Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import functools
- import inspect
- import xarray as xr
- import dask.array as da
- import numpy as np
- from . import util
- from operator import mul
- from functools import reduce
@xr.register_dataarray_accessor('reshape')
class XRReshaper(object):
    """An object for reshaping DataArrays into 2D matrices.

    This can be used to easily transform DataArrays into a format suitable
    for input to scikit-learn functions. The class is registered as the
    ``reshape`` DataArray accessor, so it is reached as ``arr.reshape``.

    Typical round trip::

        mat, dims = arr.reshape.to(['z'])
        # ... operate on the 2D ``mat`` ...
        out = arr.reshape.get(mat, dims)
    """

    def __init__(self, da):
        # ``da`` is the DataArray the accessor is attached to (supplied by
        # xarray's accessor machinery). Note it shadows the module-level
        # ``dask.array`` alias only inside this method.
        self._da = da

    @property
    def dims(self):
        """Dimension names of the wrapped DataArray."""
        return self._da.dims

    def to(self, feature_dims):
        """Reshape the data array into a 2D array.

        Parameters
        ----------
        feature_dims : str or sequence of dim names
            Dimensions that will become the features (i.e. columns) of the
            result. A single dimension name may be given as a string. All
            remaining dimensions are flattened into the rows (samples).

        Returns
        -------
        arr : array
            Reshaped data of shape ``(n_samples, n_features)``, of the same
            array type as the wrapped DataArray's ``.data``.
        dims : list of dim names
            Dim names in the order they were flattened: sample dims (in
            their original order) followed by ``feature_dims``. Useful for
            :meth:`get` to undo the reshape.
        """
        if isinstance(feature_dims, str):
            feature_dims = [feature_dims]
        feature_dims = list(feature_dims)

        A = self._da
        sample_dims = [dim for dim in A.dims if dim not in feature_dims]
        dim_list = sample_dims + feature_dims
        axes_list = [A.get_axis_num(dim) for dim in dim_list]
        npa = A.data.transpose(axes_list)

        # Split the transposed shape by position rather than with
        # ``sh[-len(feature_dims):]``, which selects the WHOLE shape when
        # feature_dims is empty (``sh[-0:] == sh``).
        sh = npa.shape
        n_sample_axes = len(sample_dims)
        # reduce(..., 1) yields an int even for an empty tuple, whereas
        # np.prod(()) returns the float 1.0, which reshape rejects.
        nsamples = reduce(mul, sh[:n_sample_axes], 1)
        nfeats = reduce(mul, sh[n_sample_axes:], 1)
        # Explicit sizes (instead of -1) also work for zero-size arrays.
        npa = npa.reshape((nsamples, nfeats))
        return npa, dim_list

    def get(self, arr, dims, extra_coords=None):
        """Reshape an array back into a labelled DataArray.

        Parameters
        ----------
        arr : array
            Array to reshape, e.g. one derived from the output of :meth:`to`.
            Its total size must equal the product of the lengths of the
            coordinates of ``dims``.
        dims : sequence of dim names
            Dimension order of the result, e.g. the ``dims`` returned by
            :meth:`to`.
        extra_coords : dict, optional
            Additional coordinates, mapping dim name to 1D coordinate values.
            These override coordinates taken from the wrapped DataArray.

        Returns
        -------
        DataArray
            ``arr`` reshaped to one axis per entry of ``dims``, with
            coordinates taken from the wrapped DataArray and
            ``extra_coords``.

        Raises
        ------
        ValueError
            If more than one dim in ``dims`` has no known coordinate, or if
            the size of ``arr`` is not a multiple of the product of the known
            coordinate lengths.

        Notes
        -----
        At most one dim may lack a coordinate. Its length is inferred from
        ``arr.size`` and it is given the coordinate ``np.arange(length)``.
        """
        # Known coordinates: those of the wrapped DataArray for the
        # requested dims, overridden/extended by extra_coords.
        coords = {dim: self._da[dim].values
                  for dim in dims if dim in self._da.coords}
        if extra_coords:
            coords.update(extra_coords)

        unknown_dims = [dim for dim in dims if dim not in coords]
        if len(unknown_dims) > 1:
            raise ValueError(
                "Only one unknown dim is allowed, got {!r}".format(unknown_dims))

        if unknown_dims:
            unknown = unknown_dims[0]
            # Only the output dims determine the shape; extra_coords may
            # contain keys that are not in ``dims``.
            n_known = reduce(
                mul, (len(coords[dim]) for dim in dims if dim != unknown), 1)
            if n_known == 0 or arr.size % n_known:
                raise ValueError(
                    "Cannot infer the length of dim {!r}: array size {} is not "
                    "a multiple of the known coordinate size {}".format(
                        unknown, arr.size, n_known))
            # Floor division keeps the inferred coordinate integer-valued.
            coords[unknown] = np.arange(arr.size // n_known)

        sh = [len(coords[dim]) for dim in dims]
        arr = arr.reshape(sh)
        return xr.DataArray(arr, dims=dims, coords=coords)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement