DeepGEMM 专门为混合专家模型(MoE)提供了分组 GEMM 支持,特别针对专家共享相同形状的场景进行优化。具体使用方法如下:
- 导入分组 GEMM 函数::
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
- 准备连续布局的输入数据::
A = torch.randn(4096, 512, dtype=torch.float8_e4m3fn).cuda() # 多个专家的输入拼接
B = torch.randn(512, 1024, dtype=torch.float8_e4m3fn).cuda()
group_sizes = [1024, 1024, 1024, 1024] # 每个专家的 token 数 - 执行分组 GEMM::
C = m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(A, B, group_sizes)
print(C)
Advertências:
- 输入矩阵 A 的 M 轴需按专家分组拼接,且每个分组的大小需对齐 GEMM M 块大小
- B 矩阵的 N 和 K 轴需保持固定
Essa resposta foi extraída do artigoDeepGEMM: uma biblioteca de código aberto com suporte eficiente para operações de matriz FP8 (DeepSeek Open Source Week Day 3)O