[Common][PyTorch]Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114
[Common][PyTorch]Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114vthumbe1503 wants to merge 77 commits into
Conversation
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768 Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81 Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_ab566800d87047635cd27f9e64661abe Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54 Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819 Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9 Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_b4abb47c990404d73142342a19996a3f Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_495cc57eef84749103aded403a508d99 Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42 Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry. Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037 Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182 Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1 Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031 Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_261900f987bdc9397965019983a77c41 Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_d84e1fefef8641e558df064452f4689b Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248 Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164 Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31 Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511 Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence. Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3 Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci pytorch |
Greptile SummaryThis PR adds current-scaling group quantization (per-group amax → scale → FP8 cast) for both rowwise and columnwise layouts, and extends the existing group-quantize infrastructure in PyTorch and Common to support tensors with a varying last dimension (
Confidence Score: 4/5The current-scaling FP8 path is well-guarded and the new kernels handle all four shape representations correctly; the one unguarded combination (MXFP8 + last_dims) would silently produce wrong quantized output if triggered. All new code paths for FP8 current scaling, the new CUDA kernels, the 2D offset kernel, and the Python storage changes are internally consistent. NVFP4 and bgrad_group_quantize both explicitly reject a non-None last_dims, but the MXFP8 dispatch in group_quantize carries no equivalent guard. Passing last_dims with an MXFP8 quantizer causes create_grouped_tensor to size the scale buffer using the total (summed) logical last dim rather than per-tensor dims, then the MXFP8 cast kernel operates on the mismatched scale layout without error. transformer_engine/pytorch/csrc/extensions/cast.cpp — the MXFP8_GROUPED_QUANTIZE switch case needs a guard matching the one already present for NVFP4. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["group_quantize(tensor, quantizer, num_tensors,\nfirst_dims, last_dims, tensor_offsets, noop_flag)"] --> B{Quantizer type?}
B -->|NVFP4| C["NVTE_CHECK: !last_dims ✓\nnvfp4 group quantize impl"]
B -->|MXFP8| D["MXFP8_GROUPED_QUANTIZE\n❌ no last_dims guard\nnvte_group_quantize(MXFP8)"]
B -->|Float8CurrentScaling| E["compute_grouped_fp8_current_scaling_\namax_and_scale()"]
B -->|Other| F["NVTE_ERROR"]
E --> E1["nvte_group_compute_amax_with_config\n(zero + flat vectorized amax kernel)"]
E1 --> E2{"with_amax_reduction?"}
E2 -->|Yes| E3["NCCL allreduce MAX on amax buffer"]
E2 -->|No| E4
E3 --> E4["nvte_group_compute_scale_from_amax\n(scale = fp8_max / amax)"]
E4 --> E5["nvte_group_quantize\n(FP8 cast kernel)"]
subgraph "Shape Representations"
SR1["SAME_BOTH_DIMS\nflat rowwise fast path"]
SR2["VARYING_FIRST_DIM\naligned flat rowwise / generic 2D colwise"]
SR3["VARYING_LAST_DIM\nper-row flat rowwise / generic 2D colwise"]
SR4["VARYING_BOTH_DIMS\n(blocked in current scaling)"]
end
E5 --> SR1 & SR2 & SR3 & SR4
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A["group_quantize(tensor, quantizer, num_tensors,\nfirst_dims, last_dims, tensor_offsets, noop_flag)"] --> B{Quantizer type?}
B -->|NVFP4| C["NVTE_CHECK: !last_dims ✓\nnvfp4 group quantize impl"]
B -->|MXFP8| D["MXFP8_GROUPED_QUANTIZE\n❌ no last_dims guard\nnvte_group_quantize(MXFP8)"]
B -->|Float8CurrentScaling| E["compute_grouped_fp8_current_scaling_\namax_and_scale()"]
B -->|Other| F["NVTE_ERROR"]
E --> E1["nvte_group_compute_amax_with_config\n(zero + flat vectorized amax kernel)"]
E1 --> E2{"with_amax_reduction?"}
E2 -->|Yes| E3["NCCL allreduce MAX on amax buffer"]
E2 -->|No| E4
E3 --> E4["nvte_group_compute_scale_from_amax\n(scale = fp8_max / amax)"]
E4 --> E5["nvte_group_quantize\n(FP8 cast kernel)"]
subgraph "Shape Representations"
SR1["SAME_BOTH_DIMS\nflat rowwise fast path"]
SR2["VARYING_FIRST_DIM\naligned flat rowwise / generic 2D colwise"]
SR3["VARYING_LAST_DIM\nper-row flat rowwise / generic 2D colwise"]
SR4["VARYING_BOTH_DIMS\n(blocked in current scaling)"]
end
E5 --> SR1 & SR2 & SR3 & SR4
Reviews (9): Last reviewed commit: "add CPP tests for current scaling group ..." | Re-trigger Greptile |
… specific Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
|
/te-ci |
|
Pipeline: 54747206 |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…scale from amax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
…TransformerEngine into current_scaling_group_quant
Removed duplicate brief comment about scaled prefix-sum offsets. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
|
/te-ci |
Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
|
/te-ci |
Oleg-Goncharov
left a comment
There was a problem hiding this comment.
Hi @vthumbe1503, could you please also add C++ unit tests covering various shapes, similar to grouped MXFP8?
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci |
| @@ -276,7 +373,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const | |||
| } | |||
There was a problem hiding this comment.
MXFP8 case missing
last_dims guard — produces silently wrong output
The NVFP4_GROUPED_QUANTIZE case correctly rejects a non-None last_dims at line 341. The MXFP8_GROUPED_QUANTIZE case has no equivalent check. If a caller passes last_dims with an MXFP8 quantizer, MXFP8Quantizer::create_grouped_tensor will have already been called (before the switch) with is_varying_both = false, causing it to compute the total rowwise/columnwise scale buffer size via get_scale_shape({logical_first_dim, sum(last_dims)}, …) — treating all tensors as having the same last dimension equal to the sum, instead of computing per-tensor scale counts. The MXFP8 kernel then reads scale offsets that were sized for the wrong shape, producing wrong quantized output with no error raised. bgrad_group_quantize documents this exact constraint ("MXFP8 dbias kernel requires a constant last dimension"), so the same guard should be applied here.
Description
Performance from Benchmarking Script for Current Scaling Group Quantize.
NOTE:
GB200

H100



Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Core Changes
common.cuhhas been seperated out into two files --grouped_layout.cuhandgrouped_tma.cuh. The idea is that,grouped_layout.cuhhas generic utilities commonly used in grouped tensor kernels. Andgrouped_tma.cuhhas the arch specific TMA changes. This is done so that current scaling can still be non-arch specific and it can use just thegrouped_layout.cuhsince we didnt need TMA for current scaling.Other Common Changes
Pytorch Group Quantize API Changes
Benchmarking Scripts/Tests added for Current Scaling for Varying First/Last Dims and to handle overallocation of grouped tensors. Varying All Dims is not supported yet.
Checklist: