import random
from collections import deque, namedtuple

import numpy as np
import torch

# Tensors are placed on the GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Initialize a ReplayBuffer object.

        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            seed (int): random seed
        """
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  # internal memory (deque)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        # Stack each field into a (batch_size, dim) array, convert to a float
        # tensor, and move it to the target device.
        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        # Booleans are cast to uint8 first so the float conversion yields 0.0 / 1.0.
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
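
A minimal usage sketch (not part of the original paste): the environment dimensions, buffer sizes, and random transitions below are illustrative assumptions, chosen only to show the shapes returned by sample().

# Hypothetical demo: 4-dimensional states, 2-dimensional actions, sizes chosen for illustration.
buffer = ReplayBuffer(action_size=2, buffer_size=int(1e5), batch_size=64, seed=0)

# Fill the buffer with random transitions until one full batch can be sampled.
for _ in range(64):
    state = np.random.randn(4)
    action = np.random.randn(2)
    reward = float(np.random.randn())
    next_state = np.random.randn(4)
    done = bool(np.random.rand() < 0.05)
    buffer.add(state, action, reward, next_state, done)

states, actions, rewards, next_states, dones = buffer.sample()
print(states.shape, actions.shape, rewards.shape, next_states.shape, dones.shape)
# Expected: torch.Size([64, 4]) torch.Size([64, 2]) torch.Size([64, 1]) torch.Size([64, 4]) torch.Size([64, 1])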