diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py	2021-02-17 11:48:18.299635171 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py	2021-02-22 15:39:02.265285353 +0100
@@ -253,7 +253,7 @@
         is_training=True)
     # Q scores for actions which we know were selected in the given state.
-    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
+    one_hot_selection = F.one_hot(torch.tensor(train_batch[SampleBatch.ACTIONS], device=policy.device).long(),
                                   policy.action_space.n)
     q_t_selected = torch.sum(
         torch.where(q_t > FLOAT_MIN, q_t,
@@ -292,8 +292,10 @@
     policy.q_loss = QLoss(
         q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best,
-        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
-        train_batch[SampleBatch.DONES].float(), config["gamma"],
+        torch.tensor(train_batch[PRIO_WEIGHTS], device=policy.device),
+        torch.tensor(train_batch[SampleBatch.REWARDS], device=policy.device),
+        torch.tensor(train_batch[SampleBatch.DONES], device=policy.device).float(),
+        config["gamma"],
         config["n_step"], config["num_atoms"], config["v_min"],
         config["v_max"])
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py	2021-02-17 11:48:18.303635120 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py	2021-02-16 17:08:20.794954673 +0100
@@ -217,7 +217,7 @@
             def value(**input_dict):
                 model_out, _ = self.model.from_batch(
-                    convert_to_torch_tensor(input_dict, self.device),
+                    input_dict,
                     is_training=False)
                 # [0] = remove the batch dim.
                 return self.model.value_function()[0]
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py	2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py	2021-02-17 12:51:10.791367734 +0100
@@ -534,10 +534,10 @@
             post_batches[agent_id] = policy.postprocess_trajectory(
                 post_batches[agent_id], other_batches, episode)
-        if log_once("after_post"):
-            logger.info(
-                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
-                format(summarize(post_batches)))
+        # if log_once("after_post"):
+        #     logger.info(
+        #         "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
+        #         format(summarize(post_batches)))
         # Append into policy batches and reset.
         from ray.rllib.evaluation.rollout_worker import get_global_worker
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py	2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py	2021-02-22 15:06:38.718194674 +0100
@@ -665,9 +665,9 @@
                 for estimator in self.reward_estimators:
                     estimator.process(sub_batch)
-        if log_once("sample_end"):
-            logger.info("Completed sample batch:\n\n{}\n".format(
-                summarize(batch)))
+        # if log_once("sample_end"):
+        #     logger.info("Completed sample batch:\n\n{}\n".format(
+        #         summarize(batch)))
         if self.compress_observations == "bulk":
             batch.compress(bulk=True)
@@ -805,10 +805,10 @@
             >>> batch = worker.sample()
             >>> worker.learn_on_batch(samples)
         """
-        if log_once("learn_on_batch"):
-            logger.info(
-                "Training on concatenated sample batches:\n\n{}\n".format(
-                    summarize(samples)))
+        # if log_once("learn_on_batch"):
+        #     logger.info(
+        #         "Training on concatenated sample batches:\n\n{}\n".format(
+        #             summarize(samples)))
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
             to_fetch = {}
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py	2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py	2021-02-17 12:50:15.336001627 +0100
@@ -920,8 +920,9 @@
             for agent_id, raw_obs in resetted_obs.items():
                 policy_id: PolicyID = episode.policy_for(agent_id)
                 policy: Policy = _get_or_raise(policies, policy_id)
-                prep_obs: EnvObsType = _get_or_raise(
-                    preprocessors, policy_id).transform(raw_obs)
+                # prep_obs: EnvObsType = _get_or_raise(
+                #     preprocessors, policy_id).transform(raw_obs)
+                prep_obs = raw_obs
                 filtered_obs: EnvObsType = _get_or_raise(
                     obs_filters, policy_id)(prep_obs)
                 episode._set_last_observation(agent_id, filtered_obs)
@@ -1017,15 +1018,16 @@
         for agent_id, raw_obs in all_agents_obs.items():
             assert agent_id != "__all__"
             policy_id: PolicyID = episode.policy_for(agent_id)
-            prep_obs: EnvObsType = _get_or_raise(preprocessors,
-                                                 policy_id).transform(raw_obs)
+            # prep_obs: EnvObsType = _get_or_raise(preprocessors,
+            #                                      policy_id).transform(raw_obs)
+            prep_obs: EnvObsType = raw_obs
             if log_once("prep_obs"):
                 logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
             filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                      policy_id)(prep_obs)
-            if log_once("filtered_obs"):
-                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
+            # if log_once("filtered_obs"):
+            #     logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
             agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
@@ -1062,11 +1064,8 @@
                 # Add extra-action-fetches to collectors.
                 pol = policies[policy_id]
                 for key, value in episode.last_pi_info_for(agent_id).items():
-                    if key in pol.view_requirements:
-                        values_dict[key] = value
-                # Env infos for this agent.
-                if "infos" in pol.view_requirements:
-                    values_dict["infos"] = agent_infos
+                    values_dict[key] = value
+                values_dict["infos"] = agent_infos
                 _sample_collector.add_action_reward_next_obs(
                     episode.episode_id, agent_id, env_id, policy_id,
                     agent_done, values_dict)
@@ -1151,8 +1150,9 @@
             # type: AgentID, EnvObsType
             for agent_id, raw_obs in resetted_obs.items():
                 policy_id: PolicyID = new_episode.policy_for(agent_id)
-                prep_obs: EnvObsType = _get_or_raise(
-                    preprocessors, policy_id).transform(raw_obs)
+                # prep_obs: EnvObsType = _get_or_raise(
+                #     preprocessors, policy_id).transform(raw_obs)
+                prep_obs = raw_obs
                 filtered_obs: EnvObsType = _get_or_raise(
                     obs_filters, policy_id)(prep_obs)
                 new_episode._set_last_observation(agent_id, filtered_obs)
@@ -1295,9 +1295,9 @@
     else:
         builder = None
-    if log_once("compute_actions_input"):
-        logger.info("Inputs to compute_actions():\n\n{}\n".format(
-            summarize(to_eval)))
+    # if log_once("compute_actions_input"):
+    #     logger.info("Inputs to compute_actions():\n\n{}\n".format(
+    #         summarize(to_eval)))
     for policy_id, eval_data in to_eval.items():
         policy: Policy = _get_or_raise(policies, policy_id)
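
Note: the sampler.py hunks above appear to skip RLlib's observation preprocessors entirely and feed the raw env observation straight into the per-policy observation filter. A minimal sketch of the resulting data path, with a hypothetical identity filter standing in for the obs_filters entry (not RLlib code):

    import numpy as np

    def identity_filter(obs):
        # Placeholder for the per-policy callable looked up in obs_filters.
        return obs

    raw_obs = np.array([0.1, -0.3, 0.7], dtype=np.float32)
    prep_obs = raw_obs                       # preprocessor.transform() bypassed, as in the patch
    filtered_obs = identity_filter(prep_obs)
    assert filtered_obs is raw_obs
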
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/modelv2.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/modelv2.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/modelv2.py	2021-02-17 11:48:18.323634862 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/modelv2.py	2021-02-22 11:27:25.177406475 +0100
@@ -199,12 +199,12 @@
         """
         restored = input_dict.copy()
-        restored["obs"] = restore_original_dimensions(
-            input_dict["obs"], self.obs_space, self.framework)
-        if len(input_dict["obs"].shape) > 2:
-            restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
-        else:
-            restored["obs_flat"] = input_dict["obs"]
+        # restored["obs"] = restore_original_dimensions(
+        #     input_dict["obs"], self.obs_space, self.framework)
+        # if len(input_dict["obs"].shape) > 2:
+        #     restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
+        # else:
+        #     restored["obs_flat"] = input_dict["obs"]
         with self.context():
             res = self.forward(restored, state or [], seq_lens)
         if ((not isinstance(res, list) and not isinstance(res, tuple))
@@ -218,11 +218,11 @@
             shape = outputs.shape
         except AttributeError:
             raise ValueError("Output is not a tensor: {}".format(outputs))
-        else:
-            if len(shape) != 2 or int(shape[1]) != self.num_outputs:
-                raise ValueError(
-                    "Expected output shape of [None, {}], got {}".format(
-                        self.num_outputs, shape))
+        # else:
+        #     if len(shape) != 2 or int(shape[1]) != self.num_outputs:
+        #         raise ValueError(
+        #             "Expected output shape of [None, {}], got {}".format(
+        #                 self.num_outputs, shape))
         if not isinstance(state, list):
             raise ValueError("State output is not a list: {}".format(state))
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py	2021-02-17 11:48:18.323634862 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py	2021-02-09 17:34:04.877052059 +0100
@@ -59,7 +59,7 @@
                     self._obs_space, gym.spaces.Box):
                 observation = np.array(observation)
             try:
-                if not self._obs_space.contains(observation):
+                if not self._obs_space.contains(observation):
                     raise ValueError(
                         "Observation ({}) outside given space ({})!",
                         observation, self._obs_space)
@@ -245,8 +245,8 @@
               offset: int) -> None:
         if not isinstance(observation, OrderedDict):
             observation = OrderedDict(sorted(observation.items()))
-        assert len(observation) == len(self.preprocessors), \
-            (len(observation), len(self.preprocessors))
+#        assert len(observation) == len(self.preprocessors), \
+#            (len(observation), len(self.preprocessors))
         for o, p in zip(observation.values(), self.preprocessors):
             p.write(o, array, offset)
             offset += p.size
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/policy.py	2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/policy.py	2021-02-16 17:04:00.533302840 +0100
@@ -585,6 +585,9 @@
             SampleBatch.EPS_ID: ViewRequirement(),
             SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
+            SampleBatch.VF_PREDS: ViewRequirement(),
+            SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(),
+            SampleBatch.ACTION_LOGP: ViewRequirement(),
             "t": ViewRequirement(),
         }
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py	2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py	2021-02-22 15:17:05.590408618 +0100
@@ -588,8 +588,8 @@
     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(
-            functools.partial(convert_to_torch_tensor, device=self.device))
+        # train_batch.set_get_interceptor(
+        #     functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch
     def _lazy_numpy_dict(self, postprocessed_batch):
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py	2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py	2021-02-16 16:37:00.641800138 +0100
@@ -246,10 +246,10 @@
                 self.action_space, config)
             # Perform test runs through postprocessing- and loss functions.
-            self._initialize_loss_from_dummy_batch(
-                auto_remove_unneeded_view_reqs=True,
-                stats_fn=stats_fn,
-            )
+            # self._initialize_loss_from_dummy_batch(
+            #     auto_remove_unneeded_view_reqs=True,
+            #     stats_fn=stats_fn,
+            # )
             if _after_loss_init:
                 _after_loss_init(self, obs_space, action_space, config)
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py	2021-02-17 11:48:18.331634759 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py	2021-02-16 17:23:34.977863590 +0100
@@ -1,3 +1,5 @@
+from ray.rllib.policy.sample_batch import SampleBatch
+
 class UsageTrackingDict(dict):
     """Dict that tracks which keys have been accessed.
@@ -33,6 +35,10 @@
     def __getitem__(self, key):
         self.accessed_keys.add(key)
         value = dict.__getitem__(self, key)
+
+        if key == SampleBatch.CUR_OBS:
+            return value
+
         if self.get_interceptor:
             if key not in self.intercepted_values:
                 self.intercepted_values[key] = self.get_interceptor(value)
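
Note: the tracking_dict.py change above makes UsageTrackingDict return the stored value untouched whenever the key is SampleBatch.CUR_OBS, so the current observation never goes through the get-interceptor (the same torch conversion the torch_policy.py hunk disables). A minimal standalone sketch of that early-return pattern, using a plain dict subclass and a hypothetical interceptor (not RLlib code):

    class InterceptedDict(dict):
        # Simplified stand-in for UsageTrackingDict: values pass through an
        # interceptor on access, except for keys listed in `passthrough`.
        def __init__(self, data, interceptor, passthrough=("obs",)):
            super().__init__(data)
            self.interceptor = interceptor
            self.passthrough = set(passthrough)

        def __getitem__(self, key):
            value = dict.__getitem__(self, key)
            if key in self.passthrough:
                return value
            return self.interceptor(value)

    d = InterceptedDict({"obs": [1, 2], "rewards": [0.0, 1.0]}, interceptor=tuple)
    assert d["obs"] == [1, 2]          # returned unchanged
    assert d["rewards"] == (0.0, 1.0)  # passed through the interceptor
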