#!/usr/bin/python3
import os
import time
import numpy as np


def parse_lyrics_file(filename):
    """Read a lyrics file and return its whitespace-separated tokens."""
    tokens = []
    with open(filename, 'r', errors='ignore') as lyrics_file:
        for line in lyrics_file:
            tokens.extend(line.split())
    return tokens


class Episode(object):
    """One few-shot episode: support and query token matrices for a batch of artists."""

    def __init__(self, root, data, batch_size, support_size, query_size,
                 max_len, word_ids, parser):
        self.support = np.zeros((batch_size, support_size, max_len), dtype=np.int32)
        self.query = np.zeros((batch_size, query_size, max_len), dtype=np.int32)
        # Sample batch_size artists, then support_size + query_size songs per artist.
        artists = np.random.choice(data, size=batch_size, replace=False)
        for batch, artist in enumerate(artists):
            directory = os.path.join(root, artist)
            if not os.path.exists(directory):
                raise RuntimeError('artist directory not found: %s' % directory)
            songs = os.listdir(directory)
            sample_size = support_size + query_size
            sample = np.random.choice(songs, size=sample_size, replace=False)
            support = sample[:support_size]
            query = sample[support_size:]
            # Tokenize each song and store word ids, truncated to max_len
            # (unused positions stay zero from the initialization above).
            for song_idx, song in enumerate(support):
                song_path = os.path.join(root, artist, song)
                for token_idx, token in enumerate(parser(song_path)[:max_len]):
                    self.support[batch][song_idx][token_idx] = word_ids[token]
            for song_idx, song in enumerate(query):
                song_path = os.path.join(root, artist, song)
                for token_idx, token in enumerate(parser(song_path)[:max_len]):
                    self.query[batch][song_idx][token_idx] = word_ids[token]


class EpisodeSampler(object):
    def __init__(self, root, split, batch_size, support_size, query_size,
                 max_len, split_proportions=(8, 1, 1), persist_split=True,
                 persist_ids=True, parser=parse_lyrics_file):
        self.root = root
        self.split = split
        self.batch_size = batch_size
        self.support_size = support_size
        self.query_size = query_size
        self.max_len = max_len
        self.parser = parser
        self.word_ids = {}
        if split not in ['train', 'val', 'test']:
            raise RuntimeError('unknown split: %s' % split)
        if not os.path.exists(root):
            raise RuntimeError('data directory not found')

        # Load the word-to-id vocabulary from disk if present; otherwise build it by
        # parsing every file under root and, if requested, persist it to word_ids.csv.
        word_ids_path = os.path.join(root, 'word_ids.csv')
        if persist_ids and os.path.exists(word_ids_path):
            with open(word_ids_path, 'r') as word_ids_csv:
                for line in word_ids_csv:
                    row = line.rstrip('\n').split(',', 1)
                    self.word_ids[row[1]] = int(row[0])
        else:
            print('Parsing lyrics...')
            curr_word_id = 0
            for directory, _, filenames in os.walk(root):
                for filename in filenames:
                    filepath = os.path.join(directory, filename)
                    if not os.path.isdir(filepath):
                        for word in parser(filepath):
                            if word not in self.word_ids:
                                self.word_ids[word] = curr_word_id
                                curr_word_id += 1
            print('done')
            if persist_ids:
                with open(word_ids_path, 'w') as word_ids_csv:
                    for word in self.word_ids:
                        word_ids_csv.write('%s,%s\n' % (self.word_ids[word], word))

        # Load a cached artist split if one exists; otherwise build the
        # train/val/test split from artists with enough songs and optionally persist it.
        split_csv_path = os.path.join(root, '%s.csv' % split)
        if persist_split and os.path.exists(split_csv_path):
            split_csv = open(split_csv_path, 'r')
            self.data = [line.strip() for line in split_csv.readlines()]
            split_csv.close()
        else:
            dirs = []
            for artist in os.listdir(root):
                if os.path.isdir(os.path.join(root, artist)):
                    dirs.append(artist)
            # Keep only artists with at least support_size + query_size songs.
            artists = []
            skipped_count = 0
            for artist in dirs:
                song_count = len(os.listdir(os.path.join(root, artist)))
                if song_count >= support_size + query_size:
                    artists.append(artist)
                else:
                    skipped_count += 1
            if skipped_count > 0:
                print("%s artists don't have K+K'=%s songs. Using %s artists" % (
                    skipped_count, support_size + query_size, len(artists)))
            train_count = int(float(split_proportions[0]) / sum(split_proportions) * len(artists))
            val_count = int(float(split_proportions[1]) / sum(split_proportions) * len(artists))
            np.random.shuffle(artists)
            if persist_split:
                train_csv = open(os.path.join(root, 'train.csv'), 'w')
                val_csv = open(os.path.join(root, 'val.csv'), 'w')
                test_csv = open(os.path.join(root, 'test.csv'), 'w')
                train_csv.write('\n'.join(artists[:train_count]))
                val_csv.write('\n'.join(artists[train_count:train_count + val_count]))
                test_csv.write('\n'.join(artists[train_count + val_count:]))
                train_csv.close()
                val_csv.close()
                test_csv.close()
            if split == 'train':
                self.data = artists[:train_count]
            elif split == 'val':
                self.data = artists[train_count:train_count + val_count]
            else:
                self.data = artists[train_count + val_count:]

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return 'EpisodeSampler("%s", "%s")' % (self.root, self.split)

    def get_episode(self):
        return Episode(
            self.root,
            self.data,
            self.batch_size,
            self.support_size,
            self.query_size,
            self.max_len,
            self.word_ids,
            self.parser
        )


if __name__ == '__main__':
    # Quick smoke test: build a sampler over ./lyrics_data and time one episode.
    root = './lyrics_data'
    split = 'train'
    batch_size = 10
    support_size = 10
    query_size = 10
    max_len = 100
    sampler = EpisodeSampler(root, split, batch_size, support_size, query_size,
                             max_len)
    start = time.time()
    episode = sampler.get_episode()
    end = time.time()
    print(episode.support.shape)
    print(episode.query.shape)
    print('Elapsed: %s' % (end - start))