0035 - Linear Algebra Matrix
| Status | Under Review |
|---|---|
| Authors | |
- Planned Version: SM 6.10
Introduction
GPUs are exceptional parallel data processors, but increasingly it is becoming important to model operations with cross-thread data dependencies. In HLSL and Direct3D these operations have been called Wave operations, or in Vulkan Subgroup operations. Related terms like Quad or derivatives have similar meaning in different scoping contexts. Vulkan has also recently introduced the term “cooperative” when talking about operations that require participation from multiple threads, these can be viewed much like derivative operations but across the full SIMD unit instead of a subset of threads.
All of these terms refer to the way the underlying instructions execute, not necessarily what they do. One big part of this proposal is to take 5 steps back and talk about what they do: linear algebra.
Motivation
HLSL has a Vulkan extension for SIMD matrix types 0021 - vk::Cooperative Matrix, and DirectX had previewed a similar feature in SM 6.8 called Wave Matrix. This proposal is aimed at merging the two into a unified language feature that can be supported on all platforms (with some platform-specific limitations).
This proposal is similar but not directly aligned with 0031 - HLSL Vector Matrix Operations.
Proposed solution
Below is a proposed pseudo-HLSL API. The proposal uses C++20 concepts to represent template type constraints so as to avoid needing SFINAE complications.
Some portion of this API surface is portable between DirectX and Vulkan using the proposed
DXIL for DirectX and
SPV_KHR_cooperative_matrix
for Vulkan. Not all features proposed here are supported in Vulkan, so the API
as described is in the dx namespace.
A subsequent revision to HLSL’s 0021 - Vulkan Cooperative
Matrix support could be considered separately to align
on a base set of functionality for inclusion in the hlsl namespace.
namespace dx {
namespace linalg {
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
ByteAddressBuffer Buf;
uint Offset;
};
template <typename T, int N, ComponentEnum DT> struct InterpretedVector {
vector<T, N> Data;
static const ComponentEnum Interpretation = DT;
static const SIZE_TYPE Size =
__detail::ComponentTypeTraits<DT>::ElementsPerScalar * N;
};
template <ComponentEnum DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
InterpretedVector<T, N, DT> IV = {Vec};
return IV;
}
template <ComponentEnum DestTy, ComponentEnum OriginTy, typename T, int N>
InterpretedVector<typename __detail::ComponentTypeTraits<DestTy>::Type,
__detail::DstN<DestTy, OriginTy, N>::Value, DestTy>
Convert(vector<T, N> Vec) {
vector<typename __detail::ComponentTypeTraits<DestTy>::Type,
__detail::DstN<DestTy, OriginTy, N>::Value>
Result;
/* Do conversion somehow... */
return MakeInterpretedVector<DestTy>(Result);
}
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
MatrixUseEnum Use, MatrixScopeEnum Scope>
class Matrix {
using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
// If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
// packed in each scalar value.
static const uint ElementsPerScalar =
__detail::ComponentTypeTraits<ComponentTy>::ElementsPerScalar;
static const bool IsNativeScalar =
__detail::ComponentTypeTraits<ComponentTy>::IsNativeScalar;
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
bool Transpose = false>
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
Cast();
template <typename T>
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
Splat(T Val);
static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout, uint Align = 128);
static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout, uint Align = 128);
template <typename T, SIZE_TYPE Size>
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
(M * N / ElementsPerScalar <= Size),
Matrix>::type
Load(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
MatrixLayoutEnum Layout);
template <ComponentEnum LocalComp = ComponentTy>
typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
uint>::type
Length();
template <ComponentEnum LocalComp = ComponentTy>
typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
uint2>::type
GetCoordinate(uint Index);
template <ComponentEnum LocalComp = ComponentTy>
typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
ElementType>::type
Get(uint Index);
template <ComponentEnum LocalComp = ComponentTy>
typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
void>::type
Set(uint Index, ElementType Value);
void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout, uint Align = 128);
template <typename T, SIZE_TYPE Size>
typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
(M * N / ElementsPerScalar <= Size),
void>::type
Store(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
MatrixLayoutEnum Layout);
// Accumulate methods
template <MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout,
uint Align = 128);
template <typename T, MatrixUseEnum UseLocal = Use,
MatrixScopeEnum ScopeLocal = Scope, SIZE_TYPE Size>
typename hlsl::enable_if<
hlsl::is_arithmetic<T>::value && Use == MatrixUse::Accumulator &&
UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
Scope == MatrixScope::Wave && ScopeLocal == Scope,
void>::type
InterlockedAccumulate(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
MatrixLayoutEnum Layout);
template <ComponentEnum CompTy, MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
Accumulate(const Matrix<CompTy, M, N, MatrixUse::A, Scope> MatrixA);
template <ComponentEnum CompTy, MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
Accumulate(const Matrix<CompTy, M, N, MatrixUse::B, Scope> MatrixB);
template <ComponentEnum LHSTy, ComponentEnum RHSTy, SIZE_TYPE K,
MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope> MatrixA,
const Matrix<RHSTy, K, N, MatrixUse::B, Scope> MatrixB);
};
// Thread-scope Matrices are read-only. Using a template partial specialization
// for this simplifies the SFINAE-foo above.
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
MatrixUseEnum Use>
class Matrix<ComponentTy, M, N, Use, MatrixScope::Thread> {
using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
template <MatrixLayoutEnum Layout, MatrixUseEnum UseLocal = Use>
static typename hlsl::enable_if<Use == MatrixUse::A && UseLocal == Use,
Matrix>::type
Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
uint Align = 128);
template <MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset);
};
MatrixUseEnum AccumulatorLayout();
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave> MatrixA,
const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave> MatrixB);
template <ComponentEnum CompTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<CompTy, M, K, MatrixUse::A, MatrixScope::Wave> MatrixA,
const Matrix<CompTy, K, N, MatrixUse::B, MatrixScope::Wave> MatrixB);
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::ThreadGroup> MatrixA,
const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup> MatrixB);
template <ComponentEnum CompTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
const Matrix<CompTy, M, K, MatrixUse::A, MatrixScope::ThreadGroup> MatrixA,
const Matrix<CompTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup> MatrixB);
// Cooperative Vector Replacement API
// Cooperative Vector operates on per-thread vectors multiplying against A
// matrices with thread scope.
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
ComponentEnum MatrixDT>
vector<OutputElTy, M>
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec);
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
vector<OutputElTy, M>
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K>, vector<BiasElTy, M> Vec);
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum MatrixDT>
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
vector<BiasElTy, M> Bias);
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum MatrixDT>
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
VectorRef<BiasElTy, M> BiasRef);
// Outer product functions
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
} // namespace linalg
} // namespace dx
Example Usage: Wave Matrix
RWByteAddressBuffer B : register(u0);
void WaveMatrixExample() {
using namespace dx::linalg;
using MatrixATy =
Matrix<ComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
using MatrixBTy =
Matrix<ComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
using MatrixAccumTy = Matrix<ComponentType::F16, 8, 16,
MatrixUse::Accumulator, MatrixScope::Wave>;
using MatrixAccum32Ty = Matrix<ComponentType::F32, 8, 16,
MatrixUse::Accumulator, MatrixScope::Wave>;
MatrixATy MatA = MatrixATy::Load(
B, 0, /* Row stride = number of columns * element size */ 32 * 4,
MatrixLayout::RowMajor);
MatrixBTy MatB = MatrixBTy::Load(
B, 0, /* Row stride = number of columns * element size */ 16 * 4,
MatrixLayout::RowMajor);
for (uint I = 0; I < MatB.Length(); ++I) {
uint2 Pos = MatB.GetCoordinate(I);
// Run `tanh` on all but the diagonal components for no reasonable reason.
if (Pos.x != Pos.y) {
float16_t Val = MatB.Get(I);
MatB.Set(I, tanh(Val));
}
}
MatrixAccumTy Accum = Multiply(MatA, MatB);
MatrixAccum32Ty Accum32 = Multiply<ComponentType::F32>(MatA, MatB);
}
Example Usage: Cooperative Vectors
ByteAddressBuffer MBuf : register(t0);
void CoopVec() {
using namespace dx::linalg;
using MatrixATy =
Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;
vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
MatrixATy MatA = MatrixATy::Load<MatrixLayout::RowMajor>(
MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4);
vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec);
vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias);
VectorRef<ComponentType::F8_E4M3FN, 16> MemBias = {MBuf,
/*start offset*/ 4096};
vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias);
// Clang doesn't yet support packed types.
#ifdef __hlsl_dx_compiler
vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData), MemBias);
vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData),
NullBias);
vector<float16_t, 16> Layer6 = MultiplyAdd<float16_t>(
MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData), MemBias);
// This example creates an interpreted vector where the data needs to be
// converted from a source type to a destination type.
vector<uint, 16> SomeData2 = (vector<uint, 16>)0;
vector<float16_t, 16> Layer7 = MultiplyAdd<float16_t>(
MatA, Convert<ComponentType::F8_E4M3FN, ComponentType::U32>(SomeData2),
MemBias);
#endif
}
Example Usage: OuterProduct and InterlockedAccumulate
RWByteAddressBuffer Buf : register(u1);
void OuterProdAccum() {
using namespace dx::linalg;
using MatrixAccumTy = Matrix<ComponentType::F16, 16, 8,
MatrixUse::Accumulator, MatrixScope::Thread>;
vector<float16_t, 16> VecA = (vector<float16_t, 16>)0;
vector<float16_t, 8> VecB = (vector<float16_t, 8>)0;
MatrixAccumTy MatAcc =
OuterProduct<ComponentType::F16>(VecA, VecB);
MatAcc.InterlockedAccumulate(Buf, 0);
}
Detailed design
HLSL API Concepts
The new HLSL API introduces a new linalg::Matrix type which represents an
opaque matrix object, and contains an intangible value object that refers to the
matrix.
The linalg::Matrix template type is parameterized based on the matrix
component data type, dimensions, use, and scope. These parameters restrict where
and how a matrix can be used.
Stage Availability
All operations on Thread scope matrices are available in all shader stages.
Operations on Wave and ThreadGroup scope matrices are available in compute
shaders.
Matrix Use
The Use parameter of an instance of a linalg::Matrix denotes which argument
it can be in matrix-matrix operations or matrix-vector operations.
There are three matrix usages: A, B, and Accumulator.
- The `A` matrix usage denotes a matrix that can be the first argument to binary or ternary algebraic operations.
- The `B` matrix usage denotes a matrix that can be the second argument to binary or ternary algebraic operations.
- The `Accumulator` matrix usage denotes a matrix that can either be a produced result from a binary arithmetic operation, or the third argument to a ternary algebraic operation.
The matrix use type parameter enables implementations to optimize the storage and layout of the matrix prior to tensor operations. It may be expensive on some hardware to translate between matrix uses, for that reason we capture the use in the type and require explicit conversion in the HLSL API.
Throughout this document a matrix may be described as a matrix of its use (e.g.
a matrix with Use == Accumulator is an accumulator matrix, while a matrix
with use A is an A matrix.)
Matrix Scope
The Scope parameter of an instance of a linalg::Matrix denotes the
uniformity scope of the matrix. The scope impacts which operations can be
performed on the matrix and may have performance implications depending on the
implementation.
There are three supported matrix scopes: Thread, Wave, and ThreadGroup.
- The `Thread` matrix scope denotes that a matrix’s values may vary by thread, which requires that an implementation handle divergent matrix values.
- The `Wave` matrix scope denotes that a matrix’s values are uniform across a wave, which allows an implementation to assume all instances of the matrix across a wave are identical.
- The `ThreadGroup` matrix scope denotes that a matrix’s values are uniform across a thread group, which allows an implementation to assume all instances of the matrix across a thread group are identical.
Operations are categorized by their scope requirements. Some operations require
uniform scope matrices (Wave or ThreadGroup), while others can operate on
non-uniform (Thread) scope matrices. Operations must be called from HLSL under
control flow that is at least as uniform as the matrix scope. Thread-scope
operations may be called in non-uniform control flow, Wave-scope operations must be
called in Wave-uniform control flow, and ThreadGroup-scope operations must
be called in ThreadGroup-uniform control flow. Operations implicitly
synchronize execution across all threads in the matrix’s scope. Calling an
operation from control flow that is not uniform across all participating threads
is undefined behavior.
When using ThreadGroup scope matrices, explicit barriers are required only when
there are actual cross-thread dependencies, such as when multiple threads
contribute to building or modifying the matrix before it is used by other
threads. The matrix scope semantics handle most synchronization automatically,
eliminating the need for barriers between every matrix operation.
The following table summarizes the operations supported for each matrix scope:
| Operation | Thread Scope | Wave Scope | ThreadGroup Scope |
|---|---|---|---|
Matrix::Cast() | ✗ | ✓ | ✓ |
Matrix::Length() | ✗ | ✓ | ✓ |
Matrix::GetCoordinate(uint) | ✗ | ✓ | ✓ |
Matrix::Get(uint) | ✗ | ✓ | ✓ |
Matrix::Set(uint, T) | ✗ | ✓ | ✓ |
Matrix::Splat() | ✗ | ✓ | ✓ |
Matrix::Load(ByteAddressBuffer) | ✓ | ✓ | ✓ |
Matrix::Load(RWByteAddressBuffer) | ✗ | ✓ | ✓ |
Matrix::Load(groupshared) | ✗ | ✓ | ✓ |
Matrix::Store(RWByteAddressBuffer) | ✗ | ✓ | ✓ |
Matrix::Store(groupshared) | ✗ | ✓ | ✓ |
Matrix::InterlockedAccumulate(RWByteAddressBuffer) | ✓ | ✓ | ✓ |
Matrix::InterlockedAccumulate(groupshared) | ✗ | ✓ | ✓ |
Matrix::Accumulate(Matrix) | ✗ | ✓ | ✓ |
Matrix::MultiplyAccumulate() | ✗ | ✓ | ✓ |
linalg::Multiply(Matrix, Matrix) | ✗ | ✓ | ✓ |
linalg::Multiply(Matrix, vector) | ✓ | ✗ | ✗ |
linalg::MultiplyAdd(Matrix, vector, vector) | ✓ | ✗ | ✗ |
linalg::OuterProduct(vector, vector) | ✓ | ✗ | ✗ |
Throughout this document a matrix may be described as having a scope as
specified by the Scope parameter (e.g. a matrix with Scope == Thread is a
matrix with thread scope).
Matrix storage is always opaque; the Scope does not directly restrict how the
matrix is stored, it merely denotes the allowed scope of data divergence.
A matrix with thread scope must behave as if each thread has a unique copy of
the matrix. An implementation may coalesce identical matrices across threads.
Matrix Storage
In HLSL, matrix objects are intangible objects so they do not have defined size or memory layout. When in use, implementations are expected to distribute the storage of matrices across the thread-local storage for all threads in a SIMD unit. An implementation may also utilize caches or other memory regions as appropriate. At the DXIL level a matrix is represented as a value object. Because LLVM 3.7 doesn’t allow value objects of opaque types, the matrix object stores a pointer in the IR, but implementations will replace this with an implementation-defined object.
An A matrix is a collection of per-thread vectors representing matrix rows, while a B matrix is a collection of per-thread vectors representing matrix columns.
An Accumulator matrix may be either an A matrix, or a B matrix, and it varies by hardware implementation.
Restrictions on Dimensions
The HLSL API will enforce restrictions on the K dimension as found in the
formula: MxK * KxN = MxN
This restriction impacts the number of columns in an A matrix, and rows in a B matrix, but has no impact on an accumulator matrix.
The minimum and maximum K dimension for matrices is hardware dependent and
varies by scope. The table below describes the maximums enforced by HLSL and
DXIL validation.
| Matrix Scope | Scalar element dimensions |
|---|---|
| Thread | [4,128] |
| Wave | [4,128] |
| ThreadGroup | [1,1024] |
Sizes for matrices of packed data types are 4 times the valid size for a scalar element.
Not all hardware is required to support all possible dimensions for thread and wave scope matrices, or all possible element types. The shader compiler will encode the dimensions and input and output data types used by each shader in the Pipeline State Validation metadata.
HLSL API Documentation
HLSL Enumerations
// Put this in a dxil constants header.
namespace dxil {
// This enum is _exactly_ the DXIL constants.
enum class ComponentType : uint32_t {
Invalid = 0,
I1 = 1,
I16 = 2,
U16 = 3,
I32 = 4,
U32 = 5,
I64 = 6,
U64 = 7,
F16 = 8,
F32 = 9,
F64 = 10,
SNormF16 = 11,
UNormF16 = 12,
SNormF32 = 13,
UNormF32 = 14,
SNormF64 = 15,
UNormF64 = 16,
PackedS8x32 = 17,
PackedU8x32 = 18,
// BEGIN NEW FOR SM 6.10
I8 = 19,
U8 = 20,
F8_E4M3FN = 21,
F8_E5M2 = 22,
// END
LastEntry
};
}
namespace dx {
namespace linalg {
#define __COMPONENT_TYPE(type) type = (uint)dxil::ComponentType::type
// This enum only defines values that are valid for Matrix component types.
// Each enumeration's value matches the corresponding DXIL constant.
struct ComponentType {
enum ComponentEnum {
// Signed integers.
__COMPONENT_TYPE(I8),
__COMPONENT_TYPE(I16),
__COMPONENT_TYPE(I32),
__COMPONENT_TYPE(I64),
// Unsigned integers.
__COMPONENT_TYPE(U8),
__COMPONENT_TYPE(U16),
__COMPONENT_TYPE(U32),
__COMPONENT_TYPE(U64),
// Floating point types.
__COMPONENT_TYPE(F8_E4M3FN),
__COMPONENT_TYPE(F8_E5M2),
__COMPONENT_TYPE(F16),
__COMPONENT_TYPE(F32),
__COMPONENT_TYPE(F64),
};
};
using ComponentEnum = ComponentType::ComponentEnum;
struct MatrixUse {
enum MatrixUseEnum {
A = 0,
B = 1,
Accumulator = 2,
};
};
using MatrixUseEnum = MatrixUse::MatrixUseEnum;
struct MatrixScope {
enum MatrixScopeEnum {
Thread = 0,
Wave = 1,
ThreadGroup = 2,
};
};
using MatrixScopeEnum = MatrixScope::MatrixScopeEnum;
struct MatrixLayout {
enum MatrixLayoutEnum {
RowMajor = 0,
ColMajor = 1,
MulOptimal = 2,
MulOptimalTranspose = 3,
OuterProductOptimal = 4,
OuterProductOptimalTranspose = 5,
};
};
using MatrixLayoutEnum = MatrixLayout::MatrixLayoutEnum;
New hlsl enable_if
namespace hlsl {
template <bool B, typename T> struct enable_if {};
template <typename T> struct enable_if<true, T> {
using type = T;
};
} // namespace hlsl
This proposal depends on adding a new SFINAE construct hlsl::enable_if which
works just like std::enable_if in C++.
New hlsl type traits
namespace hlsl {
template <typename T> struct is_arithmetic {
static const bool value = false;
};
#define __ARITHMETIC_TYPE(type) \
template <> struct is_arithmetic<type> { \
static const bool value = true; \
};
#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif
__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)
} // namespace hlsl
This proposal depends on a new is_arithmetic type trait added to the hlsl
namespace.
dx::linalg::__detail type traits
namespace __detail {
template <ComponentEnum CompTy> struct ComponentTypeTraits {
using Type = uint;
static const bool IsNativeScalar = false;
static const uint ElementsPerScalar = 4;
};
#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type) \
template <> struct ComponentTypeTraits<enum_val> { \
using Type = type; \
static const bool IsNativeScalar = true; \
static const uint ElementsPerScalar = 1; \
};
#if __HLSL_ENABLE_16_BIT
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I16, int16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U16, uint16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F16, float16_t)
#endif
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I32, int32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U32, uint32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F32, float)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I64, int64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double)
template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
static const int Value =
(SrcN * ComponentTypeTraits<SrcTy>::ElementsPerScalar +
ComponentTypeTraits<DstTy>::ElementsPerScalar - 1) /
ComponentTypeTraits<DstTy>::ElementsPerScalar;
};
template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
static const SIZE_TYPE M = MVal;
static const SIZE_TYPE N = NVal;
};
template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
static const SIZE_TYPE M = NVal;
static const SIZE_TYPE N = MVal;
};
} // namespace __detail
The linalg::__detail::ComponentTypeTraits struct is provided as an
implementation detail to enable mapping ComponentType values to their
native HLSL element types and differentiating between types that have native
scalar support.
The linalg::__detail::DstN struct computes the destination vector size when
converting between two different component types. Value is calculated as the
number of source vector elements (SrcN) multiplied by the source type’s
ElementsPerScalar, divided by the destination type’s ElementsPerScalar,
rounded up to the nearest integer.
The linalg::__detail::DimMN struct conditionally swaps the M and N
dimension values based on the Transposed parameter. When false, M and N
are passed through unchanged. When Transposed is true, M and N are
swapped. This is used by Matrix::Cast to compute the dimensions of a
transposed matrix.
linalg::Convert
template <ComponentEnum DestTy, ComponentEnum OriginTy, typename T, int N>
InterpretedVector<typename __detail::ComponentTypeTraits<DestTy>::Type,
__detail::DstN<DestTy, OriginTy, N>::Value, DestTy>
linalg::Convert(vector<T, N> Vec);
Converts a vector of data interpreted as the OriginTy to a vector of data in
the DestTy. If the OriginTy is a native HLSL type, it must match the type of
the input vector.
The conversions are applied following the documented conversion rules. These rules are different from the standard HLSL type casting rules, and they apply to native and non-native types.
Matrix::Cast
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
bool Transpose = false>
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
Matrix::Cast();
Requires Wave or ThreadGroup scope input and output matrices.
The Matrix::Cast() function supports casting component types and matrix Use.
Must be called from uniform control flow on scope-uniform matrices.
Matrix::Splat(T)
template <typename T>
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
Matrix::Splat(T Val);
Requires Wave or ThreadGroup scope matrix output.
Constructs a matrix filled with the provided value cast to the element type. This operation shall behave equivalently to:
Matrix::Splat(WaveReadLaneFirst(Val));
Matrix::Load
static Matrix Matrix::Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout,
uint Align = 128);
// Not available on Thread scope matrices.
static Matrix Matrix::Load(RWByteAddressBuffer Res, uint StartOffset,
uint Stride, MatrixLayoutEnum Layout,
uint Align = 128);
// Not available on Thread scope matrices.
template <typename T, SIZE_TYPE Size>
static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
(M * N / ElementsPerScalar <= Size),
Matrix>::type
Matrix::Load(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
MatrixLayoutEnum Layout);
The following table specifies the valid values for the Layout parameter
given the Load method type and matrix scope. All other combinations are
unsupported:
| Operation | Matrix Scope | Matrix Layout |
|---|---|---|
Matrix::Load(ByteAddressBuffer) | Thread | any |
Matrix::Load(*) | Wave, ThreadGroup | RowMajor, ColMajor |
The matrix Load methods create a new matrix of the specified dimensions and
fill the matrix by reading data from the supplied source. Thread scope matrices
can only be read from ByteAddressBuffer objects. Wave scope matrices can be
read from [RW]ByteAddressBuffer objects or groupshared arrays. When read
from [RW]ByteAddressBuffer objects the data is assumed to already be in the
expected target data format. When read from groupshared memory, the data may
be in any arithmetic or packed data type. If the type mismatches the target data
type of the matrix a data conversion is applied on load.
This operation may be called in divergent control flow when loading a thread scope matrix, and must be called in uniform control flow when loading a wave scope matrix.
This operation permits loads from RowMajor, ColMajor and Optimal layouts
for Thread scope matrices, with an option to transpose the matrix by setting the
Layout parameter to MulOptimalTranspose or OuterProductOptimalTranspose.
Not all component types support transposing; support is implementation specific.
Applications need to query the driver to determine if a matrix transpose is
supported.
For the Load operations on [RW]ByteAddressBuffers, the Stride argument
represents the row or column stride in bytes. For the Load operations on
groupshared arrays, the Stride argument is the count of elements in the
groupshared array.
Reads from memory through Load functions are not atomic and may require
explicit synchronization.
Matrix::Length
uint Matrix::Length();
Requires Wave or ThreadGroup scope matrix.
Returns the number of matrix components accessible to the current thread. If the matrix’s element type does not have a native HLSL representation this function may not be used.
The mapping and distribution of threads to matrix elements is opaque and
implementation-specific. The value returned by Length may be different for
each thread. The sum of the values returned by Length across all threads must
be greater than or equal to the total number of matrix elements. Some
implementations may map multiple threads to the same matrix element. Therefore,
developers should take this into consideration when programming side-effects,
such as atomic operations and/or UAV writes, within user-defined matrix
operations.
May be called from non-uniform control flow. However, given the above rules,
calling Length from divergent threads may result in unpredictable behavior.
For example, the number of matrix elements accessible to each thread will be
inconsistent across different implementations.
Matrix::GetCoordinate
uint2 Matrix::GetCoordinate(uint Index);
Requires Wave or ThreadGroup scope matrix. If the matrix’s element type does
not have a native HLSL representation this function may not be used.
Converts a specified index into row and column coordinates. The valid range of
Index is [0, Length()-1]. If the value of Index is out of
range, then the result value is UINT32_MAX.xx. The mapping of indices to
matrix coordinates is implementation-specific.
Matrix::Get
ElementType Matrix::Get(uint Index);
Requires Wave or ThreadGroup scope matrix. If the matrix’s element type does
not have a native HLSL representation this function may not be used.
Retrieves the value of a matrix component at the specified index. The valid
range of Index is [0, Length()-1]. If the value of Index is out of range,
then the result value is zero cast to the ElementType.
Matrix::Set
void Matrix::Set(uint Index, ElementType Value);
Requires Wave or ThreadGroup scope matrix. If the matrix’s element type does
not have a native HLSL representation this function may not be used.
Sets the value of a matrix component at the specified index. The valid
range of Index is [0, Length()-1]. If the value of Index is out of range,
then the operation is a no-op.
Matrix::Store
void Matrix::Store(
RWByteAddressBuffer Res, uint StartOffset, uint Stride, MatrixLayout Layout,
uint Align = 128);
template <typename T, SIZE_TYPE Size>
typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
(M * N / ElementsPerScalar <= Size),
void>::type
Matrix::Store(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
MatrixLayout Layout);
The following table specifies the valid values for the Layout parameter
given the Store method type and matrix scope. All other combinations are
unsupported:
| Operation | Matrix Scope | Matrix Layout |
|---|---|---|
| Matrix::Store(*) | Wave, ThreadGroup | RowMajor, ColMajor |
The matrix Store methods store the matrix data to a target
RWByteAddressBuffer or groupshared array. When storing to
RWByteAddressBuffer objects the data is stored in the component type of the
matrix object. When storing to groupshared memory, the matrix component data
is converted to the target arithmetic or packed data type if the data types do
not match.
For the Store operations on [RW]ByteAddressBuffers, the Stride argument
represents the row or column stride in bytes. For the Store operations on
groupshared arrays, the Stride argument is the count of elements in the
groupshared array.
Writes to memory through Store functions are not atomic and may require
explicit synchronization.
Matrix::InterlockedAccumulate
// When Scope != Thread, the following overloads are available:
template <MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
Matrix::InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset,
uint Stride, MatrixLayoutEnum Layout,
uint Align = 128);
template <typename T, MatrixUseEnum UseLocal = Use,
MatrixScopeEnum ScopeLocal = Scope, SIZE_TYPE Size>
typename hlsl::enable_if<
hlsl::is_arithmetic<T>::value && Use == MatrixUse::Accumulator &&
UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
Scope == MatrixScope::Wave && ScopeLocal == Scope,
void>::type
Matrix::InterlockedAccumulate(/*groupshared*/ T Arr[Size], uint StartIdx,
uint Stride, MatrixLayoutEnum Layout);
// When Scope == Thread, the following overload is available:
template <MatrixUseEnum UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
Matrix::InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset);
Matrices with Wave and ThreadGroup scope support a Layout parameter which
must be RowMajor or ColMajor. Matrices of Thread scope must be
OuterProductOptimal layout, so no layout parameter is supported.
When used with Wave and ThreadGroup matrices this must be called from
uniform control flow on uniform matrices.
The matrix InterlockedAccumulate methods atomically add the matrix data to a
target RWByteAddressBuffer or groupshared array. These methods are only
available for matrices with MatrixUse::Accumulator use. The
RWByteAddressBuffer overload is available for all matrix scopes, while the
groupshared overload is only available for Wave scope matrices.
When accumulating to RWByteAddressBuffer objects, the accumulation is
performed on the component type of the matrix object. When accumulating to
groupshared memory, the matrix component data is converted to the target
arithmetic or packed data type before atomic arithmetic is performed. No
conversion is performed if the target arithmetic type matches the matrix
component type.
Matrix::MultiplyAccumulate(Matrix, Matrix)
template <ComponentType LHSTy, ComponentType RHSTy, uint K,
MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
Scope != MatrixScope::Thread && UseLocal == Use,
void>::type
Matrix::MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope> MatrixA,
const Matrix<RHSTy, K, N, MatrixUse::B, Scope> MatrixB);
Requires Wave or ThreadGroup scope matrix, and must be called from uniform
control flow on uniform matrices.
An accumulator matrix with wave or thread group scope has a method MultiplyAccumulate which
takes as parameters an M x K A matrix with the same scope and a K x N B matrix with
the same scope. The matrix arguments are multiplied against each other and added
back into the implicit object accumulator matrix.
Matrix::Accumulate(Matrix)
template <ComponentType LHSTy, ComponentType RHSTy, MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
Scope != MatrixScope::Thread && UseLocal == Use,
void>::type
Matrix::Accumulate(const Matrix<LHSTy, M, N, MatrixUse::A, Scope> MatrixA);
template <ComponentType LHSTy, ComponentType RHSTy, MatrixUse UseLocal = Use>
typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
Scope != MatrixScope::Thread && UseLocal == Use,
void>::type
Matrix::Accumulate(const Matrix<RHSTy, M, N, MatrixUse::B, Scope> MatrixB);
Requires Wave or ThreadGroup scope matrix, and must be called from uniform
control flow on uniform matrices.
An accumulator matrix with wave or thread group scope has a method Accumulate
which takes as a parameter an M x N A or B matrix. The method adds the
provided matrix argument into the accumulator matrix.
Matrix::AccumulatorLayout()
MatrixUse linalg::AccumulatorLayout();
Returns the MatrixUse that identifies the hardware-dependent layout used by
Accumulator matrices. This can return MatrixUse::A or MatrixUse::B, and
should be evaluated by the driver compiler as a compile-time constant allowing
optimizing control flow and dead code elimination.
linalg::Multiply(Matrix, Matrix)
template <ComponentType OutTy, ComponentType ATy,
ComponentType BTy, uint M, uint N, uint K, MatrixScope Scope>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
linalg::Multiply(const Matrix<ATy, M, K, MatrixUse::A, Scope> MatrixA,
const Matrix<BTy, K, N, MatrixUse::B, Scope> MatrixB);
template <ComponentType CompTy, uint M, uint N, uint K>
Matrix<CompTy, M, N, MatrixUse::Accumulator, Scope>
linalg::Multiply(const Matrix<CompTy, M, K, MatrixUse::A, Scope> MatrixA,
const Matrix<CompTy, K, N, MatrixUse::B, Scope> MatrixB);
Requires Wave or ThreadGroup scope matrix inputs and output, and must be
called from uniform control flow on uniform matrices.
The linalg::Multiply function has two overloads that take an MxK Wave-scope
A matrix, and a KxN Wave-scope B matrix and yields an MxN Wave-scope
Accumulator matrix initialized with the product of the two input matrices. One
of the overloads infers the type of the output accumulator to match the input matrices, the other overload takes a template parameter for the output matrix type and takes arguments with potentially mismatched element types.
linalg::Multiply(Matrix, vector)
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
ComponentEnum MatrixDT>
vector<OutputElTy, M>
linalg::Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec);
Requires Thread scope matrix input, may be called from divergent control flow.
The linalg::Multiply function has an overload that takes an MxK A matrix
with Thread scope, a K-element vector Vec. The operation multiplies the
matrix by the K-element vector Vec producing a result M-element vector.
linalg::OuterProduct(vector, vector)
template <ComponentType OutTy, typename InputElTy,
uint M, uint N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
linalg::OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
The linalg::OuterProduct function takes an M-element vector and an N-element
vector and yield an MxN Accumulator matrix with Thread scope initialized
with the outer product of the two input vectors. The function takes a template
parameter for the output matrix element type.
linalg::MultiplyAdd(Matrix, vector, vector)
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
vector<OutputElTy, M>
linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
Requires Thread scope matrix input, may be called from divergent control flow.
The linalg::MultiplyAdd function has an overload that takes an MxK A matrix
with Thread scope, a K-element vector Vec, and a M-element vector
Bias. The operation multiplies the matrix by the K-element vector Vec and
then adds the M-element vector Bias producing a result M-element vector.
Either vector may be a native vector or an InterpretedVector which combines a
packed element vector with an interpretation type. The M-element vector Bias
may also be a VectorRef which refers to a vector in memory. Using the
VectorRef overload makes it easier for the backend compiler to optimize the
bias vector loads with the ALU operations.
DXIL Types
This feature adds the following new DXIL enumerations, which are used as immediate arguments to the new operations.
namespace DXIL {
// Role a matrix plays in a multiply: left operand (A), right operand (B),
// or the accumulator/result.
enum class MatrixUse : uint32_t { // FIX: was misspelled `unit32_t`
  A = 0,
  B = 1,
  Accumulator = 2,
};
// Execution scope across which a matrix value is uniform.
enum class UniformityScope : uint32_t {
  Thread = 0, // should we reserve Quad even though we don't need it?
  Wave = 1,
  ThreadGroup = 2,
};
// Element/interpretation types. Values 0-18 match the pre-existing DXIL
// ComponentType constants; 19-22 are new for SM 6.10. These values are ABI
// and must not be renumbered.
enum class ComponentType : uint32_t {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
  // BEGIN NEW FOR SM 6.10
  I8 = 19,
  U8 = 20,
  F8_E4M3FN = 21,
  F8_E5M2 = 22,
  // END
  LastEntry
};
} // namespace DXIL
The compiler will generate a permutation of typed matrix handles with names of
the format %dx.types.LinAlgMatrix<mangling>. The mangling scheme for
each type name will capture the type parameterization with the tokens C,
M, N, U and S denoting each encoded property.
; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
%dx.types.LinAlgMatrixC10M16N16U0S1 = type { i8 * }
; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
%dx.types.LinAlgMatrixC10M16N16U1S1 = type { i8 * }
; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
%dx.types.LinAlgMatrixC11M16N16U2S1 = type { i8 * }
DXIL validation will enforce that a LinAlgMatrix of any type may not
be bitcast to any other type.
LinAlg Component Types
DXIL validation will enforce that ComponentType for matrix types must be one
of the valid linalg component types listed below:
ComponentType::I8, ComponentType::I16, ComponentType::I32, ComponentType::I64, ComponentType::U8, ComponentType::U16, ComponentType::U32, ComponentType::U64, ComponentType::F8_E4M3FN, ComponentType::F8_E5M2, ComponentType::F16, ComponentType::F32, ComponentType::F64
Type Metadata
A new named metadata dx.targetTypes will be added to contain mappings of
attributed matrix types to their type parameters avoiding needing to parse the
type mangling. For the given examples above metadata of the form below will be
generated:
!dx.targetTypes = !{!1, !2, !3}
; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
!1 = !{%dx.types.LinAlgMatrixC10M16N16U0S1 undef, i32 10, i32 16, i32 16, i32 0, i32 1 }
; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
!2 = !{%dx.types.LinAlgMatrixC10M16N16U1S1 undef, i32 10, i32 16, i32 16, i32 1, i32 1 }
; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
!3 = !{%dx.types.LinAlgMatrixC11M16N16U2S1 undef, i32 11, i32 16, i32 16, i32 2, i32 1 }
Note: to ease compatibility with modern LLVM we want the metadata to avoid encoding pointers, since modern LLVM will convert pointers to opaque pointers, losing the type information.
DXIL Operations
A new overload shape [MatTy] is introduced in the signatures below. This
shall be the <mangling> part of %dx.types.LinAlgMatrix<mangling> preceded
by the letter m.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgFillMatrix.[MatTy].[TY](
immarg i32, ; opcode
[Ty] ; fill value
)
Fills a matrix with a scalar value. The scalar’s type does not need to match the matrix component’s type, a type conversion is applied following the rules documented in the Conversions section.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgCopyConvertMatrix.[MatTy1].[MatTy2](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix source
immarg i1 ; transpose
)
Returns a new matrix which is a copy of the source matrix where the element and
use type of the returned matrix have been converted to MatTy1 from MatTy2.
The source matrix remains valid and unmodified after this operation is applied.
Validation shall enforce that:
- Both matrix types have the same scope
- If the transpose argument is 0, both matrices must have the same dimensions.
- If the transpose argument is 1, the dimensions of MatTy1 and MatTy2 are swapped (the M dimension of MatTy2 will match the N dimension of MatTy1, and the N dimension of MatTy2 will match the M dimension of MatTy1).
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixLoadFromDescriptor.[MatTy](
immarg i32, ; opcode
%dx.types.Handle, ; ByteAddressBuffer
i32, ; Offset
i32, ; Stride
i32, ; matrix layout
i32 ; alignment
)
Populates a matrix with data from a [RW]ByteAddressBuffer. This operation must observe bounds checking behavior described below.
Question: Do we need to specify a source format for the data or should we assume DXILComponentType?
Validation rules will enforce that:
- Layout is RowMajor or ColMajor for a matrix with MatrixScope of Wave or ThreadGroup
- Stride is 0 if the Layout is not RowMajor or ColMajor
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixLoadFromMemory.[MatTy].[Ty](
immarg i32, ; opcode
[Ty] addrspace(3)*, ; groupshared Ty[M * N]
i32, ; Offset
i32, ; Stride
i32 ; matrix layout
)
Populates a matrix with data from a groupshared array. Data conversions
between opaque matrices and groupshared memory are defined in the
Conversions section below.
declare i32 @dx.op.linAlgMatrixLength.[MatTy](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling> ; matrix
)
Returns the number of elements stored in thread-local storage on the active thread for the provided matrix.
declare <2 x i32> @dx.op.linAlgMatrixGetCoordinate.[MatTy](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
i32 ; thread-local index
)
Returns a two element vector containing the column and row of the matrix that the thread-local index corresponds to.
declare [Ty] @dx.op.linAlgMatrixGetElement.[Ty].[MatTy](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
i32 ; thread-local index
)
Gets the element of the matrix corresponding to the thread local index provided. If the index is out of range for the values stored in this thread the result is 0.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixSetElement.[MatTy].[MatTy].[Ty](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; input matrix
i32, ; thread-local index
[Ty] ; value
)
Sets the element of the matrix corresponding to the thread local index provided to the value provided. If the index is out of range for the values stored in this thread the result is a no-op.
declare void @dx.op.linAlgMatrixStoreToDescriptor.[MatTy](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
%dx.types.Handle, ; ByteAddressBuffer
i32, ; Offset
i32, ; Stride
i32, ; matrix layout
i32 ; alignment
)
Store a matrix to a RWByteAddressBuffer at a specified offset. This operation must observe bounds checking behavior described below.
Validation rules will enforce that:
- Layout is RowMajor or ColMajor
declare void @dx.op.linAlgMatrixStoreToMemory.[MatTy].[Ty](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
[Ty] addrspace(3)*, ; groupshared Ty[M * N]
i32, ; Offset
i32, ; Stride
i32 ; matrix layout
)
Store a matrix to groupshared memory. Data conversions between opaque matrices and groupshared memory are defined in the Conversions section below.
The validator will ensure that the group shared target memory is large enough for the write.
declare i32 @dx.op.linAlgMatrixQueryAccumulatorLayout(
immarg i32 ; opcode
)
This opcode must be evaluated at driver compile time and replaced with the
appropriate architecture specific value denoting the layout of accumulator
matrices. A return value of 0 will denote that accumulator matrices are A
layout while a return value of 1 will denote that accumulator matrices are B
layout.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiply.[MatTyC].[MatTyA].[MatTyB](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix A
%dx.types.LinAlgMatrix<mangling> ; matrix B
)
This operation multiplies an A matrix and B matrix into new accumulator matrix
following the form C = A * B.
Validation rules will enforce that:
- argument A is an A matrix
- argument B is a B matrix
- return value (C) is an Accumulator matrix
- All three matrices have the same scope (Wave or ThreadGroup)
- Matrix A’s dimensions shall be M x K
- Matrix B’s dimensions shall be K x N
- Matrix C’s dimensions shall be M x N
- The element types are compatible
Must be called from wave-uniform control flow.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixAccumulate.[MatTyC].[MatTyLHS].[MatTyRHS](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix LHS
%dx.types.LinAlgMatrix<mangling> ; matrix RHS
)
This operation accumulates an A or B matrix into an accumulator following
the form LHS = LHS + RHS.
Validation rules will enforce that:
- Argument RHS is an A or B matrix
- Argument LHS is an Accumulator matrix
- Type of LHS is the same as the return type
- Both matrices have the same scope (Wave or ThreadGroup)
- Both matrices have the same dimensions
- The element types are compatible
Must be called from wave-uniform control flow.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiplyAccumulate.[MatTyR].[MatTyA].[MatTyB].[MatTyC](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix A
%dx.types.LinAlgMatrix<mangling>, ; matrix B
%dx.types.LinAlgMatrix<mangling> ; matrix C
)
This operation multiplies an A matrix and B matrix and accumulates it into an
accumulator matrix following the form R = C + (A * B).
Validation rules will enforce that:
- argument A is an A matrix
- argument B is a B matrix
- argument C is an Accumulator matrix
- return value (R) is an Accumulator matrix
- All four matrices have the same scope (Wave or ThreadGroup)
- Matrix A’s dimensions shall be M x K
- Matrix B’s dimensions shall be K x N
- Matrix C’s dimensions shall be M x N
- Matrix R’s dimensions shall be M x N
- The element types are compatible
Must be called from wave-uniform control flow.
declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix A
immarg i1, ; is output signed
<[NUMi] x [TYi]>, ; input vector
immarg i32 ; input interpretation type (DXIL::ComponentType)
)
This operation implements a column-vector multiplication against an A matrix
of Thread scope.
Validation will enforce that:
- The input vector length matches the K matrix dimension
- The matrix A is an A matrix of Thread scope
- The input interpretation type must be one of the valid linalg component types specified in the list in the LinAlg Component Types section.
- The sign bit for output types should always be true if the output type is a vector of native floating point types.
declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMulAdd.v[NUMo][TYo].[MatTy].v[NUMi][TYi].v[NUMo][TYb](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix A
immarg i1, ; is output signed
<[NUMi] x [TYi]>, ; input vector
immarg i32, ; input interpretation type (DXIL::ComponentType)
<[NUMo] x [TYb]>, ; bias vector
immarg i32 ; bias interpretation type (DXIL::ComponentType)
)
This operation implements a column-vector multiplication against an A matrix
of Thread scope with a bias vector added to the result.
Validation will enforce that:
- The input vector length matches the K matrix dimension
- The bias vector length matches the M matrix dimension
- The matrix A is an A matrix of Thread scope
- The input and bias interpretation type must be one of the valid linalg component types specified in the list in the LinAlg Component Types section.
- The sign bit for output types should always be true if the output type is a vector of native floating point types.
declare void @dx.op.linAlgMatrixAccumulateToDescriptor.[MatTy](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
%dx.types.Handle, ; RWByteAddressBuffer
i32, ; Offset
i32, ; Stride
i32, ; matrix layout
i32 ; alignment
)
Accumulates a matrix to a RWByteAddressBuffer at a specified offset. This
operation is only available for matrices with MatrixUse::Accumulator. The
matrix data is added to the existing data in the buffer. The matrix component
data is converted to the target arithmetic or packed data type if the data types
do not match, then added to the existing data in memory. This operation must
observe bounds checking behavior described below.
Validation rules will enforce that:
- Layout is OuterProductOptimal for a matrix with MatrixScope of Thread
- Layout is RowMajor or ColMajor for a matrix with MatrixScope of Wave or ThreadGroup
- Stride is 0 if the Layout is not RowMajor or ColMajor
declare void @dx.op.linAlgMatrixAccumulateToMemory.[MatTy].[Ty](
immarg i32, ; opcode
%dx.types.LinAlgMatrix<mangling>, ; matrix
[Ty] addrspace(3)*, ; groupshared Ty[M * N]
i32, ; Offset
i32, ; Stride
i32 ; matrix layout
)
Accumulates a matrix to groupshared memory. This operation is only available for
matrices with MatrixUse::Accumulator and Wave or ThreadGroup scope. Data
conversions between opaque matrices and groupshared memory are defined in the
Conversions section below.
The validator will ensure that the group shared target memory is large enough for the write.
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixOuterProduct.[MatTy].v[M][TY].v[N][TY](
immarg i32, ; opcode
<[M] x [Ty]>, ; vector A
<[N] x [Ty]> ; vector B
)
Writes the outer product of the two input vectors into the provided matrix.
The matrix scope must be Thread.
Validation will ensure that:
- The M dimension of the matrix matches the length of vector A, or 1/4th the length for packed types.
- The N dimension of the matrix matches the length of vector B, or 1/4th the length for packed types.
- The element type of the matrix argument matches the element type of the input vectors, or the input vectors are i32 if the matrix uses types not directly representable in DXIL.
- The element type of vector A and vector B must be the same.
- The matrix output type must be Thread scope.
declare <[NUMo] x [TYo]> @dx.op.linAlgConvert.v[NUMo][TYo].v[NUMi][TYi](
immarg i32, ; opcode
<[NUMi] x [TYi]>, ; input vector
immarg i32, ; input interpretation type (DXIL::ComponentType)
immarg i32 ; output interpretation type (DXIL::ComponentType)
)
Converts an input vector containing data of the input interpretation type to a vector containing data of the output interpretation type following the documented conversion rules.
Validation will ensure that:
- If the input interpretation type enum refers to a type that has a native DXIL scalar representation, the input vector type matches that scalar type; otherwise the input vector type should be i32 as if storing 32-bit opaque values.
- If the output interpretation type enum refers to a type that has a native DXIL scalar representation, the output vector type matches that scalar type; otherwise the output vector type should be i32 as if storing 32-bit opaque values.
- The output vector length must be equal to NUMi multiplied by the number of elements per scalar in the input interpretation, divided by the number of elements per scalar in the output interpretation (see the __detail::DstN template).
Data Conversion Rules
All APIs introduced in this specification which may apply conversions shall obey these conversion rules.
If the source and destination types are integer types, and the destination type can exactly represent the source value, the value is preserved; otherwise the result is saturated.
Note: this is different from the normal conversion rules for HLSL native data types!
If the source and destination types are floating point types, and the destination type can exactly represent the source value, the result is the exact value; otherwise the conversion is a best-approximation of the source value and is implementation-defined.
If the source is an integer type and the destination is a floating point type the result is a round to nearest ties to even (RTNE) conversion.
If the source type is a floating point type and the destination is an integer type the conversion is a round to nearest ties to even (RTNE) saturating conversion.
FP8 Types
This specification introduces two FP8 data formats which may be used with linear
algebra objects. They are F8_E4M3FN and F8_E5M2, and they identify floating
point formats composed of 8 bits with 4 exponent and 3 mantissa bits and 5
exponent and 2 mantissa bits respectively.
| E4M3FN (finite) | E5M2 | |
|---|---|---|
| Exponent Bias | 7 | 15 |
| Infinity | N/A | S.11111.00 |
| NaN | S.1111.111 | S.11111.{01,10,11} |
| Zero | S.0000.000 | S.00000.00 |
| Max | S.1111.110 (448) | S.11110.11 (57344) |
| Min | S.0000.001 (2^-9) | S.00000.01 (2^-16) |
Emulating FP8
The DirectX API specification requires that all implementations support both FP8 formats for matrices, bias, and input vectors. If the target hardware does not support the used F8 type an implementation is allowed to convert to any other floating point format as long as the destination format can accurately represent all values of the format used in the shader’s DXIL.
If the driver sees a conversion to an F8 type that is not supported, and the result of that conversion is only used by linear algebra operations, the driver may eliminate the conversion or replace it with a conversion to any other floating point format as long as the new destination format can accurately represent all values of the format used in the shader’s DXIL.
Note: Under emulation if a source value would be converted to a saturated infinity (see: conversion rules) when converting to an F8 type, but the source value can be represented accurately in the emulated FP type, this may cause expected behavior differences.
Bounds Checking Behavior
The @dx.op.linAlgMatrixLoadFromDescriptor operation loads data from a
descriptor. For load operations a default element value of zero casted to the
element type is substituted for out of bounds reads. An implementation may
either perform bounds checking on the full bounds of the load initializing the
full matrix to the default element value if any element is out of bounds, or it
may perform per-element bounds checking initializing only the out of bounds
elements to the default value.
The @dx.op.linAlgMatrixStoreToDescriptor and
@dx.op.linAlgMatrixAccumulateToDescriptor operations write data to a
descriptor. Writes to out of bounds memory are a no-op. An implementation may
either perform bounds checking on the full bounds of the store converting the
whole store to a no-op if any element is out of bounds, or it may perform
per-element bounds checking only converting the out of bounds stores to no-ops.
Note: bounds checking is not required for reads and writes to root descriptors as D3D does not attach dimensions to root descriptors.
Pipeline State Validation Metadata
Shader Model 6.10 will introduce a version 4 of the Pipeline
State Validation RuntimeInfo structure. A new 32-bit unsigned integer
LinalgMatrixUses will count the number of MatrixUse objects appended after
the signature output vectors (presently the last data at the end of PSV0 for
version 3).
The MatrixUse object is defined:
// Describes one matrix configuration (shape, scope, element types) used by
// the shader's linalg operations, so the runtime can validate device support
// at pipeline-creation time. Appended after the signature output vectors in
// PSV0 version 4; the count is given by the LinalgMatrixUses field.
struct MatrixUse {
uint32_t Dimensions[3]; // M, N, K
uint8_t Scope; // a DXIL scope enumeration value (Thread/Wave/ThreadGroup)
uint8_t OperandType; // DXIL::ComponentType of the multiply operands
uint8_t ResultType; // DXIL::ComponentType of the result/accumulator
uint8_t RESERVED; // unused but reserved for padding/alignment.
uint32_t Flags; // do we need this?
};
This object will encode each matrix shape and element type as used by the DXIL
operations in the linAlgMatrixMultiply, linAlgMatrixMultiplyAccumulate and
linAlgMatVecMulAdd opcode classes.
The Scope field will encode one of the values defined in the DXIL
UniformityScope enumeration.
The OperandType and ResultType fields will encode one of the values defined
in the DXIL::ComponentType enumeration.
Open questions:
- Do we need the M and N dimensions or just the K dimension?
- Do we need both operand types, or should we expect the operands to be the same type?
- What flags do we need?
Appendix 1: HLSL Header
Note: this mostly works with Clang, but has some issues to work out still.
namespace hlsl {
// Size/index type abstraction: the DXC compiler uses a signed int for size
// template parameters while other compilers use uint.
#ifdef __hlsl_dx_compiler
#define SIZE_TYPE int
#else
#define SIZE_TYPE uint
#endif
// Minimal is_arithmetic trait (mirrors std::is_arithmetic): false by
// default, specialized to true below for each HLSL arithmetic scalar type.
template <typename T> struct is_arithmetic {
  static const bool value = false;
};
// Declares the is_arithmetic<type> specialization with value == true for one
// arithmetic scalar type.
#define __ARITHMETIC_TYPE(type) \
  template <> struct is_arithmetic<type> { \
    static const bool value = true; \
  };
// 16-bit scalar types only count as arithmetic when 16-bit types are enabled.
#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif
__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)
// Minimal enable_if (mirrors std::enable_if): `type` exists only when the
// condition is true, for SFINAE-based overload constraints.
template <bool B, typename T> struct enable_if {};
template <typename T> struct enable_if<true, T> {
  using type = T;
};
} // namespace hlsl
// Put this in a dxil constants header.
namespace dxil {
// This enum is _exactly_ the DXIL constants (DXIL::ComponentType); the
// numeric values are ABI and must not be reordered or renumbered.
enum class ComponentType : uint32_t {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  // Normalized and packed formats exist in DXIL but are not valid linalg
  // matrix component types (see the LinAlg Component Types section).
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
  // BEGIN NEW FOR SM 6.10
  I8 = 19,
  U8 = 20,
  F8_E4M3FN = 21,
  F8_E5M2 = 22,
  // END
  LastEntry
};
} // namespace dxil
namespace dx {
namespace linalg {
// Expands to `name = <matching dxil::ComponentType value>` so every linalg
// enumerator keeps exactly the DXIL encoding.
#define __COMPONENT_TYPE(type) type = (uint)dxil::ComponentType::type
// This enum only defines values that are valid for Matrix component types.
// Each enumeration's value matches the corresponding DXIL constant.
// Wrapped in a struct so uses read as ComponentType::F16 while remaining a
// plain (implicitly convertible) enum.
struct ComponentType {
  enum ComponentEnum {
    // Signed integers.
    __COMPONENT_TYPE(I8),
    __COMPONENT_TYPE(I16),
    __COMPONENT_TYPE(I32),
    __COMPONENT_TYPE(I64),
    // Unsigned integers.
    __COMPONENT_TYPE(U8),
    __COMPONENT_TYPE(U16),
    __COMPONENT_TYPE(U32),
    __COMPONENT_TYPE(U64),
    // Floating point types.
    __COMPONENT_TYPE(F8_E4M3FN),
    __COMPONENT_TYPE(F8_E5M2),
    __COMPONENT_TYPE(F16),
    __COMPONENT_TYPE(F32),
    __COMPONENT_TYPE(F64),
  };
};
using ComponentEnum = ComponentType::ComponentEnum;
// Role a matrix plays in a multiply: left operand (A), right operand (B), or
// accumulator/result. Values match DXIL::MatrixUse.
struct MatrixUse {
  enum MatrixUseEnum {
    A = 0,
    B = 1,
    Accumulator = 2,
  };
};
using MatrixUseEnum = MatrixUse::MatrixUseEnum;
// Execution scope across which a matrix value is shared/uniform. Values
// match the DXIL scope enumeration.
struct MatrixScope {
  enum MatrixScopeEnum {
    Thread = 0,
    Wave = 1,
    ThreadGroup = 2,
  };
};
using MatrixScopeEnum = MatrixScope::MatrixScopeEnum;
// Memory layouts for matrix load/store/accumulate operations. RowMajor and
// ColMajor are explicit layouts for Wave/ThreadGroup scope matrices; the
// *Optimal values name implementation-chosen (opaque) layouts, e.g.
// OuterProductOptimal is required for Thread-scope accumulator traffic.
struct MatrixLayout {
  enum MatrixLayoutEnum {
    RowMajor = 0,
    ColMajor = 1,
    MulOptimal = 2,
    MulOptimalTranspose = 3,
    OuterProductOptimal = 4,
    OuterProductOptimalTranspose = 5,
  };
};
using MatrixLayoutEnum = MatrixLayout::MatrixLayoutEnum;
namespace __detail {
// Primary template: component types without a native HLSL scalar (the 8-bit
// types) are represented as four elements packed into one 32-bit uint.
template <ComponentEnum CompTy> struct ComponentTypeTraits {
  using Type = uint;
  static const bool IsNativeScalar = false;
  static const uint ElementsPerScalar = 4;
};
// Specializes ComponentTypeTraits for a component type that maps 1:1 onto a
// native HLSL scalar type.
#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type) \
  template <> struct ComponentTypeTraits<enum_val> { \
    using Type = type; \
    static const bool IsNativeScalar = true; \
    static const uint ElementsPerScalar = 1; \
  };
// 16-bit mappings only exist when 16-bit types are enabled.
#if __HLSL_ENABLE_16_BIT
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I16, int16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U16, uint16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F16, float16_t)
#endif
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I32, int32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U32, uint32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F32, float)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I64, int64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double)
// Number of destination scalars needed to hold SrcN source scalars after a
// component-type conversion: ceiling division over the packed element counts
// so the logical element count is preserved.
template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
  static const int Value =
      (SrcN * ComponentTypeTraits<SrcTy>::ElementsPerScalar +
       ComponentTypeTraits<DstTy>::ElementsPerScalar - 1) /
      ComponentTypeTraits<DstTy>::ElementsPerScalar;
};
// Compile-time M/N dimension pair; the Transposed = true specialization
// swaps the two dimensions.
template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
  static const SIZE_TYPE M = MVal;
  static const SIZE_TYPE N = NVal;
};
template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
  static const SIZE_TYPE M = NVal;
  static const SIZE_TYPE N = MVal;
};
} // namespace __detail
// Non-owning reference to a DimA-element vector of ElementType stored in a
// ByteAddressBuffer at byte offset Offset (used e.g. for memory-resident
// bias vectors in MultiplyAdd).
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
  ByteAddressBuffer Buf;
  uint Offset;
};
// A vector paired with a compile-time reinterpretation tag: Data holds N
// values of T that are to be treated as `Size` elements of component type
// DT. Size exceeds N when DT packs multiple elements into each scalar
// (ElementsPerScalar > 1).
template <typename T, int N, ComponentEnum DT> struct InterpretedVector {
  vector<T, N> Data;
  static const ComponentEnum Interpretation = DT;
  static const SIZE_TYPE Size =
      __detail::ComponentTypeTraits<DT>::ElementsPerScalar * N;
};
// Wraps a raw vector in an InterpretedVector tagged with interpretation DT.
// The element type and width are deduced from the argument; only DT must be
// supplied explicitly.
template <ComponentEnum DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
  InterpretedVector<T, N, DT> Wrapped;
  Wrapped.Data = Vec;
  return Wrapped;
}
// Converts Vec's elements from interpretation OriginTy to DestTy, returning
// an InterpretedVector whose scalar count (DstN) accounts for any change in
// packing density between the two component types. The conversion itself is
// intentionally unspecified here -- this is illustrative spec pseudo-code.
template <ComponentEnum DestTy, ComponentEnum OriginTy, typename T, int N>
InterpretedVector<typename __detail::ComponentTypeTraits<DestTy>::Type,
                  __detail::DstN<DestTy, OriginTy, N>::Value, DestTy>
Convert(vector<T, N> Vec) {
  vector<typename __detail::ComponentTypeTraits<DestTy>::Type,
         __detail::DstN<DestTy, OriginTy, N>::Value>
      Result;
  /* Do conversion somehow... */
  return MakeInterpretedVector<DestTy>(Result);
}
// A Scope-cooperative M x N matrix whose elements have logical component
// type ComponentTy and which plays operand role `Use` in matrix arithmetic.
// Non-native component types are stored packed, ElementsPerScalar per
// 32-bit scalar.
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
          MatrixUseEnum Use, MatrixScopeEnum Scope>
class Matrix {
  // Scalar storage type (uint for packed, non-native component types).
  using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
  // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
  // packed in each scalar value.
  static const uint ElementsPerScalar =
      __detail::ComponentTypeTraits<ComponentTy>::ElementsPerScalar;
  static const bool IsNativeScalar =
      __detail::ComponentTypeTraits<ComponentTy>::IsNativeScalar;
  // Converts element type and/or use; when Transpose is true the result
  // type's M and N are swapped (via __detail::DimMN).
  template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
            bool Transpose = false>
  Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
         __detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
  Cast();
  // Broadcasts a single arithmetic value to every element.
  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  Splat(T Val);
  // Loads a matrix from a byte-address buffer; Stride is the byte distance
  // between rows/columns per Layout, Align the assumed buffer alignment.
  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayoutEnum Layout, uint Align = 128);
  static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayoutEnum Layout, uint Align = 128);
  // Load from a groupshared array; enabled only if the array can hold all
  // M*N elements in packed form.
  template <typename T, SIZE_TYPE Size>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                                      (M * N / ElementsPerScalar <= Size),
                                  Matrix>::type
  Load(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
       MatrixLayoutEnum Layout);
  // Element accessors, enabled only for natively-representable component
  // types. The defaulted LocalComp parameter keeps the enable_if condition
  // type-dependent so SFINAE applies per instantiation.
  template <ComponentEnum LocalComp = ComponentTy>
  typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
                           uint>::type
  Length();
  // Maps a linear element index (as used by Get/Set) to its (row, column)
  // coordinate.
  template <ComponentEnum LocalComp = ComponentTy>
  typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
                           uint2>::type
  GetCoordinate(uint Index);
  template <ComponentEnum LocalComp = ComponentTy>
  typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
                           ElementType>::type
  Get(uint Index);
  template <ComponentEnum LocalComp = ComponentTy>
  typename hlsl::enable_if<LocalComp == ComponentTy && IsNativeScalar,
                           void>::type
  Set(uint Index, ElementType Value);
  // Stores mirror the Load overloads (buffer and groupshared forms).
  void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             MatrixLayoutEnum Layout, uint Align = 128);
  template <typename T, SIZE_TYPE Size>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               (M * N / ElementsPerScalar <= Size),
                           void>::type
  Store(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
        MatrixLayoutEnum Layout);
  // Accumulate methods
  // Accumulates this matrix into memory with interlocked operations;
  // restricted to Accumulator use.
  template <MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                        MatrixLayoutEnum Layout,
                        uint Align = 128);
  // Groupshared variant; additionally restricted to Wave scope.
  template <typename T, MatrixUseEnum UseLocal = Use,
            MatrixScopeEnum ScopeLocal = Scope, SIZE_TYPE Size>
  typename hlsl::enable_if<
      hlsl::is_arithmetic<T>::value && Use == MatrixUse::Accumulator &&
          UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
          Scope == MatrixScope::Wave && ScopeLocal == Scope,
      void>::type
  InterlockedAccumulate(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
                        MatrixLayoutEnum Layout);
  // Element-wise accumulation of a same-shape A- or B-use matrix.
  template <ComponentEnum CompTy, MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  Accumulate(const Matrix<CompTy, M, N, MatrixUse::A, Scope> MatrixA);
  template <ComponentEnum CompTy, MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  Accumulate(const Matrix<CompTy, M, N, MatrixUse::B, Scope> MatrixB);
  // this += MatrixA * MatrixB, with the usual MxK * KxN shape constraint
  // enforced by the parameter types.
  template <ComponentEnum LHSTy, ComponentEnum RHSTy, SIZE_TYPE K,
            MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope> MatrixA,
                     const Matrix<RHSTy, K, N, MatrixUse::B, Scope> MatrixB);
};
// Thread-scope Matrices are read-only. Using a template partial specialization
// for this simplifies the SFINAE-foo above.
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
          MatrixUseEnum Use>
class Matrix<ComponentTy, M, N, Use, MatrixScope::Thread> {
  using ElementType = typename __detail::ComponentTypeTraits<ComponentTy>::Type;
  // Only A-use thread matrices can be loaded; unlike the primary template,
  // the layout is a compile-time template argument here.
  template <MatrixLayoutEnum Layout, MatrixUseEnum UseLocal = Use>
  static typename hlsl::enable_if<Use == MatrixUse::A && UseLocal == Use,
                                  Matrix>::type
  Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
       uint Align = 128);
  // Accumulator-use thread matrices can only be accumulated out to memory.
  // NOTE(review): this overload takes no Stride/Layout, presumably leaving
  // the layout implementation-chosen -- confirm against the DXIL spec.
  template <MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset);
};
// NOTE(review): declared to return MatrixUseEnum, but the name suggests it
// reports a layout (MatrixLayoutEnum) -- confirm the intended return type.
MatrixUseEnum AccumulatorLayout();
// Wave-scope matrix multiply: produces an M x N Accumulator from MxK A and
// KxN B operands. The first overload lets the caller choose the output
// component type explicitly; the second reuses the operands' component type.
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
          SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave> MatrixA,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave> MatrixB);
template <ComponentEnum CompTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<CompTy, M, K, MatrixUse::A, MatrixScope::Wave> MatrixA,
         const Matrix<CompTy, K, N, MatrixUse::B, MatrixScope::Wave> MatrixB);
// ThreadGroup-scope matrix multiply; mirrors the Wave-scope overload pair
// above (explicit output component type, then same-type shorthand).
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
          SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
    const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::ThreadGroup> MatrixA,
    const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup> MatrixB);
template <ComponentEnum CompTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
    const Matrix<CompTy, M, K, MatrixUse::A, MatrixScope::ThreadGroup> MatrixA,
    const Matrix<CompTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup> MatrixB);
// Cooperative Vector Replacement API
// Cooperative Vector operates on per-thread vectors multiplying against A
// matrices with thread scope.
// Computes the M-wide result of MatrixA (MxK) times the K-wide vector Vec.
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
          ComponentEnum MatrixDT>
vector<OutputElTy, M>
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
         vector<InputElTy, K> Vec);
// Computes MatrixA * Vec + Bias for a thread-scope MxK A matrix, with the
// bias supplied as an in-register M-wide vector.
// Fix: the original declared both the input and the bias parameter as
// `Vec`; the bias is renamed to `Bias` to match the sibling overloads.
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
          SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
vector<OutputElTy, M>
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
            vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
// MultiplyAdd taking an interpreted input vector; enabled only when the
// interpreted (unpacked) element count of InterpVec equals the matrix's K
// dimension.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
          typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
          ComponentEnum MatrixDT>
typename hlsl::enable_if<
    InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
    vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
            InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
            vector<BiasElTy, M> Bias);
// MultiplyAdd with the bias read from memory through a typed VectorRef
// (component type BiasElTy, M elements).
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
          SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
                         vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
            vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
// MultiplyAdd combining both variations above: interpreted input vector
// (unpacked size must equal K) and a memory-resident bias via VectorRef.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
          ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
          ComponentEnum MatrixDT>
typename hlsl::enable_if<
    InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
    vector<OutputElTy, M> >::type
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
            InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
            VectorRef<BiasElTy, M> BiasRef);
// Outer product functions
// Forms the M x N thread-scope accumulator from an M-wide and an N-wide
// vector (VecA outer VecB), with the output component type chosen by OutTy.
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Thread>
OuterProduct(vector<InputElTy, M> VecA, vector<InputElTy, N> VecB);
} // namespace linalg
} // namespace dx
RWByteAddressBuffer B : register(u0);
// Example: wave-scope 8x32 (A) times 32x16 (B) multiply, with per-element
// mutation of B between load and multiply.
void WaveMatrixExample() {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<ComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<ComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using MatrixAccumTy = Matrix<ComponentType::F16, 8, 16,
                               MatrixUse::Accumulator, MatrixScope::Wave>;
  using MatrixAccum32Ty = Matrix<ComponentType::F32, 8, 16,
                                 MatrixUse::Accumulator, MatrixScope::Wave>;
  // NOTE(review): the stride comments below use 4 bytes per element, but the
  // component type is F16 (2 bytes) -- confirm the intended buffer layout.
  MatrixATy MatA = MatrixATy::Load(
      B, 0, /* Row stride = number of columns * element size */ 32 * 4,
      MatrixLayout::RowMajor);
  MatrixBTy MatB = MatrixBTy::Load(
      B, 0, /* Row stride = number of columns * element size */ 16 * 4,
      MatrixLayout::RowMajor);
  // Element-wise traversal via Length/GetCoordinate/Get/Set (available since
  // F16 is a native scalar here).
  for (uint I = 0; I < MatB.Length(); ++I) {
    uint2 Pos = MatB.GetCoordinate(I);
    // Run `tanh` on all but the diagonal components for no reasonable reason.
    if (Pos.x != Pos.y) {
      float16_t Val = MatB.Get(I);
      MatB.Set(I, tanh(Val));
    }
  }
  // Same-component-type multiply, then the OutTy overload producing a wider
  // F32 accumulator from the same F16 operands.
  MatrixAccumTy Accum = Multiply(MatA, MatB);
  MatrixAccum32Ty Accum32 = Multiply<ComponentType::F32>(MatA, MatB);
}
ByteAddressBuffer MBuf : register(t0);
// Example: cooperative-vector style usage -- per-thread vectors multiplied
// against a thread-scope A matrix, chaining "layers" as in ML inference.
void CoopVec() {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;
  vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
  // Thread-scope load: layout is a template argument on this specialization.
  // NOTE(review): the stride comment assumes 4 bytes per element, but the
  // component type is F16 (2 bytes) -- confirm the intended buffer layout.
  MatrixATy MatA = MatrixATy::Load<MatrixLayout::RowMajor>(
      MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4);
  vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec);
  vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias);
  // Bias referenced in place from the buffer instead of loaded into registers.
  VectorRef<ComponentType::F8_E4M3FN, 16> MemBias = {MBuf,
                                                     /*start offset*/ 4096};
  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias);
  // Clang doesn't yet support packed types.
#ifdef __hlsl_dx_compiler
  // Packed 8-bit input reinterpreted as 16 F8 elements (4 per packed uint).
  vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
  vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
      MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData), MemBias);
  vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
      MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData),
      NullBias);
  // NOTE(review): Layer6 repeats the Layer4 call verbatim -- possibly an
  // editing leftover in the example.
  vector<float16_t, 16> Layer6 = MultiplyAdd<float16_t>(
      MatA, MakeInterpretedVector<ComponentType::F8_E4M3FN>(SomeData), MemBias);
  // This example creates an interpreted vector where the data needs to be
  // converted from a source type to a destination type.
  vector<uint, 16> SomeData2 = (vector<uint, 16>)0;
  vector<float16_t, 16> Layer7 = MultiplyAdd<float16_t>(
      MatA, Convert<ComponentType::F8_E4M3FN, ComponentType::U32>(SomeData2),
      MemBias);
#endif
}
RWByteAddressBuffer Buf : register(u1);
// Example: thread-scope outer product (16x1 times 1x8 -> 16x8 accumulator)
// followed by interlocked accumulation of the result into a buffer.
void OuterProdAccum() {
  using namespace dx::linalg;
  using MatrixAccumTy = Matrix<ComponentType::F16, 16, 8,
                               MatrixUse::Accumulator, MatrixScope::Thread>;
  vector<float16_t, 16> VecA = (vector<float16_t, 16>)0;
  vector<float16_t, 8> VecB = (vector<float16_t, 8>)0;
  MatrixAccumTy MatAcc = OuterProduct<ComponentType::F16>(VecA, VecB);
  // Thread-scope accumulators use the stride-less InterlockedAccumulate.
  MatAcc.InterlockedAccumulate(Buf, 0);
}