Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/python3
- import os
- import time
- import numpy as np
def parse_lyrics_file(filename):
    """Tokenize a lyrics file into a flat list of whitespace-separated words.

    Args:
        filename: Path to a text file of lyrics. Undecodable bytes are
            ignored rather than raising (``errors='ignore'``).

    Returns:
        list[str]: All tokens in file order.
    """
    tokens = []
    # Context manager closes the handle (the original leaked it), and
    # extend() replaces the original `tokens = tokens + line.split()`,
    # which rebuilt the list each line (quadratic in file length).
    with open(filename, 'r', errors='ignore') as lyrics_file:
        for line in lyrics_file:
            tokens.extend(line.split())
    return tokens
class Episode(object):
    """One few-shot episode: padded support/query token-id tensors.

    For each of ``batch_size`` artists sampled (without replacement) from
    ``data``, draws ``support_size + query_size`` distinct songs from the
    artist's directory, tokenizes each with ``parser``, and writes the first
    ``max_len`` token ids into int32 arrays:

        support: (batch_size, support_size, max_len)
        query:   (batch_size, query_size, max_len)

    Positions beyond a song's token count remain 0 (zero padding).
    """

    def __init__(self, root, data, batch_size, support_size, query_size,
                 max_len, word_ids, parser):
        """Build the episode arrays.

        Args:
            root: Dataset root directory containing one directory per artist.
            data: Sequence of artist directory names eligible for sampling.
            batch_size: Number of artists per episode.
            support_size: Songs per artist in the support set.
            query_size: Songs per artist in the query set.
            max_len: Maximum tokens kept per song (rest truncated).
            word_ids: Mapping from token string to integer id.
            parser: Callable(path) -> list of token strings.

        Raises:
            RuntimeError: If a sampled artist's directory does not exist.
        """
        self.support = np.zeros((batch_size, support_size, max_len), dtype=np.int32)
        self.query = np.zeros((batch_size, query_size, max_len), dtype=np.int32)
        # Sample artists without replacement so each batch row is distinct.
        artists = np.random.choice(data, size=batch_size, replace=False)
        for batch, artist in enumerate(artists):
            directory = os.path.join(root, artist)
            if not os.path.exists(directory):
                raise RuntimeError('artist directory not found: %s' % directory)
            songs = os.listdir(directory)
            sample = np.random.choice(songs, size=support_size + query_size,
                                      replace=False)
            # First support_size songs are the support set, the rest the query
            # set; the fill logic was duplicated verbatim in the original.
            self._fill(self.support, batch, root, artist,
                       sample[:support_size], max_len, word_ids, parser)
            self._fill(self.query, batch, root, artist,
                       sample[support_size:], max_len, word_ids, parser)

    @staticmethod
    def _fill(target, batch, root, artist, songs, max_len, word_ids, parser):
        """Write word ids for each song into target[batch] row by row."""
        for song_idx, song in enumerate(songs):
            song_path = os.path.join(root, artist, song)
            for token_idx, token in enumerate(parser(song_path)[:max_len]):
                target[batch, song_idx, token_idx] = word_ids[token]
class EpisodeSampler(object):
    """Samples few-shot episodes of artist lyrics from a directory tree.

    Expects ``root`` to contain one sub-directory per artist, each holding
    one file per song. On construction, builds (or loads a persisted copy of)
    the word-id vocabulary and the train/val/test artist split, then serves
    ``Episode`` objects for the requested split via :meth:`get_episode`.
    """

    def __init__(self, root, split, batch_size, support_size, query_size,
                 max_len, split_proportions=(8, 1, 1), persist_split=True,
                 persist_ids=True, parser=parse_lyrics_file):
        """Prepare vocabulary and artist split.

        Args:
            root: Dataset root directory.
            split: One of 'train', 'val', 'test'.
            batch_size: Artists per episode.
            support_size: Support songs per artist (K).
            query_size: Query songs per artist (K').
            max_len: Max tokens per song.
            split_proportions: Relative train/val/test sizes.
            persist_split: Load/save the split as <root>/<split>.csv.
            persist_ids: Load/save the vocabulary as <root>/word_ids.csv.
            parser: Callable(path) -> list of token strings.

        Raises:
            RuntimeError: On an unknown split name or missing root directory.
        """
        self.root = root
        self.split = split
        self.batch_size = batch_size
        self.support_size = support_size
        self.query_size = query_size
        self.max_len = max_len
        self.parser = parser
        self.word_ids = {}
        if split not in ['train', 'val', 'test']:
            raise RuntimeError('unknown split: %s' % split)
        if not os.path.exists(root):
            raise RuntimeError('data directory not found')
        word_ids_path = os.path.join(root, 'word_ids.csv')
        if persist_ids and os.path.exists(word_ids_path):
            self._load_word_ids(word_ids_path)
        else:
            self._build_word_ids(parser, persist_ids, word_ids_path)
        split_csv_path = os.path.join(root, '%s.csv' % split)
        if persist_split and os.path.exists(split_csv_path):
            # with-block closes the handle (the original read path leaked it
            # in the vocabulary branch and closed it manually here).
            with open(split_csv_path, 'r') as split_csv:
                self.data = [line.strip() for line in split_csv]
        else:
            self._build_split(split, split_proportions, persist_split)

    def _load_word_ids(self, word_ids_path):
        """Load a persisted vocabulary: one 'id,word' row per line."""
        with open(word_ids_path, 'r') as word_ids_csv:
            for line in word_ids_csv:
                # Split on the first comma only, in case a token contains one.
                row = line.rstrip('\n').split(',', 1)
                self.word_ids[row[1]] = int(row[0])

    def _build_word_ids(self, parser, persist_ids, word_ids_path):
        """Walk every file under root and assign dense sequential word ids."""
        print('Parsing lyrics...')
        curr_word_id = 0
        # NOTE(review): os.walk also visits any persisted csv files at the
        # top of root, so their contents enter the vocabulary too — this
        # mirrors the original behavior; confirm whether that is intended.
        for directory, _, filenames in os.walk(self.root):
            for filename in filenames:
                filepath = os.path.join(directory, filename)
                if not os.path.isdir(filepath):
                    for word in parser(filepath):
                        if word not in self.word_ids:
                            self.word_ids[word] = curr_word_id
                            curr_word_id += 1
        print('done')
        if persist_ids:
            # Close/flush the handle explicitly (the original never closed it,
            # risking a truncated word_ids.csv on interpreter teardown).
            with open(word_ids_path, 'w') as word_ids_csv:
                for word in self.word_ids:
                    word_ids_csv.write('%s,%s\n' % (self.word_ids[word], word))

    def _build_split(self, split, split_proportions, persist_split):
        """Shuffle eligible artists and partition into train/val/test."""
        dirs = [artist for artist in os.listdir(self.root)
                if os.path.isdir(os.path.join(self.root, artist))]
        artists = []
        skipped_count = 0
        min_songs = self.support_size + self.query_size
        # Only keep artists with enough songs for one full episode draw.
        for artist in dirs:
            song_count = len(os.listdir(os.path.join(self.root, artist)))
            if song_count >= min_songs:
                artists.append(artist)
            else:
                skipped_count += 1
        if skipped_count > 0:
            print("%s artists don't have K+K'=%s songs. Using %s artists" % (
                skipped_count, min_songs, len(artists)))
        total = sum(split_proportions)
        train_count = int(float(split_proportions[0]) / total * len(artists))
        val_count = int(float(split_proportions[1]) / total * len(artists))
        np.random.shuffle(artists)
        if persist_split:
            with open(os.path.join(self.root, 'train.csv'), 'w') as train_csv:
                train_csv.write('\n'.join(artists[:train_count]))
            with open(os.path.join(self.root, 'val.csv'), 'w') as val_csv:
                val_csv.write('\n'.join(
                    artists[train_count:train_count + val_count]))
            with open(os.path.join(self.root, 'test.csv'), 'w') as test_csv:
                test_csv.write('\n'.join(artists[train_count + val_count:]))
        if split == 'train':
            self.data = artists[:train_count]
        elif split == 'val':
            self.data = artists[train_count:train_count + val_count]
        else:
            self.data = artists[train_count + val_count:]

    def __len__(self):
        """Number of artists available in this split."""
        return len(self.data)

    def __repr__(self):
        return 'EpisodeSampler("%s", "%s")' % (self.root, self.split)

    def get_episode(self):
        """Sample and return a fresh Episode for this split."""
        return Episode(
            self.root,
            self.data,
            self.batch_size,
            self.support_size,
            self.query_size,
            self.max_len,
            self.word_ids,
            self.parser
        )
def _demo():
    """Smoke test: build a train sampler over ./lyrics_data and time one episode."""
    root = './lyrics_data'
    split = 'train'
    batch_size = 10
    support_size = 10
    query_size = 10
    max_len = 100
    sampler = EpisodeSampler(root, split, batch_size, support_size, query_size,
                             max_len)
    start = time.time()
    episode = sampler.get_episode()
    end = time.time()
    print(episode.support.shape)
    print(episode.query.shape)
    print('Elapsed: %s' % (end - start))


if __name__ == '__main__':
    # Guarded entry point delegates to a function so the module-level
    # namespace stays clean when imported; output is unchanged.
    _demo()
Add Comment
Please, Sign In to add comment