TMATMUL_MX¶
指令示意图¶
简介¶
带额外缩放 Tile 的矩阵乘法（GEMM），用于在支持的目标上执行混合精度/量化矩阵乘法。
数学语义¶
Let:
M = aMatrix.GetValidRow()
K = aMatrix.GetValidCol()
N = bMatrix.GetValidCol()
Conceptually, the result corresponds to a matrix multiply over the effective matmul domain (0 <= i < M, 0 <= j < N), with the scaling tiles aScaleMatrix / bScaleMatrix configuring implementation-defined mixed-precision behavior:
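\[ \mathrm{C}_{i,j} = \sum_{k=0}^{K-1} \mathrm{A}_{i,k} \cdot \mathrm{B}_{k,j} \]
For the accumulate form (tmatmul.mx.acc), the product is added to the input accumulator:
\[ \mathrm{C}^{\mathrm{out}}_{i,j} = \mathrm{C}^{\mathrm{in}}_{i,j} + \sum_{k=0}^{K-1} \mathrm{A}_{i,k} \cdot \mathrm{B}_{k,j} \]
For the bias form (tmatmul.mx.bias), the single-row bias tile is broadcast over all rows:
\[ \mathrm{C}_{i,j} = \mathrm{Bias}_{0,j} + \sum_{k=0}^{K-1} \mathrm{A}_{i,k} \cdot \mathrm{B}_{k,j} \]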
The exact role of aScaleMatrix / bScaleMatrix (and any dequant/quant semantics) is target-defined.
汇编语法¶
PTO-AS 形式：参见 PTO-AS Specification。
Synchronous forms (conceptual):
%c = tmatmul.mx %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
%c_out = tmatmul.mx.acc %c_in, %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
%c = tmatmul.mx.bias %a, %a_scale, %b, %b_scale, %bias : (!pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
AS Level 1 (SSA)¶
%c = pto.tmatmul.mx %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>)
-> !pto.tile<...>
%c_out = pto.tmatmul.mx.acc %c_in, %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>,
!pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
%c = pto.tmatmul.mx.bias %a, %a_scale, %b, %b_scale, %bias : (!pto.tile<...>, !pto.tile<...>,
!pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
AS Level 2 (DPS)¶
pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%c : !pto.tile_buf<...>)
pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>,
!pto.tile_buf<...>, !pto.tile_buf<...>) outs(%c_out : !pto.tile_buf<...>)
pto.tmatmul.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>,
!pto.tile_buf<...>, !pto.tile_buf<...>) outs(%c : !pto.tile_buf<...>)
AS Level 1 (SSA)¶
%c = pto.tmatmul.mx %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>, !pto.tile<...>, !pto.tile<...>)
-> !pto.tile<...>
%c_out = pto.tmatmul.mx.acc %c_in, %a, %a_scale, %b, %b_scale : (!pto.tile<...>, !pto.tile<...>,
!pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
%c = pto.tmatmul.mx.bias %a, %a_scale, %b, %b_scale, %bias : (!pto.tile<...>, !pto.tile<...>,
!pto.tile<...>, !pto.tile<...>, !pto.tile<...>) -> !pto.tile<...>
AS Level 2 (DPS)¶
pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>)
outs(%c : !pto.tile_buf<...>)
pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>,
!pto.tile_buf<...>, !pto.tile_buf<...>) outs(%c_out : !pto.tile_buf<...>)
pto.tmatmul.mx.bias ins(%a, %a_scale, %b, %b_scale, %bias : !pto.tile_buf<...>, !pto.tile_buf<...>, !pto.tile_buf<...>,
!pto.tile_buf<...>, !pto.tile_buf<...>) outs(%c : !pto.tile_buf<...>)
C++ 内建接口¶
声明于 include/pto/common/pto_instr.hpp:
// Plain form (tmatmul.mx): cMatrix = aMatrix x bMatrix, using aScaleMatrix / bScaleMatrix as the scaling tiles.
// NOTE(review): `events` presumably lists events to wait on before the instruction issues, and the returned
// RecordEvent presumably signals its completion; confirm against pto_instr.hpp.
template <typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight, typename TileRightScale,
typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, WaitEvents &... events);
// Plain form with an explicit AccPhase template argument (the meaning of Phase is not described on this page).
template <AccPhase Phase, typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight,
typename TileRightScale, typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, WaitEvents &... events);
// Accumulate form (tmatmul.mx.acc): reads the accumulator cInMatrix and writes cOutMatrix.
// Both accumulators share the same TileRes type.
template <typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight, typename TileRightScale,
typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cOutMatrix, TileRes &cInMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, WaitEvents &... events);
// Accumulate form with an explicit AccPhase template argument.
template <AccPhase Phase, typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight,
typename TileRightScale, typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cOutMatrix, TileRes &cInMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, WaitEvents &... events);
// Bias form (tmatmul.mx.bias): adds biasData to the product. TileBias must have DType float,
// Loc == TileType::Bias and Rows == 1.
template <typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight, typename TileRightScale,
typename TileBias, typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, TileBias &biasData, WaitEvents &... events);
// Bias form with an explicit AccPhase template argument.
template <AccPhase Phase, typename TileRes, typename TileLeft, typename TileLeftScale, typename TileRight,
typename TileRightScale, typename TileBias, typename... WaitEvents>
PTO_INST RecordEvent TMATMUL_MX(TileRes &cMatrix, TileLeft &aMatrix, TileLeftScale &aScaleMatrix, TileRight &bMatrix, TileRightScale &bScaleMatrix, TileBias &biasData, WaitEvents &... events);
约束¶
- 实现检查 (A5):
  - m/k/n are taken from aMatrix.GetValidRow(), aMatrix.GetValidCol(), bMatrix.GetValidCol().
  - Static legality checks are enforced via CheckMadMxValid<...>() (types, shapes, fractals, and scaling tile legality).
- Bias form:
  - TileBias::DType must be float, TileBias::Loc must be TileType::Bias, and TileBias::Rows must be 1 (A5 checks these via static_assert).
示例¶
自动(Auto)¶
#include <pto/pto-inst.hpp>
using namespace pto;
void example_auto() {
using A = TileLeft<float8_e5m2_t, 16, 64>;
using B = TileRight<float8_e5m2_t, 64, 32>;
using ScaleA = TileLeftScale<float8_e8m0_t, 16, 2>;
using ScaleB = TileRightScale<float8_e8m0_t, 2, 32>;
using Bias = Tile<TileType::Bias, float, 1, 32>;
using C = TileAcc<float, 16, 32>;
A a;
B b;
ScaleA scaleA;
ScaleB scaleB;
Bias bias;
C c;
TMATMUL_MX(c, a, scaleA, b, scaleB, bias);
}
手动(Manual)¶
#include <pto/pto-inst.hpp>
using namespace pto;
void example_manual() {
using A = TileLeft<float8_e5m2_t, 16, 64>;
using B = TileRight<float8_e5m2_t, 64, 32>;
using ScaleA = TileLeftScale<float8_e8m0_t, 16, 2>;
using ScaleB = TileRightScale<float8_e8m0_t, 2, 32>;
using Bias = Tile<TileType::Bias, float, 1, 32>;
using C = TileAcc<float, 16, 32>;
A a;
B b;
ScaleA scaleA;
ScaleB scaleB;
Bias bias;
C c;
TASSIGN(a, 0x1000);
TASSIGN(b, 0x2000);
TASSIGN(scaleA, GetScaleAddr(a.data()));
TASSIGN(scaleB, GetScaleAddr(b.data()));
TASSIGN(bias, 0x3000);
TASSIGN(c, 0x4000);
TMATMUL_MX(c, a, scaleA, b, scaleB, bias);
}