Skip to content

Batched PW EXX optimization#7469

Draft
zhubonan wants to merge 17 commits into
deepmodeling:developfrom
bonan-group:batch-exx-pr
Draft

Batched PW EXX optimization#7469
zhubonan wants to merge 17 commits into
deepmodeling:developfrom
bonan-group:batch-exx-pr

Conversation

@zhubonan

@zhubonan zhubonan commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Reminder

  • Have you linked an issue with this pull request?
  • Have you added adequate unit tests and/or case tests for your pull request?
  • Have you noticed possible changes of behavior below or in the linked issue?
  • Have you explained the changes of codes in core modules of ESolver, HSolver, ElecState, Hamilt, Operator or Psi? (ignore if not applicable)

Linked Issue

Fix #...

Unit Tests and/or Case Tests for my changes

All ABACUS runtime validations below used OMP_NUM_THREADS=1 unless otherwise noted.

Build/test checks:

cmake --build /tmp/abacus-batch-exx-pr-cpu-mpi --target abacus_pw_para -j2
cmake --build /tmp/abacus-batch-exx-pr-p0p1-gpu-build --target abacus_pw_gpu -j2
ctest --test-dir build-codex-batch-exx-pr-test -R 'MODULE_PW_basis_pw_k_serial$'

CUDA CI-style smoke after the latest FFT metadata fix:

OMP_NUM_THREADS=12 bash ../integrate/Autotest.sh \
  -n 2 \
  -a /tmp/abacus-ci-full-cuda-noelpa-run-build/abacus_basic_gpu \
  -f CASES_GPU.txt \
  -r '^scf_bpcg$'

Result: passed.

Additional local GPU suites run with mpirun -np 2 equivalent Autotest settings:

Suite Result
tests/11_PW_GPU 56 passed
tests/12_NAO_Gamma_GPU 109 passed
tests/13_NAO_multik_GPU 21 passed
tests/15_rtTDDFT_GPU 65 passed
tests/16_SDFT_GPU 28 passed

Validation Inputs

Si8 CPU stress reduced-EXX-grid test:

basis_type              pw
calculation             scf
device                  cpu
precision               double
dft_functional          hse
ks_solver               dav
ecutwfc                 50
ecutexx                 <40 or 100>
nbands                  32
scf_thr                 1e-8
scf_nmax                1
gamma_only              0
symmetry                -1
exxace                  0
exx_separate_loop       0
exx_batch_fft_size      1
cal_stress              1
kpar                    1

Q-tile variant:

exx_use_q_tile          1
exx_band_tile_size      8
exx_q_tile_size         2

K-point grid:

K_POINTS
0
Gamma
2 2 2 0 0 0

Si8 structure: conventional 8-atom Si cell, LATTICE_CONSTANT 1.889766, lattice vectors 5.43090 0 0, 0 5.43090 0, 0 0 5.43090, with the 8 diamond-cubic fractional positions.

GPU timing tests used RTX 5090, precision=single, K_POINTS Gamma 2 2 2, and PW HSE. The patched develop baseline is official/develop plus only the private EXX FFT precision fix needed to avoid the known CUDA memory_op.cu failure in the legacy GPU EXX path.

Si8 GPU inputs:

basis_type              pw
device                  gpu
precision               single
dft_functional          hse
ecutwfc                 50
nbands                  32
symmetry                -1
exxace                  <0 or 1>
exx_separate_loop       <0 for no-ACE, 1 for ACE>
exx_batch_fft_size      8

MgO64 GPU input summary:

basis_type              pw
device                  gpu
precision               single
dft_functional          hse
ecutwfc                 50
nbands                  320
symmetry                1
exxace                  1
exx_separate_loop       1
exx_batch_fft_size      8
exx_full_q_cache        1

Q-tile GPU variants additionally used:

exx_use_q_tile          1
exx_band_tile_size      8
exx_q_tile_size         <1, 2, or 4>

Key Timing and Correctness Results

CPU stress, Si8, mpirun -np 2:

Case EXX FFT / npw Path Final ETOT (eV) E_exx (eV) Pressure (kbar) Total time Stress time
ecutexx=40 24 24 24, npw=2312 non-q-tile -1718.7220983469073872 -40.3665181109 -240.636564 30.23 s 3.85 s
ecutexx=40 24 24 24, npw=2312 q-tile -1718.7220983469057956 -40.3665181110 -240.636564 22.36 s 1.52 s
ecutexx=100 36 36 36, npw=9139 non-q-tile -1720.9038698769220446 -40.3594664072 -239.158678 99.54 s 11.93 s
ecutexx=100 36 36 36, npw=9139 q-tile -1720.9038698769215898 -40.3594664072 -239.158678 83.04 s 5.21 s

GPU timing summary:

Case Main settings Total wall time Key EXX timer Result check
MgO64 HSE ACE, patched develop 2x2x2, ecutwfc=50, nbands=320, symmetry=1, exxace=1 2298 s construct_ace 2029.78 s ETOT -61302.76778492656 eV, E_exx=-1963.6735019729 eV
MgO64 HSE ACE, this PR same, exx_batch_fft_size=8, exx_full_q_cache=1 343.94 s construct_ace 307.12 s ETOT -61302.76734772023 eV, E_exx=-1963.6731245711 eV
Si8 HSE ACE, patched develop 2x2x2, ecutwfc=50, nbands=32, symmetry=-1, exxace=1 18.92 s construct_ace 15.10 s ETOT -857.2818154074363 eV, E_exx=-45.1454174940 eV
Si8 HSE ACE, this PR same, exx_batch_fft_size=8 8.13 s construct_ace 4.26 s ETOT -857.2818340449073 eV, E_exx=-45.1454128819 eV
Si8 HSE no-ACE, patched develop 2x2x2, ecutwfc=50, nbands=32, symmetry=-1, exxace=0 729.15 s act_op 592.99 s ETOT -857.2821719114121 eV, E_exx=-45.1453093445 eV
Si8 HSE no-ACE, this PR same, exx_batch_fft_size=8 35.84 s act_op_batch 23.61 s; cal_exx_energy_batch 7.98 s ETOT -857.2819593448013 eV, E_exx=-45.1455110661 eV

GPU q-tile checks on the current branch:

Case q-tile setting Total wall time Key EXX timer Result check
Si8 HSE ACE batch off 9.60 s construct_ace 4.59 s; act_op_batch 4.41 s ETOT -857.2818340449073 eV, E_exx=-45.1454128819 eV, gap 1.2563273446 eV
Si8 HSE ACE q-tile q=4, band=8 7.89 s construct_ace 3.54 s; act_op_qtile 3.45 s ETOT delta vs no-q +2.34e-5 eV, E_exx delta -6.59e-7 eV
Si8 HSE no-ACE batch off 36.05 s act_op_batch 23.42 s; cal_exx_energy_batch 7.86 s ETOT -857.2819593448013 eV, E_exx=-45.1455110661 eV, gap 1.2562612512 eV
Si8 HSE no-ACE q-tile q=4, band=8 27.66 s act_op_qtile 17.32 s; cal_exx_energy_qtile 6.58 s ETOT delta vs no-q -2.87e-6 eV, E_exx delta +4.07e-7 eV
MgO64 HSE ACE batch, full-q cache on off 345.63 s construct_ace 313.24 s; build_full_q_cache 4.36 s ETOT -61302.76701807489 eV, E_exx=-1963.6730981641 eV, gap 6.3900535284 eV
MgO64 HSE ACE q-tile, full-q cache on q=2, band=8 263.47 s construct_ace 231.30 s; q_tile_pair 21.33 s; build_full_q_cache 4.16 s ETOT delta vs no-q -3.65e-4 eV, E_exx delta +7.11e-5 eV
MgO64 HSE ACE q-tile, full-q cache off q=2, band=8 262.15 s construct_ace 229.93 s; q_tile_pair 21.40 s ETOT delta vs cache-on q-tile -4.78e-4 eV, E_exx delta -3.74e-4 eV

Key observations:

  • MgO64 HSE ACE is 6.68x faster in total wall time and 6.61x faster in ACE construction than patched develop.
  • Si8 HSE ACE is 2.33x faster in total wall time and 3.54x faster in ACE construction than patched develop.
  • Si8 HSE no-ACE is 20.35x faster in total wall time and 25.12x faster in Hamiltonian EXX apply than patched develop.
  • MgO64 q-tile reduced total wall time by about 24% versus the corresponding no-q-tile branch path.
  • The full-q cache diagnostic estimated 3499.14 MB for MgO64 cache-on and 0 MB for cache-off.

What's changed?

This PR ports and completes the batched/q-tile PW EXX implementation for the develop branch.

Main behavior changes:

  • Add batched PW EXX Hamiltonian apply and EXX energy paths for CPU/GPU.
  • Enable GPU batched EXX FFT by default with exx_batch_fft_size=8.
  • Add q-tile EXX layout for no-ACE and ACE/separate-loop paths.
  • Add ACE KPAR q-state fetch/broadcast support for the q-tile path.
  • Restrict PW EXX KPAR to ACE plus exx_separate_loop=1; no-ACE KPAR is blocked for now.
  • Fully honor ecutexx > 0 with separate EXX reciprocal and real-space FFT grids. ecutexx=0 keeps the previous effective ecutrho behavior.
  • Add CPU poolnproc > 1 redistribution between wavefunction and EXX grids.
  • Block GPU poolnproc > 1, because GPU PW FFT does not support intra-pool MPI distribution and would be very slow/unsupported.
  • Refactor q-tile EXX stress to reuse target/q tiles and (k,q) potentials.
  • Add exx_full_q_cache, default 1, for symmetry-reduced PW EXX. This materializes explicit full-q reciprocal wavefunctions to avoid repeated symmetry remaps. Users can set exx_full_q_cache 0 for the lower-memory reduced-q remap-on-demand path.

Q-tile layout

Compared with develop, which applies EXX one target (k,n) and one source (q,m) state at a time, this PR groups real-space EXX work into tiles:

  • target real states: [n_local][ir], controlled by exx_band_tile_size;
  • source real states: [q_local][m_local][ir], controlled by exx_q_tile_size * exx_band_tile_size;
  • source weights: [q_local][m_local];
  • KPAR q-owner pools materialize q states and broadcast source tiles with Parallel_Common::bcast_dev.

For symmetry-reduced runs, exx_full_q_cache 1 keeps the reduced k-point SCF problem but stores explicit full-q EXX wavefunctions. This trades memory for less repeated symmetry rotation/remap work. exx_full_q_cache 0 keeps the memory-saving reduced-q path.

Batched FFT implementation

This PR adds batch-aware FFT setup to PW_Basis and PW_Basis_K through setuptransform(batch_fft_size), while keeping the default setuptransform() behavior equivalent to the existing single-transform path.

The new transform helpers accept contiguous batches of reciprocal or real-space states:

  • PW_Basis::recip_to_real_batch / PW_Basis::real_to_recip_batch for charge and density-like grids;
  • PW_Basis_K::recip_to_real_batch / PW_Basis_K::real_to_recip_batch for k-dependent wavefunction grids.

PW EXX uses these helpers to transform source wavefunction tiles, density products, and EXX energy batches in one batched FFT call instead of launching one FFT per band pair. Batch FFT setup is currently scoped to EXX-owned bases: rhopw_dev, wfcpw_exx, and wfcpw_exx_fullq. Standard PW DFT bases still use the default setup path in this PR. The infrastructure may be reused later for standard PW DFT calculations, but that should be a separate benchmarked follow-up.

Any changes of core modules? (ignore if not applicable)

Core PW modules are changed.

  • OperatorEXXPW and related PW EXX kernels now support batched source/state transforms, q-tile apply/energy, full ecutexx grid honor, full-q cache loads, and ACE KPAR q-tile communication.
  • PW EXX stress is updated to use the tiled q-layout and EXX-grid-sized buffers consistently.
  • PW_Basis and PW_Basis_K gain batch transform setup and batch reciprocal/real transform helpers. The default non-batch call path remains available.
  • Input handling adds/updates PW EXX controls such as exx_full_q_cache, exx_batch_fft_size, exx_use_q_tile, exx_band_tile_size, exx_q_tile_size, and ecutexx documentation/echo.
  • GPU batch FFT planning is intentionally limited to EXX-owned bases in this PR; ordinary PW DFT bases are not switched to batched planning.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant