diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py 2021-02-17 11:48:18.299635171 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py  2021-02-22 15:39:02.265285353 +0100
@@ -253,7 +253,7 @@
         is_training=True)
 
     # Q scores for actions which we know were selected in the given state.
-    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
+    one_hot_selection = F.one_hot(torch.tensor(train_batch[SampleBatch.ACTIONS], device=policy.device).long(),
                                   policy.action_space.n)
     q_t_selected = torch.sum(
         torch.where(q_t > FLOAT_MIN, q_t,
@@ -292,8 +292,10 @@
 
     policy.q_loss = QLoss(
         q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best,
-        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
-        train_batch[SampleBatch.DONES].float(), config["gamma"],
+        torch.tensor(train_batch[PRIO_WEIGHTS], device=policy.device),
+        torch.tensor(train_batch[SampleBatch.REWARDS], device=policy.device),
+        torch.tensor(train_batch[SampleBatch.DONES], device=policy.device).float(),
+        config["gamma"],
        config["n_step"], config["num_atoms"], config["v_min"],
         config["v_max"])
 
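Note on the DQN hunks above: the batch columns are wrapped in torch.tensor(..., device=policy.device) at the point of use, which matches the later torch_policy.py change that disables the lazy tensor-conversion interceptor. A minimal, self-contained sketch of the same pattern (names, shapes and values are illustrative only, not taken from the patch):

    import numpy as np
    import torch
    import torch.nn.functional as F

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_actions = 4
    actions = np.array([0, 2, 3])                     # stand-in for train_batch[SampleBatch.ACTIONS]
    q_t = torch.randn(3, num_actions, device=device)  # stand-in for the Q-value batch

    # Same idea as the patched line: build the tensor explicitly on the policy
    # device, then pick the Q-value of the taken action via a one-hot mask.
    one_hot_selection = F.one_hot(torch.tensor(actions, device=device).long(), num_actions)
    q_t_selected = torch.sum(q_t * one_hot_selection, dim=1)  # shape: (3,)
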
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py 2021-02-17 11:48:18.303635120 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py  2021-02-16 17:08:20.794954673 +0100
@@ -217,7 +217,7 @@
 
                 def value(**input_dict):
                     model_out, _ = self.model.from_batch(
-                        convert_to_torch_tensor(input_dict, self.device),
+                        input_dict,
                         is_training=False)
                     # [0] = remove the batch dim.
                     return self.model.value_function()[0]
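Note on the PPO hunk above: value() now hands input_dict to model.from_batch() without calling convert_to_torch_tensor(), so whatever assembles that dict is expected to supply tensors that already live on the policy's device. A rough sketch of that assumption (make_value_input is a hypothetical helper, not part of the patch):

    import torch

    def make_value_input(obs_batch, device):
        # The patched value() no longer converts, so the caller does it up front.
        return {"obs": torch.as_tensor(obs_batch, dtype=torch.float32, device=device)}
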
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py 2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py  2021-02-17 12:51:10.791367734 +0100
@@ -534,10 +534,10 @@
             post_batches[agent_id] = policy.postprocess_trajectory(
                 post_batches[agent_id], other_batches, episode)
 
-        if log_once("after_post"):
-            logger.info(
-                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
-                format(summarize(post_batches)))
+        # if log_once("after_post"):
+        #     logger.info(
+        #         "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
+        #         format(summarize(post_batches)))
 
         # Append into policy batches and reset.
         from ray.rllib.evaluation.rollout_worker import get_global_worker
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py   2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py    2021-02-22 15:06:38.718194674 +0100
@@ -665,9 +665,9 @@
                 for estimator in self.reward_estimators:
                     estimator.process(sub_batch)
 
-        if log_once("sample_end"):
-            logger.info("Completed sample batch:\n\n{}\n".format(
-                summarize(batch)))
+        # if log_once("sample_end"):
+        #     logger.info("Completed sample batch:\n\n{}\n".format(
+        #         summarize(batch)))
 
         if self.compress_observations == "bulk":
             batch.compress(bulk=True)
@@ -805,10 +805,10 @@
             >>> batch = worker.sample()
             >>> worker.learn_on_batch(samples)
         """
-        if log_once("learn_on_batch"):
-            logger.info(
-                "Training on concatenated sample batches:\n\n{}\n".format(
-                    summarize(samples)))
+        # if log_once("learn_on_batch"):
+        #     logger.info(
+        #         "Training on concatenated sample batches:\n\n{}\n".format(
+        #             summarize(samples)))
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
             to_fetch = {}
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py  2021-02-17 11:48:18.311635017 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py   2021-02-17 12:50:15.336001627 +0100
@@ -920,8 +920,9 @@
                 for agent_id, raw_obs in resetted_obs.items():
                     policy_id: PolicyID = episode.policy_for(agent_id)
                     policy: Policy = _get_or_raise(policies, policy_id)
-                    prep_obs: EnvObsType = _get_or_raise(
-                        preprocessors, policy_id).transform(raw_obs)
+                    # prep_obs: EnvObsType = _get_or_raise(
+                    #     preprocessors, policy_id).transform(raw_obs)
+                    prep_obs = raw_obs
                     filtered_obs: EnvObsType = _get_or_raise(
                         obs_filters, policy_id)(prep_obs)
                     episode._set_last_observation(agent_id, filtered_obs)
@@ -1017,15 +1018,16 @@
         for agent_id, raw_obs in all_agents_obs.items():
             assert agent_id != "__all__"
             policy_id: PolicyID = episode.policy_for(agent_id)
-            prep_obs: EnvObsType = _get_or_raise(preprocessors,
-                                                 policy_id).transform(raw_obs)
+            # prep_obs: EnvObsType = _get_or_raise(preprocessors,
+            #                                      policy_id).transform(raw_obs)
+            prep_obs: EnvObsType = raw_obs
             if log_once("prep_obs"):
                 logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
 
             filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                      policy_id)(prep_obs)
-            if log_once("filtered_obs"):
-                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
+            # if log_once("filtered_obs"):
+            #     logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
 
             agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
 
@@ -1062,11 +1064,8 @@
                 # Add extra-action-fetches to collectors.
                 pol = policies[policy_id]
                 for key, value in episode.last_pi_info_for(agent_id).items():
-                    if key in pol.view_requirements:
-                        values_dict[key] = value
-                # Env infos for this agent.
-                if "infos" in pol.view_requirements:
-                    values_dict["infos"] = agent_infos
+                    values_dict[key] = value
+                values_dict["infos"] = agent_infos
                 _sample_collector.add_action_reward_next_obs(
                     episode.episode_id, agent_id, env_id, policy_id,
                     agent_done, values_dict)
@@ -1151,8 +1150,9 @@
                 # type: AgentID, EnvObsType
                 for agent_id, raw_obs in resetted_obs.items():
                     policy_id: PolicyID = new_episode.policy_for(agent_id)
-                    prep_obs: EnvObsType = _get_or_raise(
-                        preprocessors, policy_id).transform(raw_obs)
+                    # prep_obs: EnvObsType = _get_or_raise(
+                    #     preprocessors, policy_id).transform(raw_obs)
+                    prep_obs = raw_obs
                     filtered_obs: EnvObsType = _get_or_raise(
                         obs_filters, policy_id)(prep_obs)
                     new_episode._set_last_observation(agent_id, filtered_obs)
@@ -1295,9 +1295,9 @@
     else:
         builder = None
 
-    if log_once("compute_actions_input"):
-        logger.info("Inputs to compute_actions():\n\n{}\n".format(
-            summarize(to_eval)))
+    # if log_once("compute_actions_input"):
+    #     logger.info("Inputs to compute_actions():\n\n{}\n".format(
+    #         summarize(to_eval)))
 
     for policy_id, eval_data in to_eval.items():
         policy: Policy = _get_or_raise(policies, policy_id)
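Note on the sampler.py hunks above: the per-policy preprocessor is bypassed (prep_obs is simply the raw observation), and every extra action fetch plus the env infos are forwarded to the collector regardless of the policy's view requirements. That only works if the environment already emits observations in the form the model expects. A purely illustrative sketch of such an environment (PassthroughObsEnv is hypothetical, not part of the patch):

    import gym
    import numpy as np
    import torch

    class PassthroughObsEnv(gym.Env):
        # Observations leave this env as torch tensors and, with the patch,
        # reach the model without Preprocessor.transform() touching them.
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(8,), dtype=np.float32)
        action_space = gym.spaces.Discrete(2)

        def reset(self):
            return torch.zeros(8)

        def step(self, action):
            return torch.rand(8), 0.0, False, {}
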
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/modelv2.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/modelv2.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/modelv2.py  2021-02-17 11:48:18.323634862 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/modelv2.py   2021-02-22 11:27:25.177406475 +0100
@@ -199,12 +199,12 @@
         """
 
         restored = input_dict.copy()
-        restored["obs"] = restore_original_dimensions(
-            input_dict["obs"], self.obs_space, self.framework)
-        if len(input_dict["obs"].shape) > 2:
-            restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
-        else:
-            restored["obs_flat"] = input_dict["obs"]
+        # restored["obs"] = restore_original_dimensions(
+        #     input_dict["obs"], self.obs_space, self.framework)
+        # if len(input_dict["obs"].shape) > 2:
+        #     restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
+        # else:
+        #     restored["obs_flat"] = input_dict["obs"]
         with self.context():
             res = self.forward(restored, state or [], seq_lens)
         if ((not isinstance(res, list) and not isinstance(res, tuple))
@@ -218,11 +218,11 @@
             shape = outputs.shape
         except AttributeError:
             raise ValueError("Output is not a tensor: {}".format(outputs))
-        else:
-            if len(shape) != 2 or int(shape[1]) != self.num_outputs:
-                raise ValueError(
-                    "Expected output shape of [None, {}], got {}".format(
-                        self.num_outputs, shape))
+        # else:
+        #     if len(shape) != 2 or int(shape[1]) != self.num_outputs:
+        #         raise ValueError(
+        #             "Expected output shape of [None, {}], got {}".format(
+        #                 self.num_outputs, shape))
         if not isinstance(state, list):
             raise ValueError("State output is not a list: {}".format(state))
 
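Note on the modelv2.py hunks above: __call__ no longer rebuilds restored["obs"] / restored["obs_flat"] and no longer enforces the [None, num_outputs] output-shape check, so a custom model's forward() receives the observation exactly as it was collected and must do its own reshaping. A toy illustration of that responsibility (RawObsModel is hypothetical and not an RLlib ModelV2):

    import torch
    import torch.nn as nn

    class RawObsModel(nn.Module):
        def __init__(self, obs_size=8, num_outputs=2):
            super().__init__()
            self.net = nn.Linear(obs_size, num_outputs)

        def forward(self, input_dict):
            obs = input_dict["obs"]
            # With restore_original_dimensions()/flatten() disabled, the model
            # itself flattens whatever shape the sampler delivered.
            obs = obs.reshape(obs.shape[0], -1).float()
            return self.net(obs)
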
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py    2021-02-17 11:48:18.323634862 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/models/preprocessors.py 2021-02-09 17:34:04.877052059 +0100
@@ -59,7 +59,7 @@
                     self._obs_space, gym.spaces.Box):
                 observation = np.array(observation)
             try:
-                if not self._obs_space.contains(observation):
+                 if not self._obs_space.contains(observation):
                     raise ValueError(
                         "Observation ({}) outside given space ({})!",
                         observation, self._obs_space)
@@ -245,8 +245,8 @@
               offset: int) -> None:
         if not isinstance(observation, OrderedDict):
             observation = OrderedDict(sorted(observation.items()))
-        assert len(observation) == len(self.preprocessors), \
-            (len(observation), len(self.preprocessors))
+#        assert len(observation) == len(self.preprocessors), \
+#            (len(observation), len(self.preprocessors))
         for o, p in zip(observation.values(), self.preprocessors):
             p.write(o, array, offset)
             offset += p.size
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/policy.py   2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/policy.py    2021-02-16 17:04:00.533302840 +0100
@@ -585,6 +585,9 @@
             SampleBatch.EPS_ID: ViewRequirement(),
             SampleBatch.UNROLL_ID: ViewRequirement(),
             SampleBatch.AGENT_INDEX: ViewRequirement(),
+            SampleBatch.VF_PREDS: ViewRequirement(),
+            SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(),
+            SampleBatch.ACTION_LOGP: ViewRequirement(),
             "t": ViewRequirement(),
         }
 
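Note on the policy.py hunk above: VF_PREDS, ACTION_DIST_INPUTS and ACTION_LOGP are added to the default view requirements up front. That appears to compensate for the torch_policy_template.py change further down, which skips _initialize_loss_from_dummy_batch() and therefore the automatic discovery of those columns; this reading is an inference, not stated in the paste. The added entries correspond to constructs like the following (illustrative only; the real default dict contains more keys):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    extra_view_reqs = {
        SampleBatch.VF_PREDS: ViewRequirement(),
        SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(),
        SampleBatch.ACTION_LOGP: ViewRequirement(),
    }
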
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py 2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py  2021-02-22 15:17:05.590408618 +0100
@@ -588,8 +588,8 @@
 
     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(
-            functools.partial(convert_to_torch_tensor, device=self.device))
+        # train_batch.set_get_interceptor(
+        #     functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch
 
     def _lazy_numpy_dict(self, postprocessed_batch):
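Note on the torch_policy.py hunk above: _lazy_tensor_dict() no longer installs the convert_to_torch_tensor interceptor, so train_batch[...] returns values exactly as stored; that is why the DQN loss at the top of this paste converts the columns it needs by hand. A small sketch of the difference (convert_like_interceptor is a hypothetical stand-in for the disabled functools.partial):

    import numpy as np
    import torch

    def convert_like_interceptor(value, device="cpu"):
        # Roughly what the disabled interceptor did on first access.
        return torch.as_tensor(value, device=device)

    batch = {"rewards": np.array([1.0, 0.5], dtype=np.float32)}
    as_stored = batch["rewards"]                            # patched path: untouched NumPy array
    as_tensor = convert_like_interceptor(batch["rewards"])  # original path: torch tensor on device
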
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py    2021-02-17 11:48:18.327634811 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_template.py 2021-02-16 16:37:00.641800138 +0100
@@ -246,10 +246,10 @@
                                   self.action_space, config)
 
             # Perform test runs through postprocessing- and loss functions.
-            self._initialize_loss_from_dummy_batch(
-                auto_remove_unneeded_view_reqs=True,
-                stats_fn=stats_fn,
-            )
+            # self._initialize_loss_from_dummy_batch(
+            #     auto_remove_unneeded_view_reqs=True,
+            #     stats_fn=stats_fn,
+            # )
 
             if _after_loss_init:
                 _after_loss_init(self, obs_space, action_space, config)
diff -ur -x '*.pyc' /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py
--- /tmp/myEnvOrig/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py 2021-02-17 11:48:18.331634759 +0100
+++ /tmp/myEnvPatched/lib/python3.7/site-packages/ray/rllib/utils/tracking_dict.py  2021-02-16 17:23:34.977863590 +0100
@@ -1,3 +1,5 @@
+from ray.rllib.policy.sample_batch import SampleBatch
+
 class UsageTrackingDict(dict):
     """Dict that tracks which keys have been accessed.
 
@@ -33,6 +35,10 @@
     def __getitem__(self, key):
         self.accessed_keys.add(key)
         value = dict.__getitem__(self, key)
+
+        if key == SampleBatch.CUR_OBS:
+            return value
+
         if self.get_interceptor:
             if key not in self.intercepted_values:
                 self.intercepted_values[key] = self.get_interceptor(value)
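
Note on the tracking_dict.py hunk above: __getitem__ now returns SampleBatch.CUR_OBS ("obs") exactly as stored, skipping any get_interceptor, so current observations are never converted even when an interceptor is set. A small usage sketch of the patched behavior (values are illustrative):

    import numpy as np
    import torch
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.tracking_dict import UsageTrackingDict

    d = UsageTrackingDict({
        SampleBatch.CUR_OBS: np.zeros(4, dtype=np.float32),
        SampleBatch.REWARDS: np.array([1.0], dtype=np.float32),
    })
    d.set_get_interceptor(lambda v: torch.as_tensor(v))

    type(d[SampleBatch.CUR_OBS])   # numpy.ndarray -- interceptor skipped by the patch
    type(d[SampleBatch.REWARDS])   # torch.Tensor  -- interceptor applied as before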