import os
import torch
import dqn_agent                      # project-local module (assumed)
from trading_env import TradingEnv    # project-local env class (assumed import path)

def main():
    env = TradingEnv(custom_args=args, env_id='custom_trading_env', obs_data_len=obs_data_len, step_len=step_len, sample_len=sample_len,
                     df=df, fee=fee, initial_budget=1, n_action_intervals=n_action_intervals, deal_col_name='c', sell_at_end=True,
                     feature_names=['o', 'h', 'l', 'c', 'v', 'num_trades', 'taker_base_vol'])
    agent = dqn_agent.Agent(action_size=2 * n_action_intervals + 1, obs_len=obs_data_len, num_features=env.reset().shape[-1], **hyperparams)
    # Load pretrained Rainbow weights into the local Q-network.
    agent.qnetwork_local.load_state_dict(torch.load(os.path.join(load_location, 'TradingGym_Rainbow_1000.pth'), map_location=device))
    agent.qnetwork_local.to(device)
    for eps in range(n_episode):  # original `range(n_episode=500)` is invalid syntax
        state = env.reset()       # the original loop never initialized or updated `state`
        done = False
        while not done:
            next_state, reward, done, _ = env.step(agent.act(state))
            agent.learn()         # assumes the agent buffers transitions and samples internally
            state = next_state
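
The snippet references several names (args, df, obs_data_len, step_len, sample_len, fee, n_action_intervals, hyperparams, load_location, device, n_episode) that are never defined in the paste. Below is a minimal sketch of placeholder definitions that would let it run, assuming an OHLCV dataframe with the listed feature columns; every value and filename here is illustrative, not from the original:

import argparse
import pandas as pd
import torch

args = argparse.Namespace()     # custom CLI args forwarded to TradingEnv
df = pd.read_csv('klines.csv')  # data with 'o','h','l','c','v','num_trades','taker_base_vol' columns (hypothetical file)
obs_data_len = 256              # observation window length (placeholder)
step_len = 16                   # bars advanced per env.step() (placeholder)
sample_len = 480                # length of each sampled episode segment (placeholder)
fee = 0.001                     # per-trade fee (placeholder)
n_action_intervals = 5          # buy/sell granularity, so action_size = 2 * 5 + 1 = 11
hyperparams = {}                # extra kwargs forwarded to dqn_agent.Agent
load_location = 'checkpoints'   # directory containing TradingGym_Rainbow_1000.pth
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_episode = 500                 # taken from the original loop header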