Advertisement
Guest User

Untitled

a guest
May 24th, 2019
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.57 KB | None | 0 0
import os
import pickle
import re
import sys

import numpy as np
import tensorflow.keras as krs
  4.  
  5. lines_per_song = 128
  6. words_per_line = 32
  7. max_words = lines_per_song * words_per_line
  8. latent_dim = 64
  9.  
  10.  
  11. iterations = 10000
  12. batch_size = 20
  13. save_dir = 'your_dir'
  14.  
  15.  
  16.  
  17. def turn_word_to_float(data, dictionary):
  18.     assert(len(data) <= max_words)
  19.     result = np.zeros(shape=(lines_per_song,words_per_line), dtype=np.float64)
  20.     length = len(dictionary)
  21.     line = 0
  22.     word = 0
  23.     for elem in data:
  24.         #print("pos = %s, %s" % (line, word))
  25.         if elem == "\n":
  26.             #print("found n")
  27.             word = -1
  28.             line += 1
  29.         else:
  30.             if elem in dictionary:
  31.                 #print("found %s with val" % elem)
  32.                 #print("yep %s %s" % (word, dictionary[word]))
  33.                 result[line][word] = np.float64(dictionary[elem]/length)
  34.             else:
  35.                 print("not found %s" % elem)
  36.         word += 1
  37.         if word == words_per_line:
  38.             word = 0
  39.             line += 1
  40.         if line == lines_per_song:
  41.             break
  42.     return result
  43.            
  44. def turn_num_array_to_text_array(num_array, reversed_dictionary):
  45.     result = []
  46.     length = len(reversed_dictionary)
  47.     for line in num_array:
  48.         for elem in line:
  49.             inified = int(elem*length)
  50.             if inified in reversed_dictionary:
  51.                 if elem != 0.0:
  52.                     result.append(reversed_dictionary[inified])
  53.         result.append("\n")
  54.     return result
  55.  
  56. def clean_string(data):
  57.     data = data.lower()
  58.     data = data.replace("\n", " \n ")
  59.     data = data.replace("\"", " \" ")
  60.     data = data.replace("!", " ! ")
  61.     data = data.replace("?", " ? ")
  62.     data = data.replace(",", " , ")
  63.     data = data.replace("(", " ( ")
  64.     data = data.replace(")", " ) ")
  65.     data = data.replace("{", " { ")
  66.     data = data.replace("}", " } ")
  67.     #data = data.replace("'", " ' ")
  68.     #data = data.replace("#", " # ")
  69.     data = data.replace("  ", " ")
  70.     data = data.replace("  ", " ")
  71.     data = data.replace("  ", " ")
  72.     #data = data.split("\n")
  73.     #data = " ".join(data)
  74.     return data
  75.  
  76. def turn_file_to_num_array( filename, dictionary):
  77.     #print("opening file '%s'" % filename)
  78.     fo = open(filename, "r")
  79.     data = fo.read()
  80.     data = clean_string(data).split(" ")
  81.     data = turn_word_to_float(data, dictionary)
  82.     #print(data)
  83.     fo.close()
  84.     return data
  85.  
  86.  
  87. def grab_dataset_from_folder(folder_to_read,dict_name):
  88.     #try:
  89.     print("opening ditctionaries %s.dict.pickle and %s.rdict.pickle" % (dict_name,dict_name))
  90.    
  91.     with open('%s.dict.pickle' % dict_name, 'rb') as handle:
  92.         dictionary = pickle.load(handle)
  93.    
  94.     with open('%s.rdict.pickle' % dict_name, 'rb') as handle:
  95.         reversed_dictionary = pickle.load(handle)
  96.  
  97.     print("opening folder '%s'" % folder_to_read)
  98.     result = []
  99.     for filename in os.listdir(folder_to_read):
  100.         result.append(turn_file_to_num_array( "%s/%s" % (folder_to_read,filename), dictionary))
  101.         #break
  102.     #reversed = turn_num_array_to_text_array(result[0], reversed_dictionary)
  103.     #print(" ".join(reversed))
  104.     return np.array(result)
  105.     #except:
  106.     #    print("oups! cannot build dataset")
  107.     #    exit()
  108.  
  109. def generate_generator():
  110.     generator_input = krs.Input(shape=(latent_dim,))
  111.     x = krs.layers.Dense(128 * lines_per_song * words_per_line)(generator_input)
  112.     x = krs.layers.LeakyReLU()(x)
  113.     x = krs.layers.Reshape((32, 32, 128))(x)
  114.     x = krs.layers.Conv2D(256, 5, padding='same')(x)
  115.     x = krs.layers.LeakyReLU()(x)
  116.     x = krs.layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
  117.     x = krs.layers.LeakyReLU()(x)
  118.     x = krs.layers.UpSampling2D()(x)
  119.     x = krs.layers.Conv2D(256, 5, padding='same')(x)
  120.     x = krs.layers.LeakyReLU()(x)
  121.     x = krs.layers.UpSampling2D()(x)
  122.     x = krs.layers.Conv2D(256, 5, padding='same')(x)
  123.     x = krs.layers.LeakyReLU()(x)
  124.     x = krs.layers.Conv2D(1, 4, activation='tanh', padding='same')(x)
  125.     x = krs.layers.Flatten()(x)
  126.     x = krs.layers.Reshape((lines_per_song, words_per_line, 1))(x)
  127.     generator = krs.models.Model(generator_input, x)
  128.     print("-- Generator -- ")
  129.     generator.summary()
  130.     return generator
  131.  
  132.  
  133.  
  134. def generate_discriminator():
  135.     discriminator_input = krs.layers.Input(shape=(lines_per_song, words_per_line, 1))
  136.     x = krs.layers.Conv2D(128, 3)(discriminator_input)
  137.     x = krs.layers.LeakyReLU()(x)
  138.     x = krs.layers.Conv2D(128, 4, strides=2)(x)
  139.     x = krs.layers.LeakyReLU()(x)
  140.     x = krs.layers.Conv2D(128, 4, strides=2)(x)
  141.     x = krs.layers.LeakyReLU()(x)
  142.     x = krs.layers.Conv2D(128, 4, strides=2)(x)
  143.     x = krs.layers.LeakyReLU()(x)
  144.     x = krs.layers.Flatten()(x)
  145.     x = krs.layers.Dropout(0.4)(x)
  146.     x = krs.layers.Dense(1, activation='sigmoid')(x)
  147.     discriminator = krs.models.Model(discriminator_input, x)
  148.     print("-- Discriminator -- ")
  149.     discriminator.summary()
  150.     discriminator_optimizer = krs.optimizers.RMSprop(
  151.         lr=0.0008,
  152.         clipvalue=1.0,
  153.         decay=1e-8
  154.     )
  155.     discriminator.compile(
  156.         optimizer=discriminator_optimizer,
  157.         loss='binary_crossentropy'
  158.     )
  159.     discriminator.trainable = False
  160.     return discriminator
  161.  
  162. if __name__ == '__main__':
  163.     if len(sys.argv) != 3:
  164.         print("script folder dictname")
  165.         exit()
  166.     x_train = grab_dataset_from_folder(sys.argv[1], sys.argv[2])
  167.     #print(dataset)
  168.     print(x_train[0].shape)
  169.     generator = generate_generator()
  170.     discriminator = generate_discriminator()
  171.  
  172.     gan_input = krs.Input(shape=(latent_dim,))
  173.     gan_output = discriminator(generator(gan_input))
  174.     gan = krs.models.Model(gan_input, gan_output)
  175.     gan_optimizer = krs.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
  176.     gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')
  177.     start = 0
  178.     for step in range(iterations):
  179.         random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
  180.         generated_images = generator.predict(random_latent_vectors)
  181.         stop = start + batch_size
  182.         real_images = x_train[start : stop]
  183.         combined_images = np.concatenate([generated_images, real_images])
  184.         labels = np.concatenate([np.ones((batch_size, 1)),
  185.         np.zeros((batch_size, 1))])
  186.         labels += 0.05 * np.random.random(labels.shape)
  187.         d_loss = discriminator.train_on_batch(combined_images, labels)
  188.         random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
  189.         misleading_targets = np.zeros((batch_size, 1))
  190.         a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)
  191.         start += batch_size
  192.         if start > len(x_train) - batch_size:
  193.             start = 0
  194.         if step % 100 == 0:
  195.             gan.save_weights('gan.h5')
  196.             print('discriminator loss:', d_loss)
  197.             print('adversarial loss:', a_loss)
  198.             reversed = " ".join(turn_num_array_to_text_array(generated_images[0], reversed_dictionary))
  199.            
  200.             print("---------------------\n%s\n---------------------" % reversed)
  201.            
  202.            
  203.             file1 = open(os.path.join(save_dir,'generated_song' + str(step) + '.txt'),"w")
  204.             file1.write(reversed)
  205.             file1.close()
  206.             reversed = " ".join(turn_num_array_to_text_array(real_images[0], reversed_dictionary))
  207.             print("---------------------\n%s\n---------------------" % reversed)
  208.             file1 = open(os.path.join(save_dir,'real_song' + str(step) + '.txt'),"w")
  209.             file1.write(reversed)
  210.             file1.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement