Advertisement
snakemelon

Why can't I start training?

May 14th, 2022
795
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import itertools
  3. import logging
  4. import tensorflow as tf
  5. import keras
  6. from keras import backend as K
  7. from keras.models import Sequential
  8. from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, BatchNormalization
  9. from tensorflow.keras.optimizers import Adam
  10. from keras.utils import np_utils
  11. from keras import regularizers
  12. import matplotlib
  13. from matplotlib import pyplot as plt
  14. from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
  15. import load_data
  16. import matplotlib
  17. import time
  18.  
  19. matplotlib.use("Agg")
  20.  
  21. # Not show the Warning information
  22. logging.getLogger("tensorflow").setLevel(logging.ERROR)
  23.  
  24. # use the GPU to train the network
  25. CUDA_VISIBEL_DEVICES = 0
  26.  
  27. # Models to be passed to Music_Genre_CNN
  28. song_labels = ["Blues", "Classical", "Country", "Disco", "Hip hop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
  29. MODEL_PATH = "CNN/"
  30. MAX_ITERATION = 100
  31.  
  32.  
  33. # remote interpreter path:
  34. # /root/miniconda3/envs/myconda/bin/python
  35.  
  36. def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
  37.     """
  38.    This function prints and plots the confusion matrix.
  39.    Normalization can be applied by setting `normalize=True`.
  40.    """
  41.     if normalize:
  42.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  43.         print("Normalized confusion matrix")
  44.     else:
  45.         print('Confusion matrix, without normalization')
  46.  
  47.     print(cm)
  48.  
  49.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  50.     plt.title(title)
  51.     plt.colorbar()
  52.     tick_marks = np.arange(len(classes))
  53.     plt.xticks(tick_marks, classes, rotation=45)
  54.     plt.yticks(tick_marks, classes)
  55.  
  56.     fmt = '.1f' if normalize else 'd'
  57.     thresh = cm.max() / 2.
  58.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  59.         plt.text(j, i, format(cm[i, j], fmt),
  60.                  horizontalalignment="center",
  61.                  color="white" if cm[i, j] > thresh else "black")
  62.  
  63.     plt.ylabel('True label')
  64.     plt.xlabel('Predicted label')
  65.     plt.tight_layout()
  66.  
  67.  
  68. def metric(y_true, y_pred):
  69.     return K.mean(K.equal(K.argmax(y_true, axis=1), K.argmax(y_pred, axis=1)))
  70.  
  71.  
  72. def cnn(kernel_size=(4,4), num_genres=10, input_shape=(64, 173, 1), learning_rate=0.01):
  73.     model = Sequential()
  74.  
  75.     model.add(Conv2D(64, kernel_size=kernel_size, activation='relu', input_shape=input_shape))
  76.     model.add(BatchNormalization())
  77.     model.add(MaxPooling2D(pool_size=(2, 4)))
  78.     model.add(Conv2D(64, (3, 5), activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  79.     model.add(MaxPooling2D(pool_size=(2, 2)))
  80.     model.add(Dropout(0.2))
  81.     model.add(Conv2D(64, (2, 2), activation='relu'))
  82.     # ,kernel_regularizer=regularizers.l2(0.04)
  83.     model.add(BatchNormalization())
  84.     model.add(MaxPooling2D(pool_size=(2, 2)))
  85.     model.add(Dropout(0.2))
  86.     model.add(Flatten())
  87.     model.add(Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  88.     model.add(Dropout(0.5))
  89.     model.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  90.     model.add(Dense(num_genres, activation='softmax'))
  91.     model.compile(loss=keras.losses.categorical_crossentropy,
  92.                   optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0),
  93.                   metrics=[metric])
  94.     return model
  95.  
  96.  
  97. def model_predict(model_name, model, test_x, test_y):
  98.     pred = model.predict(test_x, batch_size=32, verbose=1).argmax(axis=1)
  99.     cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  100.     np.set_printoptions(precision=2)
  101.  
  102.     # visualization
  103.     plt.figure()
  104.     plot_confusion_matrix(cnf_matrix,
  105.                           classes=song_labels,
  106.                           normalize=True,
  107.                           title='Normalized confusion matrix')
  108.     print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  109.     plt.savefig(str(model_name) + ".png", dpi=600)
  110.  
  111.  
  112. class Model(object):
  113.     # Main network thingy to train
  114.  
  115.     @tf.autograph.experimental.do_not_convert
  116.     def __init__(self, ann_model):
  117.         self.model = ann_model()
  118.  
  119.     t = time.gmtime()
  120.  
  121.     @tf.autograph.experimental.do_not_convert
  122.     def train_model(self, train_x, train_y, val_x=None, val_y=None, small_batch_size=300, max_iteration=MAX_ITERATION,
  123.                     print_interval=1, test_x=None, test_y=None,
  124.                     confusion_matrix_img_name=time.strftime("%m-%d %H:%M:%S", t)):
  125.  
  126.         batch_size = str(small_batch_size)
  127.         train_accuracy_list = []
  128.         train_loss_list = []
  129.         validation_accuracy_list = []
  130.         validation_loss_list = []
  131.         test_accuracy_list = []
  132.         test_loss_list = []
  133.         epoch_list = []
  134.  
  135.         m = len(train_x)
  136.         for it in range(max_iteration):
  137.  
  138.             # split training data into even batches
  139.             batch_idx = np.random.permutation(m)
  140.             train_x = train_x[batch_idx]
  141.             train_y = train_y[batch_idx]
  142.  
  143.             num_batches = int(m / small_batch_size)
  144.             for batch in range(num_batches):
  145.                 x_batch = train_x[batch * small_batch_size: (batch + 1) * small_batch_size]
  146.                 y_batch = train_y[batch * small_batch_size: (batch + 1) * small_batch_size]
  147.                 print("starting batch\t", batch, "\t Epoch:\t", it)
  148.  
  149.                 # tf.autograph.experimental.do_not_convert(func=self.model.train_on_batch(x_batch, y_batch))
  150.                 self.model.train_on_batch(x_batch, y_batch)
  151.  
  152.             if it % print_interval == 0:
  153.                 validation_accuracy = self.model.evaluate(val_x, val_y)
  154.                 training_accuracy = self.model.evaluate(train_x, train_y)
  155.                 testing_accuracy = self.model.evaluate(test_x, test_y)
  156.  
  157.                 # print of test error used only after development of the model
  158.                 print("\nTraining accuracy: %f\t Validation accuracy: %f\t Testing Accuracy: %f" %
  159.                       (training_accuracy[1], validation_accuracy[1], testing_accuracy[1]))
  160.                 print("\nTraining loss: %f    \t Validation loss: %f    \t Testing Loss: %f \n" %
  161.                       (training_accuracy[0], validation_accuracy[0], testing_accuracy[0]))
  162.                 print()
  163.  
  164.                 # the data for draw
  165.                 epoch_list.append(it)
  166.  
  167.                 train_accuracy_list.append(training_accuracy[1])
  168.                 train_loss_list.append(training_accuracy[0])
  169.  
  170.                 validation_accuracy_list.append(validation_accuracy[1])
  171.                 validation_loss_list.append(validation_accuracy[0])
  172.  
  173.                 test_accuracy_list.append(testing_accuracy[1])
  174.                 test_loss_list.append(testing_accuracy[0])
  175.  
  176.             # 验证集准确度>.81则输出训练的效果图 matrix
  177.             if validation_accuracy[1] >= 0.8:
  178.                 print("Saving confusion data...")
  179.                 # 较早版本的
  180.                 model_name = "model_" + str(100 * validation_accuracy[1]) + "_" + str(100 * testing_accuracy[1]) + ".h5"
  181.                 # self.model.save(model_name)
  182.                 # pred = self.model.predict_classes(test_x, verbose=1)
  183.                 pred = self.model.predict(test_x, verbose=1).argmax(axis=1)
  184.                 cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  185.                 np.set_printoptions(precision=2)
  186.  
  187.                 # visualization
  188.                 plt.figure()
  189.                 plot_confusion_matrix(cnf_matrix,
  190.                                       classes=song_labels,
  191.                                       normalize=True,
  192.                                       title='Normalized confusion matrix')
  193.                 print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  194.                 plt.savefig(confusion_matrix_img_name + ".png", dpi=600)
  195.                 # break
  196.  
  197.         # save the png file and analyze the batch_size and the learning_rate
  198.         fig = plt.figure()
  199.         ax1 = fig.add_subplot(111)
  200.         ax1.plot(epoch_list, train_accuracy_list, label="train_accuracy", color='blue')
  201.         ax2 = ax1.twinx()
  202.         ax2.plot(epoch_list, train_loss_list, label="train_loss", color='red')
  203.  
  204.         learning_rate = 0.01
  205.         title_name = "batch_size=" + batch_size + " learning_rate=" + str(learning_rate)
  206.  
  207.         plt.legend()
  208.         plt.title(title_name)
  209.         plt.savefig(title_name + ".png", dpi=600)
  210.         print(title_name, "saved successfully!")
  211.  
  212.  
  213. def main():
  214.     # Data stuff
  215.     data = load_data.loadall('melspects.npz')
  216.  
  217.     # tmp = np.load('melspects.npz')
  218.     x_tr = data['x_tr']
  219.     y_tr = data['y_tr']
  220.     x_te = data['x_te']
  221.     y_te = data['y_te']
  222.     x_cv = data['x_cv']
  223.     y_cv = data['y_cv']
  224.  
  225.     tr_idx = np.random.permutation(len(x_tr))
  226.     te_idx = np.random.permutation(len(x_te))
  227.     cv_idx = np.random.permutation(len(x_cv))
  228.  
  229.     x_tr = x_tr[tr_idx]
  230.     y_tr = y_tr[tr_idx]
  231.     x_te = x_te[te_idx]
  232.     y_te = y_te[te_idx]
  233.     x_cv = x_cv[cv_idx]
  234.     y_cv = y_cv[cv_idx]
  235.  
  236.     x_tr = x_tr[:, :, :, np.newaxis]
  237.     x_te = x_te[:, :, :, np.newaxis]
  238.     x_cv = x_cv[:, :, :, np.newaxis]
  239.  
  240.     y_tr = np_utils.to_categorical(y_tr)
  241.     y_te = np_utils.to_categorical(y_te)
  242.     y_cv = np_utils.to_categorical(y_cv)
  243.  
  244.     # training = np.load('gtzan/gtzan_tr.npy')
  245.     # x_tr = np.delete(training, -1, 1)
  246.     # label_tr = training[:,-1]
  247.  
  248.     # test = np.load('gtzan/gtzan_te.npy')
  249.     # x_te = np.delete(test, -1, 1)
  250.     # label_te = test[:,-1]
  251.  
  252.     # cv = np.load('gtzan/gtzan_cv.npy')
  253.     # x_cv = np.delete(cv, -1, 1)
  254.     # label_cv = test[:,-1]
  255.  
  256.     # temp = np.zeros((len(label_tr),10))
  257.     # temp[np.arange(len(label_tr)),label_tr.astype(int)] = 1
  258.     # y_tr = temp
  259.     # temp = np.zeros((len(label_te),10))
  260.     # temp[np.arange(len(label_te)),label_te.astype(int)] = 1
  261.     # y_te = temp
  262.     # temp = np.zeros((len(label_cv),10))
  263.     # temp[np.arange(len(label_cv)),label_cv.astype(int)] = 1
  264.     # y_cv = temp
  265.     # del temp
  266.  
  267.     #################################################
  268.  
  269.     #    if True:
  270.     #   model = keras.models.load_model('model84.082.0.h5', custom_objects={'metric': metric})
  271.     #   print("Saving confusion data...")
  272.     #   pred = model.predict_classes(x_te, verbose=1)
  273.     #   cnf_matrix = confusion_matrix(np.argmax(y_te, axis=1), pred)
  274.     #   np.set_printoptions(precision=1)
  275.     #   plt.figure()
  276.     #   plot_confusion_matrix(cnf_matrix, classes=song_labels, normalize=True, title='Normalized confusion matrix')
  277.     #   print(precision_recall_fscore_support(np.argmax(y_te, axis=1),pred, average='macro'))
  278.     #   plt.savefig("matrix",format='png', dpi=1000)
  279.     #   raise SystemExit
  280.     al = {
  281.         "kernel_size": [],
  282.         "learning_rate": [],
  283.         "batch_size": [],
  284.     }
  285.     for a in [(4, 4), (3, 3), (5, 5)]:
  286.         for b in [0.01, 0.001, 0.1]:
  287.             for c in [300, 200, 400]:
  288.                 al["kernel_size"].append(a)
  289.                 al["learning_rate"].append(b)
  290.                 al["batch_size"].append(c)
  291.  
  292.     for i in range(0, len(al["kernel_size"])):
  293.         confusion_filename = "kernel_size:" + str(al["kernel_size"][i]) + \
  294.                              " lr:" + str(al["learning_rate"][i]) + \
  295.                              " bs:" + str(al["batch_size"][i])
  296.         print("train_network:",
  297.               str(al["kernel_size"][i]),
  298.               str(al["learning_rate"][i]),
  299.               str(al["batch_size"][i]),
  300.               confusion_filename
  301.               )
  302.         ml_train = cnn(kernel_size=al["kernel_size"][i], learning_rate=al["learning_rate"][i], input_shape=(64,173,1))
  303.         ann = Model(ml_train)
  304.         ann.train_model(confusion_matrix_img_name=confusion_filename,
  305.                         train_x=x_tr, train_y=y_tr, val_x=x_cv, val_y=y_cv, test_x=x_te, test_y=y_te)
  306.  
  307.     # ann = Model(cnn(kernel_size=(5, 5), learning_rate=0.01))
  308.     # dot_img_file = 'model_cnn_architecture.png'
  309.     # tf.keras.utils.plot_model(ann, to_file=dot_img_file, show_shapes=True)
  310.     # ann.train_model(x_tr, y_tr, val_x=x_cv, val_y=y_cv, test_x=x_te, test_y=y_te)
  311.  
  312.     # for item in os.listdir():
  313.     #     if item.rfind(".h5") != -1:
  314.     #         model_predict(item, keras.models.load_model(item, custom_objects={'metric': metric}), x_te, y_te)
  315.  
  316.  
  317. if __name__ == '__main__':
  318.     main()
  319.  
Advertisement
Advertisement
Advertisement
RAW Paste Data Copied
Advertisement