Performance Optimization: Optimized TileShape Configuration for f8 #3617

Open

wants to merge 2 commits into base: main
Conversation

MatrixAssembler

Performance Issue with Current F8 TileShape Configuration

The current FBGEMM f8 kernels use a TileShape configuration of 128x128x128,
while the optimal shape for the dense f8 tensor core on H100 is m64n256k32.
The current configuration therefore leads to suboptimal tensor core
and bandwidth utilization.

Optimized TileShape (128x256x128) Implementation

This PR changes the TileShape configuration from 128x128x128 to 128x256x128 for large GEMM
operations and uses a cooperative kernel, enabling better bandwidth and tensor core utilization.
This configuration is notably used in Flash Attention V3 for f8.
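
For illustration, below is a minimal sketch of how such a tile configuration is typically expressed with CUTLASS 3.x type aliases. The ClusterShape value and alias names are illustrative assumptions, not the exact FBGEMM template parameters from this diff.

```cpp
// Sketch only: CUTLASS 3.x-style type aliases for an SM90 f8 GEMM tile.
// The surrounding FBGEMM builder code is omitted; ClusterShape is a guess.
#include <cute/tensor.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/epilogue/dispatch_policy.hpp>

using namespace cute;

// CTA tile: 128 x 256 x 128 (M x N x K), i.e. the proposed configuration.
using TileShape    = Shape<_128, _256, _128>;
// Thread-block cluster shape (illustrative value, not necessarily FBGEMM's).
using ClusterShape = Shape<_2, _1, _1>;

// Cooperative (persistent, warp-specialized) schedules for the mainloop
// and epilogue, as used for large f8 GEMMs in this PR.
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
```

The cooperative schedule is CUTLASS's persistent warp-specialized variant with two consumer warp groups, which pairs naturally with the larger 128x256 output tile.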

Benchmark Results on H100 GPU

Benchmark configuration:

  • PyTorch 2.6
  • CUDA 12.4
  • CPU: AMD EPYC
  • GPU: NVIDIA H100
  • Each measurement uses 30 kernel launch iterations and is averaged over
    25 benchmark runs (a sketch of this timing loop follows this list).
  • We used the same GEMM sizes as in the Colfax benchmarks.
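
As a rough illustration of this methodology (not the actual benchmark script), the timing loop below uses CUDA events; launch_f8_gemm() is a hypothetical placeholder for enqueueing the FBGEMM kernel under test.

```cpp
// Sketch of the timing methodology: 30 kernel launches per measurement,
// averaged over 25 measurements, reported as TFLOPS. Error checking elided.
#include <cuda_runtime.h>

// Hypothetical placeholder: enqueue one f8f8bf16 GEMM on the current stream.
void launch_f8_gemm() { /* call the FBGEMM op under test here */ }

double bench_tflops(double flops_per_gemm) {
  const int kIters   = 30;  // kernel launches per measurement
  const int kRepeats = 25;  // measurements to average
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  double total_ms_per_gemm = 0.0;
  for (int r = 0; r < kRepeats; ++r) {
    cudaEventRecord(start);
    for (int i = 0; i < kIters; ++i) {
      launch_f8_gemm();
    }
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    total_ms_per_gemm += ms / kIters;
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);

  const double avg_s = (total_ms_per_gemm / kRepeats) * 1e-3;
  return flops_per_gemm / avg_s / 1e12;  // TFLOPS
}
```

For M = N = K = 8,192, flops_per_gemm is 2 * 8192^3 ≈ 1.1e12, so an average of about 0.74 ms per GEMM corresponds to roughly 1,480 TFLOPS.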

Benchmark

f8f8bf16_grouped (G = 4, M = 2,048, N = 8,192, K = 8,192)

TileShape TFlops
128-128-128 1244
128-256-128 1374

f8f8bf16_rowwise (M = N = K = 8,192)

TileShape TFlops
128-128-128 1300
128-256-128 1480

f8f8bf16_tensorwise (M = N = K = 8,192)

TileShape TFlops
128-128-128 1271
128-256-128 1463

Technical Implementation

Modified TileShape from 128-128-128 to 128-256-128 for:

  • f8f8bf16_grouped
  • f8f8bf16_rowwise
  • f8f8bf16_tensorwise

Added cooperative kernel by default for:

  • f8f8bf16_rowwise
  • f8f8bf16_tensorwise

f8f8f16.cu was not modified because it is deprecated in favor of f8f8bf16_tensorwise.

The modifications only affect large GEMMs where M > 128, N > 128, and M or N > 2,048
(the dispatch rule is sketched below). The matrices are divided into tiles twice as large,
with kernels using 3 SMs instead of 2. Shapes near the lower end of the large-GEMM
heuristic may see slightly reduced efficiency compared to the previous configuration.
An empirical study of f8 kernel configurations across GEMM sizes could benefit FBGEMM.
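
As a rough sketch of that dispatch rule (the helper and struct names below are illustrative, not FBGEMM's actual heuristic code):

```cpp
// Hedged sketch of the size heuristic described above.
#include <cstdint>

struct TileConfig {
  int tile_m, tile_n, tile_k;
  bool cooperative;
};

// Large GEMMs (M > 128, N > 128, and M or N > 2048) take the new
// 128x256x128 cooperative configuration; everything else keeps the
// previous 128x128x128 path.
inline TileConfig select_f8_tile_config(int64_t M, int64_t N) {
  const bool large = (M > 128) && (N > 128) && (M > 2048 || N > 2048);
  if (large) {
    return {128, 256, 128, /*cooperative=*/true};
  }
  return {128, 128, 128, /*cooperative=*/false};
}
```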

These changes were made by modifying the minimum necessary code while respecting
existing coding practices in FBGEMM.

Test Coverage

Unit Tests Results

The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize
were run to verify the modified kernels.

@jiawenliu64 @jwfromm Thank you!

- Change TileShape from 128x128x128 to 128x256x128

- Add cooperative kernel by default for f8 kernels

netlify bot commented Jan 27, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

Name Link
🔨 Latest commit 760415c
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67984b40af02880008f9386a
😎 Deploy Preview https://deploy-preview-3617--pytorch-fbgemm-docs.netlify.app

@jiawenliu64
Member

Thanks for the contribution, @MatrixAssembler !

Wonder if you observe optimization opportunities with M <= 128 || N <= 128?

jiawenliu64 self-requested a review January 27, 2025 18:00
@facebook-github-bot
Contributor

@jiawenliu64 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@MatrixAssembler
Author

MatrixAssembler commented Jan 28, 2025

Thanks for the contribution, @MatrixAssembler !

Wonder if you observe optimization opportunities with M <= 128 || N <= 128?

Indeed @jiawenliu64, the bandwidth of H100 and B100 is so large
that the TileShape size significantly affects performance:
f8f8bf16_rowwise (M = N = K = 8,192)

TileShape TFlops
64-128-128 910
128-128-128 1300
128-256-128 1480

We can deduce that for a configuration where 128 >= M > 64
and N > 4224, the 64-128-128 kernel is roughly 30% slower
than the 128-128-128 kernel ((1300 - 910) / 1300 ≈ 30%).
(For reference, this hasn't been tested empirically yet.)

M being the batch size, this can be a common GEMM
configuration in the case of LLMs (see the llama-shape benchmarks).

This is the next contribution I'm preparing;
I have limited access to an H100 GPU and
this empirical research is quite combinatorial.

I would have liked to do the same for B100 to prepare for
Blackwell CUTLASS kernel development,
but I don't have access to a Blackwell GPU.

@facebook-github-bot
Contributor

@jiawenliu64 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
