Guest User

Untitled

a guest
May 20th, 2018
135
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.54 KB | None | 0 0
  1. import os, sys
  2. import random
  3. import csv
  4.  
  5. def main(train_size=0.7):
  6. cdir = os.getcwd()
  7. train_set = []
  8. test_set = []
  9. genre_list = [g for g in os.listdir(cdir) if os.path.isdir(g)]
  10. for genre in genre_list:
  11. file_list = os.listdir(cdir + "/%s"%genre)
  12. file_list.sort()
  13. statements = []
  14. for f in file_list:
  15. statements.append(
  16. {"file_name": f,
  17. "genre": genre,
  18. "index": genre_index(genre)}
  19. )
  20. train_subset = random.sample(population=statements,
  21. k=int(len(statements) * train_size))
  22. test_subset = [s for s in statements if not s in train_subset]
  23. train_set.append(train_subset)
  24. test_set.append(test_subset)
  25. with open(cdir + "/train.csv", "w") as f:
  26. f.write("file_name,genre,index\n")
  27. for g in train_set:
  28. for row in g:
  29. f.write("%s,%s,%s\n"%(row["file_name"], row["genre"], row["index"]))
  30. with open(cdir + "/test.csv", "w") as f:
  31. f.write("file_name,genre,index\n")
  32. for g in test_set:
  33. for row in g:
  34. f.write("%s,%s,%s\n"%(row["file_name"], row["genre"], row["index"]))
  35.  
  36. def genre_index(genre_name):
  37. return {
  38. "blues": 0,
  39. "classical": 1,
  40. "country": 2,
  41. "disco": 3,
  42. "hiphop": 4,
  43. "jazz": 5,
  44. "metal": 6,
  45. "pop": 7,
  46. "reggae": 8,
  47. "rock": 9
  48. }[genre_name]
  49.  
if __name__ == "__main__":
    # When run as a script, split the current working directory using
    # main()'s default arguments.
    main()
Add Comment
Please, Sign In to add comment