limmen

PPO

Mar 20th, 2023
from typing import Union, List, Optional
import time
import gym
import os
import numpy as np
import math
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
import csle_common.constants.constants as constants
from csle_common.dao.emulation_config.emulation_env_config import EmulationEnvConfig
from csle_common.dao.simulation_config.simulation_env_config import SimulationEnvConfig
from csle_common.dao.training.experiment_config import ExperimentConfig
from csle_common.dao.training.experiment_execution import ExperimentExecution
from csle_common.dao.training.experiment_result import ExperimentResult
from csle_common.dao.training.agent_type import AgentType
from csle_common.util.experiment_util import ExperimentUtil
from csle_common.logging.log import Logger
from csle_common.metastore.metastore_facade import MetastoreFacade
from csle_common.dao.jobs.training_job_config import TrainingJobConfig
from csle_common.dao.training.ppo_policy import PPOPolicy
from csle_common.dao.simulation_config.state import State
from csle_common.dao.simulation_config.action import Action
from csle_common.dao.training.player_type import PlayerType
from csle_agents.agents.base.base_agent import BaseAgent
import csle_agents.constants.constants as agents_constants


class PPOAgent(BaseAgent):
    """
    A PPO agent using the implementation from Stable Baselines3
    """

    def __init__(self, simulation_env_config: SimulationEnvConfig,
                 emulation_env_config: Union[None, EmulationEnvConfig], experiment_config: ExperimentConfig,
                 training_job: Optional[TrainingJobConfig] = None):
        """
        Initializes the PPO agent

        :param simulation_env_config: the simulation environment configuration
        :param emulation_env_config: the emulation environment configuration (None if no emulation is used)
        :param experiment_config: the experiment configuration
        :param training_job: an existing training job to resume (None to create a new job)
        """
        super(PPOAgent, self).__init__(simulation_env_config=simulation_env_config,
                                       emulation_env_config=emulation_env_config,
                                       experiment_config=experiment_config)
        assert experiment_config.agent_type == AgentType.PPO
        self.training_job = training_job

    def train(self) -> ExperimentExecution:
        """
        Runs the training process

        :return: the experiment execution with the training results
        """
        pid = os.getpid()

        # Setup experiment metrics
        exp_result = ExperimentResult()
        exp_result.plot_metrics.append(agents_constants.COMMON.AVERAGE_RETURN)
        exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_RETURN)
        exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON)
        exp_result.plot_metrics.append(agents_constants.COMMON.AVERAGE_TIME_HORIZON)
        exp_result.plot_metrics.append(agents_constants.COMMON.AVERAGE_UPPER_BOUND_RETURN)
        exp_result.plot_metrics.append(agents_constants.COMMON.AVERAGE_RANDOM_RETURN)
        exp_result.plot_metrics.append(agents_constants.COMMON.RUNTIME)
        descr = f"Training of policies with PPO using " \
                f"simulation:{self.simulation_env_config.name}"

        # Setup training job (the emulation config may be None when training purely in simulation)
        emulation_name = None
        if self.emulation_env_config is not None:
            emulation_name = self.emulation_env_config.name
        if self.training_job is None:
            self.training_job = TrainingJobConfig(
                simulation_env_name=self.simulation_env_config.name, experiment_config=self.experiment_config,
                progress_percentage=0, pid=pid, experiment_result=exp_result,
                emulation_env_name=emulation_name, simulation_traces=[],
                num_cached_traces=agents_constants.COMMON.NUM_CACHED_SIMULATION_TRACES,
                log_file_path=Logger.__call__().get_log_file_path(), descr=descr)
            training_job_id = MetastoreFacade.save_training_job(training_job=self.training_job)
            self.training_job.id = training_job_id
        else:
            self.training_job.pid = pid
            self.training_job.progress_percentage = 0
            self.training_job.experiment_result = exp_result
            MetastoreFacade.update_training_job(training_job=self.training_job, id=self.training_job.id)

        # Setup experiment execution
        ts = time.time()
        simulation_name = self.simulation_env_config.name
        self.exp_execution = ExperimentExecution(
            result=exp_result, config=self.experiment_config, timestamp=ts,
            emulation_name=emulation_name, simulation_name=simulation_name, descr=descr,
            log_file_path=self.training_job.log_file_path)
        exp_execution_id = MetastoreFacade.save_experiment_execution(self.exp_execution)
        self.exp_execution.id = exp_execution_id

        # Setup gym environment
        config = self.simulation_env_config.simulation_env_input_config
        orig_env = gym.make(self.simulation_env_config.gym_env_name, config=config)
        env = make_vec_env(env_id=self.simulation_env_config.gym_env_name,
                           n_envs=self.experiment_config.hparams[agents_constants.COMMON.NUM_PARALLEL_ENVS].value,
                           env_kwargs={"config": config}, vec_env_cls=DummyVecEnv)
        env = VecMonitor(env)
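        # make_vec_env creates NUM_PARALLEL_ENVS copies of the gym environment and runs them
        # sequentially in a single process via DummyVecEnv; VecMonitor wraps the vectorized
        # environment to record episode rewards and lengths for Stable Baselines3's logging.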

        # Training runs, one per seed
        for seed in self.experiment_config.random_seeds:
            self.start = time.time()
            exp_result.all_metrics[seed] = {}
            exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_RETURN] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_TIME_HORIZON] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_UPPER_BOUND_RETURN] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RANDOM_RETURN] = []
            exp_result.all_metrics[seed][agents_constants.COMMON.RUNTIME] = []
            ExperimentUtil.set_seed(seed)

            # Callback for logging training metrics
            cb = PPOTrainingCallback(
                eval_every=self.experiment_config.hparams[agents_constants.COMMON.EVAL_EVERY].value,
                eval_batch_size=self.experiment_config.hparams[agents_constants.COMMON.EVAL_BATCH_SIZE].value,
                random_seeds=self.experiment_config.random_seeds, training_job=self.training_job,
                max_steps=self.experiment_config.hparams[agents_constants.COMMON.NUM_TRAINING_TIMESTEPS].value,
                seed=seed, exp_result=exp_result, simulation_name=self.simulation_env_config.name,
                player_type=self.experiment_config.player_type,
                states=self.simulation_env_config.state_space_config.states,
                actions=(
                    self.simulation_env_config.joint_action_space_config.action_spaces[
                        self.experiment_config.player_idx].actions),
                save_every=self.experiment_config.hparams[agents_constants.COMMON.SAVE_EVERY].value,
                save_dir=self.experiment_config.output_dir, exp_execution=self.exp_execution,
                env=orig_env, experiment_config=self.experiment_config,
                L=self.experiment_config.hparams[agents_constants.COMMON.L].value,
                gym_env_name=self.simulation_env_config.gym_env_name,
                start=self.start
            )

            # Create PPO Agent
            policy_kwargs = dict(
                net_arch=[self.experiment_config.hparams[constants.NEURAL_NETWORKS.NUM_NEURONS_PER_HIDDEN_LAYER].value
                          ] * self.experiment_config.hparams[constants.NEURAL_NETWORKS.NUM_HIDDEN_LAYERS].value)
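            # For example, with hypothetical hyperparameter values of 64 neurons per hidden layer and
            # 2 hidden layers, the expression above expands to net_arch=[64, 64], i.e., two
            # fully-connected hidden layers of 64 units each in the MLP policy.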
            model = PPO(
                agents_constants.PPO.MLP_POLICY, env, verbose=0, policy_kwargs=policy_kwargs,
                n_steps=self.experiment_config.hparams[agents_constants.PPO.STEPS_BETWEEN_UPDATES].value,
                batch_size=self.experiment_config.hparams[agents_constants.COMMON.BATCH_SIZE].value,
                learning_rate=self.experiment_config.hparams[agents_constants.COMMON.LEARNING_RATE].value,
                seed=seed, device=self.experiment_config.hparams[constants.NEURAL_NETWORKS.DEVICE].value,
                gamma=self.experiment_config.hparams[agents_constants.COMMON.GAMMA].value,
                gae_lambda=self.experiment_config.hparams[agents_constants.PPO.GAE_LAMBDA].value,
                clip_range=self.experiment_config.hparams[agents_constants.PPO.CLIP_RANGE].value,
                clip_range_vf=self.experiment_config.hparams[agents_constants.PPO.CLIP_RANGE_VF].value,
                ent_coef=self.experiment_config.hparams[agents_constants.PPO.ENT_COEF].value,
                vf_coef=self.experiment_config.hparams[agents_constants.PPO.VF_COEF].value,
                max_grad_norm=self.experiment_config.hparams[agents_constants.PPO.MAX_GRAD_NORM].value,
                target_kl=self.experiment_config.hparams[agents_constants.PPO.TARGET_KL].value,
            )
            if self.experiment_config.player_type == PlayerType.ATTACKER \
                    and "stopping" in self.simulation_env_config.gym_env_name:
                orig_env.set_model(model)

            # Train
            model.learn(total_timesteps=self.experiment_config.hparams[
                agents_constants.COMMON.NUM_TRAINING_TIMESTEPS].value, callback=cb)
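            # model.learn() runs Stable Baselines3's PPO training loop; the callback passed above is
            # invoked by the library (_on_step() after each environment step and _on_rollout_end()
            # before each policy update), which is where the periodic evaluation, checkpointing,
            # and metric logging happen.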

            # Save policy
            exp_result = cb.exp_result
            ts = time.time()
            save_path = f"{self.experiment_config.output_dir}/ppo_policy_seed_{seed}_{ts}.zip"
            model.save(save_path)
            policy = PPOPolicy(
                model=model, simulation_name=self.simulation_env_config.name, save_path=save_path,
                states=self.simulation_env_config.state_space_config.states,
                actions=self.simulation_env_config.joint_action_space_config.action_spaces[
                    self.experiment_config.player_idx].actions, player_type=self.experiment_config.player_type,
                experiment_config=self.experiment_config,
                avg_R=exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN][-1])
            exp_result.policies[seed] = policy

            # Save policy metadata
            MetastoreFacade.save_ppo_policy(ppo_policy=policy)
            os.chmod(save_path, 0o777)

            # Save latest trace
            MetastoreFacade.save_simulation_trace(orig_env.get_traces()[-1])
            orig_env.reset_traces()

        # Calculate average and std metrics
        exp_result.avg_metrics = {}
        exp_result.std_metrics = {}
        for metric in exp_result.all_metrics[self.experiment_config.random_seeds[0]].keys():
            value_vectors = []
            for seed in self.experiment_config.random_seeds:
                value_vectors.append(exp_result.all_metrics[seed][metric])

            avg_metrics = []
            std_metrics = []
            for i in range(len(value_vectors[0])):
                seed_values = []
                for seed_idx in range(len(self.experiment_config.random_seeds)):
                    seed_values.append(value_vectors[seed_idx][i])
                avg_metrics.append(ExperimentUtil.mean_confidence_interval(
                    data=seed_values,
                    confidence=self.experiment_config.hparams[agents_constants.COMMON.CONFIDENCE_INTERVAL].value)[0])
                std_metrics.append(ExperimentUtil.mean_confidence_interval(
                    data=seed_values,
                    confidence=self.experiment_config.hparams[agents_constants.COMMON.CONFIDENCE_INTERVAL].value)[1])
            exp_result.avg_metrics[metric] = avg_metrics
            exp_result.std_metrics[metric] = std_metrics
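            # ExperimentUtil.mean_confidence_interval is assumed to return a (mean, confidence-radius)
            # pair, so index [0] above is the per-iteration mean across seeds and index [1] the
            # half-width of the confidence interval (stored here under std_metrics).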

        traces = orig_env.get_traces()
        if len(traces) > 0:
            MetastoreFacade.save_simulation_trace(traces[-1])
        return self.exp_execution

    def hparam_names(self) -> List[str]:
        """
        :return: a list with the hyperparameter names
        """
        return [constants.NEURAL_NETWORKS.NUM_NEURONS_PER_HIDDEN_LAYER,
                constants.NEURAL_NETWORKS.NUM_HIDDEN_LAYERS,
                agents_constants.PPO.STEPS_BETWEEN_UPDATES,
                agents_constants.COMMON.LEARNING_RATE, agents_constants.COMMON.BATCH_SIZE,
                agents_constants.COMMON.GAMMA, agents_constants.PPO.GAE_LAMBDA, agents_constants.PPO.CLIP_RANGE,
                agents_constants.PPO.CLIP_RANGE_VF, agents_constants.PPO.ENT_COEF,
                agents_constants.PPO.VF_COEF, agents_constants.PPO.MAX_GRAD_NORM, agents_constants.PPO.TARGET_KL,
                agents_constants.COMMON.NUM_TRAINING_TIMESTEPS, agents_constants.COMMON.EVAL_EVERY,
                agents_constants.COMMON.EVAL_BATCH_SIZE, constants.NEURAL_NETWORKS.DEVICE,
                agents_constants.COMMON.SAVE_EVERY, agents_constants.COMMON.NUM_PARALLEL_ENVS,
                agents_constants.COMMON.CONFIDENCE_INTERVAL, agents_constants.COMMON.RUNNING_AVERAGE,
                agents_constants.COMMON.MAX_ENV_STEPS, agents_constants.COMMON.L]


class PPOTrainingCallback(BaseCallback):
    """
    Callback for monitoring PPO training
    """

    def __init__(self, exp_result: ExperimentResult, seed: int, random_seeds: List[int],
                 training_job: TrainingJobConfig, exp_execution: ExperimentExecution,
                 max_steps: int, simulation_name: str, start: float,
                 states: List[State], actions: List[Action], player_type: PlayerType,
                 env: gym.Env, experiment_config: ExperimentConfig, verbose=0,
                 eval_every: int = 100, eval_batch_size: int = 10, save_every: int = 10, save_dir: str = "",
                 L: int = 3, gym_env_name: str = ""):
        """
        Initializes the callback

        :param exp_result: the experiment result to populate
        :param seed: the random seed
        :param random_seeds: the list of all random seeds
        :param training_job: the training job
        :param exp_execution: the experiment execution
        :param max_steps: the total number of training timesteps (used for progress reporting)
        :param simulation_name: the name of the simulation
        :param states: the list of states in the environment
        :param actions: the list of actions in the environment
        :param player_type: the type of the player
        :param verbose: whether logging should be verbose or not
        :param eval_every: how frequently to run the evaluation
        :param eval_batch_size: the batch size for evaluation
        :param save_every: how frequently to checkpoint the current model
        :param save_dir: the path to checkpoint models
        :param env: the training environment
        :param experiment_config: the experiment configuration
        :param L: the number of stops if it is a stopping environment
        :param gym_env_name: the name of the gym environment
        :param start: the start time-stamp
        """
        super(PPOTrainingCallback, self).__init__(verbose)
        self.states = states
        self.simulation_name = simulation_name
        self.iter = 0
        self.eval_every = eval_every
        self.eval_batch_size = eval_batch_size
        self.exp_result = exp_result
        self.seed = seed
        self.random_seeds = random_seeds
        self.training_job = training_job
        self.exp_execution = exp_execution
        self.max_steps = max_steps
        self.player_type = player_type
        self.actions = actions
        self.save_every = save_every
        self.save_dir = save_dir
        self.env = env
        self.experiment_config = experiment_config
        self.L = L
        self.gym_env_name = gym_env_name
        self.start = start

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interactions
        using the current policy.
        This event is triggered before collecting new samples.
        """
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        if self.experiment_config.player_type == PlayerType.ATTACKER \
                and "stopping" in self.simulation_name:
            self.env.set_model(self.model)
        return True

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        Logger.__call__().get_logger().info(f"Training iteration: {self.iter}, seed:{self.seed}, "
                                            f"progress: "
                                            f"{round(100 * round(self.num_timesteps / self.max_steps, 2), 2)}%")
        ts = time.time()
        save_path = self.save_dir + f"/ppo_model{self.iter}_{ts}.zip"

        # Save model
        if self.iter % self.save_every == 0 and self.iter > 0:
            Logger.__call__().get_logger().info(f"Saving model to path: {save_path}")
            self.model.save(save_path)
            os.chmod(save_path, 0o777)

        # Eval model
        if self.iter % self.eval_every == 0:
            if self.player_type == PlayerType.ATTACKER and "stopping" in self.simulation_name:
                self.env.set_model(self.model)
            policy = PPOPolicy(
                model=self.model, simulation_name=self.simulation_name, save_path=save_path,
                states=self.states, player_type=self.player_type, actions=self.actions,
                experiment_config=self.experiment_config, avg_R=-1)
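            # avg_R=-1 is a placeholder; it is overwritten with the measured average return
            # (policy.avg_R = avg_R) once the evaluation episodes below have completed.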
            o = self.env.reset()
            max_horizon = self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value
            avg_rewards = []
            avg_horizons = []
            avg_upper_bounds = []
            avg_random_returns = []
            info = {}
            for i in range(self.eval_batch_size):
                o = self.env.reset()
                done = False
                t = 0
                cumulative_reward = 0
                while not done and t <= max_horizon:
                    a = policy.action(o=o)
                    o, r, done, info = self.env.step(a)
                    # Accumulate the discounted return: sum_t gamma^t * r_t
                    cumulative_reward += r * math.pow(
                        self.experiment_config.hparams[agents_constants.COMMON.GAMMA].value, t)
                    t += 1
                    Logger.__call__().get_logger().debug(f"t:{t}, a1:{a}, r:{r}, info:{info}, done:{done}")
                avg_rewards.append(cumulative_reward)
                if agents_constants.ENV_METRICS.TIME_HORIZON in info:
                    avg_horizons.append(info[agents_constants.ENV_METRICS.TIME_HORIZON])
                else:
                    avg_horizons.append(-1)
                if agents_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN in info:
                    avg_upper_bounds.append(info[agents_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN])
                else:
                    avg_upper_bounds.append(-1)
                if agents_constants.ENV_METRICS.AVERAGE_RANDOM_RETURN in info:
                    avg_random_returns.append(info[agents_constants.ENV_METRICS.AVERAGE_RANDOM_RETURN])
                else:
                    avg_random_returns.append(-1)
            avg_R = np.mean(avg_rewards)
            avg_T = np.mean(avg_horizons)
            avg_random_return = np.mean(avg_random_returns)
            avg_upper_bound = np.mean(avg_upper_bounds)
            policy.avg_R = avg_R
            time_elapsed_minutes = (time.time() - self.start) // 60
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_RETURN].append(round(avg_R, 3))
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_TIME_HORIZON].append(round(avg_T, 3))
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_UPPER_BOUND_RETURN].append(
                round(avg_upper_bound, 3))
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.RUNTIME].append(time_elapsed_minutes)
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_RANDOM_RETURN].append(
                round(avg_random_return, 3))
            running_avg_J = ExperimentUtil.running_average(
                self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_RETURN],
                self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value)
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.RUNNING_AVERAGE_RETURN].append(
                round(running_avg_J, 3))
            running_avg_T = ExperimentUtil.running_average(
                self.exp_result.all_metrics[self.seed][agents_constants.COMMON.AVERAGE_TIME_HORIZON],
                self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value)
            self.exp_result.all_metrics[self.seed][agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON].append(
                round(running_avg_T, 3))
            Logger.__call__().get_logger().info(
                f"[EVAL] Training iteration: {self.iter}, Avg R:{round(avg_R, 3)}, "
                f"Running_avg_{self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value}_R: "
                f"{round(running_avg_J, 3)}, Avg T:{round(avg_T, 3)}, "
                f"Running_avg_{self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value}_T: "
                f"{round(running_avg_T, 3)}, Avg pi*: {round(avg_upper_bound, 3)}, "
                f"Avg random R:{round(avg_random_return, 3)}, time elapsed (min): {time_elapsed_minutes}")

            self.env.reset()

            # Update training job
            total_steps_done = len(self.random_seeds) * self.max_steps
            steps_done = (self.random_seeds.index(self.seed)) * self.max_steps + self.num_timesteps
            progress = round(steps_done / total_steps_done, 2)
            self.training_job.progress_percentage = progress
            self.training_job.experiment_result = self.exp_result
            if len(self.env.get_traces()) > 0:
                self.training_job.simulation_traces.append(self.env.get_traces()[-1])
            if len(self.training_job.simulation_traces) > self.training_job.num_cached_traces:
                self.training_job.simulation_traces = self.training_job.simulation_traces[1:]
            MetastoreFacade.update_training_job(training_job=self.training_job, id=self.training_job.id)

            # Update execution
            ts = time.time()
            self.exp_execution.timestamp = ts
            self.exp_execution.result = self.exp_result
            MetastoreFacade.update_experiment_execution(experiment_execution=self.exp_execution,
                                                        id=self.exp_execution.id)

        self.iter += 1
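
# ---------------------------------------------------------------------------
# Hypothetical usage sketch: the configuration objects below are assumed to be
# constructed or loaded elsewhere (e.g., from the CSLE metastore).
#
#   agent = PPOAgent(simulation_env_config=simulation_env_config,
#                    emulation_env_config=None,
#                    experiment_config=experiment_config)
#   experiment_execution = agent.train()
# ---------------------------------------------------------------------------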