pto.mad_mx

pto.mad_mx is part of the Cube MAD Ops.

Summary

Zero-init MX (microscaled) cube matrix multiply: dst[m, n] = mx_product[m, n].

MX Matmul Model

The pto.mad_mx* ops perform the same cube matrix multiply as pto.mad and additionally apply microscaling. The scale payloads are loaded with pto.mte_l1_l0a_mx / pto.mte_l1_l0b_mx and are associated with the selected %lhs / %rhs tiles; they are not direct operands of pto.mad_mx*.

The K dimension is partitioned into 32-element groups:

k_group = floor(k / 32)

mx_product[m, n] =
  sum k in 0 .. K-1:
    (lhs[m, k] * lhs_scale[m, k_group]) *
    (rhs[k, n] * rhs_scale[k_group, n])

On the current target profile, MX data tiles use f8E4M3FN. %k must be compatible with MX grouping: the current target profile consumes K in multiples of 64 elements, and each 64-element step spans two 32-element scale groups.
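The formula above can be checked against a small host-side reference model. The sketch below is illustrative only: the function name mx_matmul_ref is hypothetical, tiles are plain Python lists of floats, and it does not model f8E4M3FN rounding, the encoding of the scale values, or the 64-element K rule of the current target profile. It only reproduces the per-group scaling and the zero-initialized accumulation.

```python
GROUP = 32  # K elements covered by one scale value, per the model above


def mx_matmul_ref(lhs, rhs, lhs_scale, rhs_scale):
    """Reference mx_product (hypothetical helper, not part of PTO).

    lhs is M x K, rhs is K x N,
    lhs_scale is M x (K / 32), rhs_scale is (K / 32) x N.
    """
    m_dim, k_dim, n_dim = len(lhs), len(rhs), len(rhs[0])
    if k_dim % GROUP != 0:
        raise ValueError("K must be a whole number of 32-element groups")
    # Zero-init destination: previous contents never contribute.
    dst = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0
            for k in range(k_dim):
                g = k // GROUP  # k_group = floor(k / 32)
                acc += (lhs[m][k] * lhs_scale[m][g]) * (rhs[k][n] * rhs_scale[g][n])
            dst[m][n] = acc
    return dst


# 1 x 64 times 64 x 1, all data elements 1.0, two scale groups per operand.
out = mx_matmul_ref(
    [[1.0] * 64],
    [[1.0] for _ in range(64)],
    [[2.0, 0.5]],    # lhs_scale[m, k_group]
    [[1.0], [4.0]],  # rhs_scale[k_group, n]
)
print(out)  # [[128.0]]
```

Each group contributes 32 * 2.0 = 64.0 here (2.0 * 1.0 in the first group, 0.5 * 4.0 in the second), which shows that the scales apply per 32-element K group rather than per tile.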

Mechanism

Functionally equivalent to pto.mad, but with MX scaling applied during the multiply-accumulate. As with pto.mad, the result overwrites the destination tile in L0C rather than accumulating into its previous contents.

Syntax

pto.mad_mx %lhs, %rhs, %dst, %m, %n, %k
  unit_flag(check_only | check_and_set)?
  disable_gemv?
  (sat | nosat)?
  n_dir?
  : !pto.ptr<A, l0a>, !pto.ptr<B, l0b>, !pto.ptr<C, l0c>, i64, i64, i64

Inputs

The operand list has the same shape as pto.mad: %lhs, %rhs, %dst, %m, %n, %k. In addition, %lhs and %rhs must have matching MX scale payloads loaded into L0A / L0B before this op is issued.

See MAD Common Clauses for the optional clauses (note: tf32_mode(...) is not a clause of MX MAD).

Expected Outputs

Result Type | Description
None | Writes the produced M x N MX-scaled tile to L0C.

Side Effects

Engages the CUBE pipe; reads scale payloads associated with %lhs / %rhs; writes to L0C.

Constraints

  • Operands must use a target-supported MX dtype combination (currently f8E4M3FN on the supported profile).
  • Matching left and right MX scale payloads must be loaded before this op via pto.mte_l1_l0a_mx / pto.mte_l1_l0b_mx.
  • %k must satisfy the MX grouping rule described in MX Matmul Model.
  • tf32_mode(...) is not a clause of MX MAD.
  • Other constraints match pto.mad.

Examples

pto.mad_mx %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64
  : !pto.ptr<f8E4M3FN, l0a>, !pto.ptr<f8E4M3FN, l0b>, !pto.ptr<f32, l0c>, i64, i64, i64
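For the example above (M = 16, N = 16, K = 64), the MX Matmul Model implies two scale groups along K. The sketch below works that out and applies the 64-element K rule of the current target profile. The helper name mx_scale_shapes and its parameters are hypothetical; the returned shapes are the logical index ranges of lhs_scale[m, k_group] and rhs_scale[k_group, n], not the physical layout of the scale payloads in L0A / L0B, which this page does not specify.

```python
def mx_scale_shapes(m, n, k, group=32, k_step=64):
    """Return (lhs_scale_shape, rhs_scale_shape) for an M x N x K MX matmul.

    Hypothetical helper: group=32 and k_step=64 are taken from the
    MX Matmul Model for the current target profile.
    """
    if k <= 0 or k % k_step != 0:
        raise ValueError(f"k={k} is not a positive multiple of {k_step}")
    groups = k // group  # number of 32-element scale groups along K
    return (m, groups), (groups, n)


print(mx_scale_shapes(16, 16, 64))  # ((16, 2), (2, 16))
```

A K of 32 is a whole scale group but is rejected here, because the current target profile consumes K in 64-element multiples.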