Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from torch import Tensor
- from typing import Iterable
- from fastprogress import progress_bar
class RunningStatistics:
    """Accumulate mean and variance over the final `n_dims` dimensions of a stream of tensors.

    Each item passed to `update` has its last `n_dims` dimensions flattened and reduced, so
    collecting across `(l,m,n,o)` sized items with `n_dims=1` yields `(l,m,n)` sized statistics
    while `n_dims=2` yields statistics of size `(l,m)`. `n_dims` is expected to be >= 1; with
    `0` the slice `shape[:-0]` is empty and every dimension is flattened.

    Chunks are merged with the pairwise update from Chan, Golub, and LeVeque, "Algorithms for
    computing the sample variance: analysis and recommendations":

        `S = S1 + S2 + m/(n*(m+n)) * pow((n/m)*T1 - T2, 2)`

    where the already-collected block has `m` elements, sum `T1` and sum of squared deviations
    `S1`, and the new block has `n` elements, sum `T2` and sum of squared deviations `S2`.
    `var` and `std` report the population (biased, divide-by-`n`) statistics.

    Tracking minimum and maximum values adds noticeable overhead (the original author measured
    roughly 80% extra running time), so it is disabled unless `record_range=True`.
    """

    def __init__(self, n_dims:int=2, record_range:bool=False):
        self._n_dims,self._range = n_dims,record_range
        self.n,self.sum,self.min,self.max = 0,None,None,None
        # `_nvar` holds n * variance (the sum of squared deviations); `_shape` is the shape
        # of the collected statistics, fixed by the first call to `update`.
        self._nvar,self._shape = None,None

    def update(self, data:Tensor):
        """Fold one more tensor into the running statistics.

        The leading dimensions of `data` (all but the last `n_dims`) must match those of the
        first tensor seen; the number of reduced elements may differ between calls.
        """
        with torch.no_grad():
            # `reshape` rather than `view` so non-contiguous inputs (e.g. transposed) work too.
            data = data.reshape(*data.shape[:-self._n_dims], -1)
            new_n = data.shape[-1]
            new_sum = data.sum(-1)
            # The merge formula needs the sum of squared deviations, i.e. n * *population*
            # variance. torch's default is the Bessel-corrected estimate, which would
            # overstate it by n/(n-1) and produce NaN for single-element chunks.
            new_nvar = data.var(-1, unbiased=False) * new_n
            if self.n == 0:
                self.n = new_n
                self._shape = data.shape[:-1]
                self.sum = new_sum
                self._nvar = new_nvar
                if self._range:
                    self.min = data.min(-1)[0]
                    self.max = data.max(-1)[0]
            else:
                assert data.shape[:-1] == self._shape, f"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}."
                ratio = self.n / new_n
                # ((n/m)*T1 - T2)^2, computed in place on the temporary from the division.
                t = (self.sum / ratio).sub_(new_sum).pow_(2)
                # Keyword `alpha` replaces the deprecated `add_(number, tensor)` overload.
                self._nvar.add_(new_nvar).add_(t, alpha=ratio / (self.n + new_n))
                self.sum.add_(new_sum)
                self.n += new_n
                if self._range:
                    self.min = torch.min(self.min, data.min(-1)[0])
                    self.max = torch.max(self.max, data.max(-1)[0])

    @property
    def mean(self):
        """Mean of everything collected so far, or `None` before the first update."""
        return self.sum / self.n if self.n > 0 else None

    @property
    def var(self):
        """Population variance of everything collected so far, or `None` before the first update."""
        return self._nvar / self.n if self.n > 0 else None

    @property
    def std(self):
        """Population standard deviation, or `None` before the first update."""
        return self.var.sqrt() if self.n > 0 else None

    def __repr__(self):
        def _fmt_t(t):
            # Statistics are `None` until the first update.
            if t is None: return "None"
            if t.numel() > 5: return f"tensor of ({','.join(map(str,t.shape))})"
            def _fmt_nested(v):
                # `tolist()` gives nested lists, or a bare number for a 0-d tensor.
                if isinstance(v, list): return '[' + ','.join(_fmt_nested(x) for x in v) + ']'
                return f"{v:.3g}"
            return _fmt_nested(t.tolist())
        rng_str = f", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}" if self._range else ""
        return f"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})"
def collect_stats(items:Iterable, n_dims:int=2, record_range:bool=False):
    """Run every element of `items` through a fresh `RunningStatistics` and return it.

    `n_dims` and `record_range` are forwarded to the `RunningStatistics` constructor.
    Iteration is wrapped in `progress_bar`.
    """
    running = RunningStatistics(n_dims, record_range)
    for item in progress_bar(items):
        # Objects exposing a `data` attribute contribute that attribute;
        # anything else is handed to `update` unchanged.
        running.update(getattr(item, 'data', item))
    return running
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement