Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2018 bzhou <bzhou@server2>
#
# Distributed under terms of the MIT license.
"""
Train a BipedalWalker policy with CMA-ES, evaluating candidates in parallel over MPI.

Usage:
    mpirun -c 64 --hostfile hosts.txt python ~/cloud/es_walker.py
"""
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- import gym
- import cma
- from mpi4py import MPI
- from copy import deepcopy
class Policy(nn.Module):
    """Deterministic MLP policy mapping observations to actions in [-1, 1].

    Architecture: Linear(obs_dim, hidden) -> ReLU -> Linear(hidden, act_dim) -> Tanh.
    The defaults reproduce the original hard-coded network (24 -> 128 -> 4),
    and the layer order is unchanged, so flattened parameter vectors keep the
    same layout.

    Args:
        obs_dim: size of the input observation vector.
        act_dim: size of the output action vector.
        hidden: width of the single hidden layer.
    """

    def __init__(self, obs_dim=24, act_dim=4, hidden=128):
        super(Policy, self).__init__()
        self.fcs = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, act_dim),
            nn.Tanh(),  # squashes actions into [-1, 1]
        )

    def forward(self, x):
        """Map observations ``x`` with last dim ``obs_dim`` to actions with last dim ``act_dim``."""
        return self.fcs(x)
def flatten_model(model):
    """Concatenate every parameter of ``model`` into a single 1-D numpy array.

    Parameters are visited in ``model.parameters()`` order and each one is
    flattened with ``view(-1)``; ``deflatten_model`` reads them back in the
    same order.
    """
    pieces = [p.data.detach().view(-1).numpy() for p in model.parameters()]
    return np.concatenate(pieces, 0)
def deflatten_model(params, model_fn=None):
    """Build a fresh model whose parameters are loaded from a flat vector.

    Inverse of ``flatten_model``: the vector is consumed in
    ``model.parameters()`` order, each chunk reshaped to its parameter's shape.

    Args:
        params: array-like holding exactly as many values as the model has
            parameters (converted to float32).
        model_fn: optional zero-argument factory for the model; defaults to
            ``Policy()``.

    Returns:
        The newly built model with its parameters overwritten.

    Raises:
        ValueError: if ``params`` does not contain exactly the model's
            parameter count (previously extra values were silently dropped).
    """
    model = Policy() if model_fn is None else model_fn()
    flat = np.asarray(params, dtype=np.float32).ravel()
    expected = sum(p.numel() for p in model.parameters())
    if flat.size != expected:
        raise ValueError(
            'expected {} parameters, got {}'.format(expected, flat.size))
    offset = 0
    for param in model.parameters():
        n = param.numel()
        # torch.tensor copies, so the model does not alias the caller's buffer.
        param.data = torch.tensor(flat[offset:offset + n]).view_as(param.data)
        offset += n
    return model
def rollout(model, env_id='BipedalWalker-v2', episodes=10, max_steps=10000,
            make_env=None):
    """Return the mean summed reward of ``model`` over several episodes.

    Each episode runs until the env reports done or ``max_steps`` steps
    elapse. The environment is always closed afterwards, even if an episode
    raises.

    Args:
        model: callable mapping a (1, obs_dim) float tensor to a
            (1, act_dim) tensor of actions.
        env_id: Gym id passed to ``gym.make`` when ``make_env`` is not given.
        episodes: number of episodes to average over.
        max_steps: per-episode step cap.
        make_env: optional zero-argument factory returning an env with the
            old Gym API: ``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)`` and ``close()``.

    Returns:
        Mean (``np.mean``) of the per-episode reward sums.
    """
    env = make_env() if make_env is not None else gym.make(env_id)
    returns = []
    try:
        # Inference only: no autograd graph is needed for the rollout.
        with torch.no_grad():
            for _ in range(episodes):
                obs = env.reset()
                total = 0
                for _ in range(max_steps):
                    x = torch.FloatTensor(obs).unsqueeze(0)
                    # squeeze(0), not squeeze(): keeps a 1-element action 1-D.
                    action = model(x).squeeze(0).numpy()
                    obs, r, done, _info = env.step(action)
                    total += r
                    if done:
                        break
                returns.append(total)
    finally:
        env.close()
    return np.mean(returns)
def main(generations=1000, popsize=64, sigma0=0.1, log_path='log.log'):
    """Run distributed CMA-ES over ``Policy`` parameters using MPI.

    Rank 0 owns the CMA-ES state. Each generation it asks for ``popsize``
    candidate vectors, deals them round-robin to all ranks (itself included),
    gathers each candidate's mean rollout reward, and tells CMA-ES the costs
    (negated rewards). Per-generation cost statistics are printed and written
    to ``log_path``. The other ranks only evaluate the candidates they receive.

    Because candidates are dealt round-robin rather than one per rank,
    ``popsize`` no longer has to equal the number of MPI processes.
    """
    comm = MPI.COMM_WORLD
    size, rank = comm.size, comm.rank
    es = None
    log_file = None
    if rank == 0:
        params = flatten_model(Policy())
        es = cma.CMAEvolutionStrategy(params, sigma0, {'popsize': popsize})
        log_file = open(log_path, 'w')
    try:
        for _generation in range(generations):
            if rank == 0:
                solutions = es.ask()
                # Rank r evaluates solutions r, r + size, r + 2*size, ...
                chunks = [solutions[r::size] for r in range(size)]
            else:
                chunks = None
            my_solutions = comm.scatter(chunks, root=0)
            my_rewards = [rollout(deflatten_model(s)) for s in my_solutions]
            gathered = comm.gather(my_rewards, root=0)
            if rank == 0:
                # Undo the round-robin deal so cost[i] matches solutions[i].
                cost = [0.0] * len(solutions)
                for r, rewards in enumerate(gathered):
                    for j, reward in enumerate(rewards):
                        cost[r + j * size] = -reward
                es.tell(solutions, cost)
                info = 'COST--Mean: {}\t Max: {}\t Min: {}\n'.format(np.mean(cost), np.max(cost), np.min(cost))
                print(info)
                log_file.write(info)
                log_file.flush()
            else:
                assert gathered is None
    finally:
        if log_file is not None:
            log_file.close()
if __name__ == "__main__":
    # Script entry point; main() itself branches on the MPI rank.
    main()
Add Comment
Please, Sign In to add comment