Guest User

Untitled

a guest
Jul 19th, 2018
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.54 KB | None | 0 0
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. # Copyright © 2018 bzhou <bzhou@server2>
  6. #
  7. # Distributed under terms of the MIT license.
  8.  
  9. """
  10. mpirun -c 64 --hostfile hosts.txt python ~/cloud/es_walker.py
  11. """
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15.  
  16. import numpy as np
  17. import gym
  18. import cma
  19. from mpi4py import MPI
  20. from copy import deepcopy
  21.  
  22.  
  23. class Policy(nn.Module):
  24. def __init__(self):
  25. super(Policy, self).__init__()
  26.  
  27. self.fcs = nn.Sequential(
  28. nn.Linear(24, 128),
  29. nn.ReLU(inplace=True),
  30. nn.Linear(128, 4),
  31. nn.Tanh()
  32. )
  33.  
  34. def forward(self, x):
  35. x = self.fcs(x)
  36. return x
  37.  
  38. def flatten_model(model):
  39. params = []
  40. for param in model.parameters():
  41. params.append(param.data.detach().view(-1).numpy())
  42. return np.concatenate(params, 0)
  43.  
  44. def deflatten_model(params):
  45. model = Policy()
  46. for param in model.parameters():
  47. size = param.data.view(-1).size(0)
  48. data = params[:size]
  49. param.data = torch.FloatTensor(data).view_as(param.data)
  50. params = params[size:]
  51. return model
  52.  
  53. def rollout(model):
  54. env = gym.make('BipedalWalker-v2')
  55. rewards = []
  56. for epi in range(10):
  57. s = env.reset()
  58. reward = 0
  59. for step in range(10000):
  60. s = torch.FloatTensor(s).unsqueeze(0)
  61. a = model(s).detach().squeeze().numpy()
  62. s, r, d, _ = env.step(a)
  63. reward += r
  64. if d:
  65. break
  66. rewards.append(reward)
  67.  
  68. return np.mean(rewards)
  69.  
  70.  
  71. def main():
  72. comm = MPI.COMM_WORLD
  73. size, rank = comm.size, comm.rank
  74.  
  75. if rank == 0:
  76. model = Policy()
  77. params = flatten_model(model)
  78. es = cma.CMAEvolutionStrategy(params, 0.1, {'popsize': 64})
  79. f = open('log.log', 'w')
  80.  
  81. for step in range(1000):
  82. if rank == 0:
  83. solution = es.ask()
  84. else:
  85. solution = None
  86.  
  87. solution = comm.scatter(solution, root=0)
  88. model = deflatten_model(solution)
  89. reward = rollout(model)
  90. data = comm.gather((reward, solution, rank), root=0)
  91.  
  92. if rank == 0:
  93. cost = [-x[0] for x in data]
  94. solution = [x[1] for x in data]
  95. es.tell(solution, cost)
  96. info = 'COST--Mean: {}\t Max: {}\t Min: {}\n'.format(np.mean(cost), np.max(cost), np.min(cost))
  97. print(info)
  98. f.write(info)
  99. f.flush()
  100. else:
  101. assert data is None
  102.  
  103.  
  104.  
  105.  
  106.  
  107.  
  108.  
  109.  
  110.  
  111.  
  112.  
  113.  
  114.  
  115.  
  116.  
  117.  
  118.  
  119.  
  120.  
  121.  
  122.  
# Start the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment