Advertisement
Guest User

Untitled

a guest
Jun 18th, 2019
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.50 KB | None | 0 0
  1. import h5py
  2. import numpy as np
  3.  
  class H5Buffer:
      """Fixed-capacity ring buffer whose rows are stored in an HDF5 dataset.

      Rows are appended along axis 0. The dataset grows until it holds
      ``maxlen`` rows; after that, new rows overwrite the oldest ones,
      wrapping around to row 0.
      """

      def __init__(self, array_shape, maxlen, dtype, filename="buffer.hdf5"):
          """Create the HDF5 file and an empty, resizable dataset named 'buffer'.

          Args:
              array_shape: shape of a single row (any sequence of ints).
              maxlen: maximum number of rows kept; must be positive.
              dtype: dtype of the stored data.
              filename: path of the HDF5 file. It is opened with mode "w",
                  so an existing file at that path is overwritten.

          Raises:
              ValueError: if ``maxlen`` is not positive.
          """
          # Assigned first so close()/__del__ are safe even if setup fails below.
          self.file = None
          if maxlen <= 0:
              # A non-positive capacity previously made append() recurse forever.
              raise ValueError(f"maxlen must be positive, got {maxlen!r}")
          self.maxlen = maxlen
          # Row index at which the next append starts writing.
          self.current_idx = 0

          row_shape = tuple(array_shape)
          self.file = h5py.File(filename, "w")
          # Start with zero rows; axis 0 may grow up to maxlen via _resize().
          self.buffer = self.file.create_dataset(
              "buffer",
              (0,) + row_shape,
              maxshape=(maxlen,) + row_shape,
              dtype=dtype,
          )

      def append(self, array):
          """Append rows to the buffer, overwriting the oldest rows once full.

          Args:
              array: array of shape ``(k,) + array_shape``. Its ``k`` rows are
                  written in order starting at ``current_idx``; writing wraps
                  to row 0 whenever the end of the dataset (``maxlen``) is hit.
          """
          add_size = array.shape[0]
          if add_size == 0:
              return
          # While still growing, current_idx equals the dataset length, so
          # extending by add_size (capped at maxlen) makes room for the write.
          if self.buffer.shape[0] < self.maxlen:
              self._resize(self.buffer.shape[0], add_size)

          start = 0
          while start < add_size:
              # Write as many rows as fit before the end of the dataset.
              count = min(add_size - start, self.maxlen - self.current_idx)
              end_idx = self.current_idx + count
              self.buffer[self.current_idx:end_idx] = array[start:start + count]
              start += count
              # Wrap the write cursor when the end of the dataset is reached.
              self.current_idx = 0 if end_idx == self.maxlen else end_idx

      def _resize(self, current_size, add_size):
          """Grow axis 0 of the dataset by ``add_size`` rows, capped at ``maxlen``."""
          self.buffer.resize(min(current_size + add_size, self.maxlen), axis=0)

      def sample(self, start_idx, end_idx):
          """Return dataset rows ``start_idx:end_idx``.

          Indices are physical positions in the dataset; once the buffer has
          wrapped, they do not correspond to insertion order.
          """
          return self.buffer[start_idx:end_idx]

      def length(self):
          """Return the number of rows currently stored (never more than ``maxlen``)."""
          return self.buffer.shape[0]

      def close(self):
          """Close the HDF5 file if it is open. Safe to call more than once."""
          h5file = getattr(self, "file", None)
          if h5file:
              h5file.close()
          self.file = None

      def __enter__(self):
          return self

      def __exit__(self, exc_type, exc, tb):
          self.close()
          return False

      def __del__(self):
          self.close()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement