---------------------------------------------------------------------------------------------------------------------------------------
PPO.PY:
---------------------------------------------------------------------------------------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import wandb


class PPOMemory:
    def __init__(self, batch_size, num_trajectories, num_steps_trajectory, state_size, device):

        self.states = torch.zeros((num_trajectories, num_steps_trajectory) + (state_size,)).to(device)
        self.actions = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.logprobs = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.values = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.rewards = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.next_states = torch.zeros((num_trajectories, num_steps_trajectory) + (state_size,)).to(device)
        self.dones = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)

        self.advantages = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.returns = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)


    def store(self, state, action, log_prob, value, reward, next_state, done, trajectory, step):
        self.states[trajectory][step] = state
        self.actions[trajectory][step] = action
        self.logprobs[trajectory][step] = log_prob
        self.values[trajectory][step] = value
        self.rewards[trajectory][step] = reward
        self.next_states[trajectory][step] = next_state
        self.dones[trajectory][step] = done


    def clear(self, num_trajectories, num_steps_trajectory, state_size, device):
        self.states = torch.zeros((num_trajectories, num_steps_trajectory) + (state_size,)).to(device)
        self.actions = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.logprobs = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.values = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.rewards = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.next_states = torch.zeros((num_trajectories, num_steps_trajectory) + (state_size,)).to(device)
        self.dones = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)

        self.advantages = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)
        self.returns = torch.zeros((num_trajectories, num_steps_trajectory)).to(device)

    def flatten(self, state_size):
        # Flatten data
        states_flat = self.states.reshape((-1,) + (state_size,))
        actions_flat = self.actions.reshape((-1,))
        logprobs_flat = self.logprobs.reshape((-1,))
        advantages_flat = self.advantages.reshape((-1,))
        returns_flat = self.returns.reshape((-1,))
        values_flat = self.values.reshape((-1,))

        return states_flat, actions_flat, logprobs_flat, advantages_flat, returns_flat, values_flat

    def generate_batches(self, batch_size, minibatch_size):

        # Generate batch indices
        indices = np.arange(batch_size)
        batch_start_ind = np.arange(0, batch_size, minibatch_size)
        np.random.shuffle(indices)
        batches = [indices[i:i + minibatch_size] for i in batch_start_ind]

        return batches

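# Note on PPOMemory: the buffers are laid out as (num_trajectories, num_steps_trajectory[, state_size]).
# flatten() collapses them into a single batch of num_trajectories * num_steps_trajectory transitions,
# and generate_batches() returns shuffled index arrays of size minibatch_size over that batch.
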
class BuildNetwork(nn.Module):
    def __init__(self, nn_architecture, input_size, output_size, learning_rate, device):
        super().__init__()

        # Initialize an empty list to fill with layer info
        layers = []

        # Iteratively add hidden layers (nn.Linear + activation) with layer info from
        # the nn_architecture parameter
        in_size = input_size
        for out_size, activation in nn_architecture:

            # Add layer to list layers using nn.Linear
            layers.append(nn.Linear(in_size, out_size))

            # Add activation function to list layers
            if activation.lower() == "relu":
                layers.append(nn.ReLU())
            elif activation.lower() == "sigmoid":
                layers.append(nn.Sigmoid())
            elif activation.lower() == "tanh":
                layers.append(nn.Tanh())
            elif activation.lower() == "linear":
                pass
            else:
                raise ValueError(f"Unsupported activation function: {activation}")

            in_size = out_size

        # Add output layer (unactivated) using nn.Linear
        layers.append(nn.Linear(in_size, output_size))

        # Create network using the layer info
        self.model = nn.Sequential(*layers)

        # Set optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        # Move model to device
        self.to(device)

    def forward(self, state):

        # Output unactivated outputs given a state
        output = self.model(state)

        return output

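# nn_architecture is expected to be a list of (layer_size, activation_name) tuples, with the
# activation given as "relu", "sigmoid", "tanh", or "linear". A hypothetical example (the
# values are illustrative, not taken from the original configuration):
#     nn_architecture = [(64, "tanh"), (64, "tanh")]
#     actor = BuildNetwork(nn_architecture, input_size=159, output_size=5,
#                          learning_rate=3e-4, device=torch.device("cpu"))
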
class PPO_Agent():
    def __init__(self, env, config, device):

        self.device = device
        self.env = env

        # General
        self.state_size = config["state_size"]
        self.num_actions = config["num_actions"]
        self.num_episodes_train = config["num_episodes_train"]
        self.num_episodes_validate = config["num_episodes_validate"]
        self.num_episodes_test = config["num_episodes_test"]
        self.num_trajectories = config["num_trajectories"]
        self.num_steps_trajectory = config["num_steps_trajectory"]
        self.num_epochs = config["num_epochs"]
        self.batch_size = self.num_trajectories * self.num_steps_trajectory
        self.num_minibatches = config["num_minibatches"]
        self.minibatch_size = self.batch_size // self.num_minibatches

        # Hyperparameters
        self.nn_architecture_actor = config["nn_architecture_actor"]
        self.nn_architecture_critic = config["nn_architecture_critic"]
        self.learning_rate_actor = config["learning_rate_actor"]
        self.learning_rate_critic = config["learning_rate_critic"]
        self.gamma = config["gamma"]
        self.lam = config["lam"]
        self.clip_ratio = config["clip_ratio"]
        self.entropy_coef = config["entropy_coef"]
        self.value_coef = config["value_coef"]
        self.entropy_coef_min = config["entropy_coef_min"]
        self.entropy_coef_decay = config["entropy_coef_decay"]
        self.max_grad_norm = config["max_grad_norm"]

        # Memory Initialization
        self.memory = PPOMemory(self.batch_size, self.num_trajectories, self.num_steps_trajectory, self.state_size, self.device)

        # Network Initialization
        self.actor = BuildNetwork(self.nn_architecture_actor, self.state_size, self.num_actions, self.learning_rate_actor, self.device)
        self.critic = BuildNetwork(self.nn_architecture_critic, self.state_size, 1, self.learning_rate_critic, self.device)

    def load_weights(self, path):
        self.actor.load_state_dict(torch.load(path))
        self.actor.eval()

    def save_weights(self, path):
        torch.save(self.actor.state_dict(), path)

    def get_action(self, state):

        # Output logits given a state
        logits = self.actor(state)

        # Create a Categorical distribution from the output
        dist = Categorical(logits=logits)

        action = dist.sample()
        log_prob = dist.log_prob(action)

        return action, log_prob

    def get_value(self, state):

        # Output the value estimate given a state
        value = self.critic(state)

        return value

    def calc_advantage(self, values, next_states, rewards, dones, device):
        with torch.no_grad():

            # Initialize advantages and returns
            advantages = torch.zeros_like(values).to(device)
            returns = torch.zeros_like(values).to(device)

            # Last advantage accumulator
            advantage = torch.zeros(self.num_trajectories).to(device)

            # Iterate in reverse over time steps
            for t in reversed(range(self.num_steps_trajectory)):

                # Mask for terminal states
                mask = 1.0 - dones[:, t]
                if t == self.num_steps_trajectory - 1:
                    next_values = self.get_value(next_states[:, t])
                else:
                    next_values = values[:, t + 1]

                # Compute delta/TD-error
                delta = rewards[:, t] + self.gamma * next_values.view(-1) * mask - values[:, t]

                # Recursive GAE
                advantage = delta + self.gamma * self.lam * mask * advantage
                advantages[:, t] = advantage

                # Compute return
                returns[:, t] = advantage + values[:, t]

        return advantages, returns

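    # Note on the method above: calc_advantage implements Generalized Advantage
    # Estimation (GAE), computed backwards over each trajectory:
    #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    #     A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
    #     R_t     = A_t + V(s_t)
    # At the final step the bootstrap value V(s_{t+1}) is taken from the critic on
    # next_states; past the end of the trajectory, A_{t+1} = 0.
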
    def compute_loss(self, states_mb, actions_mb, old_logprobs_mb, advantages_norm_mb, returns_mb):

        # Compute prob_ratio
        logits = self.actor(states_mb)

        dist = Categorical(logits=logits)

        new_logprobs_mb = dist.log_prob(actions_mb)
        prob_ratio = (new_logprobs_mb - old_logprobs_mb).exp()
        clipped_prob_ratio = torch.clamp(prob_ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)

        # Compute entropy loss
        entropy = dist.entropy()
        entropy_loss = entropy.mean()

        # Compute PPO objective function
        actor_loss = -(torch.min(advantages_norm_mb * prob_ratio, advantages_norm_mb * clipped_prob_ratio).mean())

        # Compute value loss (flatten critic output so its shape matches returns_mb)
        critic_values = self.critic(states_mb).view(-1)
        critic_loss = ((critic_values - returns_mb) ** 2).mean()

        # Compute total loss
        total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy_loss

        return total_loss, actor_loss, critic_loss, entropy_loss

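    # Note on the method above: compute_loss combines the clipped PPO surrogate,
    # the critic's MSE loss, and an entropy bonus into one objective:
    #     r_t(theta) = exp(log pi_theta(a_t|s_t) - log pi_theta_old(a_t|s_t))
    #     L_clip     = mean( min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) )
    #     total_loss = -L_clip + value_coef * MSE(V(s_t), R_t) - entropy_coef * H(pi)
    # where eps is clip_ratio and A_t are the minibatch-normalized advantages.
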
    def update_networks(self):

        # Compute advantages and returns
        advantages, returns = self.calc_advantage(self.memory.values, self.memory.next_states, self.memory.rewards, self.memory.dones, self.device)
        self.memory.advantages = advantages
        self.memory.returns = returns

        # Flatten memory
        states, actions, old_logprobs, advantages, returns, values = self.memory.flatten(self.state_size)

        for epoch in range(self.num_epochs):

            # Generate batches
            batches = self.memory.generate_batches(self.batch_size, self.minibatch_size)

            for minibatch in batches:
                # States, logprobs, actions, returns for the current minibatch
                states_mb = states[minibatch]
                old_logprobs_mb = old_logprobs[minibatch]
                actions_mb = actions[minibatch]
                returns_mb = returns[minibatch]
                values_mb = values[minibatch]

                # Advantage normalization for the current minibatch
                advantages_mb = advantages[minibatch]
                advantages_norm_mb = (advantages_mb - advantages_mb.mean()) / (advantages_mb.std() + 1e-8)

                total_loss, actor_loss, critic_loss, entropy_loss = self.compute_loss(states_mb, actions_mb, old_logprobs_mb, advantages_norm_mb, returns_mb)

                # Perform updates
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        # Return the losses from the last minibatch update for logging
        return total_loss, actor_loss, critic_loss, entropy_loss

    def run(self, mode, num_episodes):

        rewards = []

        if mode == "training":
            training = True
        else:
            training = False

        total_loss = 0
        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0

        for episode in range(num_episodes):

            # Initialize reward to zero
            reward_episode = 0

            # Keep track of the number of trajectories per policy rollout
            trajectory = episode % self.num_trajectories

            # Reset environment
            state = torch.tensor(self.env.reset(mode), dtype=torch.float32).to(self.device)

            # Perform one trajectory
            for step in range(self.num_steps_trajectory):

                # Perform one environment step
                with torch.no_grad():
                    action, log_prob = self.get_action(state)
                    value = self.get_value(state).squeeze()
                    reward, next_state, done = self.env.step(action)
                reward_episode += reward

                # Transform to tensors
                reward = torch.tensor([reward]).to(self.device)
                next_state = torch.tensor(next_state, dtype=torch.float32).to(self.device)
                done = torch.tensor([done]).to(self.device)

                # Keep track of rollout data in case of training
                if training:
                    self.memory.store(state, action, log_prob, value, reward, next_state, done, trajectory, step)

                # Update state for next step
                state = next_state

            # Append episodic reward
            rewards.append(reward_episode)

            average_reward = np.mean(rewards)

            # Perform update step with rollout data and clear memory afterwards for new rollout phase
            if training and trajectory == self.num_trajectories - 1:

                total_loss, actor_loss, critic_loss, entropy_loss = self.update_networks()

                self.memory.clear(self.num_trajectories, self.num_steps_trajectory, self.state_size, self.device)

            # Log relevant data
            wandb.log({"Charts/Average Reward": average_reward}, episode)
            wandb.log({"Charts/Reward per Episode": reward_episode}, episode)
            wandb.log({"Losses/Total Loss": total_loss}, episode)
            wandb.log({"Losses/Actor Loss": actor_loss}, episode)
            wandb.log({"Losses/Critic Loss": critic_loss}, episode)
            wandb.log({"Losses/Entropy": entropy_loss}, episode)

            print(f"EPISODE: {episode + 1} / {num_episodes}, Total Reward: {reward_episode}")

        return rewards

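# Illustrative usage sketch (all values below are assumptions; only the key names match
# what PPO_Agent.__init__ reads from config):
#     config = {
#         "state_size": 159, "num_actions": 5,
#         "num_episodes_train": 5000, "num_episodes_validate": 100, "num_episodes_test": 100,
#         "num_trajectories": 4, "num_steps_trajectory": 200,
#         "num_epochs": 4, "num_minibatches": 4,
#         "nn_architecture_actor": [(64, "tanh"), (64, "tanh")],
#         "nn_architecture_critic": [(64, "tanh"), (64, "tanh")],
#         "learning_rate_actor": 3e-4, "learning_rate_critic": 1e-3,
#         "gamma": 0.99, "lam": 0.95, "clip_ratio": 0.2,
#         "entropy_coef": 0.01, "value_coef": 0.5,
#         "entropy_coef_min": 0.001, "entropy_coef_decay": 0.999,
#         "max_grad_norm": 0.5,
#     }
#     agent = PPO_Agent(env, config, torch.device("cpu"))
#     rewards = agent.run("training", config["num_episodes_train"])
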
---------------------------------------------------------------------------------------------------------------------------------------
ENVIRONMENT.PY
---------------------------------------------------------------------------------------------------------------------------------------

# actions: 0 (nothing), 1 (up), 2 (right), 3 (down), 4 (left)

# positions in grid:
# - (0,0) is the upper left corner
# - first index is vertical (increasing from top to bottom)
# - second index is horizontal (increasing from left to right)

# if a new item appears in a cell that the agent moves into (or stays at) during the same
# time step, it is not picked up; to pick it up, the agent has to stay in that cell in the
# next time step

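# the action encoding corresponds to these (vertical, horizontal) offsets applied in step():
#     0 -> (0, 0), 1 -> (-1, 0), 2 -> (0, +1), 3 -> (+1, 0), 4 -> (0, -1)
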
import random
from typing import List, Tuple
import pandas as pd
from copy import deepcopy
from itertools import compress
import numpy as np


# TODO: delete / move
def manhatten_dist(pos: Tuple[int, int], tgt: Tuple[int, int]) -> int:
    return abs(pos[0] - tgt[0]) + abs(pos[1] - tgt[1])

class Environment(object):
    def __init__(self, variant, data_dir):
        self.variant = variant
        self.vertical_cell_count = 5
        self.horizontal_cell_count = 5
        self.vertical_idx_target = 2
        self.horizontal_idx_target = 0
        self.target_loc = (self.vertical_idx_target, self.horizontal_idx_target)
        self.episode_steps = 200
        self.max_response_time = 15 if self.variant == 2 else 10
        self.reward = 25 if self.variant == 2 else 15
        self.data_dir = data_dir

        self.training_episodes = pd.read_csv(self.data_dir + f"/variant_{self.variant}/training_episodes.csv")
        self.training_episodes = self.training_episodes.training_episodes.tolist()
        self.validation_episodes = pd.read_csv(self.data_dir + f"/variant_{self.variant}/validation_episodes.csv")
        self.validation_episodes = self.validation_episodes.validation_episodes.tolist()
        self.test_episodes = pd.read_csv(self.data_dir + f"/variant_{self.variant}/test_episodes.csv")
        self.test_episodes = self.test_episodes.test_episodes.tolist()

        self.remaining_training_episodes = deepcopy(self.training_episodes)
        self.validation_episode_counter = 0

        if self.variant == 0 or self.variant == 2:
            self.agent_capacity = 1
        else:
            self.agent_capacity = 3

        if self.variant == 0 or self.variant == 1:
            self.eligible_cells = [(0,0), (0,1), (0,2), (0,3), (0,4),
                                   (1,0), (1,1), (1,2), (1,3), (1,4),
                                   (2,0), (2,1), (2,2), (2,3), (2,4),
                                   (3,0), (3,1), (3,2), (3,3), (3,4),
                                   (4,0), (4,1), (4,2), (4,3), (4,4)]
        else:
            self.eligible_cells = [(0,0),        (0,2), (0,3), (0,4),
                                   (1,0),        (1,2),        (1,4),
                                   (2,0),        (2,2),        (2,4),
                                   (3,0), (3,1), (3,2),        (3,4),
                                   (4,0), (4,1), (4,2),        (4,4)]

        self.current_episode_target_count = 0  # Counts number of items dropped off at target cell

    # initialize a new episode (specify if training, validation, or testing via the mode argument)
    def reset(self, mode):
        modes = ["training", "validation", "testing"]
        if mode not in modes:
            raise ValueError("Invalid mode. Expected one of: %s" % modes)

        self.step_count = 0
        self.agent_loc = (self.vertical_idx_target, self.horizontal_idx_target)
        self.agent_load = 0  # number of items loaded (0 or 1, except for first extension, where it can be 0, 1, 2, 3)
        self.item_locs = []
        self.item_times = []

        self.past_items = []
        #self.item_distances = []

        if mode == "testing":
            episode = self.test_episodes[0]
            self.test_episodes.remove(episode)
        elif mode == "validation":
            episode = self.validation_episodes[self.validation_episode_counter]
            self.validation_episode_counter = (self.validation_episode_counter + 1) % 100
        else:
            if not self.remaining_training_episodes:
                self.remaining_training_episodes = deepcopy(self.training_episodes)
            episode = random.choice(self.remaining_training_episodes)
            self.remaining_training_episodes.remove(episode)
        self.data = pd.read_csv(self.data_dir + f"/variant_{self.variant}/episode_data/episode_{episode:03d}.csv", index_col=0)

        # For CNN: reset spawn counts per cell (used for the past-items feature in get_obs)
        self.past_items = np.zeros((self.vertical_cell_count, self.horizontal_cell_count), dtype=int)

        return self.get_obs()

    # take one environment step based on the action act
    def step(self, act):
        self.step_count += 1
        rew = 0

        # done signal (1 if episode ends, 0 if not)
        if self.step_count == self.episode_steps:
            done = 1
        else:
            done = 0

        #MP
        new_loc = self.agent_loc

        # agent movement
        if act != 0:
            if act == 1:  # up
                new_loc = (self.agent_loc[0] - 1, self.agent_loc[1])
            elif act == 2:  # right
                new_loc = (self.agent_loc[0], self.agent_loc[1] + 1)
            elif act == 3:  # down
                new_loc = (self.agent_loc[0] + 1, self.agent_loc[1])
            elif act == 4:  # left
                new_loc = (self.agent_loc[0], self.agent_loc[1] - 1)

            if new_loc in self.eligible_cells:
                self.agent_loc = new_loc
                rew += -1

        # item pick-up
        if (self.agent_load < self.agent_capacity) and (self.agent_loc in self.item_locs):
            self.agent_load += 1
            idx = self.item_locs.index(self.agent_loc)
            self.item_locs.pop(idx)
            self.item_times.pop(idx)
            rew += self.reward / 2

        # item drop-off
        if self.agent_loc == self.target_loc:
            rew += self.agent_load * self.reward / 2
            self.current_episode_target_count += self.agent_load
            self.agent_load = 0

        # track how long ago items appeared
        self.item_times = [i + 1 for i in self.item_times]

        # remove items for which max response time is reached
        mask = [i < self.max_response_time for i in self.item_times]
        self.item_locs = list(compress(self.item_locs, mask))
        self.item_times = list(compress(self.item_times, mask))

        # add items which appear in the current time step
        new_items = self.data[self.data.step == self.step_count]
        new_items = list(zip(new_items.vertical_idx, new_items.horizontal_idx))
        new_items = [i for i in new_items if i not in self.item_locs]  # not more than one item per cell
        self.item_locs += new_items
        self.item_times += [0] * len(new_items)

        # For CNN: count item spawns per cell (used for the past-items feature in get_obs)
        for loc in new_items:
            self.past_items[loc] += 1

        # get new observation
        next_obs = self.get_obs()

        return rew, next_obs, done

    def get_state_size(self):
        size = self.reset("training").shape[0]
        return size

    def get_obs(self) -> List[float]:
        grid_shape = (self.vertical_cell_count, self.horizontal_cell_count)

        # Agent position
        agent_position = np.zeros(grid_shape)
        agent_position[self.agent_loc] = 1

        # Item positions
        item_positions = np.zeros(grid_shape)
        for loc in self.item_locs:
            item_positions[loc] = 1

        # Remaining times of items (normalized)
        remaining_times = np.zeros(grid_shape)
        for loc, time in zip(self.item_locs, self.item_times):
            remaining_times[loc] = (self.max_response_time - time) / self.max_response_time

        # Past items (spawn counts at each position)
        past_items_vector = self.past_items.flatten()

        total_spawns = past_items_vector.sum()
        if total_spawns > 0:
            past_items_vector = past_items_vector / total_spawns
        else:
            past_items_vector = np.zeros_like(past_items_vector)

        # Target location
        target_location = np.zeros(grid_shape)
        target_location[self.target_loc] = 1

        # Free capacity
        free_capacity = self.agent_capacity - self.agent_load

        # Manhattan distances to items and distance to closest item
        distance_to_closest_item = 1000
        remaining_time_closest_item = 1000
        distance_closest_item_to_target = 1000
        manhattan_distance_items = np.zeros(grid_shape)
        for loc in self.item_locs:
            manhattan_distance_items[loc] = manhatten_dist(self.agent_loc, loc)

            if manhattan_distance_items[loc] < distance_to_closest_item:
                distance_to_closest_item = manhattan_distance_items[loc]
                remaining_time_closest_item = remaining_times[loc]
                distance_closest_item_to_target = manhatten_dist(loc, self.target_loc)

        # Manhattan distance to target
        manhattan_distance_target = np.zeros(grid_shape)
        manhattan_distance_target[self.target_loc] = manhatten_dist(self.agent_loc, self.target_loc)

        # Distance to the closest wall
        distance_to_walls = [
            self.agent_loc[0],
            self.vertical_cell_count - 1 - self.agent_loc[0],
            self.agent_loc[1],
            self.horizontal_cell_count - 1 - self.agent_loc[1]
        ]

        # Current number of items on the field
        num_items = len(self.item_locs)

        state = np.concatenate([
            np.array(agent_position).flatten(), # Dim 25
            np.array(item_positions).flatten(), # Dim 25
            np.array(remaining_times).flatten(), # Dim 25
            np.array(target_location).flatten(), # Dim 25
            np.array(free_capacity).flatten(), # Dim 1
            np.array(manhattan_distance_items).flatten(), # Dim 25
            np.array(manhattan_distance_target).flatten(),  # Dim 25
            np.array(distance_to_closest_item).flatten(),  # Dim 1
            np.array(remaining_time_closest_item).flatten(),  # Dim 1
            np.array(distance_closest_item_to_target).flatten(),  # Dim 1
            np.array(distance_to_walls).flatten(), # Dim 4
            np.array(num_items).flatten(), # Dim 1
        ])

        return state
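
    # The concatenated observation has dimension 159 (six 25-value grids plus the scalar
    # and wall-distance features annotated above), which is the value config["state_size"]
    # must take in PPO.PY for the networks and memory to match this environment.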