Advertisement
lukethenerd

model.py

Feb 23rd, 2019
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.29 KB | None | 0 0
  1. import pronouncing
  2. import markovify
  3. import re
  4. import random
  5. import numpy as np
  6. import os
  7. import keras
  8. from keras.models import Sequential
  9. from keras.layers import LSTM
  10. from keras.layers.core import Dense
  11.  
  12. depth = 24  # depth of the network. changing will require a retrain
  13. maxsyllables = 16  # maximum syllables per line. Change this freely without retraining the network
  14. train_mode = True
  15. artist = "MUSICIAN-NAME-GOES-HERE"  # used when saving the trained model
  16. rap_file = "OUTPUT-LYRIC-FILE.txt"  # where the rap is written to
  17.  
  18.  
  19. def create_network(depth):
  20.     model = Sequential()
  21.     model.add(LSTM(4, input_shape=(2, 2), return_sequences=True))
  22.     for i in range(depth):
  23.         model.add(LSTM(8, return_sequences=True))
  24.     model.add(LSTM(2, return_sequences=True))
  25.     model.summary()
  26.     model.compile(optimizer='rmsprop',
  27.                   loss='mse')
  28.  
  29.     if artist + ".rap" in os.listdir(".") and train_mode == False:
  30.         model.load_weights(str(artist + ".rap"))
  31.         print("loading saved network: " + str(artist) + ".rap")
  32.     return model
  33.  
  34.  
  35. def markov(text_file):
  36.     read = open(text_file, "r").read()
  37.     text_model = markovify.NewlineText(read)
  38.     return text_model
  39.  
  40.  
  41. def syllables(line):
  42.     count = 0
  43.     for word in line.split(" "):
  44.         vowels = 'aeiouy'
  45.         word = word.lower().strip(".:;?!")
  46.         if word and word[0] in vowels:
  47.             count += 1
  48.         for index in range(1, len(word)):
  49.             if word[index] in vowels and word[index - 1] not in vowels:
  50.                 count += 1
  51.         if word.endswith('e'):
  52.             count -= 1
  53.         if word.endswith('le'):
  54.             count += 1
  55.         if count == 0:
  56.             count += 1
  57.     return count / maxsyllables
  58.  
  59.  
  60. def rhymeindex(lyrics):
  61.     if str(artist) + ".rhymes" in os.listdir(".") and train_mode == False:
  62.         print("loading saved rhymes from " + str(artist) + ".rhymes")
  63.         return open(str(artist) + ".rhymes", "r").read().split("\n")
  64.     else:
  65.         rhyme_master_list = []
  66.         print ("Alright, building the list of all the rhymes")
  67.         for i in lyrics:
  68.             word = re.sub(r"\W+", '', i.split(" ")[-1]).lower()
  69.             rhymeslist = pronouncing.rhymes(word)
  70.             rhymeslist = [x.encode('UTF8') for x in rhymeslist]
  71.             rhymeslistends = []
  72.             for i in rhymeslist:
  73.                 rhymeslistends.append(i[-2:])
  74.             try:
  75.                 rhymescheme = max(set(rhymeslistends), key=rhymeslistends.count)
  76.             except Exception:
  77.                 rhymescheme = word[-2:]
  78.             rhyme_master_list.append(rhymescheme)
  79.         rhyme_master_list = list(set(rhyme_master_list))
  80.  
  81.         reverselist = [x[::-1] for x in rhyme_master_list]
  82.         reverselist = sorted(reverselist)
  83.  
  84.         rhymelist = [x[::-1] for x in reverselist]
  85.  
  86.         f = open(str(artist) + ".rhymes", "w")
  87.         f.write("\n".join(rhymelist))
  88.         f.close()
  89.         print (rhymelist)
  90.         return rhymelist
  91.  
  92.  
  93. def rhyme(line, rhyme_list):
  94.     word = re.sub(r"\W+", '', line.split(" ")[-1]).lower()
  95.     rhymeslist = pronouncing.rhymes(word)
  96.     rhymeslist = [x.encode('UTF8') for x in rhymeslist]
  97.     rhymeslistends = []
  98.     for i in rhymeslist:
  99.         rhymeslistends.append(i[-2:])
  100.     try:
  101.         rhymescheme = max(set(rhymeslistends), key=rhymeslistends.count)
  102.     except Exception:
  103.         rhymescheme = word[-2:]
  104.     try:
  105.         float_rhyme = rhyme_list.index(rhymescheme)
  106.         float_rhyme = float_rhyme / float(len(rhyme_list))
  107.         return float_rhyme
  108.     except Exception:
  109.         return None
  110.  
  111.  
  112. def split_lyrics_file(text_file):
  113.     text = open(text_file).read()
  114.     text = text.split("\n")
  115.     while "" in text:
  116.         text.remove("")
  117.     return text
  118.  
  119.  
  120. def generate_lyrics(text_model, text_file):
  121.     bars = []
  122.     last_words = []
  123.     lyriclength = len(open(text_file).read().split("\n"))
  124.     count = 0
  125.     markov_model = markov(text_file)
  126.  
  127.     while len(bars) < lyriclength / 9 and count < lyriclength * 2:
  128.         bar = markov_model.make_sentence()
  129.  
  130.         if type(bar) != type(None) and syllables(bar) < 1:
  131.  
  132.             def get_last_word(bar):
  133.                 last_word = bar.split(" ")[-1]
  134.                 if last_word[-1] in "!.?,":
  135.                     last_word = last_word[:-1]
  136.                 return last_word
  137.  
  138.             last_word = get_last_word(bar)
  139.             if bar not in bars and last_words.count(last_word) < 3:
  140.                 bars.append(bar)
  141.                 last_words.append(last_word)
  142.                 count += 1
  143.     return bars
  144.  
  145.  
  146. def build_dataset(lines, rhyme_list):
  147.     dataset = []
  148.     line_list = []
  149.     for line in lines:
  150.         line_list = [line, syllables(line), rhyme(line, rhyme_list)]
  151.         dataset.append(line_list)
  152.  
  153.     x_data = []
  154.     y_data = []
  155.     for i in range(len(dataset) - 3):
  156.         line1 = dataset[i][1:]
  157.         line2 = dataset[i + 1][1:]
  158.         line3 = dataset[i + 2][1:]
  159.         line4 = dataset[i + 3][1:]
  160.  
  161.         x = [line1[0], line1[1], line2[0], line2[1]]
  162.         x = np.array(x)
  163.         x = x.reshape(2, 2)
  164.         x_data.append(x)
  165.  
  166.         y = [line3[0], line3[1], line4[0], line4[1]]
  167.         y = np.array(y)
  168.         y = y.reshape(2, 2)
  169.         y_data.append(y)
  170.  
  171.     x_data = np.array(x_data)
  172.     y_data = np.array(y_data)
  173.  
  174.     # print "x shape " + str(x_data.shape)
  175.     # print "y shape " + str(y_data.shape)
  176.     return x_data, y_data
  177.  
  178.  
  179. def compose_rap(lines, rhyme_list, lyrics_file, model):
  180.     rap_vectors = []
  181.     human_lyrics = split_lyrics_file(lyrics_file)
  182.  
  183.     initial_index = random.choice(range(len(human_lyrics) - 1))
  184.     initial_lines = human_lyrics[initial_index:initial_index + 2]
  185.  
  186.     starting_input = []
  187.     for line in initial_lines:
  188.         starting_input.append([syllables(line), rhyme(line, rhyme_list)])
  189.  
  190.     starting_vectors = model.predict(np.array([starting_input]).flatten().reshape(1, 2, 2))
  191.     rap_vectors.append(starting_vectors)
  192.  
  193.     for i in range(100):
  194.         rap_vectors.append(model.predict(np.array([rap_vectors[-1]]).flatten().reshape(1, 2, 2)))
  195.  
  196.     return rap_vectors
  197.  
  198.  
  199. def vectors_into_song(vectors, generated_lyrics, rhyme_list):
  200.     print ("\n\n")
  201.     print ("About to write rap (this could take a moment)...")
  202.     print ("\n\n")
  203.  
  204.     def last_word_compare(rap, line2):
  205.         penalty = 0
  206.         for line1 in rap:
  207.             word1 = line1.split(" ")[-1]
  208.             word2 = line2.split(" ")[-1]
  209.  
  210.             while word1[-1] in "?!,. ":
  211.                 word1 = word1[:-1]
  212.  
  213.             while word2[-1] in "?!,. ":
  214.                 word2 = word2[:-1]
  215.  
  216.             if word1 == word2:
  217.                 penalty += 0.2
  218.  
  219.         return penalty
  220.  
  221.     def calculate_score(vector_half, syllables, rhyme, penalty):
  222.         desired_syllables = vector_half[0]
  223.         desired_rhyme = vector_half[1]
  224.         desired_syllables = desired_syllables * maxsyllables
  225.         desired_rhyme = desired_rhyme * len(rhyme_list)
  226.  
  227.         score = 1.0 - (abs((float(desired_syllables) - float(syllables))) + abs(
  228.             (float(desired_rhyme) - float(rhyme)))) - penalty
  229.  
  230.         return score
  231.  
  232.     dataset = []
  233.     for line in generated_lyrics:
  234.         line_list = [line, syllables(line), rhyme(line, rhyme_list)]
  235.         dataset.append(line_list)
  236.  
  237.     rap = []
  238.  
  239.     vector_halves = []
  240.  
  241.     for vector in vectors:
  242.         vector_halves.append(list(vector[0][0]))
  243.         vector_halves.append(list(vector[0][1]))
  244.  
  245.     for vector in vector_halves:
  246.         scorelist = []
  247.         for item in dataset:
  248.             line = item[0]
  249.  
  250.             if len(rap) != 0:
  251.                 penalty = last_word_compare(rap, line)
  252.             else:
  253.                 penalty = 0
  254.             total_score = calculate_score(vector, item[1], item[2], penalty)
  255.             score_entry = [line, total_score]
  256.             scorelist.append(score_entry)
  257.  
  258.         fixed_score_list = []
  259.         for score in scorelist:
  260.             fixed_score_list.append(float(score[1]))
  261.         max_score = max(fixed_score_list)
  262.         for item in scorelist:
  263.             if item[1] == max_score:
  264.                 rap.append(item[0])
  265.                 print ((str)(item[0]))
  266.  
  267.                 for i in dataset:
  268.                     if item[0] == i[0]:
  269.                         dataset.remove(i)
  270.                         break
  271.                 break
  272.     return rap
  273.  
  274.  
  275. def train(x_data, y_data, model):
  276.     model.fit(np.array(x_data), np.array(y_data),
  277.               batch_size=2,
  278.               epochs=5,
  279.               verbose=1)
  280.     model.save_weights(artist + ".rap")
  281.  
  282.  
  283. def main(depth, train_mode):
  284.     model = create_network(depth)
  285.     text_file = "lyrics.txt"
  286.     text_model = markov(text_file)
  287.  
  288.     if train_mode == True:
  289.         bars = split_lyrics_file(text_file)
  290.  
  291.     if train_mode == False:
  292.         bars = generate_lyrics(text_model, text_file)
  293.  
  294.     rhyme_list = rhymeindex(bars)
  295.     if train_mode == True:
  296.         x_data, y_data = build_dataset(bars, rhyme_list)
  297.         train(x_data, y_data, model)
  298.  
  299.     if train_mode == False:
  300.         vectors = compose_rap(bars, rhyme_list, text_file, model)
  301.         rap = vectors_into_song(vectors, bars, rhyme_list)
  302.         f = open(rap_file, "w")
  303.         for bar in rap:
  304.             f.write(bar)
  305.             f.write("\n")
  306.  
  307. print(rap_file)
  308. main(depth, train_mode)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement