#!/usr/bin/env python

"""
simple-random-nasbench.py

Random-search baseline on NAS-Bench-101, using the 108-epoch results.
"""

import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from matplotlib import pyplot as plt

from nasbench.api import NASBench

np.random.seed(123)

# --
# Helpers

cummax = np.maximum.accumulate
cumsum = np.cumsum

def cumargmax(x):
    # At each step, index of the best value seen so far (ties resolve to the most recent index)
    z = np.arange(x.shape[0], dtype=float)
    z[x != cummax(x)] = np.nan
    z = pd.Series(z).ffill()
    return z.values.astype(int)

def sample_one_column(x):
    # For each row of `x`, pick the entry in one uniformly random column
    i = np.arange(x.shape[0])
    j = np.random.choice(x.shape[1], x.shape[0], replace=True)
    return x[(i, j)]
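
# Illustrative check (not in the original paste): `cumargmax` tracks the index of the
# running maximum, e.g. for [0.1, 0.3, 0.2, 0.5] the incumbent indices are [0, 1, 1, 3].
assert (cumargmax(np.array([0.1, 0.3, 0.2, 0.5])) == [0, 1, 1, 3]).all()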

# --
# IO

path = 'data/nasbench_only108.tfrecord'
api = NASBench(path)

# --
# ETL

hashes = np.array(list(api.hash_iterator()))
n_models = len(hashes)
n_runs = 3

test_acc = np.zeros((n_models, n_runs))
valid_acc = np.zeros((n_models, n_runs))
cost = np.zeros((n_models, n_runs))

for i, h in tqdm(enumerate(hashes), total=len(hashes)):
    _, result = api.get_metrics_from_hash(h)
    result = result[108]

    valid_acc[i] = [r['final_validation_accuracy'] for r in result]
    test_acc[i] = [r['final_test_accuracy'] for r in result]
    cost[i] = [r['final_training_time'] for r in result]

# Average over the `n_runs` repeated trainings of each architecture
mean_valid_acc = valid_acc.mean(axis=-1)
mean_test_acc = test_acc.mean(axis=-1)
mean_cost = cost.mean(axis=-1)
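
# Illustrative summary (not in the original paste): best architecture by mean
# validation accuracy, and its corresponding mean test accuracy.
best = mean_valid_acc.argmax()
print('n_models=%d | best mean_valid_acc=%.4f | its mean_test_acc=%.4f' % (
    n_models, mean_valid_acc[best], mean_test_acc[best]))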

# --
# Random runs

def random_run(valid_acc, mean_test_acc, mean_cost, models_per_run=int(1e4)):
    n_models = valid_acc.shape[0]

    # Randomly sample `models_per_run` architectures w/o replacement
    sel = np.random.choice(n_models, models_per_run, replace=False)

    # Get one of the `n_runs` validation accuracies for each selected model
    valid_acc_run = sample_one_column(valid_acc[sel])

    # Compute index of arch. w/ best validation accuracy so far
    best_val_idx = cumargmax(valid_acc_run)

    # Mean test accuracy of the model w/ best validation accuracy so far
    test_acc_run = mean_test_acc[sel][best_val_idx]

    # Cumulative cost of run
    cum_cost_run = cumsum(mean_cost[sel])

    return test_acc_run, cum_cost_run
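
# Illustrative single call (not in the original paste): each run returns two curves of
# length `models_per_run` -- the incumbent's mean test accuracy and the cumulative
# training cost. Re-seed afterwards so the 500 seeded runs below are unaffected.
_test_curve, _cost_curve = random_run(valid_acc, mean_test_acc, mean_cost, models_per_run=10)
print(_test_curve.shape, _cost_curve.shape)  # -> (10,) (10,)
np.random.seed(123)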

rand_results = [random_run(valid_acc, mean_test_acc, mean_cost) for _ in trange(500)]

test_acc_runs, cum_cost_runs = list(zip(*rand_results))

# Average test acc of the selected models, over the 500 runs
mean_test_acc_run = np.stack(test_acc_runs).mean(axis=0)

# Average cumulative cost, over the 500 runs
mean_cum_cost_run = np.stack(cum_cost_runs).mean(axis=0)

# Plot mean test regret (gap to the best mean test accuracy) vs. cumulative training time
_ = plt.plot(mean_cum_cost_run, mean_test_acc.max() - mean_test_acc_run, c='red', label='random search (mean of 500 runs)')
_ = plt.xscale('log')
_ = plt.yscale('log')
_ = plt.ylim(1e-3, 1e-1)
_ = plt.legend()
_ = plt.grid(which='both', alpha=0.5)
_ = plt.axhline(4e-3, c='grey', alpha=0.26)
_ = plt.xlabel('cumulative training time (s)')
_ = plt.ylabel('test regret')
plt.show()

# Performance at 1e7 seconds is `5.5 * 1e-3`, compared to about `4.1 * 1e-3` in the paper
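
# Illustrative check of the number quoted above (not in the original paste): read the
# mean test regret off the averaged curve at ~1e7 seconds of cumulative training time.
idx = min(np.searchsorted(mean_cum_cost_run, 1e7), len(mean_test_acc_run) - 1)
print('mean test regret at ~1e7 seconds: %.2e' % (mean_test_acc.max() - mean_test_acc_run[idx]))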