import math
import random

import torch
import torch.nn.functional as F
from triton.testing import do_bench


def get_backend(backend):
    # Context manager that restricts scaled_dot_product_attention to a single backend.
    if backend == "math":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
    elif backend == "flash":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
    elif backend == "mem_efficient":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
    return backend


def bench_size(batch_size, n_heads, seq_len, head_dim, dtype):
    query = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)
    key = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)
    value = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)

    def f():
        # Forward + backward of causal attention; this is the function being benchmarked.
        out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
        out.sum().backward()

    # print(f"batch_size={batch_size}, n_heads={n_heads}, seq_len={seq_len}, head_dim={head_dim}, dtype={dtype}")
    print(f"{batch_size}, {n_heads}, {seq_len}, {head_dim}")
    with get_backend("flash"):
        flash = do_bench(f)[0]

    with get_backend("mem_efficient"):
        mem_efficient = do_bench(f)[0]

    # Report which backend was faster and the flash/mem_efficient time ratio.
    print(f'{"FLASH" if flash < mem_efficient else "MEM"}, {flash/mem_efficient}', flush=True)


def get_logrand(min_size, max_size, divisible):
    # Sample log-uniformly between min_size and max_size, rounded up to a multiple of `divisible`.
    min_exp = math.ceil(math.log(min_size, 2))
    max_exp = int(math.log(max_size, 2))
    val = 2 ** (random.random() * (max_exp - min_exp) + min_exp)
    return int((val + divisible - 1) // divisible * divisible)


for _ in range(1000):
    batch_size = get_logrand(1, 128, 1)
    num_heads = get_logrand(8, 96, 8)
    seq_len = get_logrand(128, 4096, 256)
    head_dim = get_logrand(64, 128, 8)

    # Skip shapes whose total element count would be too large to benchmark.
    if batch_size * num_heads * seq_len * head_dim > 1e9:
        print(f"skipping {batch_size}, {num_heads}, {seq_len}, {head_dim}")
        continue
    bench_size(batch_size, num_heads, seq_len, head_dim, torch.bfloat16)