Advertisement
Guest User

Untitled

a guest
Mar 22nd, 2017
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.37 KB | None | 0 0
  1. import functools
  2. import inspect
  3. import xarray as xr
  4. import dask.array as da
  5. import numpy as np
  6.  
  7. from . import util
  8.  
  9. from operator import mul
  10. from functools import reduce
  11.  
  12. @xr.register_dataarray_accessor('reshape')
  13. class XRReshaper(object):
  14. """An object for reshaping DataArrays into 2D matrices
  15.  
  16. This can be used to easily transform dataarrays into a format suitable for
  17. input to scikit-learn functions.
  18.  
  19. """
  20.  
  21. def __init__(self, da):
  22. self._da = da
  23.  
  24. @property
  25. def dims(self):
  26. return self._da.dims
  27.  
  28. def to(self, feature_dims):
  29. """Reshape data array into 2D array
  30.  
  31. Parameters
  32. ----------
  33. feature_dims: seq of dim names
  34. list of dimensions that will be the features (i.e. columns) for the result
  35.  
  36. Returns
  37. -------
  38. arr: matrix
  39. reshaped data
  40. dims: seq of dim names
  41. list of dim names in the same order as the output array. useful for the from function below.
  42.  
  43. """
  44. A = self._da
  45.  
  46.  
  47. dim_list = [dim for dim in A.dims if dim not in feature_dims] \
  48. + feature_dims
  49.  
  50. axes_list = [A.get_axis_num(dim) for dim in dim_list]
  51.  
  52. npa = A.data.transpose(axes_list)
  53.  
  54. sh = npa.shape
  55.  
  56. nfeats = np.prod(sh[-len(feature_dims):])
  57. npa = npa.reshape((-1, nfeats))
  58.  
  59. return npa, dim_list
  60.  
  61. def get(self, arr, dims, extra_coords={}):
  62.  
  63. coords = {}
  64. unknown_dims = []
  65. # get known coordinats
  66. for i, dim in enumerate(dims):
  67. if dim in self._da.coords:
  68. coords[dim] = self._da[dim].values
  69.  
  70.  
  71. # merge in extra coords
  72. coords.update(extra_coords)
  73.  
  74. unknown_dims = [dim for dim in dims
  75. if dim not in coords]
  76.  
  77. # deal with unknown coords
  78. if len(unknown_dims) == 0:
  79. pass
  80. elif len(unknown_dims) == 1:
  81. n_known_coords = reduce(mul, (len(val) for _,val in coords.items()))
  82. n_unknown_coord = arr.size / n_known_coords
  83. coords[unknown_dims[0]] = np.arange(n_unknown_coord)
  84. else:
  85. print(unknown_dims)
  86. raise ValueError("Only one unknown dim is allowed")
  87.  
  88. # create new shape
  89. sh = [len(coords[dim]) for dim in dims]
  90.  
  91. # reshape
  92. arr = arr.reshape(sh)
  93.  
  94. return xr.DataArray(arr, dims=dims, coords=coords)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement