Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (8, 4)
# Fetch the inputs for this notebook from Google Drive into ./data:
#   1. the image archive, downloaded with unzip=True
#      (presumably this provides ./data/test — confirm);
#   2. the trained Keras model file, downloaded as-is (unzip=False).
from google_drive_downloader import GoogleDriveDownloader as gdd

for file_id, dest_path, unzip in (
    ('13WSlx4cmXh3wfvzNEXbZAbc2a1RHpPPP', './data/klasifikacia.zip', True),
    ('1k3Lz79cF40peKCf-UdobyIUdbKZ2Onfp', './data/model.h5', False),
):
    gdd.download_file_from_google_drive(
        file_id=file_id,
        dest_path=dest_path,
        unzip=unzip,
    )
- import keras
- import numpy as np
- from keras.preprocessing.image import ImageDataGenerator
- from keras.models import load_model
- from matplotlib import pyplot as plt
- import matplotlib.pyplot as plt
- %matplotlib inline
# Test-time configuration and input pipeline.
test_path = './data/test'
# NOTE(review): model_name is not used in the visible code; the model is
# loaded from 'data/model.h5' instead — confirm whether this is still needed.
model_name = 'keras_drone_trained_model.h5'

# Stream test images from test_path in batches of 8, resized to 100x100,
# reading only these four class subdirectories. The resulting
# name -> index mapping is exposed as test_batches.class_indices.
# No pixel rescaling is applied by this generator.
test_batches = ImageDataGenerator().flow_from_directory(
    test_path,
    target_size=(100, 100),
    classes=['biker', 'pedestrian', 'golf_cart', 'skater'],
    batch_size=8,
)

# NOTE(review): this 1/255-rescaling generator is never used in the visible
# code. If the model was trained on rescaled pixels, the test pipeline should
# apply the same rescaling — confirm against the training code.
test_datagen = ImageDataGenerator(rescale=1./255)
def plots(ims, figsize=(12, 6), rows=1, interp=False, titles=None):
    """Show a batch of images side by side in a grid of subplots.

    Args:
        ims: Sequence of images. If its first element is a NumPy array, the
            whole batch is converted to a uint8 array; if that array's last
            axis is not 3 it is transposed with axes (0, 2, 3, 1), i.e. from
            channels-first to channels-last (assumes a 4-D batch).
        figsize: Figure size in inches, passed to ``plt.figure``.
        rows: Number of subplot rows; the column count is derived from it so
            that every image gets its own cell.
        interp: If false, images are drawn with ``interpolation='none'``;
            if true, matplotlib's default interpolation is used.
        titles: Optional per-image titles, indexed in step with ``ims``.
    """
    if isinstance(ims[0], np.ndarray):
        ims = np.array(ims).astype(np.uint8)
        if ims.shape[-1] != 3:
            # (N, C, H, W) -> (N, H, W, C), the layout imshow expects.
            ims = ims.transpose((0, 2, 3, 1))
    f = plt.figure(figsize=figsize)
    # Ceiling division so rows * cols >= len(ims). The previous parity test
    # (len % 2) produced too few cells when rows > 2 (e.g. 4 images on
    # 3 rows -> 1 column -> add_subplot index out of range) and an extra
    # empty column for an odd count on a single row.
    cols = -(-len(ims) // rows)
    for i, im in enumerate(ims):
        sp = f.add_subplot(rows, cols, i + 1)
        sp.axis('off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=14)
        sp.imshow(im, interpolation=None if interp else 'none')
# Take one batch from the test iterator and display it, titled with its
# one-hot labels.
test_imgs, test_labels = next(test_batches)
plots(test_imgs, titles=test_labels)

# Class name -> class index mapping used by test_batches.
classes = test_batches.class_indices
print(classes)

model = load_model('data/model.h5')

# Predict on the exact batch that was plotted. Calling
# model.predict_generator(test_batches, ...) instead draws from the iterator
# again, which is not guaranteed to return these same images, so the
# predictions might not line up with the plot.
x = model.predict(test_imgs, verbose=0)

# x is the model's output: one row of per-class scores per image (shape
# (8, 4) according to the reported error). Passing x to
# model.predict_classes fed those scores back into the network as if they
# were images, which raised "expected conv2d_1_input to have 4 dimensions,
# but got array with shape (8, 4)". The predicted class index for each image
# is the arg-max of its row.
predict = np.argmax(x, axis=1)
predict
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement