Advertisement
Guest User

Untitled

a guest
Oct 19th, 2019
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.58 KB | None | 0 0
  1. import * as tf from '@tensorflow/tfjs';
  2.  
  /**
   * Custom layer that bilinearly resizes a batch of images to a fixed spatial
   * size (default 256x256) via `tf.image.resizeBilinear`.
   *
   * It is registered for serialization under the class name 'Lambda'.
   * NOTE(review): presumably this stands in for a Keras `Lambda` layer when
   * loading a converted model. Every layer serialized as 'Lambda' would then
   * deserialize as this resize — confirm that is intended.
   */
  class Lambda extends tf.layers.Layer {
    /**
     * @param {Object} [config] - Layer config. Standard layer args (e.g. `name`,
     *   `trainable`) are forwarded to `tf.layers.Layer`.
     * @param {number[]} [config.targetSize=[256, 256]] - Output `[height, width]`.
     */
    constructor(config) {
      const args = config ?? {};
      // Forward the config instead of `{}` so properties such as the layer name
      // are not discarded when the layer is rebuilt from a saved config.
      super(args);
      this.supportsMasking = true;
      const [height, width] = args.targetSize ?? [256, 256];
      this.constOutputShape = [height, width];
    }

    /**
     * Output shape: batch and channel dims come from the input; height and
     * width are replaced by the target size. Assumes a rank-4 NHWC shape,
     * which is what `resizeBilinear` operates on for batched input.
     *
     * @param {Array<number|null>|Array<Array<number|null>>} inputShape
     * @returns {Array<number|null>}
     */
    computeOutputShape(inputShape) {
      // For multiple inputs the shape arrives as a list of shapes; mirror
      // `call`, which only uses the first input.
      const shape = Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
      const [height, width] = this.constOutputShape;
      return [shape[0], height, width, shape[3]];
    }

    /**
     * Resizes the (first) input tensor to the target height and width.
     *
     * @param {tf.Tensor|tf.Tensor[]} inputs
     * @param {Object} kwargs - Unused; kept to match the `Layer.call` signature.
     * @returns {tf.Tensor}
     */
    call(inputs, kwargs) {
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      return tf.image.resizeBilinear(input, this.constOutputShape);
    }

    /**
     * Adds `targetSize` to the base layer config so a saved layer is rebuilt
     * with the same output size.
     *
     * @returns {Object}
     */
    getConfig() {
      return { ...super.getConfig(), targetSize: [...this.constOutputShape] };
    }

    static get className() {
      return 'Lambda';
    }
  }
  tf.serialization.registerClass(Lambda);
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement