Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class TensorBoardSummaries(gym.Wrapper):
    """Writes env summaries (episode returns and lengths) via ``writer``.

    Supports a single env or a vectorized env whose unwrapped object exposes
    an ``nenvs`` attribute. For every sub-env the wrapper accumulates the sum
    of rewards and the number of steps of the episode in progress. An episode
    counts as finished when ``info["real_done"]`` is truthy, or, when that key
    is absent, when the matching ``done`` flag is truthy. Once every sub-env
    has finished at least one episode since the last write, scalar summaries
    are written at ``self.step_idx`` and the "has ended" flags are cleared.

    Args:
        env: Environment to wrap.
        writer: Summary writer; only ``add_scalar(name, value, step)`` is used.
        prefix: Tag prefix for every summary. Defaults to the env spec id, or
            to the unwrapped env's class name when the env has no spec.
        start_idx: Initial value of the step counter passed to the writer.
        running_mean_size: How many recent episode returns per sub-env are
            kept for the running-mean summary.
    """

    def __init__(
        self,
        env: gym.Env,
        writer: SummaryWriter,
        prefix: Optional[str] = None,
        start_idx: int = 0,
        running_mean_size: int = 100,
    ):
        super().__init__(env)
        # Kept for compatibility; this class never updates it.
        self.episode_counter = 0
        if not prefix:
            # Unregistered (custom / vectorized) envs have ``spec is None``;
            # fall back to the class name instead of failing on ``None.id``.
            spec = getattr(self.env, "spec", None)
            prefix = getattr(spec, "id", None) or type(self.env.unwrapped).__name__
        self.prefix = prefix
        self.step_idx = start_idx
        self.writer = writer
        nenvs = getattr(self.env.unwrapped, "nenvs", 1)
        # Sum of rewards and step count of the episode in progress, per env.
        self.rewards = np.zeros(nenvs)
        self.episode_lengths = np.zeros(nenvs)
        # Length of the most recently finished episode, per env.
        self.last_episode_lengths = np.zeros(nenvs)
        # ``np.bool`` was removed in NumPy 1.24; the builtin works everywhere.
        self.had_ended_episodes = np.zeros(nenvs, dtype=bool)
        self.reward_queues = [
            deque([], maxlen=running_mean_size) for _ in range(nenvs)
        ]

    def should_write_summaries(self):
        """Returns True once every sub-env has finished at least one episode
        since summaries were last written."""
        return np.all(self.had_ended_episodes)

    def add_summaries(self):
        """Writes scalar summaries at ``self.step_idx`` and clears the per-env
        "has ended" flags.

        Requires every reward queue to be non-empty, which holds whenever
        ``should_write_summaries()`` is true.
        """
        last_rewards = [q[-1] for q in self.reward_queues]
        summaries = {
            f"{self.prefix}/rewards/total": np.mean(last_rewards),
            f"{self.prefix}/rewards/mean_{self.reward_queues[0].maxlen}": np.mean(
                [np.mean(q) for q in self.reward_queues]),
            f"{self.prefix}/episode_length": np.mean(self.last_episode_lengths),
        }
        if self.had_ended_episodes.size > 1:
            summaries[f"{self.prefix}/rewards/min"] = min(last_rewards)
            summaries[f"{self.prefix}/rewards/max"] = max(last_rewards)
        for name, value in summaries.items():
            self.writer.add_scalar(name, float(value), self.step_idx)
        # Running ``episode_lengths`` are NOT reset here: envs that are
        # mid-episode must keep counting so their next length is complete.
        self.had_ended_episodes.fill(False)

    def step(self, action):
        """Steps the wrapped env, updates per-env episode statistics, and
        writes summaries once every sub-env has finished an episode.

        Returns the wrapped env's ``(obs, rew, done, info)`` unchanged.
        """
        obs, rew, done, info = self.env.step(action)
        self.rewards += rew
        self.episode_lengths += 1

        # Normalize single-env outputs to per-env sequences. ``np.atleast_1d``
        # also accepts ``np.bool_`` scalars, which ``isinstance(done, bool)``
        # does not.
        infos = [info] if isinstance(info, dict) else info
        dones = np.atleast_1d(done)
        for i, env_info in enumerate(infos):
            # ``real_done``, when present, overrides ``done``. Presumably it
            # is set by a wrapper whose ``done`` does not mark a true episode
            # end (e.g. life loss) — confirm against the wrappers in use.
            if not env_info.get("real_done", dones[i]):
                continue
            # Record every finished episode, even if this env already finished
            # one since the last write; otherwise its return would keep
            # accumulating across several episodes.
            self.reward_queues[i].append(self.rewards[i])
            self.last_episode_lengths[i] = self.episode_lengths[i]
            self.rewards[i] = 0
            self.episode_lengths[i] = 0
            self.had_ended_episodes[i] = True

        if self.should_write_summaries():
            self.add_summaries()
        self.step_idx += 1
        return obs, rew, done, info
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement