HLSL Specifications

0035 - Linear Algebra Matrix

Status: Under Review
Authors:
Planned Version: SM 6.10

Introduction

GPUs are exceptional parallel data processors, but it is becoming increasingly important to model operations with cross-thread data dependencies. In HLSL and Direct3D these operations have been called Wave operations; in Vulkan, Subgroup operations. Related terms like Quad or derivatives have similar meanings in different scoping contexts. Vulkan has also recently introduced the term “cooperative” for operations that require participation from multiple threads; these can be viewed much like derivative operations but across the full SIMD unit instead of a subset of threads.

All of these terms refer to the way the underlying instructions execute, not necessarily what they do. One big part of this proposal is to take 5 steps back and talk about what they do: linear algebra.

Motivation

HLSL has a Vulkan extension for SIMD matrix types 0021 - vk::Cooperative Matrix, and DirectX had previewed a similar feature in SM 6.8 called Wave Matrix. This proposal is aimed at merging the two into a unified language feature that can be supported on all platforms (with some platform-specific limitations).

This proposal is similar but not directly aligned with 0031 - HLSL Vector Matrix Operations.

Proposed solution

Below is a proposed pseudo-HLSL API. The proposal uses C++20 concepts to represent template type constraints so as to avoid needing SFINAE complications.

Some portion of this API surface is portable between DirectX and Vulkan using the proposed DXIL for DirectX and SPV_KHR_cooperative_matrix for Vulkan. Not all features proposed here are supported in Vulkan, so the API as described is in the dx namespace.

A subsequent revision to HLSL’s 0021 - Vulkan Cooperative Matrix support could be considered separately to align on a base set of functionality for inclusion in the hlsl namespace.

namespace dx {
namespace linalg {

template <MatrixComponentType ComponentTy, uint M, uint N, MatrixUse Use,
          MatrixScope Scope>
class Matrix {
  using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
  // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
  // packed in each scalar value.
  static const uint ElementsPerScalar =
      __detail::ComponentTypeTraits<ComponentTy>::IsNativeScalar ? 1 : 4;
  // Computes the number of scalars actually stored in the matrix M dimension
  // accounting for packing.
  static const uint MScalars =
      (M + (ElementsPerScalar - 1)) / ElementsPerScalar;
  // Computes the number of scalars actually stored in the matrix N dimension
  // accounting for packing.
  static const uint NScalars =
      (N + (ElementsPerScalar - 1)) / ElementsPerScalar;

  template <MatrixComponentType NewCompTy, MatrixUse NewUse = Use>
  typename hlsl::enable_if<Scope != MatrixScope::Thread,
                           Matrix<NewCompTy, M, N, NewUse, Scope>>::type
  cast();

  // Element-wise operations
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator+=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator-=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator*=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator/=(T);

  // Apply a unary operation to each element.
  template <UnaryOperation Op, MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                           Matrix>::type
  ApplyUnaryOperation();

  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                                      Scope != MatrixScope::Thread,
                                  Matrix>::type
  Splat(T Val);

  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <MatrixScope ScopeLocal = Scope>
  static typename hlsl::enable_if<
      Scope != MatrixScope::Thread && ScopeLocal == Scope, Matrix>::type
  Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
       MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                                      Scope != MatrixScope::Thread,
                                  Matrix>::type
  Load(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
       MatrixLayout Layout);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           Matrix>::type
      FromThreadVectors(vector<ElementType, MScalars>);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           Matrix>::type
      FromThreadVectors(vector<ElementType, NScalars>);

  template <MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                           void>::type
  Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
        MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T>
  typename hlsl::enable_if<
      hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, void>::type
  Store(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
        MatrixLayout Layout);

  // Accumulate methods
  template <MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && ScopeLocal == Scope,
                           void>::type
  Accumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T, MatrixUse UseLocal = Use>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  Accumulate(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
             MatrixLayout Layout);

  // Extract the thread-specific vector.
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           vector<ElementType, MScalars>>::type
  GetThreadVector(uint Index = 0);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           vector<ElementType, NScalars>>::type
  GetThreadVector(uint Index = 0);

  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                     const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  SumAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);
};

MatrixUse AccumulatorLayout();

template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave>);

template <MatrixComponentType T, uint M, uint N, uint K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::Wave>);

template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

template <MatrixComponentType T, uint M, uint N, uint K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

// Cooperative Vector Replacement API
// Cooperative Vector operates on per-thread vectors multiplying against B
// matrices.

template <typename OutputElTy, typename InputElTy, uint M, uint K,
          MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);

template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
          uint K, MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
                                  vector<BiasElTy, K>);

// Outer product functions
template <MatrixComponentType OutTy, MatrixScope Scope, typename InputElTy,
          uint M, uint N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
    OuterProduct(vector<InputElTy, M>, vector<InputElTy, N>);

} // namespace linalg
} // namespace dx

Example Usage: Wave Matrix

RWByteAddressBuffer B : register(u0);

void WaveMatrixExample() {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<MatrixComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using MatrixAccumTy = Matrix<MatrixComponentType::F16, 8, 16,
                               MatrixUse::Accumulator, MatrixScope::Wave>;
  using MatrixAccum32Ty = Matrix<MatrixComponentType::F32, 8, 16,
                                 MatrixUse::Accumulator, MatrixScope::Wave>;

  MatrixATy MatA = MatrixATy::Load(B, 0, 8 * 4, MatrixLayout::RowMajor);
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, MatrixLayout::RowMajor);

  MatrixAccumTy Accum = Multiply(MatA, MatB);
  MatrixAccum32Ty Accum32 = Multiply<MatrixComponentType::F32>(MatA, MatB);
}

Example Usage: Cooperative Vectors

ByteAddressBuffer B : register(t0);

void CoopVec() {
  using namespace dx::linalg;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;

  vector<float16_t, 32> Vec = (vector<float16_t, 32>)0;
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, MatrixLayout::RowMajor);
  vector<float16_t, 16> Accum = Multiply<float16_t>(Vec, MatB);
}

Example Usage: OuterProduct and Accumulate

RWByteAddressBuffer Buf : register(u1);

void OuterProdAccum() {
  using namespace dx::linalg;
  using MatrixAccumTy = Matrix<MatrixComponentType::F16, 16, 8,
                               MatrixUse::Accumulator, MatrixScope::Thread>;

  vector<float16_t, 16> VecA = (vector<float16_t, 16>)0;
  vector<float16_t, 8> VecB = (vector<float16_t, 8>)0;
  MatrixAccumTy MatAcc =
      OuterProduct<MatrixComponentType::F16, MatrixScope::Thread>(VecA, VecB);

  MatAcc.Accumulate(Buf, 0, 0, MatrixLayout::OuterProductOptimal);
}

Detailed design

HLSL API Concepts

The new HLSL API introduces a new linalg::Matrix type which represents an opaque matrix object, and contains an intangible handle that refers to the allocated matrix.

The linalg::Matrix template type is parameterized based on the matrix component data type, dimensions, use, and scope. These parameters restrict where and how a matrix can be used.

Matrix Use

The Use parameter of an instance of a linalg::Matrix denotes which argument it can be in matrix-matrix operations or matrix-vector operations.

There are three matrix usages: A, B, and Accumulator.

  • The A matrix usage denotes a matrix that can be the first argument to binary or ternary algebraic operations.
  • The B matrix usage denotes a matrix that can be the second argument to binary or ternary algebraic operations.
  • The Accumulator matrix usage denotes a matrix that can either be a produced result from a binary arithmetic operation, or the third argument to a ternary algebraic operation.

The matrix use type parameter enables implementations to optimize the storage and layout of the matrix prior to tensor operations. It may be expensive on some hardware to translate between matrix uses; for that reason, the use is captured in the type, and the HLSL API requires explicit conversion.

Throughout this document a matrix may be described as a matrix of its use (e.g. a matrix with Use == Accumulator is an accumulator matrix, while a matrix with use A is an A matrix).

Matrix Scope

The Scope parameter of an instance of a linalg::Matrix denotes the uniformity scope of the matrix. The scope impacts which operations can be performed on the matrix and may have performance implications depending on the implementation.

There are three supported matrix scopes: Thread, Wave, and ThreadGroup.

  • The Thread matrix scope denotes that a matrix’s values may vary by thread, which requires that an implementation handle divergent matrix values.
  • The Wave matrix scope denotes that a matrix’s values are uniform across a wave, which allows an implementation to assume all instances of the matrix across a wave are identical.
  • The ThreadGroup matrix scope denotes that a matrix’s values are uniform across a thread group, which allows an implementation to assume all instances of the matrix across a thread group are identical.

Operations are categorized by their scope requirements. Some operations require uniform scope matrices (Wave or ThreadGroup), while others can operate on non-uniform (Thread) scope matrices. Operations that support non-uniform scope also support uniform scopes. There may be significant performance benefits when using uniform scope matrices.

When using ThreadGroup scope matrices, explicit barriers are required only when there are actual cross-thread dependencies, such as when multiple threads contribute to building or modifying the matrix before it is used by other threads. The matrix scope semantics handle most synchronization automatically, eliminating the need for barriers between every matrix operation.

The following table summarizes the operations supported for each matrix scope:

Operation                                   | Thread Scope | Wave Scope | ThreadGroup Scope
Matrix::cast()                              | No           | Yes        | Yes
Matrix::operator+=()                        | No           | Yes        | Yes
Matrix::operator-=()                        | No           | Yes        | Yes
Matrix::operator*=()                        | No           | Yes        | Yes
Matrix::operator/=()                        | No           | Yes        | Yes
Matrix::ApplyUnaryOperation()               | No           | Yes        | Yes
Matrix::Splat()                             | No           | Yes        | Yes
Matrix::Load(ByteAddressBuffer)             | Yes          | Yes        | Yes
Matrix::Load(RWByteAddressBuffer)           | No           | Yes        | Yes
Matrix::Load(groupshared)                   | No           | Yes        | Yes
Matrix::Store(RWByteAddressBuffer)          | No           | Yes        | Yes
Matrix::Store(groupshared)                  | No           | Yes        | Yes
Matrix::Accumulate(RWByteAddressBuffer)     | Yes          | Yes        | Yes
Matrix::Accumulate(groupshared)             | No           | Yes        | Yes
Matrix::FromThreadVectors()                 | No           | Yes        | Yes
Matrix::GetThreadVector()                   | No           | Yes        | Yes
Matrix::MultiplyAccumulate()                | No           | Yes        | Yes
Matrix::SumAccumulate()                     | No           | Yes        | Yes
linalg::Multiply(Matrix, Matrix)            | No           | Yes        | Yes
linalg::Multiply(vector, Matrix)            | Yes          | Yes        | Yes
linalg::MultiplyAdd(vector, Matrix, vector) | Yes          | Yes        | Yes
linalg::OuterProduct(vector, vector)        | Yes          | Yes        | Yes

Throughout this document a matrix may be described as having a scope as specified by the Scope parameter (e.g. a matrix with Scope == Thread is a matrix with thread scope).

Matrix storage is always opaque. The Scope does not directly restrict how the matrix is stored; it merely denotes the allowed scope of data divergence. A matrix with thread scope must behave as if each thread has a unique copy of the matrix, though an implementation may coalesce identical matrices across threads.

Matrix Storage

In HLSL, matrix objects are intangible objects so they do not have defined size or memory layout. When in use, implementations are expected to distribute the storage of matrices across the thread-local storage for all threads in a SIMD unit. An implementation may also utilize caches or other memory regions as appropriate. At the DXIL level a matrix is represented as a handle object.

An A matrix is a collection of per-thread vectors representing matrix columns, while a B matrix is a collection of per-thread vectors representing matrix rows.

An Accumulator matrix may be laid out as either an A matrix or a B matrix; which one varies by hardware implementation.

Restrictions on Dimensions

The HLSL API will enforce restrictions on the K dimension as found in the formula: MxK * KxN = MxN

This restriction impacts the number of columns in an A matrix and the number of rows in a B matrix, but has no impact on an accumulator matrix.

The minimum and maximum K dimension for Wave and Thread scope matrices is tied to the minimum and maximum wave size, while the minimum and maximum K dimension for ThreadGroup matrices is tied to the thread group size.

Matrix Scope | Scalar element dimensions
Thread       | Powers of two between [4,128]
Wave         | Powers of two between [4,128]
ThreadGroup  | [1,1024]

Sizes for matrices of packed data types are 4 times the valid size for a scalar element.

Not all hardware is required to support all possible dimensions for thread and wave scope matrices, or all possible element types. The shader compiler will encode the dimensions and input and output data types used by each shader in the Pipeline State Validation metadata.

HLSL API Documentation

HLSL Enumerations

enum class MatrixComponentType {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
};

enum class MatrixUse {
  A = 0,
  B = 1,
  Accumulator = 2,
};

enum class MatrixScope {
  Thread = 0,
  Wave = 1,
  ThreadGroup = 2,
};

enum class UnaryOperation {
  NOp = 0,
  Negate = 1,
  Abs = 2,
  Sin = 3,
  Cos = 4,
  Tan = 5,
};

enum class MatrixLayout {
  RowMajor = 0,
  ColMajor = 1,
  MulOptimal = 2,
  OuterProductOptimal = 3,
};

New hlsl enable_if

namespace hlsl {
template <bool B, typename T> struct enable_if {};

template <typename T> struct enable_if<true, T> {
  using type = T;
};

} // namespace hlsl

This proposal depends on adding a new SFINAE construct hlsl::enable_if which works just like std::enable_if in C++.

New hlsl type traits

namespace hlsl {

template <typename T> struct is_arithmetic {
  static const bool value = false;
};

#define __ARITHMETIC_TYPE(type)                                                \
  template <> struct is_arithmetic<type> {                                     \
    static const bool value = true;                                            \
  };

#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif
__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)

} // namespace hlsl

This proposal depends on a new is_arithmetic type trait added to the hlsl namespace.

dx::linalg::__detail type traits

namespace __detail {
template<MatrixComponentType T>
struct ComponentTypeTraits {
    using Type = uint;
    static const bool IsNativeScalar = false;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::I16> {
    using Type = int16_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::U16> {
    using Type = uint16_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::I32> {
    using Type = int32_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::U32> {
    using Type = uint32_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::I64> {
    using Type = int64_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::U64> {
    using Type = uint64_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::F16> {
    using Type = float16_t;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::F32> {
    using Type = float;
    static const bool IsNativeScalar = true;
};

template<>
struct ComponentTypeTraits<MatrixComponentType::F64> {
    using Type = double;
    static const bool IsNativeScalar = true;
};
} // namespace __detail

The linalg::__detail::ComponentTypeTraits struct is provided as an implementation detail to enable mapping MatrixComponentType values to their native HLSL element types and differentiating between types that have native scalar support.

Matrix::cast

template <MatrixComponentType NewCompTy, MatrixUse NewUse = Use>
typename hlsl::enable_if<Scope != MatrixScope::Thread,
                         Matrix<NewCompTy, M, N, NewUse, Scope>>::type
Matrix::cast();

The Matrix::cast() function supports casting component types and matrix Use.

Element-wise Operators

template <typename T>
typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
    Matrix::operator+=(T);
template <typename T>
typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
    Matrix::operator-=(T);
template <typename T>
typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
    Matrix::operator*=(T);
template <typename T>
typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
    Matrix::operator/=(T);

For any arithmetic scalar type, the +=, -=, *= and /= compound assignment operators perform element-wise arithmetic between the scalar and each element of the matrix. The Matrix returned by value contains the same handle and refers to the same (now modified) matrix.

Matrix::ApplyUnaryOperation<>()

template <linalg::UnaryOperation Op, linalg::MatrixScope ScopeLocal = Scope>
typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                         Matrix>::type
Matrix::ApplyUnaryOperation();

Taking the linalg::UnaryOperation enumeration value as a template parameter, this function applies a unary operation to each element in the matrix. Each unary operation will behave with regard to special values in the same way as if the standalone HLSL intrinsic had been applied.

Matrix::Splat(T)

template <typename T>
static typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
Matrix::Splat(T Val);

Constructs a matrix filled with the provided value cast to the element type. If the matrix is a Wave or ThreadGroup scope matrix, this operation shall behave as equivalent to:

Matrix::Splat(WaveReadLaneFirst(Val));

This operation must be called in uniform control flow when creating a wave scope or thread group scope matrix.

Matrix::Load

static Matrix Matrix::Load(
    ByteAddressBuffer Res, uint StartOffset, uint Stride, MatrixLayout Layout,
    uint Align = sizeof(__detail::ComponentTypeTraits<ComponentTy>::Type));

template <MatrixScope ScopeLocal = Scope>
static typename hlsl::enable_if<
    Scope != MatrixScope::Thread && ScopeLocal == Scope, Matrix>::type
Matrix::Load(
    RWByteAddressBuffer Res, uint StartOffset, uint Stride, MatrixLayout Layout,
    uint Align = sizeof(__detail::ComponentTypeTraits<ComponentTy>::Type));

template <typename T>
static typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, Matrix>::type
Matrix::Load(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
             MatrixLayout Layout);

The matrix Load methods create a new matrix of the specified dimensions and fill the matrix by reading data from the supplied source. Thread scope matrices can only be read from ByteAddressBuffer objects. Wave and ThreadGroup scope matrices can be read from [RW]ByteAddressBuffer objects or groupshared arrays. When read from [RW]ByteAddressBuffer objects, the data is assumed to already be in the expected target data format. When read from groupshared memory, the data may be in any arithmetic or packed data type; if the type does not match the target data type of the matrix, a data conversion is applied on load.

The following table specifies the valid values for the Layout parameter given the Load method type and matrix scope. All other combinations are unsupported:

Operation                       | Matrix Scope      | Matrix Layout
Matrix::Load(ByteAddressBuffer) | Thread            | any
Matrix::Load(*)                 | Wave, ThreadGroup | RowMajor, ColMajor

This operation may be called in divergent control flow when loading a thread scope matrix, and must be called in uniform control flow when loading a wave scope or thread group scope matrix.

Matrix::FromThreadVectors

template <MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::A && Scope != MatrixScope::Thread &&
                             UseLocal == Use,
                         Matrix>::type
    Matrix::FromThreadVectors(vector<ElementType, MScalars>);

template <MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::B && Scope != MatrixScope::Thread &&
                             UseLocal == Use,
                         Matrix>::type
    Matrix::FromThreadVectors(vector<ElementType, NScalars>);

Produces a matrix from per-thread vectors. An A matrix is produced from per-thread column vectors, while a B matrix is produced from per-thread row vectors. The FromThreadVectors construction method is not available for accumulator matrices, whose layout varies by hardware implementation.

When creating an A wave scope matrix, the N dimension must be less than or equal to the wave size. When creating an A thread group scope matrix, the N dimension must be less than or equal to the thread group size.

When creating a B wave scope matrix, the M dimension must be less than or equal to the wave size. When creating a B thread group scope matrix, the M dimension must be less than or equal to the thread group size.

Vectors from threads outside the matrix dimensions are ignored.

Must be called from wave-uniform control flow.

Matrix::Store

template <MatrixScope ScopeLocal = Scope>
typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                         void>::type
Matrix::Store(
    RWByteAddressBuffer Res, uint StartOffset, uint Stride, MatrixLayout Layout,
    uint Align = sizeof(__detail::ComponentTypeTraits<ComponentTy>::Type));

template <typename T>
typename hlsl::enable_if<
    hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, void>::type
Matrix::Store(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
              MatrixLayout Layout);

The matrix Store methods store the matrix data to a target RWByteAddressBuffer or groupshared array. When storing to RWByteAddressBuffer objects the data is stored in the component type of the matrix object. When storing to groupshared memory, the matrix component data is converted to the target arithmetic or packed data type if the data types do not match.

The following table specifies the valid values for the Layout parameter given the Store method type and matrix scope. All other combinations are unsupported:

Operation        | Matrix Scope      | Matrix Layout
Matrix::Store(*) | Wave, ThreadGroup | RowMajor, ColMajor

Matrix::Accumulate

template <MatrixScope ScopeLocal = Scope>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && ScopeLocal == Scope,
                         void>::type
Matrix::Accumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                   MatrixLayout Layout, uint Align = sizeof(ElementType));

template <typename T, MatrixUse UseLocal = Use>
typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                             Use == MatrixUse::Accumulator &&
                             Scope != MatrixScope::Thread && UseLocal == Use,
                         void>::type
Matrix::Accumulate(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
                   MatrixLayout Layout);

The matrix Accumulate methods add the matrix data to a target RWByteAddressBuffer or groupshared array. These methods are only available for matrices with MatrixUse::Accumulator. The RWByteAddressBuffer overload works with all matrix scopes, while the groupshared overload only works with Wave and ThreadGroup scope matrices. When accumulating to RWByteAddressBuffer objects, the data is added in the component type of the matrix object. When accumulating to groupshared memory, the matrix component data is converted to the target arithmetic or packed data type if the data types do not match.

The following table specifies the valid values for the Layout parameter given the Accumulate method type and matrix scope. All other combinations are unsupported:

Operation                               | Matrix Scope      | Matrix Layout
Matrix::Accumulate(RWByteAddressBuffer) | Thread            | OuterProductOptimal
Matrix::Accumulate(*)                   | Wave, ThreadGroup | RowMajor, ColMajor

Matrix::GetThreadVector(uint)

template <MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::A && Scope != MatrixScope::Thread &&
                             UseLocal == Use,
                         vector<ElementType, MScalars>>::type
Matrix::GetThreadVector(uint Index = 0);

template <MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::B && Scope != MatrixScope::Thread &&
                             UseLocal == Use,
                         vector<ElementType, NScalars>>::type
Matrix::GetThreadVector(uint Index = 0);

Returns the underlying vector for the associated thread in the matrix. The optional index is used when the matrix K dimension is larger than the wave size to compute the starting offset (i.e. (Index * WaveSize) + ThreadID).

An A matrix produces a vector containing a column of the matrix, while a B matrix produces a vector containing a row of the matrix. This method may not be used on an Accumulator matrix because the matrix layout varies by hardware implementation.

Threads outside the matrix dimensions will return a vector with all elements zero-initialized.

Must be called from wave-uniform control flow for a wave scope matrix and thread group-uniform control flow for a thread group scope matrix.

Matrix::MultiplyAccumulate(Matrix, Matrix)

template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
          MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                             Scope != MatrixScope::Thread && UseLocal == Use,
                         void>::type
Matrix::MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                           const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

An accumulator matrix with wave or thread group scope has a method MultiplyAccumulate which takes as parameters an M x K A matrix with the same scope and a K x N B matrix with the same scope. The matrix arguments are multiplied against each other and added back into the implicit object accumulator matrix.

Must be called from wave-uniform control flow.
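A hedged sketch of a tiled accumulation loop built on this method; the buffer layout, offsets, and strides are illustrative assumptions, not part of the proposal:

```hlsl
void MultiplyAccumulateExample(ByteAddressBuffer Buf) {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<MatrixComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using AccumTy = Matrix<MatrixComponentType::F32, 8, 16,
                         MatrixUse::Accumulator, MatrixScope::Wave>;

  AccumTy Accum = AccumTy::Splat(0.0);
  [unroll]
  for (uint I = 0; I < 4; ++I) {
    // Offsets assume the K-dimension tiles are packed back to back.
    MatrixATy MatA =
        MatrixATy::Load(Buf, I * 512, 32 * 2, MatrixLayout::RowMajor);
    MatrixBTy MatB =
        MatrixBTy::Load(Buf, 4096 + I * 1024, 16 * 2, MatrixLayout::RowMajor);
    Accum.MultiplyAccumulate(MatA, MatB); // Accum += MatA * MatB
  }
}
```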

Matrix::SumAccumulate(Matrix, Matrix)

template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
          MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                             Scope != MatrixScope::Thread && UseLocal == Use,
                         void>::type
Matrix::SumAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                      const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

An accumulator matrix with wave or thread group scope has a SumAccumulate method which takes as parameters an M x K A matrix and a K x N B matrix, both with the same scope as the accumulator. The matrix arguments are added together, then added back into the implicit object accumulator matrix.

Must be called from wave-uniform control flow.

Matrix::AccumulatorLayout()

MatrixUse linalg::AccumulatorLayout();

Returns the MatrixUse that identifies the hardware-dependent layout used by Accumulator matrices. This can return MatrixUse::A or MatrixUse::B, and should be evaluated by the driver compiler as a compile-time constant, enabling control-flow optimization and dead code elimination.
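One hedged sketch of how this might be used: selecting a store layout at driver compile time. The buffer and the particular layout choice per branch are purely illustrative:

```hlsl
void StoreAccumulatorExample(RWByteAddressBuffer Buf) {
  using namespace dx::linalg;
  using AccumTy = Matrix<MatrixComponentType::F32, 16, 16,
                         MatrixUse::Accumulator, MatrixScope::Wave>;
  AccumTy Accum = AccumTy::Splat(0.0);

  // AccumulatorLayout() folds to a constant, so only one branch survives
  // dead code elimination in the driver compiler.
  if (AccumulatorLayout() == MatrixUse::A)
    Accum.Store(Buf, 0, 16 * 4, MatrixLayout::ColMajor);
  else
    Accum.Store(Buf, 0, 16 * 4, MatrixLayout::RowMajor);
}
```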

linalg::Multiply(Matrix, Matrix)

template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K, MatrixScope Scope>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
linalg::Multiply(const Matrix<ATy, M, K, MatrixUse::A, Scope>,
                 const Matrix<BTy, K, N, MatrixUse::B, Scope>);

template <MatrixComponentType T, uint M, uint N, uint K, MatrixScope Scope>
Matrix<T, M, N, MatrixUse::Accumulator, Scope>
linalg::Multiply(const Matrix<T, M, K, MatrixUse::A, Scope>,
                 const Matrix<T, K, N, MatrixUse::B, Scope>);

The linalg::Multiply function has two overloads that take an MxK A matrix and a KxN B matrix with matching wave or thread group scope and yield an MxN Accumulator matrix of the same scope, initialized with the product of the two input matrices. One of the overloads infers the element type of the output accumulator to match the input matrices; the other takes a template parameter for the output matrix type and accepts arguments with potentially mismatched element types.

Must be called from wave-uniform control flow.

linalg::Multiply(vector, Matrix)

template <typename OutputElTy, typename InputElTy, uint M, uint K,
          MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K>
    linalg::Multiply(vector<InputElTy, M>,
                     Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);

The linalg::Multiply function has an overload that takes an M-element vector and an MxK B matrix of any scope. The function returns a K-element vector.

linalg::OuterProduct(vector, vector)

template <MatrixComponentType OutTy, MatrixScope Scope, typename InputElTy,
          uint M, uint N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
    linalg::OuterProduct(vector<InputElTy, M>, vector<InputElTy, N>);

The linalg::OuterProduct function has two overloads that take an M-element vector and an N-element vector and yield an MxN Accumulator matrix with the specified scope initialized with the outer product of the two input vectors. One overload infers the type of the output accumulator to match the input vector element type, the other overload takes a template parameter for the output matrix element type. All matrix scopes are allowed for the output matrix.

linalg::MultiplyAdd(vector, Matrix, vector)

template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
          uint K, MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K>
    linalg::MultiplyAdd(vector<InputElTy, M>,
                        Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
                        vector<BiasElTy, K>);

The linalg::MultiplyAdd function has an overload that takes an M-element vector, an MxK B matrix of any scope, and a K-element bias vector. The operation multiplies the M-element vector by the matrix, then adds the K-element vector, producing a K-element result vector.
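A hedged usage sketch mirroring the CoopVec example in Appendix 2; the buffer, offset, and stride are illustrative:

```hlsl
void MultiplyAddExample(ByteAddressBuffer Buf) {
  using namespace dx::linalg;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;

  vector<float16_t, 32> Vec = (vector<float16_t, 32>)0;
  vector<float16_t, 16> Bias = (vector<float16_t, 16>)0;
  MatrixBTy MatB = MatrixBTy::Load(Buf, 0, 16 * 2, MatrixLayout::RowMajor);
  // Result = Vec * MatB + Bias.
  vector<float16_t, 16> Result = MultiplyAdd<float16_t>(Vec, MatB, Bias);
}
```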

DXIL Enumerations

This feature adds the following new DXIL enumerations, which are used as immediate arguments to the new operations.

enum class DXILMatrixUse {
  A = 0,
  B = 1,
  Accumulator = 2,
};

enum class DXILMatrixScope {
  Thread = 0,
  Wave = 1,
  ThreadGroup = 2,
};

enum class DXILMatrixUnaryOperation {
  nop = 0,
  negate = 1,
  abs = 2,
  sin = 3,
  cos = 4,
  tan = 5,
};

enum class DXILMatrixElementwiseOperation {
  invalid = 0,
  add = 1,
  sub = 2,
  mul = 3,
  div = 4,
};

enum class DXILMatrixComponentType {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
};

DXIL Operations

declare %dx.types.MatrixRef *@dx.op.createMatrix(
  immarg i32, ; opcode
  immarg i32, ; component type (DXILMatrixComponentType)
  immarg i32, ; M dimension
  immarg i32, ; N dimension
  immarg i32, ; matrix Use (DXILMatrixUse)
  immarg i32  ; matrix Scope (DXILMatrixScope)
  )

Creates a new uninitialized matrix with the component, dimensions, use and scope as specified.

declare @dx.op.fillMatrix.[TY](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  [Ty]                   ; fill value
  )

Fills a matrix with a scalar value. The scalar’s type need not match the matrix component type; a conversion is applied following the rules documented in the Conversions section.

declare void @dx.op.castMatrix(
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix destination
  %dx.types.MatrixRef *  ; matrix source
  )

Converts the element and use type of the source matrix to the destination matrix. The source matrix remains valid and unmodified after this operation is applied. Validation shall enforce that both matrices have the same scope.

declare void @dx.op.matrixElementwiseUnaryOp(
  immarg i32,            ; opcode
  immarg i32,            ; unary operation (DXILMatrixUnaryOperation)
  %dx.types.MatrixRef *  ; matrix
  )

Applies a unary math function to each element of the provided matrix.

declare void @dx.op.matrixElementwiseBinaryOp.[TY](
  immarg i32,            ; opcode
  immarg i32,            ; binary operation (DXILMatrixElementwiseOperation)
  %dx.types.MatrixRef *, ; matrix
  [TY]                   ; Value to binary operation
  )

Applies a binary math operation with a wave-uniform value to the elements of the provided matrix.

declare void @dx.op.matrixLoadFromDescriptor(
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  %dx.types.Handle *,    ; ByteAddressBuffer
  i32,                   ; Offset
  i32,                   ; Stride
  i32                    ; matrix layout
  )

Populates a matrix with data from a [RW]ByteAddressBuffer. If any part of the load is out of bounds, the matrix is zero-initialized.

Question: Do we need to specify a source format for the data or should we assume DXILMatrixComponentType?

Validation rules will enforce that:

  • Layout is RowMajor or ColMajor for matrix with MatrixScope of Wave or ThreadGroup
  • Stride is 0 if the Layout is not RowMajor or ColMajor
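To make the Offset/Stride addressing concrete, here is a non-normative Python model. It assumes Stride is the byte distance between consecutive rows (RowMajor) or columns (ColMajor), with elements tightly packed along the other dimension; the actual conversion rules belong to the Conversions section:

```python
def element_offset(offset, stride, elem_size, row, col, layout):
    """Byte address of matrix element (row, col) in a ByteAddressBuffer.

    Assumes Stride is the byte distance between consecutive rows
    (RowMajor) or columns (ColMajor). Illustrative model only.
    """
    if layout == "RowMajor":
        return offset + row * stride + col * elem_size
    if layout == "ColMajor":
        return offset + col * stride + row * elem_size
    raise ValueError("opaque layouts have no software-visible addressing")

# 8x32 float16 matrix, rows packed back to back: stride = 32 * 2 bytes.
assert element_offset(0, 64, 2, 0, 0, "RowMajor") == 0
assert element_offset(0, 64, 2, 1, 0, "RowMajor") == 64
assert element_offset(0, 64, 2, 1, 3, "RowMajor") == 70
```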
declare void @dx.op.matrixLoadFromMemory.p[Ty](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  [Ty] * addrspace(4),   ; groupshared T[M * N]
  i32,                   ; Offset
  i32,                   ; Stride
  i32,                   ; matrix layout
  )

Populates a matrix with data from a groupshared array. Data conversions between opaque matrices and groupshared memory are defined in the Conversions section below.

declare void @dx.op.matrixLoadFromThreads.v[NUM][TY](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  < NUM x [Ty]>          ; Vector
  )

Populates a matrix from per-thread vectors. For an A matrix, NUM corresponds to the M dimension; for a B matrix it corresponds to the N dimension. NUM must match the corresponding matrix dimension, unless the element is a packed data type, in which case it must be the number of 32-bit unsigned integers used to store that dimension’s elements. This operation may not be used on Accumulator matrices.

For an A matrix the N dimension must be less than or equal to the WaveSize. For a B matrix the M dimension must be less than or equal to the WaveSize. Values from additional threads are discarded.

The result of this operation is undefined if called from non-uniform control flow.

declare void @dx.op.matrixStoreToDescriptor(
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  %dx.types.Handle *,    ; ByteAddressBuffer
  i32,                   ; Offset
  i32,                   ; Stride
  i32,                   ; matrix layout
  )

Store a matrix to a RWByteAddressBuffer at a specified offset. If any destination address is out of bounds the entire store is a no-op.

Validation rules will enforce that:

  • Layout is RowMajor or ColMajor
declare void @dx.op.matrixStoreToMemory.p[Ty](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  [Ty] *,                ; groupshared T[M * N]
  i32,                   ; Offset
  i32,                   ; Stride
  i32,                   ; matrix layout
  )

Store a matrix to groupshared memory. Data conversions between opaque matrices and groupshared memory are defined in the Conversions section below.

The validator will ensure that the groupshared target memory is large enough for the write.

declare < NUM x [Ty]> @dx.op.matrixExtractToThreads.v[NUM][TY](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  i32                    ; Index
  )

Extracts per-thread vectors from a matrix. For an A matrix, NUM corresponds to the M dimension; for a B matrix it corresponds to the N dimension. NUM must match the corresponding matrix dimension, unless the element is a packed data type, in which case it must be the number of 32-bit unsigned integers used to store that dimension’s elements. This operation may not be used on Accumulator matrices.

The Index argument specifies the starting row or column as a multiple of the wave size. The resulting vector corresponds to the row or column numbered (Index * WaveSize) + ThreadID.

Must be called from wave-uniform control flow.
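A non-normative Python model of the (Index * WaveSize) + ThreadID mapping described above; the behavior for lanes mapping outside the matrix is an assumption carried over from the GetThreadVector description:

```python
def extracted_row(index, wave_size, thread_id, dim):
    """Which row/column a lane receives from dx.op.matrixExtractToThreads.

    Returns None for lanes mapping outside the matrix dimension, where
    the result vector is assumed to be zero-filled. Illustrative model.
    """
    row = index * wave_size + thread_id
    return row if row < dim else None

# 48 rows, wave size 32: Index 0 covers rows 0..31, Index 1 covers 32..47.
assert extracted_row(0, 32, 5, 48) == 5
assert extracted_row(1, 32, 5, 48) == 37
assert extracted_row(1, 32, 20, 48) is None  # lane maps past row 47
```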

declare i32 @dx.op.matrixQueryAccumulatorLayout(
  immarg i32             ; opcode
  )

This opcode must be evaluated at driver compile time and replaced with the appropriate architecture-specific value denoting the layout of accumulator matrices. A return value of 0 denotes that accumulator matrices use the A layout; a return value of 1 denotes the B layout.

declare void @dx.op.matrixOp(
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix A
  %dx.types.MatrixRef *, ; matrix B
  %dx.types.MatrixRef *  ; matrix C
  )

Two opcodes are available for this operation class: one multiplies the matrices and stores the result (C = A * B); the other performs multiply-accumulation (C += A * B).

Validation rules will enforce that:

  • argument A is an A matrix
  • argument B is a B matrix
  • argument C is an Accumulator matrix
  • All three matrices have the same scope (Wave or ThreadGroup)
  • Matrix A’s dimensions shall be M x K
  • Matrix B’s dimensions shall be K x N
  • Matrix C’s dimensions shall be M x N
  • The element types are compatible

Must be called from wave-uniform control flow.
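The dimension rules above can be modeled in a few lines of non-normative Python:

```python
def valid_matrix_op(a_shape, b_shape, c_shape):
    """Check the dimension rules for dx.op.matrixOp (C = A * B, C += A * B).

    Shapes are (rows, cols) tuples; illustrative model of the validator rules.
    """
    m, k = a_shape
    k2, n = b_shape
    return k == k2 and c_shape == (m, n)

assert valid_matrix_op((8, 32), (32, 16), (8, 16))
assert not valid_matrix_op((8, 32), (16, 16), (8, 16))  # inner dims differ
assert not valid_matrix_op((8, 32), (32, 16), (16, 8))  # wrong output shape
```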

declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
  immarg i32,           ; opcode
  <[NUMi] x [TYi]>,     ; input vector
  %dx.types.MatrixRef * ; matrix A
)

This operation implements a row-vector multiplication against a B matrix.

Note for this operation the matrix can be of any scope.

Validation will enforce that:

  • The input vector’s element count matches the matrix M dimension
  • The matrix argument is a B matrix
declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi](
  immarg i32,            ; opcode
  <[NUMi] x [TYi]>,      ; input vector
  %dx.types.MatrixRef *, ; matrix A
  <[NUMo] x [TYo]>       ; bias vector
)

This operation implements a row-vector multiplication against a B matrix with a bias vector added to the result.

Note for this operation the matrix can be of any scope.

declare void @dx.op.matrixAccumulateToDescriptor(
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  %dx.types.Handle *,    ; RWByteAddressBuffer
  i32,                   ; Offset
  i32,                   ; Stride
  i32                    ; matrix layout
  )

Accumulates a matrix into a RWByteAddressBuffer at a specified offset. This operation is only available for matrices with MatrixUse::Accumulator. If the component data types do not match, the matrix component data is converted to the target arithmetic or packed data type before being added to the existing data in memory. If any destination address is out of bounds, the entire accumulate operation is a no-op.

Validation rules will enforce that:

  • Layout is OuterProductOptimal for matrix with MatrixScope of Thread
  • Layout is RowMajor or ColMajor for matrix with MatrixScope of Wave or ThreadGroup
  • Stride is 0 if the Layout is not RowMajor or ColMajor
declare void @dx.op.matrixAccumulateToMemory.p[Ty](
  immarg i32,            ; opcode
  %dx.types.MatrixRef *, ; matrix
  [Ty] *,                ; groupshared T[M * N]
  i32,                   ; Offset
  i32,                   ; Stride
  i32                    ; matrix layout
  )

Accumulates a matrix to groupshared memory. This operation is only available for matrices with MatrixUse::Accumulator and Wave or ThreadGroup scope. Data conversions between opaque matrices and groupshared memory are defined in the Conversions section below.

The validator will ensure that the groupshared target memory is large enough for the write.

declare %dx.types.MatrixRef *@dx.op.matrixOuterProduct(
  immarg i32,            ; opcode
  immarg i32,            ; component type (DXILMatrixComponentType)
  immarg i32,            ; M dimension
  immarg i32,            ; N dimension
  immarg i32,            ; matrix Scope (DXILMatrixScope)
  <[M] x [Ty]>,          ; vector A
  <[N] x [Ty]>           ; vector B
  )

Creates a new MxN accumulator matrix initialized with the outer product of the two input vectors. The matrix scope can be Thread, Wave, or ThreadGroup. The element type of the output matrix matches the element type of the input vectors.
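The outer product semantics (result[i][j] = A[i] * B[j]) in a non-normative Python model:

```python
def outer_product(a, b):
    """result[i][j] = a[i] * b[j]; models dx.op.matrixOuterProduct."""
    return [[x * y for y in b] for x in a]

assert outer_product([1.0, 2.0], [3.0, 4.0, 5.0]) == [
    [3.0, 4.0, 5.0],
    [6.0, 8.0, 10.0],
]
```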

Pipeline State Validation Metadata

Shader Model 6.10 will introduce a version 4 of the Pipeline State Validation RuntimeInfo structure. A new 32-bit unsigned integer LinalgMatrixUses will count the number of MatrixUse objects appended after the signature output vectors (presently the last data at the end of PSV0 for version 3).

The MatrixUse object is defined:

struct MatrixUse {
  uint32_t Dimensions[3]; // M, N, K
  uint8_t Scope;
  uint8_t OperandType;
  uint8_t ResultType;
  uint8_t RESERVED; // unused but reserved for padding/alignment.
  uint32_t Flags; // do we need this?
};

This object will encode each matrix shape and element type as used by the DXIL operations in the matrixOp and matvecmuladd opcode classes.

The Scope field will encode one of the values defined in the DXILMatrixScope enumeration.

The OperandType and ResultType fields will encode one of the values defined in the DXILMatrixComponentType enumeration.
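As a sanity check on the record layout (assuming tight packing, which natural C alignment also produces here since no padding is required), the MatrixUse record occupies 20 bytes:

```python
import struct

# Three uint32 dimensions (M, N, K), four uint8 fields (Scope, OperandType,
# ResultType, RESERVED), then a uint32 Flags; little-endian, no padding.
FMT = "<3I4BI"
record = struct.pack(FMT, 16, 16, 32, 1, 9, 9, 0, 0)

assert struct.calcsize(FMT) == 20
assert len(record) == 20
```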

Open questions:

  1. Do we need the M and N dimensions or just the K dimension?
  2. Do we need both operand types, or should we expect the operands to be the same type?
  3. What flags do we need?

Conversions

Appendix 1: Outstanding Questions

  • What is the exhaustive list of data types we need to support?
  • What data type conversions do we need to support?
  • Support for other number formats that aren’t natively supported by HLSL?
  • Do we need to specify a source/destination format for the data in the load and store operations that operate on descriptors or should we assume DXILMatrixComponentType?

Appendix 2: HLSL Header

Compiler Explorer

Note: this mostly works with Clang, but has some issues to work out still.

namespace hlsl {

template <typename T> struct is_arithmetic {
  static const bool value = false;
};

#define __ARITHMETIC_TYPE(type)                                                \
  template <> struct is_arithmetic<type> {                                     \
    static const bool value = true;                                            \
  };

#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif
__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)

template <bool B, typename T> struct enable_if {};

template <typename T> struct enable_if<true, T> {
  using type = T;
};

} // namespace hlsl

namespace dx {

namespace linalg {

enum class MatrixComponentType {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
};

namespace __detail {
template <MatrixComponentType T> struct ComponentTypeTraits {
  using Type = uint;
  static const bool IsNativeScalar = false;
};

#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type)                      \
  template <> struct ComponentTypeTraits<enum_val> {                           \
    using Type = type;                                                         \
    static const bool IsNativeScalar = true;                                   \
  };

#if __HLSL_ENABLE_16_BIT
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I16, int16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U16, uint16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F16, float16_t)
#endif

__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I32, int32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U32, uint32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F32, float)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I64, int64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U64, uint64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F64, double)

} // namespace __detail

enum class MatrixUse {
  A = 0,
  B = 1,
  Accumulator = 2,
};

enum class MatrixScope {
  Thread = 0,
  Wave = 1,
  ThreadGroup = 2,
};

enum class UnaryOperation {
  NOp = 0,
  Negate = 1,
  Abs = 2,
  Sin = 3,
  Cos = 4,
  Tan = 5,
};

enum class MatrixLayout {
  RowMajor = 0,
  ColMajor = 1,
  MulOptimal = 2,
  OuterProductOptimal = 3,
};

template <MatrixComponentType ComponentTy, uint M, uint N, MatrixUse Use,
          MatrixScope Scope>
class Matrix {
  using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
  // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
  // packed in each scalar value.
  static const uint ElementsPerScalar =
      __detail::ComponentTypeTraits<ComponentTy>::IsNativeScalar ? 1 : 4;
  // Computes the number of scalars actually stored in the matrix M dimension
  // accounting for packing.
  static const uint MScalars =
      (M + (ElementsPerScalar - 1)) / ElementsPerScalar;
  // Computes the number of scalars actually stored in the matrix N dimension
  // accounting for packing.
  static const uint NScalars =
      (N + (ElementsPerScalar - 1)) / ElementsPerScalar;

  template <MatrixComponentType NewCompTy, MatrixUse NewUse = Use>
  typename hlsl::enable_if<Scope != MatrixScope::Thread,
                           Matrix<NewCompTy, M, N, NewUse, Scope>>::type
  cast();

  // Element-wise operations
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator+=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator-=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator*=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Scope != MatrixScope::Thread,
                           Matrix>::type
  operator/=(T);

  // Apply a unary operation to each element.
  template <UnaryOperation Op, MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                           Matrix>::type
  ApplyUnaryOperation();

  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                                      Scope != MatrixScope::Thread,
                                  Matrix>::type
  Splat(T Val);

  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <MatrixScope ScopeLocal = Scope>
  static typename hlsl::enable_if<
      Scope != MatrixScope::Thread && ScopeLocal == Scope, Matrix>::type
  Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
       MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                                      Scope != MatrixScope::Thread,
                                  Matrix>::type
  Load(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
       MatrixLayout Layout);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           Matrix>::type
      FromThreadVectors(vector<ElementType, MScalars>);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           Matrix>::type
      FromThreadVectors(vector<ElementType, NScalars>);

  template <MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Scope != MatrixScope::Thread && ScopeLocal == Scope,
                           void>::type
  Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
        MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T>
  typename hlsl::enable_if<
      hlsl::is_arithmetic<T>::value && Scope != MatrixScope::Thread, void>::type
  Store(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
        MatrixLayout Layout);

  // Accumulate methods
  template <MatrixScope ScopeLocal = Scope>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && ScopeLocal == Scope,
                           void>::type
  Accumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             MatrixLayout Layout, uint Align = sizeof(ElementType));

  template <typename T, MatrixUse UseLocal = Use>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  Accumulate(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
             MatrixLayout Layout);

  // Extract the thread-specific vector.
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           vector<ElementType, MScalars>>::type
  GetThreadVector(uint Index = 0);

  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           vector<ElementType, NScalars>>::type
  GetThreadVector(uint Index = 0);

  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                     const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope != MatrixScope::Thread && UseLocal == Use,
                           void>::type
  SumAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);
};

MatrixUse AccumulatorLayout();

template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave>);

template <MatrixComponentType T, uint M, uint N, uint K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::Wave>);

template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

template <MatrixComponentType T, uint M, uint N, uint K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

// Cooperative Vector Replacement API
// Cooperative Vector operates on per-thread vectors multiplying against B
// matrices.

template <typename OutputElTy, typename InputElTy, uint M, uint K,
          MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);

template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
          uint K, MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
                                  vector<BiasElTy, K>);

// Outer product functions
template <MatrixComponentType OutTy, MatrixScope Scope, typename InputElTy,
          uint M, uint N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
    OuterProduct(vector<InputElTy, M>, vector<InputElTy, N>);

} // namespace linalg
} // namespace dx

RWByteAddressBuffer B : register(u0);

void WaveMatrixExample() {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<MatrixComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using MatrixAccumTy = Matrix<MatrixComponentType::F16, 8, 16,
                               MatrixUse::Accumulator, MatrixScope::Wave>;
  using MatrixAccum32Ty = Matrix<MatrixComponentType::F32, 8, 16,
                                 MatrixUse::Accumulator, MatrixScope::Wave>;

  MatrixATy MatA = MatrixATy::Load(B, 0, 8 * 4, MatrixLayout::RowMajor);
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, MatrixLayout::RowMajor);

  MatrixAccumTy Accum = Multiply(MatA, MatB);
  MatrixAccum32Ty Accum32 = Multiply<MatrixComponentType::F32>(MatA, MatB);
}

void CoopVec() {
  using namespace dx::linalg;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;

  vector<float16_t, 32> Vec = (vector<float16_t, 32>)0;
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, MatrixLayout::RowMajor);
  vector<float16_t, 16> Accum = Multiply<float16_t>(Vec, MatB);
}

RWByteAddressBuffer Buf : register(u1);

void OuterProdAccum() {
  using namespace dx::linalg;
  using MatrixAccumTy = Matrix<MatrixComponentType::F16, 16, 8,
                               MatrixUse::Accumulator, MatrixScope::Thread>;

  vector<float16_t, 16> VecA = (vector<float16_t, 16>)0;
  vector<float16_t, 8> VecB = (vector<float16_t, 8>)0;
  MatrixAccumTy MatAcc =
      OuterProduct<MatrixComponentType::F16, MatrixScope::Thread>(VecA, VecB);

  MatAcc.Accumulate(Buf, 0, 0, MatrixLayout::OuterProductOptimal);
}