Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def main():
    """Evaluate/continue training a pretrained Rainbow-DQN agent on a trading environment.

    Builds the custom trading env, restores the local Q-network weights from
    ``TradingGym_Rainbow_1000.pth``, then runs 500 episodes of interaction,
    calling ``agent.learn()`` after each environment step.

    Relies on module-level names defined elsewhere in the file:
    args, obs_data_len, step_len, sample_len, df, fee, n_action_intervals,
    hyperparams, load_location, device, TradingEnv, dqn_agent.
    """
    env = TradingEnv(custom_args=args, env_id='custom_trading_env',
                     obs_data_len=obs_data_len, step_len=step_len,
                     sample_len=sample_len, df=df, fee=fee, initial_budget=1,
                     n_action_intervals=n_action_intervals, deal_col_name='c',
                     sell_at_end=True,
                     feature_names=['o', 'h', 'l', 'c', 'v',
                                    'num_trades', 'taker_base_vol'])
    # num_features is inferred from a reset observation's last dimension.
    agent = dqn_agent.Agent(action_size=2 * n_action_intervals + 1,
                            obs_len=obs_data_len,
                            num_features=env.reset().shape[-1],
                            **hyperparams)
    # map_location lets CPU-only machines load GPU-saved checkpoints.
    agent.qnetwork_local.load_state_dict(
        torch.load(os.path.join(load_location, 'TradingGym_Rainbow_1000.pth'),
                   map_location=device))
    agent.qnetwork_local.to(device)

    n_episode = 500  # was `range(n_episode=500)` — invalid keyword arg to range()
    for eps in range(n_episode):
        # Each episode must start from a fresh observation; the original
        # never defined `state` before calling agent.act(state).
        state = env.reset()
        done = False
        while not done:
            next_state, reward, done, _ = env.step(agent.act(state))
            # NOTE(review): agent.learn() presumably samples from an internal
            # replay buffer populated by act/step — confirm against dqn_agent.
            agent.learn()
            state = next_state  # advance the observation for the next action
Advertisement
Add Comment
Please sign in to add a comment
Advertisement