Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/**
 * Splice a tensor along one axis, analogous to Array.prototype.splice:
 * removes `deleteCount` slices starting at index `start` along `axis` and
 * inserts `toInsert` in their place. A new tensor is returned; `original`
 * and `toInsert` are left untouched.
 *
 * @param {tf.Tensor} original - the tensor to splice (not modified).
 * @param {number} axis - the axis to splice along; negative values count
 *   from the last axis.
 * @param {number} start - index along `axis` at which to begin. Negative
 *   values count from the end, and out-of-range values are clamped to
 *   the axis bounds (Array.prototype.splice semantics).
 * @param {number} deleteCount - how many slices to remove along `axis`.
 *   It is clamped to the slices remaining after `start`, and
 *   null/undefined is treated as 0.
 * @param {tf.Tensor} [toInsert] - tensor to insert at `start`. All
 *   dimensions must match `original` except along `axis`. Omit it (or pass
 *   null) to only delete.
 * @returns {tf.Tensor} the spliced tensor. If everything is deleted and
 *   nothing is inserted, this is an empty tensor with size 0 along `axis`.
 * @throws {Error} if `axis` is out of range for the tensor's rank.
 */
function spliceTensor(original, axis, start, deleteCount, toInsert) {
  const { shape } = original;
  const rank = shape.length;
  const ax = axis < 0 ? axis + rank : axis;
  if (!Number.isInteger(ax) || ax < 0 || ax >= rank) {
    throw new Error(`spliceTensor: axis ${axis} is out of range for a tensor of rank ${rank}`);
  }

  const length = shape[ax];
  // Normalise the range like Array.prototype.splice so the slice sizes
  // computed below can never be negative or run past the axis end.
  const begin = start < 0 ? Math.max(length + start, 0) : Math.min(start, length);
  const count = deleteCount == null ? 0 : Math.max(deleteCount, 0);
  const end = Math.min(begin + count, length);

  // tf.tidy disposes the intermediate slices once concat has copied them;
  // only the returned tensor survives the scope.
  return tf.tidy(() => {
    const parts = [];
    if (begin > 0) {
      // Everything before `begin` along `ax` (slice takes begin + size).
      parts.push(original.slice(
        shape.map(() => 0),
        shape.map((dim, i) => (i === ax ? begin : dim)),
      ));
    }
    if (toInsert != null) {
      parts.push(toInsert);
    }
    if (end < length) {
      // Everything from `end` to the end of `ax`.
      parts.push(original.slice(
        shape.map((_, i) => (i === ax ? end : 0)),
        shape.map((dim, i) => (i === ax ? length - end : dim)),
      ));
    }
    if (parts.length === 0) {
      // tf.concat rejects an empty list, so build the empty result directly.
      return tf.zeros(shape.map((dim, i) => (i === ax ? 0 : dim)), original.dtype);
    }
    return tf.concat(parts, ax);
  });
}
/* usage */
const tf = require("@tensorflow/tfjs-core");

const orig = tf.ones([4, 4]);
const vertical = tf.zeros([4, 2]);
const horizontal = tf.zeros([2, 4]);
orig.print();
vertical.print();
horizontal.print();

// Insert two zero columns before column 2, deleting nothing.
let spliced = spliceTensor(orig, 1, 2, 0, vertical);
spliced.print();
/*
Tensor
    [[1, 1, 0, 0, 1, 1],
     [1, 1, 0, 0, 1, 1],
     [1, 1, 0, 0, 1, 1],
     [1, 1, 0, 0, 1, 1]]
*/

// Replace row 2 with two rows of zeros.
// tfjs tensors are not garbage-collected, so release the previous result first.
spliced.dispose();
spliced = spliceTensor(orig, 0, 2, 1, horizontal);
spliced.print();
/*
Tensor
    [[1, 1, 1, 1],
     [1, 1, 1, 1],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [1, 1, 1, 1]]
*/

// Delete row 2 without inserting anything.
spliced.dispose();
spliced = spliceTensor(orig, 0, 2, 1);
spliced.print();
/*
Tensor
    [[1, 1, 1, 1],
     [1, 1, 1, 1],
     [1, 1, 1, 1]]
*/

tf.dispose([orig, vertical, horizontal, spliced]);
Add Comment
Please, Sign In to add comment