Advertisement
snakemelon

Why can't I start training?

May 14th, 2022
1,090
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.26 KB | None | 0 0
  1. import numpy as np
  2. import itertools
  3. import logging
  4. import tensorflow as tf
  5. import keras
  6. from keras import backend as K
  7. from keras.models import Sequential
  8. from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, BatchNormalization
  9. from tensorflow.keras.optimizers import Adam
  10. from keras.utils import np_utils
  11. from keras import regularizers
  12. import matplotlib
  13. from matplotlib import pyplot as plt
  14. from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
  15. import load_data
  16. import matplotlib
  17. import time
  18.  
  19. matplotlib.use("Agg")
  20.  
  21. # Not show the Warning information
  22. logging.getLogger("tensorflow").setLevel(logging.ERROR)
  23.  
  24. # use the GPU to train the network
  25. CUDA_VISIBEL_DEVICES = 0
  26.  
  27. # Models to be passed to Music_Genre_CNN
  28. song_labels = ["Blues", "Classical", "Country", "Disco", "Hip hop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
  29. MODEL_PATH = "CNN/"
  30. MAX_ITERATION = 100
  31.  
  32.  
  33. # remote interpreter path:
  34. # /root/miniconda3/envs/myconda/bin/python
  35.  
  36. def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
  37.     """
  38.    This function prints and plots the confusion matrix.
  39.    Normalization can be applied by setting `normalize=True`.
  40.    """
  41.     if normalize:
  42.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  43.         print("Normalized confusion matrix")
  44.     else:
  45.         print('Confusion matrix, without normalization')
  46.  
  47.     print(cm)
  48.  
  49.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  50.     plt.title(title)
  51.     plt.colorbar()
  52.     tick_marks = np.arange(len(classes))
  53.     plt.xticks(tick_marks, classes, rotation=45)
  54.     plt.yticks(tick_marks, classes)
  55.  
  56.     fmt = '.1f' if normalize else 'd'
  57.     thresh = cm.max() / 2.
  58.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  59.         plt.text(j, i, format(cm[i, j], fmt),
  60.                  horizontalalignment="center",
  61.                  color="white" if cm[i, j] > thresh else "black")
  62.  
  63.     plt.ylabel('True label')
  64.     plt.xlabel('Predicted label')
  65.     plt.tight_layout()
  66.  
  67.  
  68. def metric(y_true, y_pred):
  69.     return K.mean(K.equal(K.argmax(y_true, axis=1), K.argmax(y_pred, axis=1)))
  70.  
  71.  
  72. def cnn(kernel_size=(4,4), num_genres=10, input_shape=(64, 173, 1), learning_rate=0.01):
  73.     model = Sequential()
  74.  
  75.     model.add(Conv2D(64, kernel_size=kernel_size, activation='relu', input_shape=input_shape))
  76.     model.add(BatchNormalization())
  77.     model.add(MaxPooling2D(pool_size=(2, 4)))
  78.     model.add(Conv2D(64, (3, 5), activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  79.     model.add(MaxPooling2D(pool_size=(2, 2)))
  80.     model.add(Dropout(0.2))
  81.     model.add(Conv2D(64, (2, 2), activation='relu'))
  82.     # ,kernel_regularizer=regularizers.l2(0.04)
  83.     model.add(BatchNormalization())
  84.     model.add(MaxPooling2D(pool_size=(2, 2)))
  85.     model.add(Dropout(0.2))
  86.     model.add(Flatten())
  87.     model.add(Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  88.     model.add(Dropout(0.5))
  89.     model.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.04)))
  90.     model.add(Dense(num_genres, activation='softmax'))
  91.     model.compile(loss=keras.losses.categorical_crossentropy,
  92.                   optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0),
  93.                   metrics=[metric])
  94.     return model
  95.  
  96.  
  97. def model_predict(model_name, model, test_x, test_y):
  98.     pred = model.predict(test_x, batch_size=32, verbose=1).argmax(axis=1)
  99.     cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  100.     np.set_printoptions(precision=2)
  101.  
  102.     # visualization
  103.     plt.figure()
  104.     plot_confusion_matrix(cnf_matrix,
  105.                           classes=song_labels,
  106.                           normalize=True,
  107.                           title='Normalized confusion matrix')
  108.     print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  109.     plt.savefig(str(model_name) + ".png", dpi=600)
  110.  
  111.  
  112. class Model(object):
  113.     # Main network thingy to train
  114.  
  115.     @tf.autograph.experimental.do_not_convert
  116.     def __init__(self, ann_model):
  117.         self.model = ann_model()
  118.  
  119.     t = time.gmtime()
  120.  
  121.     @tf.autograph.experimental.do_not_convert
  122.     def train_model(self, train_x, train_y, val_x=None, val_y=None, small_batch_size=300, max_iteration=MAX_ITERATION,
  123.                     print_interval=1, test_x=None, test_y=None,
  124.                     confusion_matrix_img_name=time.strftime("%m-%d %H:%M:%S", t)):
  125.  
  126.         batch_size = str(small_batch_size)
  127.         train_accuracy_list = []
  128.         train_loss_list = []
  129.         validation_accuracy_list = []
  130.         validation_loss_list = []
  131.         test_accuracy_list = []
  132.         test_loss_list = []
  133.         epoch_list = []
  134.  
  135.         m = len(train_x)
  136.         for it in range(max_iteration):
  137.  
  138.             # split training data into even batches
  139.             batch_idx = np.random.permutation(m)
  140.             train_x = train_x[batch_idx]
  141.             train_y = train_y[batch_idx]
  142.  
  143.             num_batches = int(m / small_batch_size)
  144.             for batch in range(num_batches):
  145.                 x_batch = train_x[batch * small_batch_size: (batch + 1) * small_batch_size]
  146.                 y_batch = train_y[batch * small_batch_size: (batch + 1) * small_batch_size]
  147.                 print("starting batch\t", batch, "\t Epoch:\t", it)
  148.  
  149.                 # tf.autograph.experimental.do_not_convert(func=self.model.train_on_batch(x_batch, y_batch))
  150.                 self.model.train_on_batch(x_batch, y_batch)
  151.  
  152.             if it % print_interval == 0:
  153.                 validation_accuracy = self.model.evaluate(val_x, val_y)
  154.                 training_accuracy = self.model.evaluate(train_x, train_y)
  155.                 testing_accuracy = self.model.evaluate(test_x, test_y)
  156.  
  157.                 # print of test error used only after development of the model
  158.                 print("\nTraining accuracy: %f\t Validation accuracy: %f\t Testing Accuracy: %f" %
  159.                       (training_accuracy[1], validation_accuracy[1], testing_accuracy[1]))
  160.                 print("\nTraining loss: %f    \t Validation loss: %f    \t Testing Loss: %f \n" %
  161.                       (training_accuracy[0], validation_accuracy[0], testing_accuracy[0]))
  162.                 print()
  163.  
  164.                 # the data for draw
  165.                 epoch_list.append(it)
  166.  
  167.                 train_accuracy_list.append(training_accuracy[1])
  168.                 train_loss_list.append(training_accuracy[0])
  169.  
  170.                 validation_accuracy_list.append(validation_accuracy[1])
  171.                 validation_loss_list.append(validation_accuracy[0])
  172.  
  173.                 test_accuracy_list.append(testing_accuracy[1])
  174.                 test_loss_list.append(testing_accuracy[0])
  175.  
  176.             # 验证集准确度>.81则输出训练的效果图 matrix
  177.             if validation_accuracy[1] >= 0.8:
  178.                 print("Saving confusion data...")
  179.                 # 较早版本的
  180.                 model_name = "model_" + str(100 * validation_accuracy[1]) + "_" + str(100 * testing_accuracy[1]) + ".h5"
  181.                 # self.model.save(model_name)
  182.                 # pred = self.model.predict_classes(test_x, verbose=1)
  183.                 pred = self.model.predict(test_x, verbose=1).argmax(axis=1)
  184.                 cnf_matrix = confusion_matrix(np.argmax(test_y, axis=1), pred)
  185.                 np.set_printoptions(precision=2)
  186.  
  187.                 # visualization
  188.                 plt.figure()
  189.                 plot_confusion_matrix(cnf_matrix,
  190.                                       classes=song_labels,
  191.                                       normalize=True,
  192.                                       title='Normalized confusion matrix')
  193.                 print(precision_recall_fscore_support(np.argmax(test_y, axis=1), pred, average='macro'))
  194.                 plt.savefig(confusion_matrix_img_name + ".png", dpi=600)
  195.                 # break
  196.  
  197.         # save the png file and analyze the batch_size and the learning_rate
  198.         fig = plt.figure()
  199.         ax1 = fig.add_subplot(111)
  200.         ax1.plot(epoch_list, train_accuracy_list, label="train_accuracy", color='blue')
  201.         ax2 = ax1.twinx()
  202.         ax2.plot(epoch_list, train_loss_list, label="train_loss", color='red')
  203.  
  204.         learning_rate = 0.01
  205.         title_name = "batch_size=" + batch_size + " learning_rate=" + str(learning_rate)
  206.  
  207.         plt.legend()
  208.         plt.title(title_name)
  209.         plt.savefig(title_name + ".png", dpi=600)
  210.         print(title_name, "saved successfully!")
  211.  
  212.  
  213. def main():
  214.     # Data stuff
  215.     data = load_data.loadall('melspects.npz')
  216.  
  217.     # tmp = np.load('melspects.npz')
  218.     x_tr = data['x_tr']
  219.     y_tr = data['y_tr']
  220.     x_te = data['x_te']
  221.     y_te = data['y_te']
  222.     x_cv = data['x_cv']
  223.     y_cv = data['y_cv']
  224.  
  225.     tr_idx = np.random.permutation(len(x_tr))
  226.     te_idx = np.random.permutation(len(x_te))
  227.     cv_idx = np.random.permutation(len(x_cv))
  228.  
  229.     x_tr = x_tr[tr_idx]
  230.     y_tr = y_tr[tr_idx]
  231.     x_te = x_te[te_idx]
  232.     y_te = y_te[te_idx]
  233.     x_cv = x_cv[cv_idx]
  234.     y_cv = y_cv[cv_idx]
  235.  
  236.     x_tr = x_tr[:, :, :, np.newaxis]
  237.     x_te = x_te[:, :, :, np.newaxis]
  238.     x_cv = x_cv[:, :, :, np.newaxis]
  239.  
  240.     y_tr = np_utils.to_categorical(y_tr)
  241.     y_te = np_utils.to_categorical(y_te)
  242.     y_cv = np_utils.to_categorical(y_cv)
  243.  
  244.     # training = np.load('gtzan/gtzan_tr.npy')
  245.     # x_tr = np.delete(training, -1, 1)
  246.     # label_tr = training[:,-1]
  247.  
  248.     # test = np.load('gtzan/gtzan_te.npy')
  249.     # x_te = np.delete(test, -1, 1)
  250.     # label_te = test[:,-1]
  251.  
  252.     # cv = np.load('gtzan/gtzan_cv.npy')
  253.     # x_cv = np.delete(cv, -1, 1)
  254.     # label_cv = test[:,-1]
  255.  
  256.     # temp = np.zeros((len(label_tr),10))
  257.     # temp[np.arange(len(label_tr)),label_tr.astype(int)] = 1
  258.     # y_tr = temp
  259.     # temp = np.zeros((len(label_te),10))
  260.     # temp[np.arange(len(label_te)),label_te.astype(int)] = 1
  261.     # y_te = temp
  262.     # temp = np.zeros((len(label_cv),10))
  263.     # temp[np.arange(len(label_cv)),label_cv.astype(int)] = 1
  264.     # y_cv = temp
  265.     # del temp
  266.  
  267.     #################################################
  268.  
  269.     #    if True:
  270.     #   model = keras.models.load_model('model84.082.0.h5', custom_objects={'metric': metric})
  271.     #   print("Saving confusion data...")
  272.     #   pred = model.predict_classes(x_te, verbose=1)
  273.     #   cnf_matrix = confusion_matrix(np.argmax(y_te, axis=1), pred)
  274.     #   np.set_printoptions(precision=1)
  275.     #   plt.figure()
  276.     #   plot_confusion_matrix(cnf_matrix, classes=song_labels, normalize=True, title='Normalized confusion matrix')
  277.     #   print(precision_recall_fscore_support(np.argmax(y_te, axis=1),pred, average='macro'))
  278.     #   plt.savefig("matrix",format='png', dpi=1000)
  279.     #   raise SystemExit
  280.     al = {
  281.         "kernel_size": [],
  282.         "learning_rate": [],
  283.         "batch_size": [],
  284.     }
  285.     for a in [(4, 4), (3, 3), (5, 5)]:
  286.         for b in [0.01, 0.001, 0.1]:
  287.             for c in [300, 200, 400]:
  288.                 al["kernel_size"].append(a)
  289.                 al["learning_rate"].append(b)
  290.                 al["batch_size"].append(c)
  291.  
  292.     for i in range(0, len(al["kernel_size"])):
  293.         confusion_filename = "kernel_size:" + str(al["kernel_size"][i]) + \
  294.                              " lr:" + str(al["learning_rate"][i]) + \
  295.                              " bs:" + str(al["batch_size"][i])
  296.         print("train_network:",
  297.               str(al["kernel_size"][i]),
  298.               str(al["learning_rate"][i]),
  299.               str(al["batch_size"][i]),
  300.               confusion_filename
  301.               )
  302.         ml_train = cnn(kernel_size=al["kernel_size"][i], learning_rate=al["learning_rate"][i], input_shape=(64,173,1))
  303.         ann = Model(ml_train)
  304.         ann.train_model(confusion_matrix_img_name=confusion_filename,
  305.                         train_x=x_tr, train_y=y_tr, val_x=x_cv, val_y=y_cv, test_x=x_te, test_y=y_te)
  306.  
  307.     # ann = Model(cnn(kernel_size=(5, 5), learning_rate=0.01))
  308.     # dot_img_file = 'model_cnn_architecture.png'
  309.     # tf.keras.utils.plot_model(ann, to_file=dot_img_file, show_shapes=True)
  310.     # ann.train_model(x_tr, y_tr, val_x=x_cv, val_y=y_cv, test_x=x_te, test_y=y_te)
  311.  
  312.     # for item in os.listdir():
  313.     #     if item.rfind(".h5") != -1:
  314.     #         model_predict(item, keras.models.load_model(item, custom_objects={'metric': metric}), x_te, y_te)
  315.  
  316.  
  317. if __name__ == '__main__':
  318.     main()
  319.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement