Guest User

Untitled

a guest
Jul 22nd, 2018
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.63 KB | None | 0 0
  1. import numpy as np
  2.  
  3.  
  4. class RunningStats:
  5. def __init__(self, shape, *, _mu=None, _s=None, _n=0):
  6. self._shape = shape
  7. self._mu = _mu if _mu is not None else np.zeros(shape)
  8. self._s = _s if _s is not None else np.zeros(shape)
  9. self.n = _n or 0
  10.  
  11. def copy(self, *, shallow=False):
  12. if shallow:
  13. return self.__class__(
  14. self._shape, _mu=self._mu, _s=self._s, _n=self.n)
  15. else:
  16. return self.__class__(
  17. self._shape, _mu=self._mu.copy(), _s=self._s.copy(), _n=self.n)
  18.  
  19. def clear(self):
  20. self._mu *= 0
  21. self._s *= 0
  22. self.n = 0
  23.  
  24. def append(self, x):
  25. if x.shape != self._shape:
  26. raise ValueError('Wrong shape for x!')
  27. self.n += 1
  28. delta = self._mu - x
  29. self._mu = ((self.n - 1) * self._mu + x) / self.n
  30. self._s += delta * delta * (self.n - 1) / self.n
  31.  
  32. def extend(self, xs):
  33. if isinstance(xs, self.__class__):
  34. mu2 = xs._mu
  35. s2 = xs._s
  36. n2 = xs.n
  37. if n2 == 0:
  38. return
  39. else:
  40. n2 = len(xs)
  41. if n2 == 0:
  42. return
  43. mu2 = np.mean(xs, axis=0)
  44. s2 = n2 * np.var(xs, axis=0)
  45. if self._mu.shape[1:] != mu2.shape[1:]:
  46. raise ValueError('Wrong shape for xs!')
  47. if self.n == 0:
  48. self._mu[:] = mu2
  49. self._s[:] = s2
  50. self.n = n2
  51. else:
  52. delta = self._mu - mu2
  53. n_sum = self.n + n2
  54. n_prod = self.n * n2
  55. self._mu = (self.n * self._mu + n2 * mu2) / n_sum
  56. self._s += s2 + (delta * delta) * n_prod / n_sum
  57. self.n = n_sum
  58.  
  59. @property
  60. def mean(self):
  61. return self._mu.copy()
  62.  
  63. @property
  64. def var(self):
  65. return self._s / self.n if self.n > 0 else self._s + np.nan
  66.  
  67. @property
  68. def std(self):
  69. return np.sqrt(self.var)
  70.  
  71.  
  72. def test():
  73. shape = (15, 30)
  74. rs = RunningStats(shape)
  75. N = 1000
  76.  
  77. for _ in range(50):
  78. xs = np.random.normal(50, 100, size=(N,) + shape)
  79. n = 0
  80. while n < N:
  81. if np.random.rand() < 0.3:
  82. rs.append(xs[n])
  83. assert np.allclose(rs.mean, xs[:n+1].mean(axis=0))
  84. assert np.allclose(rs.std, xs[:n+1].std(axis=0))
  85. n += 1
  86. else:
  87. i = np.random.randint(1, min(50, N - n) + 1)
  88. rs.extend(xs[n: n+i])
  89. assert np.allclose(rs.mean, xs[:n+i].mean(axis=0))
  90. assert np.allclose(rs.std, xs[:n+i].std(axis=0))
  91. n += i
  92. rs.clear()
Add Comment
Please, Sign In to add comment