diff -ru --ignore-trailing-space /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/attn_qk_int8_per_block.py sageattention/attn_qk_int8_per_block.py
--- /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/attn_qk_int8_per_block.py	2024-11-12 02:54:20.000000000 +1100
+++ sageattention/attn_qk_int8_per_block.py	2025-08-10 17:22:54.569382100 +1000
@@ -2,6 +2,238 @@
 import triton
 import triton.language as tl
+import os
+import json
+import time
+
+from braceexpand import braceexpand
+
+# gfx1100: best config selected: BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None;
+# within original patch limits: best config selected: BLOCK_M: 32, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None;
+# 512x512x81 best config selected: BLOCK_M: 32, BLOCK_N: 16, STAGE: 1, waves_per_eu: 3, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None;
+# best config selected: BLOCK_M: 32, BLOCK_N: 16, STAGE: 1, waves_per_eu: 3, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None;
+# 512x512x129 best config selected: BLOCK_M: 32, BLOCK_N: 16, STAGE: 1, waves_per_eu: 3, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None;
+# when the original patch limits are extended just enough to allow 64x16:
+# best config selected: BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;
+# in quad best config selected: BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None;
+# in sage w native wan vace best config selected: BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 3, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None;
+# in sage w kijai num_s=4 best config selected: BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None;
+
+BLACKLIST_CACHE = None
+BLACKLIST_MTIME = None
+BLACKLIST_PATH = os.path.join(os.path.dirname(__file__), "blacklist.json")
+
+def load_blacklist_if_updated():
+    global BLACKLIST_CACHE, BLACKLIST_MTIME
+
+    try:
+        mtime = os.path.getmtime(BLACKLIST_PATH)
+        if BLACKLIST_CACHE is None or mtime != BLACKLIST_MTIME:
+            with open(BLACKLIST_PATH, "r") as f:
+                BLACKLIST_CACHE = json.load(f)
+            BLACKLIST_MTIME = mtime
+            print(f"\033[93m[debug] sage-attn: Reloaded from {BLACKLIST_PATH}\033[0m")
+    except FileNotFoundError:
+        # Create the file with an example (with silly values)
+        default_blacklist = [
+            {"BLOCK_M": 42, "BLOCK_N": 24, "num_stages": 1, "num_ctas": 1, "waves_per_eu": 7},
+            {"waves_per_eu": 14}
+        ]
+        with open(BLACKLIST_PATH, "w") as f:
+            json.dump(default_blacklist, f, indent=2)
+        BLACKLIST_CACHE = default_blacklist
+        BLACKLIST_MTIME = os.path.getmtime(BLACKLIST_PATH)
+        print(f"\033[93m[debug] sage-attn: Created new blacklist.json with default entries at {BLACKLIST_PATH}\033[0m")
+
+
+# cuda:0
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+# True
+IS_HIP = triton.runtime.driver.active.get_current_target().backend == "hip"
+# gfx1100
+ARCH = triton.runtime.driver.active.get_current_target().arch
+# SHM_LIMIT = 32384 if ARCH.startswith('gfx10') else 65536
+# I'm not actually sure that the gfx1030 has a smaller SHM (LDS) anymore; I need to check that for myself.
+
+# shm_limit will be 65536 (it's an RDNA thing)... not sure about CDNA, though.
+SHM_LIMIT = triton.runtime.driver.active.utils.get_device_properties(triton.runtime.driver.active.get_current_device())["max_shared_mem"]
+_BM_SIZE = 32 if ARCH.startswith('gfx10') else 64
+_BN_SIZE = 16
+
+BM_SIZE = int(os.environ.get('SAGE_BM_SIZE', _BM_SIZE))
+BN_SIZE = int(os.environ.get('SAGE_BN_SIZE', _BN_SIZE))
+
+def braceexpandlist(be):
+    return [int(x) for x in braceexpand(be)]
+
+SAGE_ATTENTION_BLOCK_M = braceexpandlist(os.environ.get('SAGE_ATTENTION_BLOCK_M', str(BM_SIZE)))
+SAGE_ATTENTION_BLOCK_N = braceexpandlist(os.environ.get('SAGE_ATTENTION_BLOCK_N', str(BN_SIZE)))
+SAGE_ATTENTION_NUM_WARPS = braceexpandlist(os.environ.get('SAGE_ATTENTION_NUM_WARPS', '{2,4}'))
+SAGE_ATTENTION_NUM_STAGES = braceexpandlist(os.environ.get('SAGE_ATTENTION_NUM_STAGES', '{1,2,3,4}'))
+SAGE_ATTENTION_STAGE = braceexpandlist(os.environ.get('SAGE_ATTENTION_STAGE', '1'))
+SAGE_ATTENTION_WAVES_PER_EU = braceexpandlist(os.environ.get('SAGE_ATTENTION_WAVES_PER_EU', '{3,4}'))
+
+# DEBUG
+# print(f"\033[93m[debug] sage-attn.attn_qk_int8_per_block: SAGE_ATTENTION_NUM_WARPS: {SAGE_ATTENTION_NUM_WARPS}\033[0m")
+# print(f"\033[93m[debug] sage-attn.attn_qk_int8_per_block: SAGE_ATTENTION_BLOCK_M: {SAGE_ATTENTION_BLOCK_M}\033[0m")
+# print(f"\033[93m[debug] sage-attn.attn_qk_int8_per_block: BM_SIZE {BM_SIZE}\033[0m")
+# print(f"\033[93m[debug] sage-attn.attn_qk_int8_per_block: BN_SIZE {BN_SIZE}\033[0m")
+# print(f"\033[93m[debug] sage-attn.attn_qk_int8_per_block.configs: DEVICE: {DEVICE}, IS_HIP: {IS_HIP}, ARCH: {ARCH}, LDS: {SHM_LIMIT}\033[0m")
+
+# Autotune here
+configs = [
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'STAGE': S, 'waves_per_eu': wpe}, num_warps=nw, num_stages=ns) \
+        for BM in SAGE_ATTENTION_BLOCK_M \
+        for BN in SAGE_ATTENTION_BLOCK_N \
+        for nw in SAGE_ATTENTION_NUM_WARPS \
+        for ns in SAGE_ATTENTION_NUM_STAGES \
+        for S in SAGE_ATTENTION_STAGE \
+        for wpe in SAGE_ATTENTION_WAVES_PER_EU
+]
+# The selected config for my gfx1100 is always:
+# BLOCK_M: 64, BLOCK_N: 16, STAGE: 1, waves_per_eu: 4, num_warps: 4, num_ctas: 1, num_stages: 3 or 4, maxnreg: None;
+#
+# IMPORTANT: These configurations were developed for pytorch 2.8; they die a terrible death post-MLIR under pytorch 2.7.
+#
+# IMPORTANT: We can't actually choose an optimum BLOCK_M/BLOCK_N, as these values are hardcoded into other functions.
+#
+# ## General Waffling Comments (as proclaimed by ChatGPT 4o)
+#
+# | Constraint              | Description                                                                                                                             |
+# | ----------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
+# | **VGPR pressure**       | Vector general-purpose registers; blow past the VGPR limit and your waves get throttled hard (or you spill to memory, which is poison). |
+# | **SGPR pressure**       | Scalar registers can limit wave count per CU too, especially on RDNA3 (gfx1100+) where each thread gets more scalar state.              |
+# | **LDS (shared memory)** | Shared memory is smaller per CU than on most NVIDIA GPUs (e.g., 32–64KB). High tile size = higher usage.                                |
+# | **Occupancy**           | CUs run fewer waves if resources (VGPR/LDS) are maxed. Fewer waves = lower ability to hide latency.                                     |
+#
+# ### On `num_warps` for AMD
+#
+# * There are no warps, just wavefronts — each being 64 threads.
+# * num_warps still exists as a Triton abstraction, but its semantics are fuzzy on AMD.
+# * Triton may lower num_warps=2 into 128 threads, but it depends on the compiler pass and kernel shape, and this number does not get exposed back in Python.
+# * So: you can’t query num_threads directly, and trying to estimate the actual thread count on AMD requires inspecting the MLIR-lowered kernel (or profiling it live).
+#
+# ## General Rules (as proclaimed by ChatGPT 4o, the ultimate opinionated idiot in this field)
+# waves_per_eu=4 is better for latency hiding if you don’t run out of LDS
+# num_stages=3–4 yields better pipeline throughput, but requires sufficient LDS and VGPRs
+# block_area=2048 is only viable if num_warps>=4
+#
+# ## Avoid
+# Any config with shm_bytes > 65536 — that’s over the LDS limit
+# num_warps=2 on large blocks (1024–2048); it’s under-occupying EUs
+# waves_per_eu=3 — it’s fine for testing but less efficient than 4 on RDNA2+
+
+def keep(conf):
+    global ARCH
+
+    BLOCK_M = conf.kwargs["BLOCK_M"]
+    BLOCK_N = conf.kwargs["BLOCK_N"]
+    HEAD_DIM = 128  # Hardcoded in the kernel; only used here to calculate shm usage. Could be obtained by some triton.x.y.z trick too.
+    STAGE = conf.kwargs["STAGE"]
+    waves_per_eu = conf.kwargs["waves_per_eu"]
+    BLOCK_AREA = BLOCK_M * BLOCK_N
+    shm_bytes = BLOCK_M * HEAD_DIM * conf.num_stages * 3//8
+    shm_str = f"{shm_bytes:6}"
+
+    def conf_str():
+        return (f"BLOCK_M: {BLOCK_M}, BLOCK_N: {BLOCK_N}, STAGE: {STAGE}, "
+                f"waves_per_eu: {waves_per_eu}, num_warps: {conf.num_warps}, "
+                f"num_ctas: {conf.num_ctas}, num_stages: {conf.num_stages}, "
+                f"shm_bytes: {shm_str}")
+
+    # ChatGPT 4o: Triton assumes 32 threads per warp unless you rewrite the kernel to support otherwise
+    total_threads = conf.num_warps * 32
+    if total_threads < 64:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: total_threads ({total_threads}) < 64 → too few to fill one AMD wavefront\033[0m")
+        return False
+
+    # ChatGPT 4o: check that the shared-memory size won't exceed 64k (somewhat of an absolute limit for everything up to and including RDNA4)
+    # (but what about lower limits for cheap RDNA2 cards?)
+    if shm_bytes > SHM_LIMIT:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: shm_bytes {shm_bytes} > SHM_LIMIT ({SHM_LIMIT})\033[0m")
+        return False
+
+    # do not keep too high a block area; any higher doesn't seem to help for navi21
+    # (this comment was originally attached to a limit of 1024 for BLOCK_AREA, though it has not been disproven on my gfx1100)
+    if BLOCK_AREA > 16384:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_AREA {BLOCK_AREA} > 16384\033[0m")
+        return False
+
+    # do not keep 'mirror image' configs (i.e. keep [64, 32] and discard [32, 64])
+    if BLOCK_M < BLOCK_N:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_M < BLOCK_N\033[0m")
+        return False
+
+    # do not keep skinny sizes for now (unknown reasoning by the original author of these patches)
+    # my gfx1100 always picks 64 x 16, so it might like even skinnier blocks
+    if (BLOCK_M // BLOCK_N) >= 8:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_M // BLOCK_N >= 8\033[0m")
+        return False
+
+    # do not keep configs where num_warps is too high or too low (only disabling too low ATM)
+    # (this was extrapolated from the original author, who supplied strict rules for 1024 (2 warps) and 2048 (4 warps); these
+    # probably need to be re-thought for newer RDNA cards, as my gfx1100 likes 4 warps with BLOCK_AREA = 1024)
+    if BLOCK_AREA >= 4096:
+        if conf.num_warps < 8:
+            print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_AREA >= 4096 and num_warps < 8\033[0m")
+            return False
+    elif BLOCK_AREA >= 2048:
+        if conf.num_warps < 4:
+            print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_AREA >= 2048 and num_warps < 4\033[0m")
+            return False
+
+    # avoid num_warps=2 on large blocks (1024–2048); it’s under-occupying EUs
+    elif BLOCK_AREA >= 1024:
+        # my gfx1100 always picks 4 warps for 64x16, 2 warps for 32x16
+        if conf.num_warps < 4:
+            print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_AREA >= 1024 and num_warps < 4\033[0m")
+            return False
+    else:
+        # my gfx1100 always picks 4 warps on 64x16, 2 warps on 32x16
+        if conf.num_warps > 2:
+            print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: BLOCK_AREA < 1024 and num_warps > 2\033[0m")
+            return False
+
+    # we're only offering 3 or 4 waves_per_eu max, so this is a superfluous check, mainly included to
+    # attract opinions on optimal waves_per_eu usage.
+    # ChatGPT 4o has variously claimed 3 to be both a useless value and an essential one.
+    if waves_per_eu > 4:
+        print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: waves_per_eu {waves_per_eu} > 4\033[0m")
+        return False
+
+    load_blacklist_if_updated()
+
+    for item in BLACKLIST_CACHE:
+        blacklist = True
+        for key, value in item.items():
+            actual = None
+            if key in conf.kwargs:
+                actual = conf.kwargs[key]
+            elif hasattr(conf, key):
+                actual = getattr(conf, key)
+            elif key == "BLOCK_M":
+                actual = BLOCK_M
+            elif key == "BLOCK_N":
+                actual = BLOCK_N
+            elif key == "HEAD_DIM":
+                actual = HEAD_DIM
+            else:
+                print(f"\033[93m[debug] [BLACKLIST] Unknown key in config: {key}, skipping this rule.\033[0m")
+                blacklist = False
+                break
+
+            if actual != value:
+                blacklist = False
+                break
+
+        if blacklist:
+            print(f"\033[93m[debug] sage-attn: [SKIP] {conf_str()} reason: blacklisted by file\033[0m")
+            return False
+
+    print(f"\033[93m[debug] sage-attn: {conf_str()}\033[0m")
+    return True
+
+
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
@@ -36,6 +268,10 @@
         V_ptrs += BLOCK_N * stride_vn
     return acc, l_i
 
+@triton.autotune(
+    list(filter(keep, configs)),
+    key=['qo_len', 'kv_len', 'h_qo']
+)
 @triton.jit
 def _attn_fwd(Q, K, V, Q_scale, K_scale, Out,
               stride_qz, stride_qh, stride_qn,
@@ -81,8 +317,10 @@
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))
 
 def forward(q, k, v, q_scale, k_scale, tensor_layout="HND", output_dtype=torch.float16):
-    BLOCK_M = 128
-    BLOCK_N = 64
+    global BM_SIZE
+    global BN_SIZE
+    BLOCK_M = BM_SIZE
+    BLOCK_N = BN_SIZE
     stage = 1
 
     o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
@@ -109,7 +347,7 @@
     HEAD_DIM_K = head_dim
     num_kv_groups = h_qo // h_kv
 
-    grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
+    grid = lambda META: (triton.cdiv(qo_len, META['BLOCK_M']), h_qo, b)
     _attn_fwd[grid](
         q, k, v, q_scale, k_scale, o,
         stride_bz_q, stride_h_q, stride_seq_q,
@@ -118,8 +356,6 @@
         stride_bz_o, stride_h_o, stride_seq_o,
         qo_len, kv_len, h_qo, num_kv_groups,
-        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
-        STAGE=stage,
-        num_warps=4 if head_dim == 64 else 8,
-        num_stages=3 if head_dim == 64 else 4)
+        HEAD_DIM=HEAD_DIM_K,
+        )
     return o
\ No newline at end of file
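
For anyone retuning the sweep in the hunk above, here is a standalone sketch (not part of the patch) of how the brace-expanded SAGE_ATTENTION_* variables become candidate autotune configs, and how the patch's own shm_bytes heuristic prices the 64x16 config the comments above report for gfx1100. It assumes the braceexpand package from PyPI; the variable names below are illustrative.

import os
from itertools import product

from braceexpand import braceexpand  # pip install braceexpand

def braceexpandlist(value):
    # '{2,4}' expands to ['2', '4']; a bare '64' stays as ['64']
    return [int(x) for x in braceexpand(value)]

block_m      = braceexpandlist(os.environ.get('SAGE_ATTENTION_BLOCK_M', '64'))
block_n      = braceexpandlist(os.environ.get('SAGE_ATTENTION_BLOCK_N', '16'))
num_warps    = braceexpandlist(os.environ.get('SAGE_ATTENTION_NUM_WARPS', '{2,4}'))
num_stages   = braceexpandlist(os.environ.get('SAGE_ATTENTION_NUM_STAGES', '{1,2,3,4}'))
stage        = braceexpandlist(os.environ.get('SAGE_ATTENTION_STAGE', '1'))
waves_per_eu = braceexpandlist(os.environ.get('SAGE_ATTENTION_WAVES_PER_EU', '{3,4}'))

sweep = list(product(block_m, block_n, num_warps, num_stages, stage, waves_per_eu))
print(len(sweep))  # 1 * 1 * 2 * 4 * 1 * 2 = 16 candidates before keep() filters them

# The patch's LDS estimate (a heuristic, not a Triton API), applied to that config:
# BLOCK_M=64, HEAD_DIM=128, num_stages=4 -> 64 * 128 * 4 * 3 // 8 = 12288 bytes, well under the 65536-byte limit.
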
diff -ru --ignore-trailing-space /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/attn_qk_int8_per_block_causal.py sageattention/attn_qk_int8_per_block_causal.py
--- /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/attn_qk_int8_per_block_causal.py	2024-11-12 02:54:20.000000000 +1100
+++ sageattention/attn_qk_int8_per_block_causal.py	2025-08-09 05:19:17.268906800 +1000
@@ -44,7 +44,7 @@
         v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
         p = p.to(tl.float16)
 
-        acc += tl.dot(p, v, out_dtype=tl.float16)
+        acc += tl.dot(p, v, out_dtype=tl.float32)  # zlp
         m_i = m_ij
         K_ptrs += BLOCK_N * stride_kn
         K_scale_ptr += 1
@@ -65,6 +65,8 @@
               ):
     start_m = tl.program_id(0)
 
+    print("\033[93m%%% [debug] {}::_attn_fwd called\033[0m".format(__file__))
+
     off_z = tl.program_id(2).to(tl.int64)
     off_h = tl.program_id(1).to(tl.int64)
 
@@ -102,9 +104,12 @@
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))
 
 def forward(q, k, v, q_scale, k_scale, tensor_layout="HND", output_dtype=torch.float16):
-    BLOCK_M = 128
-    BLOCK_N = 64
+    global BM_SIZE
+    global BN_SIZE
+    BLOCK_M = BM_SIZE  # sfinktah
+    BLOCK_N = BN_SIZE  # sfinktah
     stage = 3
+    print(f"\033[93m%%% [debug] sage-attn.attn_qk_int8_per_block_causal.forward: BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N}\033[0m")
 
     o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
diff -ru --ignore-trailing-space /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/core.py sageattention/core.py
--- /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/core.py	2024-11-20 13:25:52.000000000 +1100
+++ sageattention/core.py	2025-08-07 06:37:21.103920500 +1000
@@ -79,7 +79,7 @@
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
 
     headdim = q.size(-1)
-    assert headdim in [64, 96, 128], "headdim should be in [64, 96, 128]."
+    assert headdim in [64, 96, 128], "headdim should be in [64, 96, 128], not " + str(headdim)
 
     # assert last dim is contiguous
     assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
diff -ru --ignore-trailing-space /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/quant_per_block.py sageattention/quant_per_block.py
--- /tmp/tmp.xtWqWmVq9L/unzipped/sageattention/quant_per_block.py	2024-11-15 03:34:50.000000000 +1100
+++ sageattention/quant_per_block.py	2025-08-10 17:17:18.094870300 +1000
@@ -1,6 +1,7 @@
 import torch
 import triton
 import triton.language as tl
+import os
 
 @triton.jit
 def quant_per_block_int8_kernel(Input, Output, Scale, L,
@@ -30,7 +31,18 @@
     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
     tl.store(scale_ptrs, scale)
 
-def per_block_int8(q, k, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
+ARCH = triton.runtime.driver.active.get_current_target().arch
+
+_BM_SIZE = 32 if ARCH.startswith('gfx10') else 64
+_BN_SIZE = 16
+
+BM_SIZE = int(os.environ.get('SAGE_BM_SIZE', _BM_SIZE))
+BN_SIZE = int(os.environ.get('SAGE_BN_SIZE', _BN_SIZE))
+
+print(f"\033[93m[debug] sage-attn.quant_per_block: BM_SIZE {BM_SIZE}\033[0m")
+print(f"\033[93m[debug] sage-attn.quant_per_block: BN_SIZE {BN_SIZE}\033[0m")
+
+def per_block_int8(q, k, BLKQ=BM_SIZE, BLKK=BN_SIZE, sm_scale=None, tensor_layout="HND"):
     q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
     k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
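
The blacklist.json consulted by keep() in the first file above can be maintained by hand: an entry blocks a config only when every key in that entry matches, and keys may be kernel constants (BLOCK_M, BLOCK_N, STAGE, waves_per_eu) or triton.Config attributes (num_warps, num_stages, num_ctas). load_blacklist_if_updated() re-reads the file whenever its mtime changes, so edits are picked up the next time keep() runs. A minimal sketch of writing such a file, with made-up values:

import json

# Values are illustrative; list the combinations that hang or mis-compile on your card.
blacklist = [
    # Skip one exact combination (every key in the entry must match for the rule to fire).
    {"BLOCK_M": 64, "BLOCK_N": 16, "num_warps": 4, "num_stages": 4},
    # Skip anything that asks for 3 waves per EU, regardless of block shape.
    {"waves_per_eu": 3},
]

# Written next to the module, matching BLACKLIST_PATH in the patch (path assumes the repository root).
with open("sageattention/blacklist.json", "w") as f:
    json.dump(blacklist, f, indent=2)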