Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
bb42995
Add grouped FP8 tensor-scaling quantization
May 15, 2026
f39c527
Optimize grouped FP8 tensor-scaling quantization
May 15, 2026
f18e886
Fix grouped FP8 bias backward routing
May 15, 2026
ff3a766
Optimize grouped FP8 tensor-scaling cast kernel
May 15, 2026
84b5b6c
Fix grouped FP8 vector conversion deduction
May 15, 2026
acdf26b
Optimize grouped FP8 overallocated launch and transpose staging
May 15, 2026
15aa021
Size grouped FP8 launches from active output shape
May 15, 2026
51b7582
Improve grouped FP8 columnwise quantization bandwidth
May 15, 2026
82ef5c6
Optimize grouped FP8 columnwise staging
May 15, 2026
c1cd184
Optimize grouped FP8 same-shape quantization
May 15, 2026
8226a09
Optimize grouped FP8 bidirectional full-tile quantization
May 15, 2026
4a77e8b
Use compact full-tile grouped FP8 bidirectional quantize
May 15, 2026
94c4031
Use streaming stores for bidirectional grouped FP8 rowwise output
May 15, 2026
47403e1
Handle varying-dim grouped FP8 quantization feedback
May 16, 2026
a9b9d23
Handle nonaligned varying-last grouped FP8 quantization
May 16, 2026
3d740c2
Emit same-session grouped FP8 baseline measurements
May 16, 2026
1577806
Build grouped FP8 benchmark baseline extension
May 16, 2026
a867d53
Fix grouped FP8 targeted correctness failures
May 16, 2026
cd2d545
Make GroupedLinear single-param state-dict tests hermetic
May 16, 2026
15cd492
Specialize varying-first grouped FP8 quantization
May 16, 2026
d5c9bee
Fix grouped FP8 benchmark baseline compatibility
May 16, 2026
663cfa2
Optimize varying-first grouped FP8 quantization
May 16, 2026
3d0dce6
Optimize aligned varying-first grouped FP8 quantize
May 16, 2026
261bb04
Protect grouped FP8 benchmark raw report output
May 16, 2026
c4e79ea
Mirror grouped FP8 benchmark report for Orchestra fetch
May 16, 2026
dce8d3e
Write grouped FP8 benchmark report to raw path
May 16, 2026
3ea1ec6
Emit compact grouped FP8 raw measurements
May 16, 2026
673b30f
Preserve grouped FP8 raw benchmark results
May 16, 2026
79efcca
changes to improve amax kernel
vthumbe1503 May 27, 2026
0c3f0c7
cleanups
vthumbe1503 May 27, 2026
8c30915
cleanup tests
vthumbe1503 May 27, 2026
2fc1f31
further cleanup
vthumbe1503 May 27, 2026
65b10e6
further cleanup
vthumbe1503 May 27, 2026
f062f05
profile code push for now
vthumbe1503 May 28, 2026
410b728
cleanups
vthumbe1503 May 28, 2026
b692384
grouped amax in a seperate file
vthumbe1503 May 28, 2026
7cd35ef
changes so far
vthumbe1503 May 28, 2026
8c5219f
Merge remote-tracking branch 'nvidia_origin/main' into current_scalin…
vthumbe1503 May 28, 2026
e673e94
clean
vthumbe1503 May 28, 2026
01c3b22
no need for all same first
vthumbe1503 May 28, 2026
b50146a
dead code
vthumbe1503 May 28, 2026
c331674
revert
vthumbe1503 May 28, 2026
e5dcc28
clean
vthumbe1503 May 29, 2026
f28df69
all changes:
vthumbe1503 May 29, 2026
0e21f7f
cleanup
vthumbe1503 May 29, 2026
5a08af0
cleanpus
vthumbe1503 May 29, 2026
99aaf8c
Resolve merge conflicts with main and generalize splits to offset ker…
vthumbe1503 Jun 8, 2026
18c922f
Simplify splits_to_offsets and avoid modifying the multi-offset API
vthumbe1503 Jun 8, 2026
e6adb2f
resolve merge conflicts
vthumbe1503 Jun 8, 2026
937aa56
add the single grouped weight skip check
vthumbe1503 Jun 8, 2026
a39ad23
some optimizations.. unclean
vthumbe1503 Jun 9, 2026
4549c1b
cleanups
vthumbe1503 Jun 9, 2026
24bb0cd
optimize amax kernel
vthumbe1503 Jun 9, 2026
3c7643f
optimize cast kernel
vthumbe1503 Jun 10, 2026
bf227f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2026
15eecff
improve colwise cast perf
vthumbe1503 Jun 10, 2026
f8ee99d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2026
0aaec71
nvte tensors in the API
vthumbe1503 Jun 10, 2026
5269232
Merge branch 'current_scaling_group_quant' of github.com:vthumbe1503/…
vthumbe1503 Jun 10, 2026
587ba3d
address review comments
vthumbe1503 Jun 12, 2026
6986fc8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2026
5c7d1d6
Merge branch 'main' into current_scaling_group_quant
vthumbe1503 Jun 12, 2026
d6856a4
address review comments, fix lint errors
vthumbe1503 Jun 13, 2026
624f94c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2026
59e06f2
no need for type_trait
vthumbe1503 Jun 13, 2026
34b6201
Merge branch 'current_scaling_group_quant' of github.com:vthumbe1503/…
vthumbe1503 Jun 13, 2026
b42af46
varying last dims can also be overallocated
vthumbe1503 Jun 13, 2026
d7d6628
split commmon into layout and tma files for arch specific vs non arch…
Jun 14, 2026
032dc15
Merge branch 'main' into current_scaling_group_quant
vthumbe1503 Jun 14, 2026
7946edd
address reviem comment
vthumbe1503 Jun 14, 2026
cce8d9f
no need to depend on multi tensor impl.. nvte API to compute grouepd …
vthumbe1503 Jun 14, 2026
53470ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2026
c65faa6
fix test
vthumbe1503 Jun 14, 2026
1afb576
Merge branch 'current_scaling_group_quant' of github.com:vthumbe1503/…
vthumbe1503 Jun 14, 2026
c73a0cc
Remove duplicate comment in transformer_engine.h
vthumbe1503 Jun 15, 2026
7c868df
fix comment and add nvte check
Jun 15, 2026
7161288
add CPP tests for current scaling group quantize + splits to offsets 2d
vthumbe1503 Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
732 changes: 732 additions & 0 deletions benchmarks/benchmark_group_quantize_current_scaling.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
add_executable(test_operator
test_cast.cu
test_cast_current_scaling.cu
test_cast_current_scaling_grouped.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu
test_cast_gated_swiglu.cu
Expand Down
468 changes: 468 additions & 0 deletions tests/cpp/operator/test_cast_current_scaling_grouped.cu

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions tests/cpp/operator/test_splits_to_offsets.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,102 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<1>(info.param));
return name;
});

namespace {

// Allocate a device buffer holding `host` stored as `dtype` (int32 or int64).
void *copy_to_device(const std::vector<int64_t> &host, transformer_engine::DType dtype) {
using namespace transformer_engine;
NVTE_CHECK(dtype == DType::kInt32 || dtype == DType::kInt64,
"splits_to_offsets test only supports int32/int64.");
void *dptr = nullptr;
if (dtype == DType::kInt32) {
std::vector<int32_t> tmp(host.begin(), host.end());
NVTE_CHECK_CUDA(cudaMalloc(&dptr, sizeof(int32_t) * tmp.size()));
NVTE_CHECK_CUDA(
cudaMemcpy(dptr, tmp.data(), sizeof(int32_t) * tmp.size(), cudaMemcpyHostToDevice));
} else {
NVTE_CHECK_CUDA(cudaMalloc(&dptr, sizeof(int64_t) * host.size()));
NVTE_CHECK_CUDA(
cudaMemcpy(dptr, host.data(), sizeof(int64_t) * host.size(), cudaMemcpyHostToDevice));
}
return dptr;
}

// Copy a device buffer of `n` `dtype` (int32 or int64) elements back to host as int64.
std::vector<int64_t> copy_to_host(const void *dptr, size_t n, transformer_engine::DType dtype) {
using namespace transformer_engine;
NVTE_CHECK(dtype == DType::kInt32 || dtype == DType::kInt64,
"splits_to_offsets test only supports int32/int64.");
std::vector<int64_t> out(n);
if (dtype == DType::kInt32) {
std::vector<int32_t> tmp(n);
NVTE_CHECK_CUDA(cudaMemcpy(tmp.data(), dptr, sizeof(int32_t) * n, cudaMemcpyDeviceToHost));
out.assign(tmp.begin(), tmp.end());
} else {
NVTE_CHECK_CUDA(cudaMemcpy(out.data(), dptr, sizeof(int64_t) * n, cudaMemcpyDeviceToHost));
}
return out;
}

} // namespace

class SplitsToOffsets2DTestSuite
: public ::testing::TestWithParam<std::tuple<size_t, transformer_engine::DType>> {};

TEST_P(SplitsToOffsets2DTestSuite, TestSplitsToOffsets2D) {
using namespace transformer_engine;

const size_t num_tensors = std::get<0>(GetParam());
const DType dtype = std::get<1>(GetParam());

// Generate per-tensor first/last dims. Vary both dimensions so the test
// exercises the 2D prefix sum (offset[i+1] = sum_{j<=i} first_dims[j] * last_dims[j]).
std::vector<int64_t> h_first_dims(num_tensors);
std::vector<int64_t> h_last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
h_first_dims[i] = static_cast<int64_t>((i % 17) + 1);
h_last_dims[i] = static_cast<int64_t>((i % 5) + 1) * 16;
}

std::vector<int64_t> h_expected(num_tensors + 1, 0);
for (size_t i = 0; i < num_tensors; ++i) {
h_expected[i + 1] = h_expected[i] + h_first_dims[i] * h_last_dims[i];
}

void *d_first_dims = copy_to_device(h_first_dims, dtype);
void *d_last_dims = copy_to_device(h_last_dims, dtype);

std::vector<int64_t> h_output_init(num_tensors + 1, -1);
void *d_output = copy_to_device(h_output_init, dtype);

TensorWrapper first_dims_w(d_first_dims, std::vector<size_t>{num_tensors}, dtype);
TensorWrapper last_dims_w(d_last_dims, std::vector<size_t>{num_tensors}, dtype);
TensorWrapper output_w(d_output, std::vector<size_t>{num_tensors + 1}, dtype);

nvte_splits_to_offsets_2d(first_dims_w.data(), last_dims_w.data(), output_w.data(),
0 /* stream */);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

std::vector<int64_t> h_output = copy_to_host(d_output, num_tensors + 1, dtype);

NVTE_CHECK_CUDA(cudaFree(d_first_dims));
NVTE_CHECK_CUDA(cudaFree(d_last_dims));
NVTE_CHECK_CUDA(cudaFree(d_output));

for (size_t i = 0; i < h_output.size(); ++i) {
EXPECT_EQ(h_output[i], h_expected[i])
<< "Mismatch at index " << i << ": expected " << h_expected[i] << ", got " << h_output[i];
}
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest, SplitsToOffsets2DTestSuite,
::testing::Combine(::testing::ValuesIn(splits_to_offsets_num_tensors),
::testing::Values(transformer_engine::DType::kInt32,
transformer_engine::DType::kInt64)),
[](const testing::TestParamInfo<SplitsToOffsets2DTestSuite::ParamType> &info) {
std::string name =
std::to_string(std::get<0>(info.param)) + "X" + test::typeName(std::get<1>(info.param));
return name;
});
Loading
Loading