Guest User

Untitled

a guest
Jun 24th, 2018
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.39 KB | None | 0 0
  1. from os import listdir
  2. from os.path import isfile, join
  3. import os
  4. import json
  5. import progressbar
  6. from multiprocessing import Pool
  7. import time, tqdm, random
  8. from collections import deque
  9. import numpy as np
  10.  
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.init as init
  14. from torch.autograd import Variable
  15. import torch.optim as optim
  16. import logging
  17.  
  18. from collections import OrderedDict
  19.  
  20.  
  21. torch.set_default_tensor_type('torch.FloatTensor')
  22.  
  23. class RagaDataset(object):
  24. def song_queue_push(self, json_path):
  25. spectr = torch.from_numpy(np.load(json_path.replace('json', 'npy'). \
  26. replace('metadata', 'npy_spectr')))
  27. json_file = json.load(open(json_path))
  28. return ((spectr, json_file,))
  29.  
  30. def __init__(self, data_root, song_q_len):
  31. # Set json q length
  32. self.song_q_len = song_q_len
  33.  
  34. # Load up all valid jsons
  35. self.json_q = [join(data_root, 'metadata/', f) for f in listdir(join(data_root, \
  36. 'metadata/')) if isfile(join(data_root, 'metadata/', f)) and f.endswith('json')]
  37.  
  38. # Num songs
  39. self.num_songs = len(self.json_q)
  40.  
  41. # Shuffle JSON Queue
  42. random.shuffle(self.json_q)
  43.  
  44. # Initialize empty song q
  45. self.song_q = deque([])
  46.  
  47. def __getitem__(self, index):
  48. if len(self.song_q) == 0: # Refill the song queue
  49. num_grab = min(self.song_q_len, len(self.json_q)) # num elements to pop
  50.  
  51. print("Loading more songs!")
  52.  
  53. # Multithreaded loading of songs detailed in json queue
  54. pool = Pool(os.cpu_count())
  55. for song in tqdm.tqdm(pool.imap_unordered(
  56. self.song_queue_push, self.json_q[0:num_grab]), total=num_grab):
  57.  
  58. self.song_q.append(song) # Add each loaded song to the queue
  59.  
  60. if not len(self.song_q) == num_grab: # pop items out
  61. self.json_q = self.json_q[num_grab:]
  62. else: # list should be empty
  63. self.json_q = []
  64. print("Done!")
  65.  
  66. return self.song_q.popleft()
  67.  
  68. def __len__(self):
  69. return self.num_songs
  70.  
  71. class RagaDetector(nn.Module):
  72. def __init__(self, num_outputs, lstm_input_len, k):
  73. super(RagaDetector, self).__init__()
  74.  
  75. self.num_outputs = num_outputs
  76. self.lstm_input_len = lstm_input_len
  77. self.k = k
  78.  
  79. self.encoder = nn.Sequential(OrderedDict([
  80. ('norm0', nn.BatchNorm2d(1)),
  81.  
  82. ('conv1', nn.Conv2d(1, 64, 3, padding=1)),
  83. ('norm1', nn.BatchNorm2d(64)),
  84. ('elu1', nn.ELU()),
  85. ('pool1', nn.MaxPool2d(2, 2)),
  86. ('drop1', nn.Dropout(p=0.1)),
  87.  
  88. ('conv2', nn.Conv2d(64, 128, 3, padding=1)),
  89. ('norm2', nn.BatchNorm2d(128)),
  90. ('elu2', nn.ELU()),
  91. ('pool2', nn.MaxPool2d(3, 3)),
  92. ('drop2', nn.Dropout(p=0.1)),
  93.  
  94. ('conv3', nn.Conv2d(128, 128, 3, padding=1)),
  95. ('norm3', nn.BatchNorm2d(128)),
  96. ('elu3', nn.ELU()),
  97. ('pool3', nn.MaxPool2d(4, 4)),
  98. ('drop3', nn.Dropout(p=0.1)),
  99.  
  100. ('conv4', nn.Conv2d(128, 128, 3, padding=1)),
  101. ('norm4', nn.BatchNorm2d(128)),
  102. ('elu4', nn.ELU()),
  103. ('pool4', nn.MaxPool2d(4, 4)),
  104. ('drop4', nn.Dropout(p=0.1))
  105. ]))
  106.  
  107. self.lstm_1 = torch.nn.LSTMCell(640, 100).cuda()
  108. self.lstm_2 = torch.nn.LSTMCell(100, num_outputs).cuda()
  109.  
  110. for m in self.modules():
  111. if isinstance(m, nn.Conv2d):
  112. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  113. elif isinstance(m, nn.BatchNorm2d):
  114. nn.init.constant_(m.weight, 1)
  115. nn.init.constant_(m.bias, 0)
  116.  
  117. def train(self, x, y, criterion, optimizer):
  118. batch_size = x.shape[0]
  119.  
  120. x = self.encoder(x)
  121. x = torch.split(x, 5, dim=3)[:-1]
  122.  
  123. hx_1 = torch.randn(batch_size, 100).cuda()
  124. cx_1 = torch.randn(batch_size, 100).cuda()
  125.  
  126. hx_2 = torch.randn(batch_size, self.num_outputs).cuda()
  127. cx_2 = torch.randn(batch_size, self.num_outputs).cuda()
  128.  
  129. output = []
  130. for i in range(len(x)):
  131. input = x[i].reshape((batch_size, -1))
  132. hx_1, cx_1 = self.lstm_1(input, (hx_1, cx_1))
  133. hx_2, cx_2 = self.lstm_2(hx_1, (hx_2, cx_2))
  134. output.append(hx_2.unsqueeze(2))
  135.  
  136. if (i + 1) % self.k == 0:
  137. net_out = torch.cat(output, dim=2)
  138.  
  139. target = torch.LongTensor(self.k)
  140. target[:] = y
  141. target = target.unsqueeze(0)
  142.  
  143. loss = criterion(net_out.float(), Variable(target).cuda())
  144. # loss.backward(retain_graph=True)
  145. print(loss.data[0])
  146. optimizer.step()
  147. optimizer.zero_grad()
  148. output = []
  149.  
  150. if __name__ == '__main__':
  151. dataset = RagaDataset('/home/sauhaarda/Dataset', 40)
  152. net = RagaDetector(72, 100, 10).cuda()
  153. # data = dataset[0][0].unsqueeze(0).unsqueeze(0).float()
  154. # print(data.shape)
  155. # print(net(data)[0].shape)
  156.  
  157. criterion = nn.CrossEntropyLoss().cuda()
  158. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  159. for epoch in range(2):
  160. running_loss = 0.0
  161. for x, y in dataset:
  162. x = torch.autograd.Variable(x.unsqueeze(0).unsqueeze(0).float().cuda())
  163. label = y['myragaid']
  164. optimizer.zero_grad()
  165. net.train(x, label, criterion, optimizer)
  166. del x
Add Comment
Please, Sign In to add comment