Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.78 KB | None | 0 0
  1. import numpy as np
  2. import random
  3. from PIL import Image
  4. import os
  5. import re
  6.  
  7. #convert matrix to a vector
  8. def mat2vec(x):
  9.     m = x.shape[0]*x.shape[1]
  10.     tmp1 = np.zeros(m)
  11.  
  12.     c = 0
  13.     for i in range(x.shape[0]):
  14.         for j in range(x.shape[1]):
  15.             tmp1[c] = x[i,j]
  16.             c +=1
  17.     return tmp1
  18.  
  19.  
  20. #Create Weight matrix for a single image
  21. def create_W(x):
  22.     if len(x.shape) != 1:
  23.         print("The input is not vector")
  24.         return
  25.     else:
  26.         w = np.zeros([len(x),len(x)])
  27.         for i in range(len(x)):
  28.             for j in range(i,len(x)):
  29.                 if i == j:
  30.                     w[i,j] = 0
  31.                 else:
  32.                     w[i,j] = x[i]*x[j]
  33.                     w[j,i] = w[i,j]
  34.     return w
  35.  
  36.  
  37. #Read Image file and convert it to Numpy array
  38. def readImg2array(file,size, threshold= 145):
  39.     pilIN = Image.open(file).convert(mode="L")
  40.     pilIN= pilIN.resize(size)
  41.     #pilIN.thumbnail(size,Image.ANTIALIAS)
  42.     imgArray = np.asarray(pilIN,dtype=np.uint8)
  43.     x = np.zeros(imgArray.shape,dtype=np.float)
  44.     x[imgArray > threshold] = 1
  45.     x[x==0] = -1
  46.     return x
  47.  
  48. #Convert Numpy array to Image file like Jpeg
  49. def array2img(data, outFile = None):
  50.  
  51.     #data is 1 or -1 matrix
  52.     y = np.zeros(data.shape,dtype=np.uint8)
  53.     y[data==1] = 255
  54.     y[data==-1] = 0
  55.     img = Image.fromarray(y,mode="L")
  56.     if outFile is not None:
  57.         img.save(outFile)
  58.     return img
  59.  
  60.  
  61. #Update
  62. def update(w,y_vec,theta=0.5,time=100):
  63.     for s in range(time):
  64.         m = len(y_vec)
  65.         i = random.randint(0,m-1)
  66.         u = np.dot(w[i][:],y_vec) - theta
  67.  
  68.         if u > 0:
  69.             y_vec[i] = 1
  70.         elif u < 0:
  71.             y_vec[i] = -1
  72.  
  73.     return y_vec
  74.  
  75.  
  76. #The following is training pipeline
  77. #Initial setting
  78. def hopfield(train_files, test_files,theta=0.5, time=1000, size=(100,100),threshold=60, current_path=None):
  79.  
  80.     #read image and convert it to Numpy array
  81.     print ("Importing images and creating weight matrix....")
  82.  
  83.     #num_files is the number of files
  84.     num_files = 0
  85.     for path in train_files:
  86.         print (path)
  87.         x = readImg2array(file=path,size=size,threshold=threshold)
  88.         x_vec = mat2vec(x)
  89.         print (len(x_vec))
  90.         if num_files == 0:
  91.             w = create_W(x_vec)
  92.             num_files = 1
  93.         else:
  94.             tmp_w = create_W(x_vec)
  95.             w = w + tmp_w
  96.             num_files +=1
  97.  
  98.     print ("Weight matrix is done!!")
  99.  
  100.  
  101.     #Import test data
  102.     counter = 0
  103.     for path in test_files:
  104.         y = readImg2array(file=path,size=size,threshold=threshold)
  105.         oshape = y.shape
  106.         y_img = array2img(y)
  107.         y_img.show()
  108.         print ("Imported test data")
  109.  
  110.         y_vec = mat2vec(y)
  111.         print ("Updating...")
  112.         y_vec_after = update(w=w,y_vec=y_vec,theta=theta,time=time)
  113.         y_vec_after = y_vec_after.reshape(oshape)
  114.         if current_path is not None:
  115.             outfile = current_path+"/after_"+str(counter)+".jpeg"
  116.             array2img(y_vec_after,outFile=outfile)
  117.         else:
  118.             after_img = array2img(y_vec_after,outFile=None)
  119.             after_img.show()
  120.         counter +=1
  121.  
  122.  
  123. #Main
  124. #First, you can create a list of input file path
  125. current_path = "."
  126. train_paths = []
  127. path = current_path+"/train_pics/"
  128. for i in os.listdir(path):
  129.     if re.match(r'[0-9a-zA-Z-]*.jp[e]*g',i):
  130.         train_paths.append(path+i)
  131.  
  132. #Second, you can create a list of sungallses file path
  133. test_paths = []
  134. path = current_path+"/test_pics/"
  135. for i in os.listdir(path):
  136.     if re.match(r'[0-9a-zA-Z-_]*.jp[e]*g',i):
  137.         test_paths.append(path+i)
  138.  
  139. #Hopfield network starts!
  140. hopfield(train_files=train_paths, test_files=test_paths, theta=0.,time=2000,size=(28,28),threshold=0, current_path = current_path)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement