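# Assumed input format for 'data_sdpa_fp16.txt' (the file itself is not part of this
# paste): judging from the parsing below, lines appear to alternate between four
# comma-separated integers (batch_size, n_heads, seq_len, head_dim) and a result line
# of "BACKEND, perf_ratio", e.g. (hypothetical values):
#   2, 16, 1024, 64
#   FLASH, 1.27
# Lines containing "skipping" are dropped before parsing.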
data = open('data_sdpa_fp16.txt', 'r').read()
data = data.strip().split('\n')
data = [i for i in data if "skipping" not in i]
x = data[::2]
y = data[1::2]
x = x[:len(y)]
x = [[int(i) for i in i.split(', ')] for i in x]

# Filter out cases where perf difference is insignificant
nx = []
ny = []
for xi, yi in zip(x, y):
    yi = yi.split(', ')
    if float(yi[1]) > 1.1 or float(yi[1]) < 0.9:
        nx.append(xi)
        ny.append(yi[0])

x, y = nx, ny
from sklearn.preprocessing import PolynomialFeatures

feature_names = ['batch_size', 'n_heads', 'seq_len', 'head_dim']

# Append the reciprocal of every feature as an extra column
def add_inverse(x, feature_names):
    for i in x:
        for j in range(len(i)):
            i.append(1/i[j])
    feature_names += [f"1/{i}" for i in feature_names]
    return x, feature_names

# Add interaction-only polynomial features up to the given degree
def add_polynomial(x, feature_names, degree=2):
    poly = PolynomialFeatures(degree=degree, interaction_only=True)
    x = poly.fit_transform(x)
    feature_names = list(poly.get_feature_names_out(input_features=feature_names))
    return x, feature_names

x_aug, feature_names = add_inverse(x, feature_names)
x_aug, feature_names = add_polynomial(x_aug, feature_names, degree=3)
y = [i == "FLASH" for i in y]

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.metrics import accuracy_score

X_train, X_val, y_train, y_val = train_test_split(
    x_aug, y, test_size=0.3, random_state=42)

param_grid = {
    'max_depth': [1],
    'min_samples_split': [2, 3, 4],
    'min_samples_leaf': [1, 2, 3],
}

# Hand-coded heuristic: predict whether the FLASH backend should be chosen
def cur_predict(x):
    batch_size, num_heads, query_lengths, head_dim, *_ = x
    threads_flash = batch_size * num_heads
    threads_cutlass = threads_flash * (query_lengths // 64)
    more_threads_cutlass = (threads_cutlass // 2) >= threads_flash
    small_threads_flash = threads_flash < 60
    large_head_dim = head_dim == 128
    is_flash = (small_threads_flash and more_threads_cutlass) or large_head_dim
    return is_flash

# Accuracy of an arbitrary predictor f over a dataset (X, Y)
def eval_f(f, X, Y):
    y_pred = []
    for x, y in zip(X, Y):
        y_pred.append(f(x))
    return accuracy_score(Y, y_pred)


results = []

# Iterate over all combinations of hyperparameters and train a decision tree classifier
for max_depth in param_grid['max_depth']:
    for min_samples_split in param_grid['min_samples_split']:
        for min_samples_leaf in param_grid['min_samples_leaf']:
            # Train a decision tree classifier with the current hyperparameters
            dt = DecisionTreeClassifier(max_depth=max_depth,
                                        min_samples_split=min_samples_split,
                                        min_samples_leaf=min_samples_leaf,
                                        random_state=42)
            dt.fit(X_train, y_train)

            # Evaluate the accuracy on the validation set
            y_pred = dt.predict(X_val)
            acc = accuracy_score(y_val, y_pred)

            # Append the results to the list
            results.append({
                'max_depth': max_depth,
                'min_samples_split': min_samples_split,
                'min_samples_leaf': min_samples_leaf,
                'accuracy': acc,
                'classifier': dt,
            })

# Sort the results by accuracy in descending order
results = sorted(results, key=lambda x: x['accuracy'], reverse=True)

# Print the results
for result in results:
    print(f"max_depth={result['max_depth']}, min_samples_split={result['min_samples_split']}, min_samples_leaf={result['min_samples_leaf']}, accuracy={result['accuracy']}")

print(export_text(results[0]['classifier'], feature_names=feature_names))
# print(export_text(results[0]['classifier'], feature_names=['batch_size', 'n_heads', 'seq_len', 'head_dim', 'flash_threads', 'seq_len / flash_threads']))
print(eval_f(cur_predict, x, y))
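
# Expected console output: one accuracy line per hyperparameter combination, the
# text dump of the best decision tree, and the accuracy of the hand-coded
# cur_predict heuristic on the filtered data.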