import torch
from torch import Tensor
from typing import Iterable
from fastprogress import progress_bar

class RunningStatistics:
    '''Records the running mean and variance over the final `n_dims` dimensions of each item, accumulated
    across items. Collecting across `(l,m,n,o)` sized items with `n_dims=1` therefore produces `(l,m,n)` sized
    statistics, while `n_dims=2` produces statistics of size `(l,m)`.

    Uses the pairwise-combination algorithm from Chan, Golub and LeVeque, "Algorithms for computing the
    sample variance: analysis and recommendations":

    `variance = variance1 + variance2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2), 2)`

    This combines the statistics of two blocks: block 1 with `n` elements, `variance1` and sum `t1`, and
    block 2 with `m` elements, `variance2` and sum `t2` (here `variance1`/`variance2` denote the per-block
    variances scaled by their element counts, i.e. the sums of squared deviations, which is what this class
    accumulates internally). The algorithm is proven to be numerically stable, though with a small loss of
    accuracy (~0.1% error).

    Note that collecting minimum and maximum values is fairly inefficient, adding about 80% to the running
    time, and so is disabled by default.
    '''
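    # Worked example of the merge formula above (an added illustration, not part of the original paste):
    # merging block 1 = [1, 2] (n=2, sum t1=3, squared deviations 0.5) with block 2 = [3, 4, 5]
    # (m=3, sum t2=12, squared deviations 2) gives
    #   0.5 + 2 + 2/(3*(3+2)) * ((3/2)*3 - 12)**2 = 10,
    # which equals the sum of squared deviations of [1, 2, 3, 4, 5] about its mean 3: 4 + 1 + 0 + 1 + 4 = 10.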
    def __init__(self, n_dims:int=2, record_range=False):
        self._n_dims,self._range = n_dims,record_range
        self.n,self.sum,self.min,self.max = 0,None,None,None

    def update(self, data:Tensor):
        # Flatten the final n_dims dimensions so statistics are reduced over a single trailing axis
        data = data.view(*data.shape[:-self._n_dims], -1)
        with torch.no_grad():
            new_n,new_var,new_sum = data.shape[-1],data.var(-1),data.sum(-1)
            if self.n == 0:
                # First item: initialise the running statistics directly
                self.n = new_n
                self._shape = data.shape[:-1]
                self.sum = new_sum
                self._nvar = new_var.mul_(new_n)
                if self._range:
                    self.min = data.min(-1)[0]
                    self.max = data.max(-1)[0]
            else:
                assert data.shape[:-1] == self._shape, f"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}."
                # Chan/Golub/LeVeque merge of the accumulated count-scaled variance with the new block's.
                # add_(other, alpha=...) replaces the deprecated add_(alpha, other) positional form.
                ratio = self.n / new_n
                t = (self.sum / ratio).sub_(new_sum).pow_(2)
                self._nvar.add_(new_var, alpha=new_n).add_(t, alpha=ratio / (self.n + new_n))
                self.sum.add_(new_sum)
                self.n += new_n
                if self._range:
                    self.min = torch.min(self.min, data.min(-1)[0])
                    self.max = torch.max(self.max, data.max(-1)[0])

    @property
    def mean(self): return self.sum / self.n if self.n > 0 else None
    @property
    def var(self): return self._nvar / self.n if self.n > 0 else None
    @property
    def std(self): return self.var.sqrt() if self.n > 0 else None

    def __repr__(self):
        def _fmt_t(t:Tensor):
            if t.numel() > 5: return f"tensor of ({','.join(map(str,t.shape))})"
            def __fmt_t(t:Tensor):
                return '[' + ','.join([f"{v:.3g}" if v.ndim==0 else __fmt_t(v) for v in t]) + ']'
            return __fmt_t(t)
        rng_str = f", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}" if self._range else ""
        return f"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})"

def collect_stats(items:Iterable, n_dims:int=2, record_range:bool=False):
    stats = RunningStatistics(n_dims, record_range)
    for it in progress_bar(items):
        # Some item types wrap their tensor in a `.data` attribute; use it when present
        if hasattr(it, 'data'):
            stats.update(it.data)
        else:
            stats.update(it)
    return stats
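
# --- Usage sketch (an addition, not part of the original paste) ---
# Collects per-channel statistics for a hypothetical list of (C,H,W) tensors; with
# n_dims=2 the last two dimensions are reduced away, leaving one value per channel.
if __name__ == '__main__':
    imgs = [torch.randn(3, 32, 32) for _ in range(8)]        # made-up example data
    stats = collect_stats(imgs, n_dims=2, record_range=True)
    print(stats)
    # The streamed mean should closely match the mean of all the data taken at once
    stacked = torch.stack(imgs).permute(1, 0, 2, 3).reshape(3, -1)
    print(stats.mean, stacked.mean(-1))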