Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from __future__ import print_function

import errno
import json
import os
import random

import numpy as np
import SharedArray as sa
from sklearn.utils import shuffle
def split_file(path):
    """Split ``path`` into ``(directory, filename, extension)``.

    ``directory`` is ``path`` with its basename and the one character in front
    of it (the separator) removed; it is '' when there is no directory part.
    ``filename`` is the basename up to its FIRST dot, and ``extension`` is the
    text between the first and second dots. For example,
    ``'a/song.v2.xml'`` gives ``('a', 'song', 'v2')``. A basename without any
    dot yields an empty extension instead of raising ``IndexError``.
    """
    base = os.path.basename(path)
    filename, _, rest = base.partition('.')
    # Keep the historical meaning of "extension": the second dot-separated
    # field of the basename, not the text after the last dot.
    extension = rest.split('.')[0]
    return path[:-len(base) - 1], filename, extension
def _load_nokey_pair(root_pianoroll, xml_path):
    """Load the '_nokey' melody/chord pianorolls belonging to one XML entry.

    Returns a ``(melody, chord)`` tuple of the arrays read with ``np.load``,
    or ``None`` when either ``.npy`` file does not exist.
    """
    rel_dir, fn, _ = split_file(xml_path)
    path_pianoroll = os.path.join(root_pianoroll, rel_dir)
    melody_file = os.path.join(path_pianoroll, fn + '_nokey_melody.npy')
    chords_file = os.path.join(path_pianoroll, fn + '_nokey_chords.npy')
    if not (os.path.exists(melody_file) and os.path.exists(chords_file)):
        return None
    return np.load(melody_file), np.load(chords_file)


def _bar_windows(melody, chord, use_only_84_keys=True, bar_len=96):
    """Cut a melody/chord pianoroll pair into consecutive, non-overlapping windows.

    Both arrays are windowed along axis 1 (presumably time steps; TODO
    confirm the on-disk layout), up to the shorter of the two lengths. A
    trailing partial window is dropped. Each window is reshaped with 128
    columns per step, and melody and chords are stacked on the last axis
    (channel 0 = melody, channel 1 = chords). When ``use_only_84_keys`` is
    true, only columns 24..107 are kept. Assuming axis 0 of the inputs has
    128 rows, every window has shape ``(1, bar_len, 84 or 128, 2)``.
    """
    min_length = min(melody.shape[1], chord.shape[1])
    n_bars = min_length // bar_len
    print(" meldy and chord length :::", melody.shape, " ", chord.shape)
    print('minLength', min_length)
    print('itertions = ', n_bars)
    # Put axis 1 first so each window is a contiguous row slice.
    melody = melody.transpose()
    chord = chord.transpose()
    windows = []
    for bar in range(n_bars):
        start = bar * bar_len
        sub_melody = melody[start:start + bar_len]
        sub_chord = chord[start:start + bar_len]
        print("sub meldy and subchord length :::", len(sub_melody), " ", len(sub_chord))
        melody_data = np.reshape(sub_melody, (-1, bar_len, 128, 1))
        chord_data = np.reshape(sub_chord, (-1, bar_len, 128, 1))
        if use_only_84_keys:
            melody_data = melody_data[:, :, 24:108, :]
            chord_data = chord_data[:, :, 24:108, :]
        windows.append(np.concatenate([melody_data, chord_data], axis=3))
    return windows


def _store_shared(name, array):
    """Replace the SharedArray segment ``name`` with a float64 copy of ``array``."""
    try:
        sa.delete(name)
    except OSError as err:
        # A segment that does not exist yet (e.g. first run) is fine;
        # SharedArray reports that as OSError/ENOENT. Re-raise anything else.
        if err.errno != errno.ENOENT:
            raise
    shared = sa.create(name, array.shape, dtype='float64')
    np.copyto(shared, array)


def save_on_sa(use_only_84_keys=True, rescale=True, postfix='',
               root_dir='../gdrive/My Drive/Lead Sheet DataSet'):
    """Build train/validation phrase arrays and publish them via SharedArray.

    Reads ``xml_list.json`` from ``root_dir`` and loads the matching
    ``<name>_nokey_melody.npy`` / ``<name>_nokey_chords.npy`` files from
    ``root_dir/pianoroll/<dir>``. Every pair is cut into 96-step windows
    (see ``_bar_windows``) and the windows are shuffled. The first 70% are
    stored as ``tra_X_<postfix>`` and the rest as ``val_X_<postfix>``,
    both as float64.

    Args:
        use_only_84_keys: keep only columns 24..107 of every window.
        rescale: currently unused; kept for call compatibility.
        postfix: suffix of the shared-array names. An empty postfix falls
            back to ``'phrs'`` so the historical names ``tra_X_phrs`` /
            ``val_X_phrs`` are kept.
        root_dir: dataset root holding ``xml_list.json`` and ``pianoroll/``.

    Raises:
        ValueError: if fewer than two windows were collected, because one
            of the splits would then be empty.
    """
    print('Reading...')
    root_pianoroll = os.path.join(root_dir, 'pianoroll')
    with open(os.path.join(root_dir, 'xml_list.json'), "r") as f:
        xml_list = json.load(f)

    phrases = []
    for xml_path in xml_list:
        pair = _load_nokey_pair(root_pianoroll, xml_path)
        if pair is None:
            continue
        melody, chord = pair
        # Pieces shorter than one window on either track contribute nothing.
        if min(melody.shape[1], chord.shape[1]) < 96:
            continue
        phrases.extend(_bar_windows(melody, chord, use_only_84_keys))

    if len(phrases) < 2:
        raise ValueError(
            'need at least 2 phrases for a train/val split, found %d'
            % len(phrases))

    random.shuffle(phrases)
    # Both slices meet at ``split`` so every phrase lands in exactly one set.
    split = int(len(phrases) * 0.7)
    train = np.concatenate(phrases[:split], axis=0)
    val = np.concatenate(phrases[split:], axis=0)

    # Segments are replaced only now, after all data was built successfully,
    # so a failure above leaves the previous arrays intact.
    suffix = postfix or 'phrs'
    _store_shared('tra_X_' + suffix, train)
    _store_shared('val_X_' + suffix, val)
if __name__ == '__main__':
    # Only the lead-sheet phrase dataset is built. Earlier runs targeted
    # other datasets ('../../wayne/v3.0/dataset/data_bar',
    # '../music/lpd_4dbar_12_C', './data_tab_{1bar,2bar,4dbar}_12') by
    # passing a directory as the first positional argument. save_on_sa's
    # first positional parameter is use_only_84_keys, so those calls were
    # retired.
    cli_options = {'postfix': 'phrs'}
    save_on_sa(**cli_options)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement