Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os, sys
- import random
- import csv
def _write_split(path, groups):
    """Write one dataset split to *path* as CSV.

    *groups* is a list of per-genre lists of row dicts, each dict having the
    keys ``file_name``, ``genre`` and ``index``.  A header line
    ``file_name,genre,index`` is written first.
    """
    # newline="" is what the csv module expects; lineterminator="\n" keeps the
    # one-row-per-"\n" layout the hand-written output used.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["file_name", "genre", "index"])
        for group in groups:
            for row in group:
                # csv.writer quotes names containing commas or quotes, which
                # plain "%s,%s,%s" formatting would silently corrupt.
                writer.writerow([row["file_name"], row["genre"], row["index"]])


def main(train_size=0.7):
    """Split the files of every genre directory into train and test CSVs.

    Each sub-directory of the current working directory is treated as a
    genre.  For every genre, ``int(len(files) * train_size)`` files are drawn
    at random for the training split; the remaining files (in sorted order)
    form the test split.  The results are written to ``train.csv`` and
    ``test.csv`` in the current working directory.

    Args:
        train_size: Fraction of each genre's files assigned to the training
            split.

    Directories whose name ``genre_index`` does not recognise are skipped
    with a message on stderr instead of aborting the whole run.
    """
    cdir = os.getcwd()
    train_set = []
    test_set = []

    # Sorted so the output order (and a seeded random run) does not depend on
    # the arbitrary order returned by os.listdir.
    genre_list = sorted(
        g for g in os.listdir(cdir) if os.path.isdir(os.path.join(cdir, g))
    )

    for genre in genre_list:
        # Resolve the index once per genre rather than once per file; unknown
        # directories (e.g. ".git") are not genres and must not crash the run.
        try:
            index = genre_index(genre)
        except KeyError:
            sys.stderr.write("skipping unknown genre directory: %s\n" % genre)
            continue

        file_list = sorted(os.listdir(os.path.join(cdir, genre)))
        statements = [
            {"file_name": name, "genre": genre, "index": index}
            for name in file_list
        ]

        train_subset = random.sample(
            statements, int(len(statements) * train_size)
        )
        # File names are unique within a directory, so a set of names gives
        # O(1) membership tests instead of comparing dicts in a list.
        train_names = {s["file_name"] for s in train_subset}
        test_subset = [
            s for s in statements if s["file_name"] not in train_names
        ]

        train_set.append(train_subset)
        test_set.append(test_subset)

    _write_split(os.path.join(cdir, "train.csv"), train_set)
    _write_split(os.path.join(cdir, "test.csv"), test_set)
# Genre name -> numeric class label.  Built once at import time instead of on
# every call.
_GENRE_INDEX = {
    "blues": 0,
    "classical": 1,
    "country": 2,
    "disco": 3,
    "hiphop": 4,
    "jazz": 5,
    "metal": 6,
    "pop": 7,
    "reggae": 8,
    "rock": 9,
}


def genre_index(genre_name):
    """Return the numeric label for *genre_name*.

    Args:
        genre_name: One of the ten genre names in the lookup table
            (``"blues"`` ... ``"rock"``); the match is exact and
            case-sensitive.

    Returns:
        The integer label (0-9) associated with the genre.

    Raises:
        KeyError: If *genre_name* is not a known genre.
    """
    return _GENRE_INDEX[genre_name]
# Run the split with the default train_size only when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()
Add Comment
Please, Sign In to add comment