Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def as_tensor_space(pspace, axis=None):
    """Convert a `ProductSpace` of `TensorSpace`'s to a tensor space.

    The power shape of ``pspace`` and the shape of its base space are
    merged into one shape, and a new space of the same type and dtype
    as the base space is created with that shape.

    Parameters
    ----------
    pspace : `ProductSpace`
        Power space (with arbitrary shape) whose base is a `TensorSpace`.
    axis : int or sequence of int, optional
        Indices at which the powers should be inserted as new axes.
        Negative indices count from the end of the resulting shape.
        The power axes are placed at the given positions in increasing
        order of position, independently of the order of the entries
        in ``axis``.
        For the default ``None``, the power axes are added to the left.

    Returns
    -------
    tspace : `TensorSpace`
        Space of the same type and ``dtype`` as the base of ``pspace``,
        with the merged shape.

    Raises
    ------
    TypeError
        If ``pspace`` is not a `ProductSpace` or its base space is not
        a `TensorSpace`.
    ValueError
        If ``pspace`` is not a power space, or if ``axis`` has the wrong
        number of entries, contains duplicates or is out of range.

    Examples
    --------
    >>> pspace = odl.rn(3) ** 2
    >>> as_tensor_space(pspace)
    rn((2, 3))
    >>> as_tensor_space(pspace, axis=1)
    rn((3, 2))
    >>> pspace = odl.rn(4) ** (2, 3)
    >>> as_tensor_space(pspace)
    rn((2, 3, 4))
    >>> as_tensor_space(pspace, axis=(0, 2))
    rn((2, 4, 3))
    """
    # Explicit exceptions instead of `assert`, which is removed under -O.
    if not isinstance(pspace, odl.ProductSpace):
        raise TypeError('`pspace` must be a `ProductSpace`, got {!r}'
                        ''.format(pspace))
    if not pspace.is_power_space:
        raise ValueError('`pspace` must be a power space, got {!r}'
                         ''.format(pspace))

    power_shape = pspace.shape
    power_ndim = len(power_shape)
    base = pspace[(0,) * power_ndim]
    if not isinstance(base, odl.space.base_tensors.TensorSpace):
        raise TypeError('base space of `pspace` must be a `TensorSpace`, '
                        'got {!r}'.format(base))

    total_ndim = power_ndim + base.ndim

    if axis is None:
        axes = list(range(power_ndim))
    elif np.isscalar(axis):
        axes = [axis]
    else:
        axes = list(axis)

    if len(axes) != power_ndim:
        raise ValueError('`axis` must have {} entries, got {}'
                         ''.format(power_ndim, len(axes)))

    # Map negative indices to their non-negative equivalents and reject
    # anything that does not address an axis of the resulting shape.
    power_axes = set()
    for ax in axes:
        if not -total_ndim <= ax < total_ndim:
            raise ValueError('axis {} is out of range for a resulting '
                             'space with {} dimensions'
                             ''.format(ax, total_ndim))
        power_axes.add(ax % total_ndim)

    if len(power_axes) != power_ndim:
        raise ValueError('`axis` contains duplicate entries: {}'
                         ''.format(axes))

    # Walk over the positions of the new shape, drawing the next extent
    # from the power shape at the selected positions and from the base
    # shape everywhere else. The validation above guarantees that both
    # iterators are consumed exactly.
    power_extents = iter(power_shape)
    base_extents = iter(base.shape)
    newshape = [
        next(power_extents) if pos in power_axes else next(base_extents)
        for pos in range(total_ndim)
    ]

    # TODO: This disregards weighting completely, needs fix!
    return type(base)(newshape, dtype=base.dtype)
def as_power_space(tspace, axis=None):
    """Convert a `TensorSpace` to a `ProductSpace` of smaller tensor spaces.

    The selected axes are removed from ``tspace`` and turned into the
    power shape of the returned product space; the remaining axes form
    the base space.

    Parameters
    ----------
    tspace : `TensorSpace`
        Tensor space with ``ndim >= 1`` that should be converted to a
        `ProductSpace`.
    axis : int or sequence of int, optional
        Indices of the axes that should be turned into powers.
        Negative indices count from the last axis. Repeated entries are
        treated as a single one, and the selected axes are used in
        increasing order regardless of their order in ``axis``.
        For the default ``None``, the first axis is taken.

    Returns
    -------
    pspace : `ProductSpace`
        The result of ``tspace.byaxis[remaining_axes] ** removed_shape``,
        where ``removed_shape`` holds the extents of the selected axes.

    Raises
    ------
    TypeError
        If ``tspace`` is not a `TensorSpace`.
    ValueError
        If ``tspace`` has no axes or an entry of ``axis`` is out of range.

    Examples
    --------
    >>> tspace = odl.rn((2, 3))
    >>> as_power_space(tspace)
    ProductSpace(rn(3), 2)
    >>> as_power_space(tspace, axis=1)
    ProductSpace(rn(2), 3)
    >>> tspace = odl.rn((2, 3, 4))
    >>> as_power_space(tspace)
    ProductSpace(rn((3, 4)), 2)
    >>> as_power_space(tspace, axis=(0, 2))
    ProductSpace(ProductSpace(rn(3), 4), 2)
    """
    # Explicit exceptions instead of `assert`, which is removed under -O.
    if not isinstance(tspace, odl.space.base_tensors.TensorSpace):
        raise TypeError('`tspace` must be a `TensorSpace`, got {!r}'
                        ''.format(tspace))

    ndim = tspace.ndim
    if ndim < 1:
        raise ValueError('`tspace` must have at least 1 axis, got '
                         '`ndim={}`'.format(ndim))

    if axis is None:
        axes = [0]
    elif np.isscalar(axis):
        axes = [axis]
    else:
        axes = list(axis)

    # Normalize negative indices; without this, an index such as -1 would
    # never match any axis position below and be silently ignored.
    power_axes = set()
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError('axis {} is out of range for a space with '
                             '{} dimensions'.format(ax, ndim))
        power_axes.add(ax % ndim)

    remaining_axes = [i for i in range(ndim) if i not in power_axes]
    removed_shape = [n for i, n in enumerate(tspace.shape)
                     if i in power_axes]
    return tspace.byaxis[remaining_axes] ** removed_shape
Add Comment
Please, Sign In to add comment