SHARE
TWEET

Untitled

a guest Apr 23rd, 2019 65 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. {
  2.  "cells": [
  3.   {
  4.    "cell_type": "code",
  5.    "execution_count": 1,
  6.    "metadata": {},
  7.    "outputs": [],
  8.    "source": [
  9.     "import numpy as np\n",
  10.     "\n",
  11.     "from tvm import relay\n",
  12.     "from tvm.relay import testing\n",
  13.     "from tvm.relay.testing import layers\n",
  14.     "from tvm.relay.op.nn import nn\n",
  15.     "\n",
  16.     "from tvm.relay.testing.init import create_workload\n",
  17.     "import tvm\n",
  18.     "from tvm.contrib import graph_runtime"
  19.    ]
  20.   },
  21.   {
  22.    "cell_type": "code",
  23.    "execution_count": 2,
  24.    "metadata": {},
  25.    "outputs": [
  26.     {
  27.      "name": "stdout",
  28.      "output_type": "stream",
  29.      "text": [
  30.       "26 26\n"
  31.      ]
  32.     }
  33.    ],
  34.    "source": [
  35.     "# set hyperparams\n",
  36.     "batch_size = 1\n",
  37.     "in_c = 512\n",
  38.     "num_filters = 512\n",
  39.     "in_h, in_w = 28, 28\n",
  40.     "kdim = 3\n",
  41.     "groups=1\n",
  42.     "stride=1\n",
  43.     "pad=0\n",
  44.     "data_shape = (batch_size, in_c, in_h, in_w)\n",
  45.     "num_groups = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]\n",
  46.     "out_h = (in_h + 2* pad - kdim) // stride + 1\n",
  47.     "out_w = (in_w + 2* pad - kdim) // stride + 1\n",
  48.     "print(out_w, out_h)\n",
  49.     "out_shape = (1, num_filters, out_h, out_w)"
  50.    ]
  51.   },
  52.   {
  53.    "cell_type": "code",
  54.    "execution_count": 3,
  55.    "metadata": {},
  56.    "outputs": [],
  57.    "source": [
  58.     "def conv_block(data, name, num_filters, kernel_size=(3, 3), strides=(1, 1),\n",
  59.     "               padding=(1, 1), groups=1, dtype='float32'):\n",
  60.     "    \"\"\"Helper function to construct conv layer\"\"\"\n",
  61.     "    kdim1, kdim2 = kernel_size\n",
  62.     "    weight = relay.var(name + \"_weight\")\n",
  63.     "    conv = nn.conv2d(\n",
  64.     "        data=data,\n",
  65.     "        weight=weight,\n",
  66.     "        channels=num_filters,\n",
  67.     "        kernel_size=kernel_size,\n",
  68.     "        strides=strides,\n",
  69.     "        padding=padding,\n",
  70.     "        groups=groups)\n",
  71.     "    return conv"
  72.    ]
  73.   },
  74.   {
  75.    "cell_type": "code",
  76.    "execution_count": 4,
  77.    "metadata": {},
  78.    "outputs": [],
  79.    "source": [
  80.     "def weenet(data_shape, num_filters, kdim, stride,\n",
  81.     "           pad, groups, dtype='float32'):\n",
  82.     "    \"\"\"Function to construct a WeeNet\"\"\"\n",
  83.     "    data = relay.var(\"data\", shape=data_shape, dtype=dtype)\n",
  84.     "    body = conv_block(data, 'conv_block_1', num_filters, (kdim,kdim), \n",
  85.     "                      (stride,stride), (pad,pad), groups, dtype)\n",
  86.     "    return relay.Function(relay.ir_pass.free_vars(body), body)"
  87.    ]
  88.   },
  89.   {
  90.    "cell_type": "code",
  91.    "execution_count": 5,
  92.    "metadata": {},
  93.    "outputs": [],
  94.    "source": [
  95.     "def get_workload(data_shape, num_filters, kdim, stride,\n",
  96.     "           pad, groups, dtype='float32'):\n",
  97.     "    net = weenet(data_shape, num_filters, kdim, stride, pad, groups, dtype)\n",
  98.     "    return create_workload(net)"
  99.    ]
  100.   },
  101.   {
  102.    "cell_type": "code",
  103.    "execution_count": 6,
  104.    "metadata": {},
  105.    "outputs": [],
  106.    "source": [
  107.     "def compile_model(net, params, data, opt_level=1, ctx=tvm.cpu(0), target='llvm'):\n",
  108.     "    with relay.build_config(opt_level=opt_level):\n",
  109.     "        graph, lib, params = relay.build_module.build(\n",
  110.     "            net, target, params=params)\n",
  111.     "    # create module\n",
  112.     "    module = graph_runtime.create(graph, lib, ctx)\n",
  113.     "    # set input and parameters\n",
  114.     "    module.set_input(\"data\", data)\n",
  115.     "    module.set_input(**params)\n",
  116.     "    # run\n",
  117.     "    module.run()\n",
  118.     "    # get output\n",
  119.     "    out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()\n",
  120.     "    return out, module"
  121.    ]
  122.   },
  123.   {
  124.    "cell_type": "code",
  125.    "execution_count": 7,
  126.    "metadata": {},
  127.    "outputs": [],
  128.    "source": [
  129.     "# create random input\n",
  130.     "np.random.seed(0)\n",
  131.     "data_shape = (batch_size, in_c, in_h, in_w)\n",
  132.     "data = np.random.uniform(-1, 1, size=data_shape).astype(\"float32\")\n"
  133.    ]
  134.   },
  135.   {
  136.    "cell_type": "code",
  137.    "execution_count": 8,
  138.    "metadata": {},
  139.    "outputs": [
  140.     {
  141.      "name": "stderr",
  142.      "output_type": "stream",
  143.      "text": [
  144.       "Cannot find config for target=llvm, workload=('conv2d', (1, 512, 28, 28, 'float32'), (512, 512, 3, 3, 'float32'), (1, 1), (0, 0), (1, 1), 'NCHW', 'float32'). A fallback configuration is used, which may bring great performance regression.\n",
  145.       "WARNING:autotvm:Cannot find config for target=llvm, workload=('depthwise_conv2d_nchw', (1, 512, 28, 28, 'float32'), (512, 1, 3, 3, 'float32'), (1, 1), (0, 0), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.\n"
  146.      ]
  147.     }
  148.    ],
  149.    "source": [
  150.     "results = [None] * len(num_groups)\n",
  151.     "ctx=tvm.cpu(0)\n",
  152.     "opt_level = 1\n",
  153.     "\n",
  154.     "for i, g in enumerate(num_groups):\n",
  155.     "    net, params = get_workload(data_shape, num_filters, kdim, stride, pad, g)\n",
  156.     "    out, m = compile_model(net, params, data, opt_level, ctx)\n",
  157.     "    f = m.module.time_evaluator('run', ctx)\n",
  158.     "    results[i] = f().mean\n",
  159.     "    "
  160.    ]
  161.   },
  162.   {
  163.    "cell_type": "code",
  164.    "execution_count": 9,
  165.    "metadata": {},
  166.    "outputs": [
  167.     {
  168.      "data": {
  169.       "text/plain": [
  170.        "[0.22498705270000002,\n",
  171.        " 0.8115864547999999,\n",
  172.        " 0.4177320249,\n",
  173.        " 0.2104043008,\n",
  174.        " 0.1027479518,\n",
  175.        " 0.0514274909,\n",
  176.        " 0.0266668331,\n",
  177.        " 0.0042375367,\n",
  178.        " 0.0018110453,\n",
  179.        " 0.0008504548000000001]"
  180.       ]
  181.      },
  182.      "execution_count": 9,
  183.      "metadata": {},
  184.      "output_type": "execute_result"
  185.     }
  186.    ],
  187.    "source": [
  188.     "results"
  189.    ]
  190.   },
  191.   {
  192.    "cell_type": "code",
  193.    "execution_count": null,
  194.    "metadata": {},
  195.    "outputs": [],
  196.    "source": []
  197.   }
  198.  ],
  199.  "metadata": {
  200.   "kernelspec": {
  201.    "display_name": "meth",
  202.    "language": "python",
  203.    "name": "meth"
  204.   },
  205.   "language_info": {
  206.    "codemirror_mode": {
  207.     "name": "ipython",
  208.     "version": 3
  209.    },
  210.    "file_extension": ".py",
  211.    "mimetype": "text/x-python",
  212.    "name": "python",
  213.    "nbconvert_exporter": "python",
  214.    "pygments_lexer": "ipython3",
  215.    "version": "3.5.3"
  216.   }
  217.  },
  218.  "nbformat": 4,
  219.  "nbformat_minor": 2
  220. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top