
[Common][PyTorch]Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize#3114

Open
vthumbe1503 wants to merge 77 commits into NVIDIA:main from vthumbe1503:current_scaling_group_quant

Conversation

@vthumbe1503

@vthumbe1503 vthumbe1503 commented Jun 10, 2026

Collaborator

Description

  • Adds current-scaling group quantization in Common and PyTorch for varying first/last dims.
  • Enables varying last/both dims for group quantize in the PyTorch C++/Python infrastructure (this wasn't present before for MXFP8 group quantize).
  • Enables the no_op flag in the group_quantize PyTorch infra (previously it was present only in the PyTorch single-tensor quantization infra).
  • Adds benchmarking scripts and tests for current-scaling group quantize, covering corner cases such as overallocation and non-uniform group sizes. Tests are added for both Common (C++) and PyTorch (Python).

Performance from the benchmarking script for current-scaling group quantize.
NOTE:

  1. overalloc means the grouped tensor is overallocated to 4 times its actual size. This doesn't matter to the kernel, since the kernel reads the offsets from the device and decides the CTA-to-element mapping.
  2. zipf, heavy, mild, etc. represent the amount of group imbalance. As seen below, kernel performance doesn't deteriorate under group imbalance on Blackwell. On Hopper, group imbalance currently affects the performance of the columnwise cast kernel because it uses a 2D grid. This will be fixed in a follow-up PR.
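
A small host-side sketch of why overallocation is harmless, assuming the offsets are cumulative row starts with one final entry marking the end of valid data. That layout is an assumption made for illustration, and the helper name `group_of_row` is hypothetical; the real mapping happens on the device inside the kernel.

```python
from bisect import bisect_right


def group_of_row(row, offsets):
    """Return the index of the group that owns `row`.

    `offsets` holds num_groups + 1 cumulative row starts (assumed layout).
    Rows at or beyond offsets[-1] belong to the overallocated tail and
    map to no group, so no work is issued for them.
    """
    if row < 0 or row >= offsets[-1]:
        return None
    # First offset strictly greater than `row`, minus one, is the owner.
    # Empty groups (repeated offsets) are skipped automatically.
    return bisect_right(offsets, row) - 1


# Three groups with 2, 0 and 5 rows inside a buffer allocated for 28 rows (4x).
offsets = [0, 2, 2, 7]
allocated_rows = 28
owners = [group_of_row(r, offsets) for r in range(allocated_rows)]
```

Only the first 7 rows resolve to a group; the remaining 21 allocated rows resolve to `None`, which is the host-side analogue of a CTA finding it has nothing to process.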

GB200
[image: benchmark results]

H100
[3 images: benchmark results]

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Core Changes

    • common.cuh has been separated into two files: grouped_layout.cuh and grouped_tma.cuh. grouped_layout.cuh holds generic utilities commonly used in grouped tensor kernels, and grouped_tma.cuh holds the arch-specific TMA changes. This lets current scaling stay non-arch-specific by using only grouped_layout.cuh, since current scaling doesn't need TMA.
    • Dispatch is changed to enable calling group quantize for NVTE_DELAYED_SCALING, which is the enum used for both delayed scaling and current scaling in the legacy infra. (Note: delayed scaling isn't enabled in the PyTorch infra in this PR.)
    • FP8: kernels for computing grouped amax and for grouped cast to per-tensor FP8 are added. Both reach close to 6 TB/s for all use cases on Blackwell. On H100, the rowwise cast kernel reaches 3 TB/s and the columnwise cast kernel reaches 2.5 TB/s.
  • Other Common Changes

    • NVTE APIs added for computing grouped amax and for computing grouped scale from grouped amax, using the kernels implemented in core.
    • nvte_splits_to_offsets_2d: NVTE API added for computing offsets for VARYING_BOTH_DIMS, since varying last/both dims are now enabled from the PyTorch infra.
  • PyTorch Group Quantize API Changes

    • Current scaling support added, using the NVTE APIs from Common to compute grouped amax and perform a grouped cast.
    • Accepts last_dims in the API, which enables varying last dims/both dims in the PyTorch infrastructure.
    • Accepts a no_op flag for CUDA graphs when skip_fp8_weight_update is wanted. (This is needed for single_grouped_weight correctness with CUDA graphs when skip_weight_update is enabled, i.e., when group_quantize is used on the weights.)
  • Benchmarking scripts/tests added for current scaling with varying first/last dims and for handling overallocation of grouped tensors. Varying All Dims is not supported yet.
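
For readers unfamiliar with the VARYING_BOTH_DIMS offsets mentioned above, here is a minimal host-side reference, assuming each tensor is stored contiguously and the offsets are cumulative element counts with a leading zero. The Python function below is a hypothetical stand-in written for illustration; the signature and exact output layout of the real `nvte_splits_to_offsets_2d` are not shown in this description.

```python
from itertools import accumulate


def splits_to_offsets_2d(first_dims, last_dims):
    """Cumulative element offsets for tensors of shape (first_dims[i], last_dims[i]).

    Returns num_tensors + 1 entries: offsets[i] is where tensor i starts and
    offsets[-1] is the total element count (assumed convention).
    """
    if len(first_dims) != len(last_dims):
        raise ValueError("first_dims and last_dims must have equal length")
    counts = [rows * cols for rows, cols in zip(first_dims, last_dims)]
    return [0] + list(accumulate(counts))


# Tensors of shape 2x4, 0x8 (empty) and 3x5.
offsets = splits_to_offsets_2d([2, 0, 3], [4, 8, 5])
```

With only one varying dimension the other factor is constant and a 1D prefix sum of the splits is enough, which is presumably why a separate 2D API is needed only for the both-dims case.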

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Orchestra and others added 30 commits May 14, 2026 18:42
Route grouped Float8CurrentScalingQuantizer through the existing grouped quantize entry point, prepare per-group current-scaling metadata with existing amax/scale helpers, and add focused tests plus a GB200 bandwidth benchmark.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_5507e814ee50f9ff304a4ce708d19768
Orchestra-Run: run_516e1e26891f4ce7d4cde07147c10862
Use wider vectorized grouped FP8 cast-transpose tiles and vectorized masked stores for rowwise and columnwise outputs. Capture all benchmark modes in a single post-warmup profiler range.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_3d6e33eab11e293d72eb4394bad76a81
Orchestra-Run: run_a6e2c31d5fdf850594f71438e53148da
Route non-MXFP8 grouped-linear bias backward through group_quantize plus grouped dbias while keeping MXFP8 bgrad_group_quantize fusion intact. Add focused zero-row grouped FP8 coverage and a current-scaling GroupedLinear bias-backward regression.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_ab566800d87047635cd27f9e64661abe
Orchestra-Run: run_5f9bfef17ccd854232c54d56268ef9e8
Use packed FP8 conversion and reduce columnwise transpose staging register and synchronization overhead in group_cast_fp8_kernel.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_7a830e018ceac8de0018280bd0740a54
Orchestra-Run: run_d2f1df4ffc2265d9cfa5ed01028ee476
Match the grouped FP8 conversion helper's element-count template parameter to Vec's uint32_t parameter so rowwise, columnwise, and activation instantiations can build.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_30c4b6ddb896e5ea3ca5b54731d2c819
Orchestra-Run: run_e95cdbb445943304622b95736f0eca49
Use cached grouped offsets to avoid launching FP8 quantization over unused overallocated rows, permit larger grouped backing buffers when split metadata is present, and tighten full-tile vector paths in the grouped FP8 cast kernel.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_c5db93823dc101838cb1323e283cd6e9
Orchestra-Run: run_063e2e4c724e132612aa5597d6765c9b
Use the FP8 grouped output logical shape when computing the tensor-scaling launch grid so overallocated buffers with active metadata avoid empty tail-row launches while preserving the allocated-shape fallback.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_b4abb47c990404d73142342a19996a3f
Orchestra-Run: run_8f09e7b9d7af9754ef505f2e2ce3cf90
Use larger grouped FP8 tiles with 8-warp CTAs and 16-row columnwise store fragments. Treat uniform overallocated FP8 grouped outputs as same-shape wrappers during output reuse so the timed path avoids varying-shape metadata overlaunch. Add overallocated current-scaling coverage for all grouped FP8 direction modes.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_3f98ac9c5b82192ec289d8d2a9816c7f
Orchestra-Run: run_83f3b99cc950024cf06ee836337fbf72
Stage columnwise transpose fragments through shared-memory vectors with smaller columnwise row tiles to reduce register pressure and barrier overhead while preserving the larger rowwise-only store path.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_495cc57eef84749103aded403a508d99
Orchestra-Run: run_53e038e90f83186bc6c12cb722c986b5
Add fast grouped FP8 rowwise and full-tile columnwise paths for uniform active groups while preserving the general fallback for varying grouped metadata.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_4c33e88776c8a7148e9da5cc2bae84ea
Orchestra-Run: run_2caaff219394eb5d59b7be38ab2bf346
Add a same-shape bidirectional full-tile kernel with wider input vectors and rowwise stores while preserving the existing rowwise-only, columnwise-only, and fallback grouped paths.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_87cec01d94f053b53e3c79377ad379ab
Orchestra-Run: run_ed48db00a730a4bf56530d551ecd350e
Route same-shape rowwise+columnwise grouped FP8 tensor-scaling quantization through the compact full-tile transpose schedule instead of the wide dynamic-shared-memory variant, preserving the existing single-direction and fallback paths.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_fdddd228a620039c024b4ecf43f3ab42
Orchestra-Run: run_30a2753eea9c893cb0fadb8233da8ce6
Hint the rowwise stores in the full-tile rowwise+columnwise grouped FP8 path as streaming global stores to reduce cache/writeback pressure without changing single-direction launch geometry.

Orchestra-Work-Order: wo_aea2e337b06582111bba66a6d6158a9e
Orchestra-Task: task_bf82020032e68276f4e47c65f62d97ae
Orchestra-Run: run_754ea4c864f329c6f2003b413b723c43
Add graph-safe grouped FP8 tensor-scaling metadata, support varying last dimensions, preserve same-shape fast paths, adjust grouped FP8 columnwise allocation by architecture, and expand benchmark/test coverage for the reviewed shape cases.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d104e74844fbc3d3b1a98a8d96d76037
Orchestra-Run: run_1314e997c61ffb92ff7120b0b26f0318
Map varying-last columnwise tiles per group to avoid tile-alignment device errors, expand nonaligned boundary coverage, and restore same-shape benchmark baseline criteria.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_14e0e7973300d26f69550bc0aee21acc
Orchestra-Run: run_2f42b8ba138ed8b2b4d9dc90b92caf85
Add grouped FP8 benchmark support for baseline-ref same-session reports and update the benchmark request to enforce same-shape baseline regression checks alongside the per-mode throughput thresholds.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d0cada957a4aafdce9d52be86520e182
Orchestra-Run: run_4da74e9bdb4f4a4c72304a385692b6c9
Update the grouped FP8 benchmark driver so same-session baseline checks out and builds the baseline ref into an isolated PyTorch install, verifies the baseline subprocess loads those shared objects, and preserves the required same-shape baseline comparisons.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_4fd88b172872f547f2f2d0053dce73d1
Orchestra-Run: run_6a44ee0467ffff47d4b278de6127354d
Preserve grouped delayed-FP8 amax metadata and keep unsupported FP8 tensor-scaling quantizers out of the grouped GEMM path.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_2aa8e6bf11ae356f4b34d4540b508031
Orchestra-Run: run_302681098d7f4e05b0ad96450f2d9826
Set NVTE_GROUPED_LINEAR_SINGLE_PARAM inside the targeted state-dict tests so they exercise the gated single grouped parameter path without relying on external environment setup.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_261900f987bdc9397965019983a77c41
Orchestra-Run: run_c6624e34717cbe121b3e0edcf490e3d3
Add a segmented flat rowwise kernel for varying-first grouped FP8 tensor-scaling outputs while preserving the existing same-shape fast path.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_c1b7020b27290318848ef6ac9048dd5f
Orchestra-Run: run_5c257b8a5d2e7e4aa95e67aa16436166
Omit the last_dims keyword when absent so the same-session baseline can run against the base extension, and refresh the benchmark request to include direct varying-last current-scaling coverage.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_c20a3c94fdc798e741a469bd7bb9c4df
Orchestra-Run: run_457448e6cba80fc63ac72b3db71c5fd0
Dispatch varying-first tensor-scaling work per group to reduce inactive-tail CTAs and offset lookup overhead while preserving same-shape fast paths and graph-safe device metadata handling.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_d84e1fefef8641e558df064452f4689b
Orchestra-Run: run_a361ca2f93fcec53ddd60dd99f4639e5
Add a no-tail rowwise flat kernel for aligned varying-first grouped FP8 tensor-scaling quantization and keep same-shape and varying-last dispatch isolated. Tighten benchmark profiler timing so post-warmup measured ranges exclude profiler start overhead.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_2e478be1fb38195f36d25c51320dc01f
Orchestra-Run: run_9a133a75fa3d98dc3b1a63b0ff4d84af
Write grouped FP8 benchmark reports to a sidecar path by default and label script reports as benchmark_raw_report/v1 so regular 100-iteration measurements are fetched instead of the wrapper command report.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_27770b2e1d490b1a3053244d4b4ce248
Orchestra-Run: run_214052d0c1316e231443d645183a2675
Write the grouped FP8 benchmark JSON once and mirror the completed sidecar to ORCHESTRA_BENCHMARK_RAW_REPORT when running under Orchestra so the benchmark fetch path can parse the emitted measurements.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_b2e2747371204088c8e3f7cf10263164
Orchestra-Run: run_1d4ea38266807c8acb59143ee74ba241
Allow the grouped FP8 benchmark to use ORCHESTRA_BENCHMARK_RAW_REPORT as its primary output so the benchmark wrapper can fetch canonical measurements directly.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_10fdcfef6b70de4676b7843e4bbfac31
Orchestra-Run: run_4ce57df9e86d6d03a26f7aa95ac252cc
Write canonical grouped FP8 benchmark measurements to ORCHESTRA_BENCHMARK_RAW_REPORT in a small schema-shaped payload so the benchmark wrapper can materialize per-mode threshold evidence.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_3e862eebd585c74f2a58497fedea3511
Orchestra-Run: run_3770ab3dbbf51329d0839b3d10a91b5c
Write candidate_results and nonempty measurements into the Orchestra raw report path, and fail fast if the benchmark cannot produce threshold-ready evidence.

Orchestra-Work-Order: wo_9d18259ce6c2833da1178606e08d251a
Orchestra-Task: task_aa587a7b0d35aa9c2b715ec1b7c8bec3
Orchestra-Run: run_b42870e5d5e142a6cbf53bb5a3cafc2e
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 13, 2026 01:03
@vthumbe1503

Collaborator Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jun 13, 2026

Contributor

Greptile Summary

This PR adds current-scaling group quantization (per-group amax → scale → FP8 cast) for both rowwise and columnwise layouts, and extends the existing group-quantize infrastructure in PyTorch and Common to support tensors with a varying last dimension (VARYING_LAST_DIM) and varying both dimensions (VARYING_BOTH_DIMS), in addition to the already-supported VARYING_FIRST_DIM case.

  • New CUDA kernels (group_amax_fp8.cuh, group_quantize_fp8.cuh): flat vectorized per-group amax, scale derivation, and FP8 cast with shape-representation dispatch; Blackwell fast-path uses hardware-accelerated 4-element scaled FP8 conversion.
  • PyTorch / C++ plumbing: last_dims parameter added throughout create_grouped_tensor overloads, build_grouped_tensor_offsets handles 1D/2D split-to-offset cases, and splits_to_offsets_2d_kernel computes cumulative element offsets for VARYING_BOTH_DIMS on device.
  • Shared scale_inv buffer for FP8 current scaling: on Blackwell (is_non_tn_fp8_gemm_supported), a single scale_inv tensor is aliased for both rowwise and (suppressed) columnwise paths, matching the C++ grouped tensor wrapper behavior.
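
The amax → scale → cast pipeline summarized above can be sketched in plain Python as a reference for the arithmetic only. Several things here are assumptions: the E4M3 maximum of 448 (the summary does not say which FP8 format is used), the fallback scale of 1.0 for an all-zero or empty group, and the name `current_scaling_reference`, which is hypothetical. The sketch clamps to the FP8 range but does not round to FP8, and it omits any epsilon or power-of-two handling the real kernels may apply.

```python
FP8_E4M3_MAX = 448.0  # largest finite E4M3 value (assumed target format)


def current_scaling_reference(groups, fp8_max=FP8_E4M3_MAX):
    """Per-group current scaling on flat lists of floats.

    Returns one (scale, scale_inv, scaled_values) tuple per group, with
    scale = fp8_max / amax as in the flowchart's scale formula.
    """
    results = []
    for values in groups:
        amax = max((abs(v) for v in values), default=0.0)
        # Assumption: avoid dividing by zero for empty or all-zero groups.
        scale = fp8_max / amax if amax > 0.0 else 1.0
        # Clamp to the representable range; real kernels also round to FP8.
        scaled = [min(max(v * scale, -fp8_max), fp8_max) for v in values]
        results.append((scale, 1.0 / scale, scaled))
    return results


results = current_scaling_reference([[0.5, -2.0, 1.0], [0.0, 0.0], []])
```

Because each group gets its own amax, a group with small values is not squeezed into a few FP8 codes by a large outlier in another group.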

Confidence Score: 4/5

The current-scaling FP8 path is well-guarded and the new kernels handle all four shape representations correctly; the one unguarded combination (MXFP8 + last_dims) would silently produce wrong quantized output if triggered.

All new code paths for FP8 current scaling, the new CUDA kernels, the 2D offset kernel, and the Python storage changes are internally consistent. NVFP4 and bgrad_group_quantize both explicitly reject a non-None last_dims, but the MXFP8 dispatch in group_quantize carries no equivalent guard. Passing last_dims with an MXFP8 quantizer causes create_grouped_tensor to size the scale buffer using the total (summed) logical last dim rather than per-tensor dims, then the MXFP8 cast kernel operates on the mismatched scale layout without error.

transformer_engine/pytorch/csrc/extensions/cast.cpp — the MXFP8_GROUPED_QUANTIZE switch case needs a guard matching the one already present for NVFP4.

Important Files Changed

Filename | Overview
transformer_engine/pytorch/csrc/extensions/cast.cpp | Adds current-scaling group-quantize dispatch, noop_flag forwarding, and last_dims plumbing; NVFP4/bgrad cases guard against last_dims but MXFP8 case is missing the equivalent guard, allowing wrong scale buffer sizing to silently propagate.
transformer_engine/pytorch/csrc/quantizer.cpp | Adds last_dims parameter to all create_grouped_tensor overloads, updates build_grouped_tensor_offsets for 1D/2D varying cases, and implements shared scale_inv for FP8 current scaling on Blackwell; logic is sound.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py | Extends make_grouped_tensor to handle VARYING_LAST_DIM and VARYING_BOTH_DIMS, unifies the graph-capture guard, and adds shared scale_inv for FP8 current scaling; the logical_shape conventions for varying dims are correct.
transformer_engine/common/cast/fp8/group_quantize_fp8.cuh | New file implementing per-group FP8 cast kernels for all four shape representations; fast-path specializations for rowwise-flat and same-shape transpose; VARYING_LAST_DIM handled correctly via per-tensor col bounds check in the 2D grid.
transformer_engine/common/cast/fp8/group_amax_fp8.cuh | New file with flat vectorized per-group amax kernel and scale/scale_inv derivation kernel; handles all shape representations via device-side offset arrays; shared-memory usage scales with num_tensors.
transformer_engine/common/util/splits_to_offsets.cu | Adds splits_to_offsets_2d_kernel for VARYING_BOTH_DIMS element-offset computation; single-block chunked inclusive scan with correct chunk_prefix bookkeeping across arbitrary num_tensors.
transformer_engine/pytorch/ops/fused/grouped_mlp.py | Converts rowwise_amax and columnwise_amax to keyword arguments in the nvfp4_group_quantize_with_amax call site to accommodate the new last_dims positional parameter.
transformer_engine/common/recipe/current_scaling.cu | Adds nvte_group_compute_amax_with_config API; shape detection logic matches the four representation cases correctly.
transformer_engine/pytorch/csrc/extensions/pybind.cpp | Adds last_dims (default None) and noop_flag (default None) to group_quantize binding, and last_dims to nvfp4_group_quantize_with_amax and bgrad_group_quantize bindings.
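
The "chunked inclusive scan with chunk_prefix bookkeeping" noted for splits_to_offsets.cu can be illustrated sequentially. This is only a sketch of the general technique: the function name is hypothetical, the inner loop stands in for whatever block-level scan the kernel actually performs, and the chunk size here is arbitrary.

```python
def chunked_inclusive_scan(values, chunk_size):
    """Inclusive prefix sum computed one fixed-size chunk at a time.

    `chunk_prefix` carries the total of all earlier chunks, so inputs longer
    than one chunk still produce globally correct sums.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    out = []
    chunk_prefix = 0
    for start in range(0, len(values), chunk_size):
        running = 0
        for v in values[start:start + chunk_size]:
            running += v  # scan local to this chunk
            out.append(chunk_prefix + running)
        chunk_prefix += running  # carry this chunk's total forward
    return out


# Per-tensor element counts scanned in chunks of two.
scanned = chunked_inclusive_scan([8, 0, 15, 4, 6], chunk_size=2)
```

The result does not depend on the chunk size, which is the property that lets a single block handle an arbitrary number of tensors.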

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["group_quantize(tensor, quantizer, num_tensors,\nfirst_dims, last_dims, tensor_offsets, noop_flag)"] --> B{Quantizer type?}
    B -->|NVFP4| C["NVTE_CHECK: !last_dims ✓\nnvfp4 group quantize impl"]
    B -->|MXFP8| D["MXFP8_GROUPED_QUANTIZE\n❌ no last_dims guard\nnvte_group_quantize(MXFP8)"]
    B -->|Float8CurrentScaling| E["compute_grouped_fp8_current_scaling_\namax_and_scale()"]
    B -->|Other| F["NVTE_ERROR"]
    E --> E1["nvte_group_compute_amax_with_config\n(zero + flat vectorized amax kernel)"]
    E1 --> E2{"with_amax_reduction?"}
    E2 -->|Yes| E3["NCCL allreduce MAX on amax buffer"]
    E2 -->|No| E4
    E3 --> E4["nvte_group_compute_scale_from_amax\n(scale = fp8_max / amax)"]
    E4 --> E5["nvte_group_quantize\n(FP8 cast kernel)"]
    subgraph "Shape Representations"
        SR1["SAME_BOTH_DIMS\nflat rowwise fast path"]
        SR2["VARYING_FIRST_DIM\naligned flat rowwise / generic 2D colwise"]
        SR3["VARYING_LAST_DIM\nper-row flat rowwise / generic 2D colwise"]
        SR4["VARYING_BOTH_DIMS\n(blocked in current scaling)"]
    end
    E5 --> SR1 & SR2 & SR3 & SR4
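
A minimal sketch of how the four shape representations in the flowchart could be told apart, assuming a representation is implied by which per-tensor dimension arrays are supplied. The enum values, function names, and the choice of `NotImplementedError` are all hypothetical; only the four representation names and the fact that VARYING_BOTH_DIMS is blocked for current scaling come from this PR.

```python
from enum import Enum


class ShapeRep(Enum):
    SAME_BOTH_DIMS = 0
    VARYING_FIRST_DIM = 1
    VARYING_LAST_DIM = 2
    VARYING_BOTH_DIMS = 3


def detect_shape_rep(first_dims=None, last_dims=None):
    """Pick a representation from which per-tensor dim arrays are present."""
    if first_dims is not None and last_dims is not None:
        return ShapeRep.VARYING_BOTH_DIMS
    if first_dims is not None:
        return ShapeRep.VARYING_FIRST_DIM
    if last_dims is not None:
        return ShapeRep.VARYING_LAST_DIM
    return ShapeRep.SAME_BOTH_DIMS


def check_current_scaling_supported(rep):
    """Reject the one representation the flowchart marks as blocked."""
    if rep is ShapeRep.VARYING_BOTH_DIMS:
        raise NotImplementedError(
            "VARYING_BOTH_DIMS is not supported for current scaling")
    return rep
```

For example, `detect_shape_rep(first_dims=[2, 5])` yields `ShapeRep.VARYING_FIRST_DIM`, which passes the check, while supplying both arrays is rejected.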

Reviews (9): Last reviewed commit: "add CPP tests for current scaling group ..."

Comment thread transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
Varun Thumbe and others added 2 commits June 14, 2026 12:21
Comment thread transformer_engine/common/cast/fp8/group_quantize_fp8.cuh
@vthumbe1503

Collaborator Author

/te-ci

@vthumbe1503

Collaborator Author

Pipeline: 54747206

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
vthumbe1503 and others added 2 commits June 14, 2026 21:56
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…TransformerEngine into current_scaling_group_quant
@vthumbe1503 vthumbe1503 requested a review from timmoon10 as a code owner June 14, 2026 22:13
Removed duplicate brief comment about scaled prefix-sum offsets.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503

Collaborator Author

/te-ci

Signed-off-by: Varun Thumbe <vthumbe@vthumbe-mlt.client.nvidia.com>
@vthumbe1503

Collaborator Author

/te-ci

@Oleg-Goncharov Oleg-Goncharov left a comment

Collaborator

Hi @vthumbe1503, could you please also add C++ unit tests covering various shapes, similar to grouped MXFP8?

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503

Collaborator Author

/te-ci

@vthumbe1503 vthumbe1503 changed the title Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize [Common/PyTorch] Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize Jun 22, 2026
@vthumbe1503 vthumbe1503 changed the title [Common/PyTorch] Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize [Common][PyTorch]Current Scaling Group Quantization + Enabling Varying Last/Both Dims in Group Quantize Jun 22, 2026
Comment thread transformer_engine/pytorch/csrc/quantizer.cpp
Comment on lines 363 to 373
@@ -276,7 +373,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
}

Contributor

P1 MXFP8 case missing last_dims guard — produces silently wrong output

The NVFP4_GROUPED_QUANTIZE case correctly rejects a non-None last_dims at line 341. The MXFP8_GROUPED_QUANTIZE case has no equivalent check. If a caller passes last_dims with an MXFP8 quantizer, MXFP8Quantizer::create_grouped_tensor will have already been called (before the switch) with is_varying_both = false, causing it to compute the total rowwise/columnwise scale buffer size via get_scale_shape({logical_first_dim, sum(last_dims)}, …) — treating all tensors as having the same last dimension equal to the sum, instead of computing per-tensor scale counts. The MXFP8 kernel then reads scale offsets that were sized for the wrong shape, producing wrong quantized output with no error raised. bgrad_group_quantize documents this exact constraint ("MXFP8 dbias kernel requires a constant last dimension"), so the same guard should be applied here.
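
The size mismatch described in this comment can be reproduced with simple arithmetic. The sketch assumes one MXFP8 scale per 32 consecutive elements along a row and ignores any padding or alignment that `get_scale_shape` applies, so the absolute numbers are illustrative; the helper names are hypothetical.

```python
MXFP8_BLOCK = 32  # one scale per 32 consecutive elements


def ceil_div(a, b):
    """Integer division rounded up."""
    return -(-a // b)


def rowwise_scales_summed(rows, last_dims):
    """Scale count if all tensors are treated as one tensor of width sum(last_dims)."""
    return rows * ceil_div(sum(last_dims), MXFP8_BLOCK)


def rowwise_scales_per_tensor(rows, last_dims):
    """Scale count if each tensor is sized on its own width."""
    return sum(rows * ceil_div(cols, MXFP8_BLOCK) for cols in last_dims)


# Two tensors with 4 rows each and last dims of 48.
rows, last_dims = 4, [48, 48]
summed = rowwise_scales_summed(rows, last_dims)          # 4 * ceil(96 / 32)
per_tensor = rowwise_scales_per_tensor(rows, last_dims)  # 4 * (2 + 2)
```

Here the summed sizing yields 12 scales while per-tensor sizing needs 16, because each 48-wide tensor has a partial trailing block that the summed width hides. The two agree only when every last dim is a multiple of the block size.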


2 participants