以下是使用 DeepGEMM 进行基本 FP8 GEMM 运算的步骤:
- 导入库和函数::
import torch
from deep_gemm import gemm_fp8_fp8_bf16_nt - Vorbereiten der Dateneingabe(矩阵 A 和 B,必须是 FP8 格式):
A = torch.randn(1024, 512, dtype=torch.float8_e4m3fn).cuda()
B = torch.randn(512, 1024, dtype=torch.float8_e4m3fn).cuda() - 调用函数进行矩阵乘法::
C = gemm_fp8_fp8_bf16_nt(A, B)
print(C)
Vorbehalte:
- 输入矩阵需位于 GPU 上,且格式需为 FP8(E4M3 或 E5M2)
- 输出结果为 BF16 格式,适合后续计算或存储
Diese Antwort stammt aus dem ArtikelDeepGEMM: Eine Open-Source-Bibliothek mit effizienter Unterstützung für FP8-Matrixoperationen (DeepSeek Open-Source-Woche Tag 3)Die