perf(cuda/moe): backend-aware default for the fused decode-MoE Dff threshold#643
Merged
Merged
Conversation
…reshold The fused single-token decode-MoE kernel declines to gather_qmm when the expert intermediate (Dff) exceeds MLXCEL_FUSED_MOE_MAX_DFF, whose default was a flat 4096 tuned for Metal. On CUDA the crossover is higher, so mid-size-expert models such as phi-3.5-moe (Dff 6400) silently lost the fused kernel by default. Make the default backend-aware: 4096 when Metal is available, 8192 otherwise, with the env var still overriding both backends. The Dff-bound resolution moves into a pure, unit-tested fused_moe_max_dff_from(env, metal_available); the live backend is read through a new metal_is_available() FFI that mirrors the mlx::core::metal::is_available() gate the kernel dispatch already uses, so the default follows the actual device instead of a compile-time cfg. Re-measured the CUDA crossover on GB10 (DGX Spark, sm_121) under MLX pin e9463bb. The #625 pin bump moved gather_gemm to a JIT path that sped the gather_qmm fallback dramatically for the pathological 128-expert config (qwen3-30b-a3b fallback 58.2 -> 91.0 tok/s, erasing its old +55% fused win) while leaving low-expert configs unchanged (lfm2 still +12%). Fused now clearly wins through Dff 6400 (phi-3.5-moe +5%), is break-even at 8192 (llama-4-scout), and loses at 14336 (mixtral -1.2%), so the crossover collapsed from the old ~13-14k to ~8000. The CUDA default is set to the 8192 break-even boundary, which captures phi-3.5-moe and llama-4-scout while leaving mixtral on gather_qmm. Docs and the committed per-run sweep CSV under docs/benchmark_results/ record the new backend-specific defaults and the 2026-07-03 GB10 sweep.
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The fused single-token decode-MoE kernel declines to
gather_qmmwhen the expert intermediate (Dff) exceedsMLXCEL_FUSED_MOE_MAX_DFF, whose default was a flat 4096 tuned for Metal. On CUDA the crossover is higher, so mid-size-expert models silently lost the kernel. This makes the default backend-aware: 4096 on Metal, 8192 on CUDA (re-measured on GB10), with the env var still overriding both backends.What changed
src/models/switch_layers.rs: split the Dff-bound resolution into a pure, unit-testedfused_moe_max_dff_from(env, metal_available); the default is nowFUSED_MOE_MAX_DFF_METAL(4096) when Metal is available andFUSED_MOE_MAX_DFF_CUDA(8192) otherwise. An explicitMLXCEL_FUSED_MOE_MAX_DFFstill wins on both backends. Added a unit test for the backend-aware default and env override.src/lib/mlxcel-core: newmetal_is_available()FFI (cxx bridge decl insrc/lib.rs, header decl incpp/mlx_cxx_bridge.h, impl incpp/mlx_cxx_kernels.cpp) mirroring themlx::core::metal::is_available()gate the kernel dispatch already uses, so the Rust default follows the live backend without a compile-time cfg.docs/environment-variables.mdanddocs/benchmark_results/fused-moe-decode-kernel-design.mdupdated with the backend-specific defaults and a dated (2026-07-03) sweep addendum; raw per-run data committed atdocs/benchmark_results/fused-moe-decode-dff-sweep-gb10-2026-07-03.csv.Re-measured CUDA crossover (GB10, DGX Spark, sm_121, MLX pin e9463bb)
mlxcel-bench-decode, 100 decode tokens after 20-token warmup, median of 3; fallback forced withMLXCEL_FUSED_MOE_MAX_DFF=1, fused with=20000.The crossover collapsed from the old ~13-14k to ~8000: the #625 pin bump moved
gather_gemmto JIT, which sped the fallback dramatically for the pathological 128-expert config (qwen3-30b-a3b fallback 58.2 -> 91.0, erasing its old +55% fused win) while leaving low-expert configs unchanged (lfm2 still +12%). Fused clearly wins through Dff 6400, is break-even at 8192, and loses at 14336. CUDA default set to the 8192 break-even boundary (rounded down from the 14336 regression); Metal stays 4096.Before / after by default (CUDA)
Models with Dff in (4096, 8192] that the old 4096 default declined and the new 8192 default now fuses by default:
mixtral-8x7b (Dff 14336) stays on
gather_qmm(past the crossover, -1.2% if forced).Parity
#319-style fused-vs-fallback greedy (temp 0) parity re-run on qwen3-30b-a3b. This change touches only the threshold default, not the kernel math: models that now fuse (phi-3.5-moe, llama-4-scout) use the identical pre-existing fused kernel. On GB10/CUDA greedy temp-0 output is non-deterministic run-to-run for both paths (whole-model GPU FP-reduction order): two identical gather_qmm runs already diverge, as do two identical fused runs. All four runs share the same ~30-token prefix and then diverge at the same near-tie cascade point, so the fused kernel stays inside the same envelope gather_qmm has with itself, i.e. within the documented f16 jitter class. Byte-identical greedy parity is not achievable on CUDA for either path and is unaffected by this change.
Test plan
cargo build --release --features cudafused_moe_max_dff_default_is_backend_aware_and_env_overridespasses (cargo test --release --features cuda -p mlxcel fused_moe_max_dff -- --test-threads=1: 1 passed)cargo test --release --features cuda -p mlxcel-core -- --test-threads=1: three failures pre-date this PR and are unrelated to the additivemetal_is_available()FFI (which no failing test calls):steel_gemm_edge_tile_safe_load_matches_reference(numerical drift, max_abs 0.0198),test_from_bytes_fp16_native_dtype(dtype enum code 9 vs 10) andtest_fused_paged_decode_gqa_and_batched(a Metal-only kernel that throws[metal_kernel] No Metal back-endand aborts on the CUDA build). All three are MLX pin e9463bb / CUDA-environment issues introduced by the chore(mlx): upgrade pinned MLX commit, rebase CUDA overlays, GB10 re-baseline and outlier triage #625 pin bump;ffi_tests.rsis untouched here.Closes #626