Advertisement
Guest User

Untitled

a guest
Jun 25th, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.78 KB | None | 0 0
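# Error seen when running this script (raised by model.predict_classes(x) near the end, which received the (8, 4) prediction array instead of image input): ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (8, 4)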
  2.  
  3. from google_drive_downloader import GoogleDriveDownloader as gdd
  4.  
  5. gdd.download_file_from_google_drive(
  6. file_id='13WSlx4cmXh3wfvzNEXbZAbc2a1RHpPPP',
  7. dest_path='./data/klasifikacia.zip',
  8. unzip=True)
  9.  
  10. gdd.download_file_from_google_drive(
  11. file_id='1k3Lz79cF40peKCf-UdobyIUdbKZ2Onfp',
  12. dest_path='./data/model.h5',
  13. unzip=False)
  14.  
  15. import keras
  16. import numpy as np
  17. from keras.preprocessing.image import ImageDataGenerator
  18. from keras.models import load_model
  19. from matplotlib import pyplot as plt
  20. import matplotlib.pyplot as plt
  21. %matplotlib inline
  22.  
  23. test_path = './data/test'
  24. model_name = 'keras_drone_trained_model.h5'
  25.  
  26. test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(100, 100), classes=['biker', 'pedestrian', 'golf_cart', 'skater'], batch_size=8)
  27. test_datagen = ImageDataGenerator(rescale=1./255)
  28.  
  29. def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
  30. if type(ims[0]) is np.ndarray:
  31. ims = np.array(ims).astype(np.uint8)
  32. if(ims.shape[-1] != 3):
  33. ims = ims.transpose((0,2,3,1))
  34. f = plt.figure(figsize=figsize)
  35. cols = len(ims)//rows if len(ims)%2 == 0 else len(ims)//rows + 1
  36. for i in range(len (ims)):
  37. sp = f.add_subplot(rows, cols, i+1)
  38. sp.axis('Off')
  39. if titles is not None:
  40. sp.set_title(titles[i], fontsize=14)
  41. plt.imshow(ims[i], interpolation=None if interp else 'none')
  42.  
  43. test_imgs, test_labels = next(test_batches)
  44. plots(test_imgs, titles=test_labels)
  45.  
  46. classes = test_batches.class_indices
  47. print(classes)
  48.  
  49. model = load_model('data/model.h5')
  50.  
  51. x = model.predict_generator(test_batches, steps=1, verbose=0)
  52. predict = model.predict_classes(x)
  53. predict
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement