snakemelon

python model code

May 13th, 2022 (edited)
25
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import tensorflow as tf
  3. import keras
  4. from keras import backend as K
  5. from keras.models import Sequential, load_model
  6. from keras.layers import Dense, Dropout, Flatten
  7. from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
  8. from tensorflow.keras.optimizers import Adam
  9. from keras.utils import np_utils
  10. from keras import regularizers
  11.  
  12. from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
  13. import load_data
  14. import matplotlib
  15. matplotlib.use("Agg")
  16. from matplotlib import pyplot as plt
  17. import itertools
  18. import os
  19.  
  20. # use the GPU to train!
  21. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  22.  
  23.  
# Models to be passed to Music_Genre_CNN
# GTZAN genre names, indexed by class id (0-9); used as the axis labels
# of the confusion-matrix plot in plot_confusion_matrix.
song_labels = ["Blues", "Classical", "Country", "Disco", "Hip hop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
  26.  
  27.  
  28. def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
  29.     """
  30.    This function prints and plots the confusion matrix.
  31.    Normalization can be applied by setting `normalize=True`.
  32.    """
  33.     if normalize:
  34.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  35.         print("Normalized confusion matrix")
  36.     else:
  37.         print('Confusion matrix, without normalization')
  38.  
  39.     print(cm)
  40.  
  41.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  42.     plt.title(title)
  43.     plt.colorbar()
  44.     tick_marks = np.arange(len(classes))
  45.     plt.xticks(tick_marks, classes, rotation=45)
  46.     plt.yticks(tick_marks, classes)
  47.  
  48.     fmt = '.1f' if normalize else 'd'
  49.     thresh = cm.max() / 2.
  50.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  51.         plt.text(j, i, format(cm[i, j], fmt),
  52.                  horizontalalignment="center",
  53.                  color="white" if cm[i, j] > thresh else "black")
  54.  
  55.     plt.ylabel('True label')
  56.     plt.xlabel('Predicted label')
  57.     plt.tight_layout()
  58.  
  59.  
  60. def metric(y_true, y_pred):
  61.     return K.mean(K.equal(K.argmax(y_true, axis=1), K.argmax(y_pred, axis=1)))
  62.  
  63.  
  64. def cnn(num_genres=10, input_shape=(64, 173, 1)):
  65.     model = Sequential()
  66.     model.add(Conv2D(64, kernel_size=(4, 4),
  67.                      activation='relu', #kernel_regularizer=regularizers.l2(0.04),
  68.                      input_shape=input_shape))
  69.     model.add(BatchNormalization())
  70.     model.add(MaxPooling2D(pool_size=(2, 4)))
  71.     model.add(Conv2D(64, (3, 5), activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  72.     model.add(MaxPooling2D(pool_size=(2, 2)))
  73.     model.add(Dropout(0.2))
  74.     model.add(Conv2D(64, (2, 2), activation='relu'))
  75.     #,kernel_regularizer=regularizers.l2(0.04)
  76.     model.add(BatchNormalization())
  77.     model.add(MaxPooling2D(pool_size=(2, 2)))
  78.     model.add(Dropout(0.2))
  79.     model.add(Flatten())
  80.     model.add(Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  81.     model.add(Dropout(0.5))
  82.     model.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  83.     model.add(Dense(num_genres, activation='softmax'))
  84.     model.compile(loss=keras.losses.categorical_crossentropy,
  85.                   optimizer=Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0),
  86.                   metrics=[metric])
  87.     return model
  88.  
  89.  
  90. class model(object):
  91.     # Main network thingy to train
  92.     def __init__(self, ann_model):
  93.         self.model = ann_model()
  94.  
  95.     def train_model(self, train_x, train_y, val_x=None, val_y=None,
  96.                     small_batch_size=200,
  97.                     max_iteration=300,
  98.                     print_interval=1,
  99.                     test_x=None, test_y=None):
  100.  
  101.         """
  102.        epoch数是一个超参数,它定义了学习算法在整个训练数据集中的工作次数
  103.        """
  104.  
  105.         m = len(train_x)
  106.  
  107.         #
  108.         for it in range(max_iteration):
  109.  
  110.             # split training data into even batches
  111.             batch_idx = np.random.permutation(m)
  112.             train_x = train_x[batch_idx]
  113.             train_y = train_y[batch_idx]
  114.  
  115.             num_batches = int(m / small_batch_size)
  116.             for batch in range(num_batches):
  117.  
  118.                 x_batch = train_x[batch*small_batch_size: (batch+1)*small_batch_size]
  119.                 y_batch = train_y[batch*small_batch_size: (batch+1)*small_batch_size]
  120.                 print("starting batch\t", batch, "\t Epoch:\t", it)
  121.                 self.model.train_on_batch(x_batch, y_batch)
  122.  
  123.             if it % print_interval == 0:
  124.                 validation_accuracy = self.model.evaluate(val_x, val_y)
  125.                 training_accuracy = self.model.evaluate(train_x, train_y)
  126.                 testing_accuracy = self.model.evaluate(test_x, test_y)
  127.  
  128.                 # print of test error used only after development of the model
  129.                 print("\nTraining accuracy: %f\t Validation accuracy: %f\t Testing Accuracy: %f" %
  130.                       (training_accuracy[1], validation_accuracy[1], testing_accuracy[1]))
  131.                 print("\nTraining loss: %f    \t Validation loss: %f    \t Testing Loss: %f \n" %
  132.                       (training_accuracy[0], validation_accuracy[0], testing_accuracy[0]))
  133.                 print()
  134.  
  135.             # 验证集准确度>.81则输出训练的效果图 matrix
  136.             if validation_accuracy[1] > .81:
  137.                 print("Saving confusion data...")
  138.                 model_name = "model" + str(100*validation_accuracy[1]) + str(100* testing_accuracy[1]) + ".h5"
  139.                 self.model.save(model_name)
  140.                 # pred = self.model.predict_classes(test_x, verbose=1)
  141.                 pred = self.model.predict(test_x, verbose=1)
  142.                 cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  143.                 np.set_printoptions(precision=2)
  144.  
  145.                 # visualization
  146.                 plt.figure()
  147.                 plot_confusion_matrix(cnf_matrix,
  148.                                       classes=song_labels,
  149.                                       normalize=True,
  150.                                       title='Normalized confusion matrix')
  151.                 print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  152.                 plt.savefig(str(batch)+".png", dpi=600)
  153.  
  154.     # def model_predict(self, test_x):
  155.     #     pred = self.model.predict(test_x, verbose=1)
  156.     #     cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  157.     #     np.set_printoptions(precision=2)
  158.     #
  159.     #     # visualization
  160.     #     plt.figure()
  161.     #     plot_confusion_matrix(cnf_matrix,
  162.     #                           classes=song_labels,
  163.     #                           normalize=True,
  164.     #                           title='Normalized confusion matrix')
  165.     #     print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  166.     #     plt.savefig(str(batch) + ".png", dpi=600)
  167.  
  168.  
  169. def main():
  170.     # Data stuff
  171.     # data = load_data.loadall('melspects.npz')
  172.  
  173.     tmp = np.load('melspects.npz')
  174.     x_tr = tmp['x_tr']
  175.     y_tr = tmp['y_tr']
  176.     x_te = tmp['x_te']
  177.     y_te = tmp['y_te']
  178.     x_cv = tmp['x_cv']
  179.     y_cv = tmp['y_cv']
  180.     # data = {'x_tr': x_tr, 'y_tr': y_tr,
  181.     #         'x_te': x_te, 'y_te': y_te,
  182.     #         'x_cv': x_cv, 'y_cv': y_cv, }
  183.     #
  184.     # x_tr = data['x_tr']
  185.     # y_tr = data['y_tr']
  186.     # x_te = data['x_te']
  187.     # y_te = data['y_te']
  188.     # x_cv = data['x_cv']
  189.     # y_cv = data['y_cv']
  190.  
  191.     tr_idx = np.random.permutation(len(x_tr))
  192.     te_idx = np.random.permutation(len(x_te))
  193.     cv_idx = np.random.permutation(len(x_cv))
  194.  
  195.     x_tr = x_tr[tr_idx]
  196.     y_tr = y_tr[tr_idx]
  197.     x_te = x_te[te_idx]
  198.     y_te = y_te[te_idx]
  199.     x_cv = x_cv[cv_idx]
  200.     y_cv = y_cv[cv_idx]
  201.  
  202.     x_tr = x_tr[:, :, :, np.newaxis]
  203.     x_te = x_te[:, :, :, np.newaxis]
  204.     x_cv = x_cv[:, :, :, np.newaxis]
  205.  
  206.     y_tr = np_utils.to_categorical(y_tr)
  207.     y_te = np_utils.to_categorical(y_te)
  208.     y_cv = np_utils.to_categorical(y_cv)
  209.  
  210.  
  211. # training = np.load('gtzan/gtzan_tr.npy')
  212. # x_tr = np.delete(training, -1, 1)
  213. # label_tr = training[:,-1]
  214.  
  215. # test = np.load('gtzan/gtzan_te.npy')
  216. # x_te = np.delete(test, -1, 1)
  217. # label_te = test[:,-1]
  218.  
  219. # cv = np.load('gtzan/gtzan_cv.npy')
  220. # x_cv = np.delete(cv, -1, 1)
  221. # label_cv = test[:,-1]
  222.  
  223. # temp = np.zeros((len(label_tr),10))
  224. # temp[np.arange(len(label_tr)),label_tr.astype(int)] = 1
  225. # y_tr = temp
  226. # temp = np.zeros((len(label_te),10))
  227. # temp[np.arange(len(label_te)),label_te.astype(int)] = 1
  228. # y_te = temp
  229. # temp = np.zeros((len(label_cv),10))
  230. # temp[np.arange(len(label_cv)),label_cv.astype(int)] = 1
  231. # y_cv = temp
  232. # del temp
  233.  
  234. #################################################
  235.  
  236. #    if True:
  237. #   model = keras.models.load_model('model84.082.0.h5', custom_objects={'metric': metric})
  238. #   print("Saving confusion data...")
  239. #   pred = model.predict_classes(x_te, verbose=1)
  240. #   cnf_matrix = confusion_matrix(np.argmax(y_te, axis=1), pred)
  241. #   np.set_printoptions(precision=1)
  242. #   plt.figure()
  243. #   plot_confusion_matrix(cnf_matrix, classes=song_labels, normalize=True, title='Normalized confusion matrix')
  244. #   print(precision_recall_fscore_support(np.argmax(y_te, axis=1),pred, average='macro'))
  245. #   plt.savefig("matrix",format='png', dpi=1000)
  246. #   raise SystemExit
  247.     ann = model(cnn)
  248.     ann.train_model(x_tr, y_tr, val_x=x_cv, val_y=y_cv, test_x=x_te, test_y=y_te)
  249.  
  250.  
# Script entry point: train only when executed directly, not on import.
if __name__ == '__main__':
    main()
  253.  
  254.  
RAW Paste Data Copied