perf(cuda/moe): backend-aware default for the fused decode-MoE Dff threshold #643

Merged: inureyes merged 1 commit into main from feature/issue-626-cuda-fused-moe-dff on Jul 2, 2026

@inureyes inureyes commented Jul 2, 2026


Summary

The fused single-token decode-MoE kernel falls back to gather_qmm when the expert intermediate size (Dff) exceeds MLXCEL_FUSED_MOE_MAX_DFF, whose default was a flat 4096 tuned for Metal. On CUDA the crossover is higher, so mid-size-expert models silently lost the fused kernel. This PR makes the default backend-aware: 4096 on Metal, 8192 on CUDA (re-measured on GB10), with the env var still overriding the default on both backends.

What changed

  • src/models/switch_layers.rs: split the Dff-bound resolution into a pure, unit-tested fused_moe_max_dff_from(env, metal_available); the default is now FUSED_MOE_MAX_DFF_METAL (4096) when Metal is available and FUSED_MOE_MAX_DFF_CUDA (8192) otherwise. An explicit MLXCEL_FUSED_MOE_MAX_DFF still wins on both backends. Added a unit test for the backend-aware default and env override.
  • src/lib/mlxcel-core: new metal_is_available() FFI (cxx bridge decl in src/lib.rs, header decl in cpp/mlx_cxx_bridge.h, impl in cpp/mlx_cxx_kernels.cpp) mirroring the mlx::core::metal::is_available() gate the kernel dispatch already uses, so the Rust default follows the live backend without a compile-time cfg.
  • Docs: docs/environment-variables.md and docs/benchmark_results/fused-moe-decode-kernel-design.md updated with the backend-specific defaults and a dated (2026-07-03) sweep addendum; raw per-run data committed at docs/benchmark_results/fused-moe-decode-dff-sweep-gb10-2026-07-03.csv.
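
The resolution logic described above can be sketched roughly as follows. This is a minimal illustration, not the code in src/models/switch_layers.rs: the function and constant names and the two default values come from this PR, but the parameter types (`Option<&str>`, `bool`), the `usize` return type, and the handling of unparsable values are assumptions.

```rust
/// Default Dff bound when Metal is the live backend (value from this PR).
pub const FUSED_MOE_MAX_DFF_METAL: usize = 4096;
/// Default Dff bound otherwise, i.e. on CUDA (value from this PR).
pub const FUSED_MOE_MAX_DFF_CUDA: usize = 8192;

/// Resolve the fused decode-MoE Dff bound without touching global state.
///
/// `env` is the raw value of MLXCEL_FUSED_MOE_MAX_DFF, if the variable is set.
/// Assumption: a value that does not parse as an unsigned integer is ignored
/// and the backend default is used instead.
pub fn fused_moe_max_dff_from(env: Option<&str>, metal_available: bool) -> usize {
    let backend_default = if metal_available {
        FUSED_MOE_MAX_DFF_METAL
    } else {
        FUSED_MOE_MAX_DFF_CUDA
    };
    // An explicit, parsable override wins on both backends.
    env.and_then(|raw| raw.trim().parse::<usize>().ok())
        .unwrap_or(backend_default)
}
```

Because the function takes the environment value and the backend flag as arguments, a caller could pass `std::env::var("MLXCEL_FUSED_MOE_MAX_DFF").ok().as_deref()` together with the result of `metal_is_available()`, while a unit test passes fixed values.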

Re-measured CUDA crossover (GB10, DGX Spark, sm_121, MLX pin e9463bb)

mlxcel-bench-decode, 100 decode tokens after 20-token warmup, median of 3; fallback forced with MLXCEL_FUSED_MOE_MAX_DFF=1, fused with =20000.
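
For reference, the "median of 3" aggregation can be sketched as below. The helper name `median` is hypothetical and is not part of mlxcel-bench-decode; averaging the two middle values for an even sample count is an assumption that does not arise with three runs.

```rust
/// Median of per-run throughputs (tok/s).
/// Returns None for an empty slice. For an even number of samples this
/// sketch averages the two middle values (an assumption, unused for n = 3).
pub fn median(samples: &[f64]) -> Option<f64> {
    if samples.is_empty() {
        return None;
    }
    let mut sorted = samples.to_vec();
    // total_cmp gives a total order over f64, so sorting cannot panic on NaN.
    sorted.sort_by(|a, b| a.total_cmp(b));
    let mid = sorted.len() / 2;
    Some(if sorted.len() % 2 == 1 {
        sorted[mid]
    } else {
        (sorted[mid - 1] + sorted[mid]) / 2.0
    })
}
```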

| Dff | model | gather_qmm (tok/s) | fused (tok/s) | fused/fallback |
|---|---|---|---|---|
| 768 | qwen3-30b-a3b | 91.00 | 91.34 | 1.00 |
| 1792 | lfm2-8b-a1b | 140.35 | 157.76 | 1.12 |
| 2880 | gpt-oss-20b (mxfp4 control) | 78.59 | 78.00 | 0.99 |
| 6400 | phi-3.5-moe | 51.09 | 53.62 | 1.05 |
| 8192 | llama-4-scout-17b | 21.23 | 21.08 | 0.99 |
| 14336 | mixtral-8x7b | 28.23 | 27.89 | 0.99 |

The crossover collapsed from the old ~13-14k to ~8000: the #625 pin bump moved gather_gemm to JIT, which sped up the fallback dramatically for the pathological 128-expert config (qwen3-30b-a3b fallback 58.2 -> 91.0 tok/s, erasing its old +55% fused win) while leaving low-expert configs unchanged (lfm2 still +12%). Fused clearly wins through Dff 6400, is break-even at 8192, and loses at 14336. The CUDA default is set to the 8192 break-even boundary (rounded down from the 14336 regression); Metal stays at 4096.
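
As a cross-check of the fused/fallback column, the ratios can be recomputed from the table. The sketch below hard-codes the six rows of the sweep table; the helper names and the idea of picking the largest Dff at or above a minimum ratio are illustrative assumptions, not the procedure used to choose the default.

```rust
/// Rows of the sweep table: (Dff, model, gather_qmm tok/s, fused tok/s).
pub const SWEEP: [(u32, &str, f64, f64); 6] = [
    (768, "qwen3-30b-a3b", 91.00, 91.34),
    (1792, "lfm2-8b-a1b", 140.35, 157.76),
    (2880, "gpt-oss-20b", 78.59, 78.00),
    (6400, "phi-3.5-moe", 51.09, 53.62),
    (8192, "llama-4-scout-17b", 21.23, 21.08),
    (14336, "mixtral-8x7b", 28.23, 27.89),
];

/// fused / fallback throughput ratio; above 1.0 the fused kernel is faster.
pub fn fused_ratio(fallback: f64, fused: f64) -> f64 {
    fused / fallback
}

/// Largest swept Dff whose ratio is at least `min_ratio`
/// (hypothetical helper; the tolerance is chosen by the caller).
pub fn largest_dff_at_or_above(min_ratio: f64) -> Option<u32> {
    SWEEP
        .iter()
        .filter(|row| fused_ratio(row.2, row.3) >= min_ratio)
        .map(|row| row.0)
        .max()
}
```

With a strict `min_ratio` of 1.0 the last qualifying row is Dff 6400; allowing a 1% tolerance (0.99) extends it to 8192 while still excluding 14336, which matches the reading of 8192 as the break-even boundary.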

Before / after by default (CUDA)

Models with Dff in (4096, 8192] that the old 4096 default declined and the new 8192 default now fuses by default:

| model | Dff | before (4096 default, gather_qmm), tok/s | after (8192 default, fused), tok/s | delta |
|---|---|---|---|---|
| phi-3.5-moe-4bit | 6400 | 51.09 | 53.62 | +5.0% |
| llama-4-scout-17b-4bit | 8192 | 21.23 | 21.08 | break-even |

mixtral-8x7b (Dff 14336) stays on gather_qmm (past the crossover, -1.2% if forced).
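
The effect of the new default on these models reduces to a single comparison against the bound. A minimal sketch, assuming the Dff check is the only condition that differs between the two defaults (the real dispatch also requires the single-token decode path and a supported backend); both helper names are hypothetical:

```rust
/// True when the Dff bound allows the fused kernel: the kernel is declined
/// only when Dff exceeds the bound, so the bound itself is inclusive.
pub fn dff_allows_fused(dff: usize, max_dff: usize) -> bool {
    dff <= max_dff
}

/// True for models that the old bound declined and the new bound fuses.
pub fn newly_fused(dff: usize, old_max: usize, new_max: usize) -> bool {
    !dff_allows_fused(dff, old_max) && dff_allows_fused(dff, new_max)
}
```

For the CUDA change from 4096 to 8192, `newly_fused` holds for Dff 6400 and 8192 and does not hold for 2880 (already fused) or 14336 (still declined).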

Parity

#319-style fused-vs-fallback greedy (temp 0) parity re-run on qwen3-30b-a3b. This change touches only the threshold default, not the kernel math: models that now fuse (phi-3.5-moe, llama-4-scout) use the identical pre-existing fused kernel. On GB10/CUDA greedy temp-0 output is non-deterministic run-to-run for both paths (whole-model GPU FP-reduction order): two identical gather_qmm runs already diverge, as do two identical fused runs. All four runs share the same ~30-token prefix and then diverge at the same near-tie cascade point, so the fused kernel stays inside the same envelope gather_qmm has with itself, i.e. within the documented f16 jitter class. Byte-identical greedy parity is not achievable on CUDA for either path and is unaffected by this change.
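
One way to quantify the shared-prefix observation is to compare greedy token sequences pairwise. The sketch below is an assumption about how such a check could be written, not the script used for this PR; token IDs are modeled as `u32` and both function names are hypothetical.

```rust
/// Number of leading tokens two greedy runs have in common.
pub fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
}

/// Shortest common prefix over every pair of runs, or None when fewer
/// than two runs are given. If this value is about the same whether or
/// not the fused runs are included, the fused path diverges no earlier
/// than gather_qmm diverges from itself.
pub fn min_pairwise_prefix(runs: &[Vec<u32>]) -> Option<usize> {
    let mut shortest: Option<usize> = None;
    for i in 0..runs.len() {
        for j in (i + 1)..runs.len() {
            let len = common_prefix_len(&runs[i], &runs[j]);
            shortest = Some(shortest.map_or(len, |s| s.min(len)));
        }
    }
    shortest
}
```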

Test plan

  • cargo build --release --features cuda
  • New unit test fused_moe_max_dff_default_is_backend_aware_and_env_overrides passes (cargo test --release --features cuda -p mlxcel fused_moe_max_dff -- --test-threads=1: 1 passed)
  • Fused-vs-fallback greedy parity characterized on qwen3-30b-a3b (see Parity)
  • cargo test --release --features cuda -p mlxcel-core -- --test-threads=1: three failures pre-date this PR and are unrelated to the additive metal_is_available() FFI (which no failing test calls): steel_gemm_edge_tile_safe_load_matches_reference (numerical drift, max_abs 0.0198), test_from_bytes_fp16_native_dtype (dtype enum code 9 vs 10) and test_fused_paged_decode_gqa_and_batched (a Metal-only kernel that throws [metal_kernel] No Metal back-end and aborts on the CUDA build). All three are MLX pin e9463bb / CUDA-environment issues introduced by the #625 pin bump (chore(mlx): upgrade pinned MLX commit, rebase CUDA overlays, GB10 re-baseline and outlier triage); ffi_tests.rs is untouched here.

Closes #626

…reshold

The fused single-token decode-MoE kernel declines to gather_qmm when the expert intermediate (Dff) exceeds MLXCEL_FUSED_MOE_MAX_DFF, whose default was a flat 4096 tuned for Metal. On CUDA the crossover is higher, so mid-size-expert models such as phi-3.5-moe (Dff 6400) silently lost the fused kernel by default. Make the default backend-aware: 4096 when Metal is available, 8192 otherwise, with the env var still overriding both backends.

The Dff-bound resolution moves into a pure, unit-tested fused_moe_max_dff_from(env, metal_available); the live backend is read through a new metal_is_available() FFI that mirrors the mlx::core::metal::is_available() gate the kernel dispatch already uses, so the default follows the actual device instead of a compile-time cfg.

Re-measured the CUDA crossover on GB10 (DGX Spark, sm_121) under MLX pin e9463bb. The #625 pin bump moved gather_gemm to a JIT path that sped the gather_qmm fallback dramatically for the pathological 128-expert config (qwen3-30b-a3b fallback 58.2 -> 91.0 tok/s, erasing its old +55% fused win) while leaving low-expert configs unchanged (lfm2 still +12%). Fused now clearly wins through Dff 6400 (phi-3.5-moe +5%), is break-even at 8192 (llama-4-scout), and loses at 14336 (mixtral -1.2%), so the crossover collapsed from the old ~13-14k to ~8000. The CUDA default is set to the 8192 break-even boundary, which captures phi-3.5-moe and llama-4-scout while leaving mixtral on gather_qmm.

Docs and the committed per-run sweep CSV under docs/benchmark_results/ record the new backend-specific defaults and the 2026-07-03 GB10 sweep.
@inureyes added the labels type:performance (Performance improvements), priority:medium (Medium priority), area:models (Model architectures, weights, loading, metadata), area:core (mlxcel-core: MLX FFI, primitives, KV cache, layers) and status:done (Completed) on Jul 2, 2026
@inureyes inureyes merged commit 37931d7 into main Jul 2, 2026
5 checks passed

Linked issue: perf(cuda/moe): backend-specific default for the fused decode-MoE Dff threshold
