SHARE
TWEET

Untitled

a guest Oct 21st, 2019 81 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. {
  2.  "cells": [
  3.   {
  4.    "cell_type": "code",
  5.    "execution_count": 1,
  6.    "metadata": {},
  7.    "outputs": [],
  8.    "source": [
  9.     "%matplotlib inline\n",
  10.     "\n",
  11.     "from IPython.display import clear_output\n",
  12.     "\n",
  13.     "import matplotlib.pyplot as plt\n",
  14.     "\n",
  15.     "import torch\n",
  16.     "import torch.nn.functional as F\n",
  17.     "import torch.nn as nn\n",
  18.     "from torch.distributions import Categorical, Bernoulli\n",
  19.     "import torch.optim as optim\n",
  20.     "from torch.autograd import Variable\n",
  21.     "\n",
  22.     "import gym\n",
  23.     "\n",
  24.     "from itertools import count\n",
  25.     "import numpy as np"
  26.    ]
  27.   },
  28.   {
  29.    "cell_type": "code",
  30.    "execution_count": 2,
  31.    "metadata": {},
  32.    "outputs": [],
  33.    "source": [
  34.     "class ActorNet(nn.Module):\n",
  35.     "    \"\"\"Policy network mapping a state vector to a Bernoulli action distribution.\n",
  36.     "\n",
  37.     "    Two ReLU hidden layers (24 and 36 units) feed a single output logit, so the\n",
  38.     "    policy chooses between exactly two actions (0 or 1).\n",
  39.     "\n",
  40.     "    Args:\n",
  41.     "        state_size: Length of the observation vector.\n",
  42.     "        action_size: Accepted for interface compatibility but not used; the\n",
  43.     "            output layer always has a single unit.\n",
  44.     "    \"\"\"\n",
  45.     "\n",
  46.     "    def __init__(self, state_size: int, action_size: int):\n",
  47.     "        super(ActorNet, self).__init__()\n",
  48.     "        self.fc1 = nn.Linear(state_size, 24)\n",
  49.     "        self.fc2 = nn.Linear(24, 36)\n",
  50.     "        self.fc3 = nn.Linear(36, 1)\n",
  51.     "\n",
  52.     "    def forward(self, x):\n",
  53.     "        \"\"\"Return the Bernoulli policy for state(s) ``x``.\n",
  54.     "\n",
  55.     "        Args:\n",
  56.     "            x: Float tensor whose last dimension is ``state_size``.\n",
  57.     "\n",
  58.     "        Returns:\n",
  59.     "            A ``torch.distributions.Bernoulli`` whose parameter has the shape\n",
  60.     "            of ``x`` with the last dimension replaced by 1.\n",
  61.     "        \"\"\"\n",
  62.     "        x = F.relu(self.fc1(x))\n",
  63.     "        x = F.relu(self.fc2(x))\n",
  64.     "        # Pass the raw logit instead of sigmoid(logit): log_prob is then\n",
  65.     "        # evaluated in logit space, so it is not capped and keeps a usable\n",
  66.     "        # gradient when the sigmoid would round to exactly 0 or 1.\n",
  67.     "        return Bernoulli(logits=self.fc3(x))\n",
  51.     "    \n",
  52.     "class CriticNet(nn.Module):\n",
  53.     "    \"\"\"State-value network: an MLP with 128 and 256 hidden units and one output.\"\"\"\n",
  54.     "\n",
  55.     "    def __init__(self, state_size: int):\n",
  56.     "        super().__init__()\n",
  57.     "        # Layers are created in input-to-output order.\n",
  58.     "        self.fc1 = nn.Linear(state_size, 128)\n",
  59.     "        self.fc2 = nn.Linear(128, 256)\n",
  60.     "        self.fc3 = nn.Linear(256, 1)\n",
  61.     "\n",
  62.     "    def forward(self, x):\n",
  63.     "        \"\"\"Return the value estimate for ``x``; the output layer has no activation.\"\"\"\n",
  64.     "        hidden = self.fc1(x).relu()\n",
  65.     "        hidden = self.fc2(hidden).relu()\n",
  66.     "        return self.fc3(hidden)"
  65.    ]
  66.   },
  67.   {
  68.    "cell_type": "code",
  69.    "execution_count": 3,
  70.    "metadata": {},
  71.    "outputs": [],
  72.    "source": [
  73.     "def calc_discount_reward(reward_buff, gamma=0.99, normalize=True, done_flag=True):\n",
  74.     "    \"\"\"Replace per-step rewards with discounted returns, in place.\n",
  75.     "\n",
  76.     "    Walks ``reward_buff`` from the last entry to the first, accumulating\n",
  77.     "    ``reward[t] + gamma * return[t + 1]``.\n",
  78.     "\n",
  79.     "    Args:\n",
  80.     "        reward_buff: Mutable list of per-step rewards; may hold several\n",
  81.     "            episodes back to back. Modified in place.\n",
  82.     "        gamma: Discount factor.\n",
  83.     "        normalize: If true, shift the results to zero mean and scale them to\n",
  84.     "            unit standard deviation.\n",
  85.     "        done_flag: If true, a reward equal to 0 is treated as an episode\n",
  86.     "            boundary: that entry is left at 0 and the running return is reset,\n",
  87.     "            so returns do not carry over into the preceding episode.\n",
  88.     "\n",
  89.     "    Returns:\n",
  90.     "        None. An empty buffer is left untouched.\n",
  91.     "    \"\"\"\n",
  92.     "    if len(reward_buff) == 0:\n",
  93.     "        # np.mean / np.std of an empty list would only produce NaN and warnings.\n",
  94.     "        return\n",
  95.     "\n",
  96.     "    running_return = 0\n",
  97.     "    for idx in reversed(range(len(reward_buff))):\n",
  98.     "        if done_flag and reward_buff[idx] == 0:\n",
  99.     "            running_return = 0\n",
  100.     "        else:\n",
  101.     "            reward_buff[idx] += running_return * gamma\n",
  102.     "            running_return = reward_buff[idx]\n",
  103.     "\n",
  104.     "    if normalize:\n",
  105.     "        mean, std = np.mean(reward_buff), np.std(reward_buff)\n",
  106.     "        # When every return is equal, std is 0 and dividing would fill the\n",
  107.     "        # buffer with NaN; in that case only centre the values.\n",
  108.     "        scale = std if std > 0 else 1.0\n",
  109.     "        for idx in range(len(reward_buff)):\n",
  110.     "            reward_buff[idx] = (reward_buff[idx] - mean) / scale"
  86.    ]
  87.   },
  88.   {
  89.    "cell_type": "code",
  90.    "execution_count": 4,
  91.    "metadata": {},
  92.    "outputs": [
  93.     {
  94.      "name": "stdout",
  95.      "output_type": "stream",
  96.      "text": [
  97.       "Episode:1999\n"
  98.      ]
  99.     },
  100.     {
  101.      "data": {
  102.       "image/png": "\n",
  103.       "text/plain": [
  104.        "<Figure size 432x288 with 1 Axes>"
  105.       ]
  106.      },
  107.      "metadata": {
  108.       "needs_background": "light"
  109.      },
  110.      "output_type": "display_data"
  111.     },
  112.     {
  113.      "name": "stdout",
  114.      "output_type": "stream",
  115.      "text": [
  116.       "Final results\n"
  117.      ]
  118.     },
  119.     {
  120.      "data": {
  121.       "image/png": "\n",
  122.       "text/plain": [
  123.        "<Figure size 432x288 with 1 Axes>"
  124.       ]
  125.      },
  126.      "metadata": {
  127.       "needs_background": "light"
  128.      },
  129.      "output_type": "display_data"
  130.     }
  131.    ],
  132.    "source": [
  133.     "num_episodes = 2000  # number of episodes run() plays\n",
  134.     "durations_list = []  # run() appends each episode's step count here\n",
  135.     "render_flag = False  # when True, run() calls env.render() after every step\n",
  136.     "lr = 3e-2  # learning rate passed to the actor's Adam optimizer\n",
  137.     "num_batch = 5  # run() updates the actor once every num_batch episodes\n",
  138.     "\n",
  139.     "def init_buff():\n",
  140.     "    \"\"\"Rebind the module-level state, action and reward buffers to new empty lists.\"\"\"\n",
  141.     "    global state_buff, action_buff, reward_buff\n",
  142.     "    state_buff, action_buff, reward_buff = [], [], []\n",
  144.     "    \n",
  145.     "def run(env):\n",
  146.     "    \"\"\"Train an ActorNet policy on ``env`` with batched REINFORCE and plot episode lengths.\n",
  147.     "\n",
  148.     "    Plays ``num_episodes`` episodes. After every ``num_batch`` episodes the\n",
  149.     "    buffered transitions are converted to normalised discounted returns and\n",
  150.     "    used for one Adam step on the actor. Each episode's length is appended to\n",
  151.     "    the module-level ``durations_list``, which is plotted every 100 episodes\n",
  152.     "    and once more at the end.\n",
  153.     "\n",
  154.     "    Args:\n",
  155.     "        env: Gym environment with a 1-D observation space and a discrete\n",
  156.     "            action space; the actor samples a single 0/1 action per step.\n",
  157.     "            Both the 4-tuple and the 5-tuple ``step`` return formats are\n",
  158.     "            accepted.\n",
  159.     "    \"\"\"\n",
  160.     "    actor = ActorNet(env.observation_space.shape[0], env.action_space.n)\n",
  161.     "    optimizer_actor = optim.Adam(actor.parameters(), lr=lr)\n",
  162.     "\n",
  163.     "    init_buff()\n",
  164.     "    for ep in range(num_episodes):\n",
  165.     "        reset_out = env.reset()\n",
  166.     "        # gym >= 0.26 returns (observation, info); older versions return only the observation.\n",
  167.     "        state_numpy = reset_out[0] if isinstance(reset_out, tuple) else reset_out\n",
  168.     "        for t in count():\n",
  169.     "            state_torch = torch.FloatTensor(state_numpy)\n",
  170.     "            # Acting needs no autograd graph: log-probs are recomputed at update time.\n",
  171.     "            with torch.no_grad():\n",
  172.     "                action_torch = actor(state_torch).sample()\n",
  173.     "            action_int = int(action_torch.item())\n",
  174.     "            step_out = env.step(action_int)\n",
  175.     "            next_state_numpy, reward, done = step_out[:3]\n",
  176.     "            if len(step_out) == 5:\n",
  177.     "                # gym >= 0.26: (obs, reward, terminated, truncated, info).\n",
  178.     "                done = done or step_out[3]\n",
  179.     "            if render_flag:\n",
  180.     "                env.render()\n",
  181.     "            if done:\n",
  182.     "                # A zero reward is the episode-boundary marker calc_discount_reward checks for.\n",
  183.     "                reward = 0\n",
  184.     "\n",
  185.     "            state_buff.append(state_torch)\n",
  186.     "            action_buff.append(action_torch)\n",
  187.     "            reward_buff.append(reward)\n",
  188.     "\n",
  189.     "            if done:\n",
  190.     "                durations_list.append(t + 1)\n",
  191.     "                break\n",
  192.     "            state_numpy = next_state_numpy\n",
  193.     "\n",
  194.     "        if ep % num_batch == num_batch - 1:\n",
  195.     "            calc_discount_reward(reward_buff)\n",
  196.     "            states = torch.stack(state_buff)\n",
  197.     "            actions = torch.stack(action_buff)\n",
  198.     "            returns = torch.tensor([float(r) for r in reward_buff]).unsqueeze(1)\n",
  199.     "            optimizer_actor.zero_grad()\n",
  200.     "            # One batched backward pass; the summed loss gives the same gradient\n",
  201.     "            # as accumulating one backward() call per transition.\n",
  202.     "            loss = -(actor(states).log_prob(actions) * returns).sum()\n",
  203.     "            loss.backward()\n",
  204.     "            optimizer_actor.step()\n",
  205.     "            init_buff()\n",
  206.     "\n",
  207.     "        if ep % 100 == 99:\n",
  208.     "            clear_output()\n",
  209.     "            print(f'Episode:{ep}')\n",
  210.     "            plt.plot(durations_list)\n",
  211.     "            plt.show()\n",
  212.     "\n",
  213.     "    print('Final results')\n",
  214.     "    plt.plot(durations_list)\n",
  215.     "    plt.show()\n",
  192.     "\n",
  193.     "env = gym.make('CartPole-v0')\n",
  194.     "try:\n",
  195.     "    run(env)\n",
  196.     "finally:\n",
  197.     "    # Close the environment even if training raises or is interrupted.\n",
  198.     "    env.close()"
  196.    ]
  197.   }
  198.  ],
  199.  "metadata": {
  200.   "kernelspec": {
  201.    "display_name": "pytorch",
  202.    "language": "python",
  203.    "name": "pytorch"
  204.   },
  205.   "language_info": {
  206.    "codemirror_mode": {
  207.     "name": "ipython",
  208.     "version": 3
  209.    },
  210.    "file_extension": ".py",
  211.    "mimetype": "text/x-python",
  212.    "name": "python",
  213.    "nbconvert_exporter": "python",
  214.    "pygments_lexer": "ipython3",
  215.    "version": "3.7.3"
  216.   }
  217.  },
  218.  "nbformat": 4,
  219.  "nbformat_minor": 4
  220. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top