Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
class RunningStats:
    """Streaming elementwise mean / variance / std of equally-shaped arrays.

    Samples are folded in one at a time (``append``) or in batches
    (``extend``), using Welford's single-sample update and the pairwise
    (Chan et al.) merge, so the sample history never has to be stored.
    ``var`` and ``std`` are population statistics (divide by ``n``), i.e. the
    same as ``np.var``/``np.std`` with their default ``ddof=0``.

    The accumulator arrays are never modified in place -- every update binds
    fresh arrays -- so ``copy(shallow=True)`` yields a cheap snapshot that
    later updates to either object cannot corrupt.

    Args:
        shape: shape of every sample; an int is treated as a 1-D shape.
        _mu, _s, _n: private, used by ``copy``: initial running mean,
            running sum of squared deviations, and sample count.
    """

    def __init__(self, shape, *, _mu=None, _s=None, _n=0):
        # Normalise to a tuple so ``x.shape != self._shape`` compares like with
        # like; an int or list shape would otherwise never match.
        self._shape = tuple(shape) if np.iterable(shape) else (shape,)
        self._mu = _mu if _mu is not None else np.zeros(self._shape)
        self._s = _s if _s is not None else np.zeros(self._shape)
        self.n = _n or 0

    def copy(self, *, shallow=False):
        """Return a new accumulator holding the same state.

        With ``shallow=True`` the internal arrays are shared instead of
        copied; that is safe because no method updates them in place.
        """
        if shallow:
            mu, s = self._mu, self._s
        else:
            mu, s = self._mu.copy(), self._s.copy()
        return self.__class__(self._shape, _mu=mu, _s=s, _n=self.n)

    def clear(self):
        """Reset to the empty state (no samples seen)."""
        # Fresh zeros rather than ``*= 0``: NaN * 0 and inf * 0 are NaN, so the
        # old in-place reset could not clear a poisoned accumulator. Rebinding
        # also leaves shallow copies untouched.
        self._mu = np.zeros_like(self._mu)
        self._s = np.zeros_like(self._s)
        self.n = 0

    def append(self, x):
        """Add one sample ``x``; its shape must equal ``shape``.

        Raises:
            ValueError: if ``x`` has the wrong shape.
        """
        x = np.asarray(x)
        if x.shape != self._shape:
            raise ValueError('Wrong shape for x!')
        self.n += 1
        delta = self._mu - x
        # Welford update: shift the mean by delta / n rather than recomputing
        # ((n - 1) * mu + x) / n, whose numerator grows with n.
        self._mu = self._mu - delta / self.n
        self._s = self._s + delta * delta * (self.n - 1) / self.n

    def extend(self, xs):
        """Add a batch of samples, or merge in another ``RunningStats``.

        Args:
            xs: another ``RunningStats`` of the same shape, or an array-like
                whose first axis indexes samples and whose remaining axes
                equal ``shape``.

        Raises:
            ValueError: if the samples' shape differs from ``shape``.
        """
        if isinstance(xs, RunningStats):
            mu2, s2, n2 = xs._mu, xs._s, xs.n
            if n2 == 0:
                return
        else:
            xs = np.asarray(xs)
            n2 = len(xs)
            if n2 == 0:
                return
            mu2 = np.mean(xs, axis=0)
            s2 = n2 * np.var(xs, axis=0)
        # Compare the full per-sample shape. Comparing only trailing axes let
        # a mismatched leading axis through, where it silently broadcast.
        if mu2.shape != self._mu.shape:
            raise ValueError('Wrong shape for xs!')
        if self.n == 0:
            # Copy so this accumulator never aliases the source's arrays.
            self._mu = np.array(mu2, dtype=self._mu.dtype)
            self._s = np.array(s2, dtype=self._s.dtype)
            self.n = n2
            return
        n_sum = self.n + n2
        delta = self._mu - mu2
        # Pairwise merge of (mean, sum of squared deviations, count).
        self._mu = self._mu - delta * (n2 / n_sum)
        self._s = self._s + s2 + delta * delta * (self.n * n2 / n_sum)
        self.n = n_sum

    @property
    def mean(self):
        """Elementwise mean of all samples seen (a copy)."""
        return self._mu.copy()

    @property
    def var(self):
        """Elementwise population variance; all-NaN when no samples were seen."""
        return self._s / self.n if self.n > 0 else self._s + np.nan

    @property
    def std(self):
        """Elementwise population standard deviation."""
        return np.sqrt(self.var)
def test(*, seed=0, rounds=50, n_samples=1000, max_batch=50):
    """Randomized consistency check of ``RunningStats`` against NumPy.

    Each round draws ``n_samples`` random arrays of shape (15, 30) and feeds
    them to one accumulator. At each step it randomly chooses between
    ``append`` of a single sample (30% of steps) and ``extend`` with a batch
    of 1 to ``max_batch`` samples. After every update, the running mean and
    std must match ``np.mean``/``np.std`` over all samples fed so far. Each
    round also checks that merging two partial accumulators via ``extend``
    reproduces the whole-batch statistics.

    Args:
        seed: seed for a private random generator, so a failing run can be
            reproduced exactly.
        rounds: number of rounds; the accumulator is cleared between rounds,
            which also exercises ``clear``.
        n_samples: samples drawn per round.
        max_batch: largest batch passed to a single ``extend`` call.

    Raises:
        AssertionError: if any running statistic disagrees with NumPy.
    """
    rng = np.random.default_rng(seed)
    shape = (15, 30)
    rs = RunningStats(shape)
    for _ in range(rounds):
        xs = rng.normal(50, 100, size=(n_samples,) + shape)
        n = 0
        while n < n_samples:
            if rng.random() < 0.3:
                rs.append(xs[n])
                n += 1
            else:
                step = int(rng.integers(1, min(max_batch, n_samples - n) + 1))
                rs.extend(xs[n:n + step])
                n += step
            seen = xs[:n]
            assert np.allclose(rs.mean, seen.mean(axis=0))
            assert np.allclose(rs.std, seen.std(axis=0))

        # Merging two partial accumulators must reproduce the batch result.
        half = n_samples // 2
        if 0 < half < n_samples:
            left = RunningStats(shape)
            left.extend(xs[:half])
            right = RunningStats(shape)
            right.extend(xs[half:])
            left.extend(right)
            assert left.n == n_samples
            assert np.allclose(left.mean, xs.mean(axis=0))
            assert np.allclose(left.std, xs.std(axis=0))

        rs.clear()
Add Comment
Please, Sign In to add comment