Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
# Seq of tensors to tensor
proc toTensor*[T](s: openarray[Tensor[T]]): Tensor[T] =
  ## Stacks the tensors in `s` along a new leading axis.
  ##
  ## Each tensor gets a length-1 axis inserted at position 0 (via
  ## `unsafeUnsqueeze`), and the results are joined with `concat` on axis 0.
  ## An empty `s` returns the default-initialized (empty) tensor instead of
  ## calling `concat` with no inputs.
  # NOTE(review): presumably `concat` reads its first input to get the
  # rank/shape and cannot handle an empty argument list — confirm.
  if s.len > 0:
    result = s.map(proc(t: Tensor[T]): Tensor[T] = t.unsafeUnsqueeze(0)).concat(0)
# `makeUniversal` does not work with `apply`, so `abs` is lifted to tensors
# by hand here.
proc abs*[T](t: Tensor[T]): Tensor[T] =
  ## Element-wise absolute value of `t`, computed with `map`.
  proc absElem(x: T): T =
    # Resolves to the scalar `abs` overload for element type `T`.
    abs(x)
  t.map(absElem)
# Seq to reshaped tensor, no copy
proc unsafeToTensorReshape[T](data: seq[T], shape: openarray[int]): Tensor[T] {.noSideEffect.} =
  ## Wraps `data` as a tensor with the given `shape` without copying it.
  ##
  ## `result.data` is filled with `shallowCopy`, so the returned tensor
  ## refers to the same buffer as `data`. Strides come from
  ## `shape_to_strides`, and the offset is 0.
  ##
  ## Precondition: `data.len` must equal the product of `shape`. This is
  ## asserted when assertions are enabled. Because nothing is copied or
  ## resized, a mismatch would let later element accesses run past the end
  ## of the buffer.
  when compileOption("assertions"):
    var expected = 1
    for dim in shape:
      expected *= dim
    assert data.len == expected,
      "unsafeToTensorReshape: data length does not match the product of shape"
  result.shape = @shape
  result.strides = shape_to_strides(result.shape)
  result.offset = 0
  shallowCopy(result.data, data)
# Partial implementation (per the original author): only covers indexing a
# rank-3 tensor along its first axis.
template unsafeAt[T](t: Tensor[T], x: int): Tensor[T] =
  ## Takes the slice `t[x, _, _]` with `unsafeView` and reshapes it to
  ## `[t.shape[1], t.shape[2]]` with `unsafeReshape`.
  ##
  ## NOTE(review): `t` is substituted into the body three times, so an
  ## argument expression with side effects is evaluated more than once.
  ## Binding it to a `let` would copy the tensor. Pass a plain variable.
  unsafeReshape(unsafeView(t, x, _, _), [t.shape[1], t.shape[2]])
# Squeeze a single axis without copying.
proc unsafeSqueeze*[T](t: Tensor[T], axis: int): Tensor[T] {.noSideEffect,inline.} =
  ## Removes dimension `axis` from `t`'s shape and reshapes with
  ## `unsafeReshape`.
  ##
  ## The removed dimension must have length 1. This is checked with `assert`.
  var newShape = t.shape
  assert newShape[axis] == 1
  newShape.delete(axis)
  result = unsafeReshape(t, newShape)
# unsqueeze on axis
proc unsafeUnsqueeze*[T](t: Tensor[T], axis: int): Tensor[T] {.noSideEffect, inline.} =
  ## Inserts a new dimension of length 1 at position `axis` of `t`'s shape.
  ##
  ## Reshapes with `unsafeReshape`, like its counterpart `unsafeSqueeze`, so
  ## the `unsafe` prefix holds. The plain `reshape` is the copying variant.
  # NOTE(review): assumes `unsafeReshape` shares storage with `t` (as the
  # `unsafe*` naming implies). Callers must not rely on the result being an
  # independent copy.
  var newShape = t.shape
  newShape.insert(1, axis)
  t.unsafeReshape(newShape)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement