Advertisement
Guest User

Untitled

a guest
Jun 27th, 2019
148
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.48 KB | None | 0 0
  def __init__(self, output_dim, **kwargs):
      """Create the layer.

      Args:
          output_dim: Size of the layer's output feature dimension; stored
              on the instance for use by the other methods.
          **kwargs: Forwarded unchanged to the base-class constructor.
      """
      # Record the requested width before delegating to the parent class.
      self.output_dim = output_dim
      super().__init__(**kwargs)
  4.  
  def build(self, input_shape):
      """Create the layer's trainable kernel.

      The kernel is registered with shape
      ``(self.output_dim, input_shape[1])`` and a ``'uniform'`` initializer.

      Args:
          input_shape: Shape of the incoming tensor; index 1 is used as the
              second kernel dimension.
      """
      # Debug output: requested output width and the incoming shape.
      print(self.output_dim)
      print(input_shape)

      kernel_shape = (self.output_dim, input_shape[1])
      self.kernel = self.add_weight(
          name='kernel',
          shape=kernel_shape,
          initializer='uniform',
          trainable=True,
      )

      # The parent build() must run last.
      super().build(input_shape)
  15.  
  def call(self, x):
      """Multiply the input by the transposed kernel.

      Args:
          x: Input tensor.

      Returns:
          The result of ``K.dot(x, tf.transpose(self.kernel))``.
      """
      def show_shape(label, tensor):
          # Debug helper: print a label followed by the tensor's shape.
          print(label)
          print(tensor.shape)

      show_shape('shape x: ', x)
      show_shape('shape kernel: ', self.kernel)

      # The kernel is stored as (output_dim, input_dim); transpose it so it
      # can be right-multiplied onto x.
      kernel_t = tf.transpose(self.kernel)
      show_shape('matrix', kernel_t)

      output = K.dot(x, kernel_t)
      show_shape('after product', output)

      return output
  29.  
  def compute_output_shape(self, input_shape):
      """Return the shape this layer produces for a given input shape.

      Args:
          input_shape: Shape of the incoming tensor; only index 0 is used.

      Returns:
          The tuple ``(input_shape[0], self.output_dim)``.
      """
      # Debug output: a banner, then the input shape and output width.
      for item in ('Compute output shape', input_shape, self.output_dim):
          print(item)

      leading_dim = input_shape[0]
      return (leading_dim, self.output_dim)
  35.  
  # Build a single-layer model around the custom MyLayer.
  model = Sequential()
  # Declare only the per-sample feature shape and leave the batch dimension
  # open. Pinning it with batch_input_shape=(100, 5) fixed the batch size at
  # 100, which clashed with the 32-sample batches fit() feeds by default
  # ("Incompatible shapes: [100,5] vs. [32,5]").
  model.add(MyLayer(5, input_shape=(5,)))
  model.compile(optimizer='adam', loss='mse')
  # fit model
  model.fit(X_train, X_target)
  42.  
  43. # Version in use: tf.keras.__version__ == '2.1.6-tf'
  44.  
  45. InvalidArgumentError Traceback (most recent call last)
  46. <ipython-input-77-4dee23ead957> in <module>()
  47. 6 model.compile(optimizer='adam', loss='mse')
  48. 7 # fit model
  49. ----> 8 model.fit(X_train, X_target)#, epochs=300, verbose=0)
  50.  
  51. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
  52. 1037 initial_epoch=initial_epoch,
  53. 1038 steps_per_epoch=steps_per_epoch,
  54. -> 1039 validation_steps=validation_steps)
  55. 1040
  56. 1041 def evaluate(self, x=None, y=None,
  57.  
  58. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
  59. 197 ins_batch[i] = ins_batch[i].toarray()
  60. 198
  61. --> 199 outs = f(ins_batch)
  62. 200 outs = to_list(outs)
  63. 201 for l, o in zip(out_labels, outs):
  64.  
  65. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
  66. 2713 return self._legacy_call(inputs)
  67. 2714
  68. -> 2715 return self._call(inputs)
  69. 2716 else:
  70. 2717 if py_any(is_tensor(x) for x in inputs):
  71.  
  72. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
  73. 2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
  74. 2674 else:
  75. -> 2675 fetched = self._callable_fn(*array_vals)
  76. 2676 return fetched[:len(self.outputs)]
  77. 2677
  78.  
  79. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
  80. 1380 ret = tf_session.TF_SessionRunCallable(
  81. 1381 self._session._session, self._handle, args, status,
  82. -> 1382 run_metadata_ptr)
  83. 1383 if run_metadata:
  84. 1384 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
  85.  
  86. ~/anaconda3/envs/ldsa/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
  87. 517 None, None,
  88. 518 compat.as_text(c_api.TF_Message(self.status.status)),
  89. --> 519 c_api.TF_GetCode(self.status.status))
  90. 520 # Delete the underlying status object from memory otherwise it stays alive
  91. 521 # as there is a reference to status from this from the traceback due to
  92.  
  93. InvalidArgumentError: Incompatible shapes: [100,5] vs. [32,5]
  94. [[Node: training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Shape, training_12/Adam/gradients/loss_17/my_layer_19_loss/sub_grad/Shape_1)]]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement