Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def __init__(self, output_dim, **kwargs):
    """Create the layer.

    Args:
        output_dim: size of the last axis of this layer's output
            (also the first axis of the kernel allocated in `build`).
        **kwargs: forwarded unchanged to the Keras `Layer` base class.
    """
    # Remember the requested output width before delegating the rest
    # of the configuration to the base class.
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
    """Allocate the trainable kernel once the input shape is known.

    The kernel has shape (output_dim, input_features); it is transposed
    in `call` before the matrix product, so x @ kernel.T is well-formed.
    """
    print(self.output_dim)
    print(input_shape)
    # Create a trainable weight variable for this layer.
    kernel_shape = (self.output_dim, input_shape[1])
    self.kernel = self.add_weight(
        name='kernel',
        shape=kernel_shape,
        initializer='uniform',
        trainable=True,
    )
    # The base-class build marks the layer as built; Keras requires
    # this call at the end of every custom build().
    super(MyLayer, self).build(input_shape)
def call(self, x):
    """Forward pass: project `x` to `output_dim` features via x @ kernel.T.

    The debug prints trace each tensor shape through the computation.
    """
    print('shape x: ')
    print(x.shape)
    print('shape kernel: ')
    print(self.kernel.shape)
    kernel_t = tf.transpose(self.kernel)
    print('matrix')
    print(kernel_t.shape)
    projected = K.dot(x, kernel_t)
    print('after product')
    print(projected.shape)
    return projected
def compute_output_shape(self, input_shape):
    """Return the output shape: batch axis preserved, last axis replaced
    by `output_dim`."""
    print('Compute output shape')
    print(input_shape)
    print(self.output_dim)
    batch_axis = input_shape[0]
    return (batch_axis, self.output_dim)
# Build and train a single-custom-layer model.
model = Sequential()
# Fix 1: the original created `model = Sequential()` twice; once is enough.
# Fix 2: use input_shape=(5,) instead of batch_input_shape=(100, 5).
# batch_input_shape freezes the batch size at 100, while model.fit's
# default batch_size is 32 — this mismatch is what raised
# "InvalidArgumentError: Incompatible shapes: [100,5] vs. [32,5]".
# Leaving the batch axis unspecified lets fit() use any batch size.
model.add(MyLayer(5, input_shape=(5,)))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit(X_train, X_target)
# NOTE(review): the pasted line `tf.keras.__version__ = '2.1.6-tf'` was a
# record of the installed Keras version, not runnable code — removed.
- InvalidArgumentError Traceback (most recent call last)
- <ipython-input-77-4dee23ead957> in <module>()
- 6 model.compile(optimizer='adam', loss='mse')
- 7 # fit model
- ----> 8 model.fit(X_train, X_target)#, epochs=300, verbose=0)
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
- 1037 initial_epoch=initial_epoch,
- 1038 steps_per_epoch=steps_per_epoch,
- -> 1039 validation_steps=validation_steps)
- 1040
- 1041 def evaluate(self, x=None, y=None,
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
- 197 ins_batch[i] = ins_batch[i].toarray()
- 198
- --> 199 outs = f(ins_batch)
- 200 outs = to_list(outs)
- 201 for l, o in zip(out_labels, outs):
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
- 2713 return self._legacy_call(inputs)
- 2714
- -> 2715 return self._call(inputs)
- 2716 else:
- 2717 if py_any(is_tensor(x) for x in inputs):
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
- 2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
- 2674 else:
- -> 2675 fetched = self._callable_fn(*array_vals)
- 2676 return fetched[:len(self.outputs)]
- 2677
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
- 1380 ret = tf_session.TF_SessionRunCallable(
- 1381 self._session._session, self._handle, args, status,
- -> 1382 run_metadata_ptr)
- 1383 if run_metadata:
- 1384 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
- ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
- 517 None, None,
- 518 compat.as_text(c_api.TF_Message(self.status.status)),
- --> 519 c_api.TF_GetCode(self.status.status))
- 520 # Delete the underlying status object from memory otherwise it stays alive
- 521 # as there is a reference to status from this from the traceback due to
- InvalidArgumentError: Incompatible shapes: [100,5] vs. [32,5]
- [[Node: training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Shape, training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Shape_1)]]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement