import torch
from triton.testing import do_bench

def get_backend(backend):
    # Returns a context manager that restricts scaled_dot_product_attention
    # to a single SDPA backend (math, flash, or memory-efficient).
    if backend == "math":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
    elif backend == "flash":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
    elif backend == "mem_efficient":
        backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
    else:
        raise ValueError(f"unknown backend: {backend}")
    return backend
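
# Hedged alternative, not part of the original script: newer PyTorch releases deprecate
# torch.backends.cuda.sdp_kernel in favor of torch.nn.attention.sdpa_kernel. The sketch
# below (get_backend_v2 is a hypothetical name) maps the same backend strings onto that
# API, and it only takes effect if the import succeeds, so get_backend above is untouched.
try:
    from torch.nn.attention import SDPBackend, sdpa_kernel

    def get_backend_v2(backend):
        # Same string names as get_backend, mapped onto the SDPBackend enum.
        mapping = {
            "math": SDPBackend.MATH,
            "flash": SDPBackend.FLASH_ATTENTION,
            "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
        }
        return sdpa_kernel(mapping[backend])
except ImportError:
    pass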
import torch.nn.functional as F

def bench_size(batch_size, n_heads, seq_len, head_dim, dtype):
    query = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)
    key = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)
    value = torch.randn(batch_size, n_heads, seq_len, head_dim, device='cuda', dtype=dtype, requires_grad=True)

    def f():
        # Forward + backward of causal SDPA; this is the op being timed.
        out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
        out.sum().backward()

    # print(f"batch_size={batch_size}, n_heads={n_heads}, seq_len={seq_len}, head_dim={head_dim}, dtype={dtype}")
    print(f"{batch_size}, {n_heads}, {seq_len}, {head_dim}")
    # Note: do_bench(f)[0] assumes an older triton.testing.do_bench that returns a tuple
    # of timing percentiles (median first); newer Triton returns a single float by default.
    with get_backend("flash"):
        flash = do_bench(f)[0]
    with get_backend("mem_efficient"):
        mem_efficient = do_bench(f)[0]
    print(f'{"FLASH" if flash < mem_efficient else "MEM"}, {flash/mem_efficient}', flush=True)
import random
import math

def get_logrand(min_size, max_size, divisible):
    # Sample log-uniformly between min_size and max_size, rounded up to a
    # multiple of `divisible`.
    min_exp = math.ceil(math.log(min_size, 2))
    max_exp = int(math.log(max_size, 2))
    val = 2**(random.random()*(max_exp - min_exp) + min_exp)
    return int((val + divisible - 1) // divisible * divisible)

for _ in range(1000):
    batch_size = get_logrand(1, 128, 1)
    num_heads = get_logrand(8, 96, 8)
    seq_len = get_logrand(128, 4096, 256)
    head_dim = get_logrand(64, 128, 8)
    # Skip shapes whose q/k/v tensors would be too large to benchmark comfortably.
    if batch_size * num_heads * seq_len * head_dim > 1e9:
        print(f"skipping {batch_size}, {num_heads}, {seq_len}, {head_dim}")
        continue
    bench_size(batch_size, num_heads, seq_len, head_dim, torch.bfloat16)