Advertisement
Guest User

TensorBoardSummaries

a guest
Mar 31st, 2020
120
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.64 KB | None | 0 0
  1. class TensorBoardSummaries(gym.Wrapper):
  2.     """ Writes env summaries."""
  3.  
  4.     def __init__(
  5.         self,
  6.         env: gym.Env,
  7.         writer: SummaryWriter,
  8.         prefix: Optional[str] = None,
  9.         start_idx: int = 0,
  10.         running_mean_size: int = 100,
  11.     ):
  12.         super().__init__(env)
  13.         self.episode_counter = 0
  14.         self.prefix = prefix or self.env.spec.id
  15.         self.step_idx = start_idx
  16.         self.writer = writer
  17.  
  18.         nenvs = getattr(self.env.unwrapped, "nenvs", 1)
  19.         self.rewards = np.zeros(nenvs)
  20.         self.had_ended_episodes = np.zeros(nenvs, dtype=np.bool)
  21.         self.episode_lengths = np.zeros(nenvs)
  22.         self.reward_queues = [deque([], maxlen=running_mean_size)
  23.                               for _ in range(nenvs)]
  24.  
  25.     def should_write_summaries(self):
  26.         """ Returns true if it's time to write summaries. """
  27.         return np.all(self.had_ended_episodes)
  28.  
  29.     def add_summaries(self):
  30.         """ Writes summaries. """
  31.         summaries = {
  32.             f"{self.prefix}/rewards/total": np.mean([q[-1] for q in self.reward_queues]),
  33.             f"{self.prefix}/rewards/mean_{self.reward_queues[0].maxlen}": np.mean(
  34.                 [np.mean(q) for q in self.reward_queues]),
  35.             f"{self.prefix}/episode_length": np.mean(self.episode_lengths),
  36.         }
  37.         if self.had_ended_episodes.size > 1:
  38.             summaries.update(**{
  39.                 f"{self.prefix}/rewards/min": min(q[-1] for q in self.reward_queues),
  40.                 f"{self.prefix}/rewards/max": max(q[-1] for q in self.reward_queues),
  41.             })
  42.            
  43.         for name, value in summaries.items():
  44.             self.writer.add_scalar(name, float(value), self.step_idx)
  45.        
  46.         self.episode_lengths.fill(0)
  47.         self.had_ended_episodes.fill(False)
  48.  
  49.     def step(self, action):
  50.         obs, rew, done, info = self.env.step(action)
  51.         self.rewards += rew
  52.         self.episode_lengths[~self.had_ended_episodes] += 1
  53.  
  54.         info_collection = [info] if isinstance(info, dict) else info
  55.         done_collection = [done] if isinstance(done, bool) else done
  56.         done_indices = [i for i, info in enumerate(info_collection)
  57.                         if info.get("real_done", done_collection[i])]
  58.         for i in done_indices:
  59.             if not self.had_ended_episodes[i]:
  60.                 self.had_ended_episodes[i] = True
  61.             self.reward_queues[i].append(self.rewards[i])
  62.             self.rewards[i] = 0
  63.  
  64.         if self.should_write_summaries():
  65.             self.add_summaries()
  66.        
  67.         self.step_idx += 1
  68.            
  69.         return obs, rew, done, info
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement