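# Policy-gradient (REINFORCE-style) agent for CartPole-v0 in PyTorch.
# A small MLP policy samples actions; after each episode the per-step
# losses are backpropagated one at a time, their gradients are scaled by
# the normalized discounted returns, and a single RMSprop step is taken.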
import argparse
import gym
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable
import torchvision.transforms as T


parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.9, metavar='G',
                    help='discount factor (default: 0.9)')
parser.add_argument('--seed', type=int, default=1, metavar='N',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='interval between training status logs (default: 50)')
args = parser.parse_args()

# torch.manual_seed(args.seed)

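# Two-layer policy network: 4-dimensional CartPole observation -> 128 hidden
# units -> 2 action probabilities. Note that ReLU is applied to the output
# layer as well before the softmax. Per-episode buffers for the sampled
# probabilities, actions and rewards are kept on the module itself.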
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine3 = nn.Linear(128, 2)

        self.sampled_probs = []
        self.sampled_actions = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        x = F.relu(self.affine3(x))
        return F.softmax(x)


env = gym.make('CartPole-v0')
model = Policy()
optimizer = optim.RMSprop(model.parameters(), lr=1e-2, alpha=1, eps=1e-10)


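# Run the policy on the current observation, sample an action from the
# resulting distribution, and remember both the probabilities and the
# sampled action for the end-of-episode update.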
def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = model(Variable(state))
    action = probs.multinomial(1, True).data.squeeze()[0]
    model.sampled_probs.append(probs)
    model.sampled_actions.append(action)
    return action


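# End-of-episode update. Discounted returns are computed and normalized,
# then a one-hot target is built for each sampled action and a half-MSE
# loss between target and policy output is backpropagated one timestep at
# a time. Each timestep's gradient is scaled by its normalized return, the
# scaled gradients are accumulated and averaged over the episode, and a
# single optimizer step is taken.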
def finish_episode():
    R = 0
    sampled_actions = model.sampled_actions
    sampled_probs = model.sampled_probs
    rewards = []
    # Discounted returns, computed backwards over the episode, then normalized.
    for action, r in zip(sampled_actions[::-1], model.rewards[::-1]):
        R = r + args.gamma * R
        rewards.insert(0, R)
    rewards = torch.Tensor(rewards)
    rewards = (rewards - rewards.mean()) / rewards.std()
    # One-hot targets for the actions that were actually taken.
    ys = [torch.zeros(1, 2) for _ in sampled_probs]
    for i, action in enumerate(sampled_actions):
        ys[i][0][action] = 1

    # Half-MSE loss between each one-hot target and the policy output.
    loss = [((Variable(y) - p) ** 2).sum() / 2 for p, y in zip(sampled_probs, ys)]

    # Backpropagate each timestep separately and accumulate its gradients,
    # scaled by that timestep's normalized return.
    grads = {}
    for i, l in enumerate(loss):
        optimizer.zero_grad()
        l.backward()
        for j, group in enumerate(optimizer.param_groups):
            saved_group = grads.get(j, {})
            # debug:
            # x = group['params'][-1]
            # print(x)
            # print(rewards[i] * x)
            for k, param in enumerate(group['params']):
                cumsum = saved_group.get(k, torch.zeros(param.grad.size()))
                cumsum += rewards[i] * param.grad.data
                saved_group[k] = cumsum
            # debug:
            # print(cumsum)
            # import pdb; pdb.set_trace()
            grads[j] = saved_group

    # Overwrite the gradients with the return-weighted average and step once.
    for j, group in enumerate(optimizer.param_groups):
        for k, param in enumerate(group['params']):
            param.grad.data = grads[j][k] / len(loss)
    # debug:
    # for j, group in enumerate(optimizer.param_groups):
    #     for k, param in enumerate(group['params']):
    #         print("grad: ", param.grad.data.abs().max())
    #         print("param: ", param.data.abs().max())

    optimizer.step()
    del model.rewards[:]
    del model.sampled_actions[:]
    del model.sampled_probs[:]


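# Training loop: run one episode, store the per-step rewards, then update
# the policy and report an exponential moving average of episode length.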
running_reward = 10
for i_episode in count(1):
    reward_sum = 0
    state = env.reset()
    for t in count(1):
        action = select_action(state)
        state, reward, done, _ = env.step(action)
        model.rewards.append(reward)
        reward_sum += reward
        if done:
            break

    running_reward = running_reward * 0.99 + reward_sum * 0.01
    finish_episode()
    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast length: {:5.0f}\tAverage length: {:.2f}'.format(
            i_episode, reward_sum, running_reward))
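# Example invocation (assuming the paste is saved as reinforce_cartpole.py):
#   python reinforce_cartpole.py --gamma 0.99 --log-interval 50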