Coverage Report

Created: 2025-09-26 23:27

/home/runner/work/DirectXShaderCompiler/DirectXShaderCompiler/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Line
Count
Source (jump to first uncovered line)
1
//===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
2
//
3
//                     The LLVM Compiler Infrastructure
4
//
5
// This file is distributed under the University of Illinois Open Source
6
// License. See LICENSE.TXT for details.
7
//
8
//===----------------------------------------------------------------------===//
9
10
#include "DeclResultIdMapper.h"
11
12
#include <algorithm>
13
#include <optional>
14
#include <sstream>
15
16
#include "dxc/DXIL/DxilConstants.h"
17
#include "dxc/DXIL/DxilTypeSystem.h"
18
#include "dxc/Support/SPIRVOptions.h"
19
#include "clang/AST/Expr.h"
20
#include "clang/AST/HlslTypes.h"
21
#include "clang/SPIRV/AstTypeProbe.h"
22
#include "llvm/ADT/SmallBitVector.h"
23
#include "llvm/ADT/StringMap.h"
24
#include "llvm/ADT/StringSet.h"
25
#include "llvm/Support/Casting.h"
26
27
#include "AlignmentSizeCalculator.h"
28
#include "SignaturePackingUtil.h"
29
#include "SpirvEmitter.h"
30
31
namespace clang {
32
namespace spirv {
33
34
namespace {
35
36
// Returns true if the image format is compatible with the sampled type. This is
37
// determined according to the same at
38
// https://docs.vulkan.org/spec/latest/appendices/spirvenv.html#spirvenv-format-type-matching.
39
3.25k
bool areFormatAndTypeCompatible(spv::ImageFormat format, QualType sampledType) {
40
3.25k
  if (format == spv::ImageFormat::Unknown) {
41
3.17k
    return true;
42
3.17k
  }
43
44
84
  if (hlsl::IsHLSLVecType(sampledType)) {
45
    // For vectors, we need to check if the element type is compatible. We do
46
    // not check the number of elements because it is possible that the number
47
    // of elements in the sampled type is different. I could not find in the
48
    // spec what should happen in that case.
49
72
    sampledType = hlsl::GetHLSLVecElementType(sampledType);
50
72
  }
51
52
84
  const Type *desugaredType = sampledType->getUnqualifiedDesugaredType();
53
84
  const BuiltinType *builtinType = dyn_cast<BuiltinType>(desugaredType);
54
84
  if (!builtinType) {
55
0
    return false;
56
0
  }
57
58
84
  switch (format) {
59
10
  case spv::ImageFormat::Rgba32f:
60
10
  case spv::ImageFormat::Rg32f:
61
14
  case spv::ImageFormat::R32f:
62
22
  case spv::ImageFormat::Rgba16f:
63
26
  case spv::ImageFormat::Rg16f:
64
26
  case spv::ImageFormat::R16f:
65
26
  case spv::ImageFormat::Rgba16:
66
26
  case spv::ImageFormat::Rg16:
67
26
  case spv::ImageFormat::R16:
68
26
  case spv::ImageFormat::Rgba16Snorm:
69
30
  case spv::ImageFormat::Rg16Snorm:
70
30
  case spv::ImageFormat::R16Snorm:
71
34
  case spv::ImageFormat::Rgb10A2:
72
38
  case spv::ImageFormat::R11fG11fB10f:
73
38
  case spv::ImageFormat::Rgba8:
74
42
  case spv::ImageFormat::Rg8:
75
46
  case spv::ImageFormat::R8:
76
50
  case spv::ImageFormat::Rgba8Snorm:
77
50
  case spv::ImageFormat::Rg8Snorm:
78
50
  case spv::ImageFormat::R8Snorm:
79
    // 32-bit float
80
50
    return builtinType->getKind() == BuiltinType::Float;
81
8
  case spv::ImageFormat::Rgba32i:
82
8
  case spv::ImageFormat::Rg32i:
83
8
  case spv::ImageFormat::R32i:
84
8
  case spv::ImageFormat::Rgba16i:
85
8
  case spv::ImageFormat::Rg16i:
86
8
  case spv::ImageFormat::R16i:
87
8
  case spv::ImageFormat::Rgba8i:
88
12
  case spv::ImageFormat::Rg8i:
89
12
  case spv::ImageFormat::R8i:
90
    // signed 32-bit int
91
12
    return builtinType->getKind() == BuiltinType::Int;
92
2
  case spv::ImageFormat::Rgba32ui:
93
2
  case spv::ImageFormat::Rg32ui:
94
2
  case spv::ImageFormat::R32ui:
95
6
  case spv::ImageFormat::Rgba16ui:
96
6
  case spv::ImageFormat::Rg16ui:
97
6
  case spv::ImageFormat::R16ui:
98
10
  case spv::ImageFormat::Rgb10a2ui:
99
10
  case spv::ImageFormat::Rgba8ui:
100
10
  case spv::ImageFormat::Rg8ui:
101
10
  case spv::ImageFormat::R8ui:
102
    // unsigned 32-bit int
103
10
    return builtinType->getKind() == BuiltinType::UInt;
104
6
  case spv::ImageFormat::R64i:
105
    // signed 64-bit int
106
6
    return builtinType->getKind() == BuiltinType::LongLong;
107
6
  case spv::ImageFormat::R64ui:
108
    // unsigned 64-bit int
109
6
    return builtinType->getKind() == BuiltinType::ULongLong;
110
84
  }
111
112
0
  return true;
113
84
}
114
115
156
uint32_t getVkBindingAttrSet(const VKBindingAttr *attr, uint32_t defaultSet) {
116
  // If the [[vk::binding(x)]] attribute is provided without the descriptor set,
117
  // we should use the default descriptor set.
118
156
  if (attr->getSet() == INT_MIN) {
119
44
    return defaultSet;
120
44
  }
121
112
  return attr->getSet();
122
156
}
123
124
/// Returns the :packoffset() annotation on the given decl. Returns nullptr if
125
/// the decl does not have one.
126
904
hlsl::ConstantPacking *getPackOffset(const clang::NamedDecl *decl) {
127
904
  for (auto *annotation : decl->getUnusualAnnotations())
128
92
    if (auto *packing = llvm::dyn_cast<hlsl::ConstantPacking>(annotation))
129
26
      return packing;
130
878
  return nullptr;
131
904
}
132
133
/// Returns the number of binding numbers that are used up by the given type.
134
/// An array of size N consumes N*M binding numbers where M is the number of
135
/// binding numbers used by each array element.
136
/// The number of binding numbers used by a structure is the sum of binding
137
/// numbers used by its members.
138
220
uint32_t getNumBindingsUsedByResourceType(QualType type) {
139
  // For custom-generated types that have SpirvType but no QualType.
140
220
  if (type.isNull())
141
10
    return 1;
142
143
  // For every array dimension, the number of bindings needed should be
144
  // multiplied by the array size. For example: an array of two Textures should
145
  // use 2 binding slots.
146
210
  uint32_t arrayFactor = 1;
147
286
  while (auto constArrayType = dyn_cast<ConstantArrayType>(type)) {
148
76
    arrayFactor *=
149
76
        static_cast<uint32_t>(constArrayType->getSize().getZExtValue());
150
76
    type = constArrayType->getElementType();
151
76
  }
152
153
  // Once we remove the arrayness, we expect the given type to be either a
154
  // resource OR a structure that only contains resources.
155
210
  assert(isResourceType(type) || isResourceOnlyStructure(type));
156
157
  // In the case of a resource, each resource takes 1 binding slot, so in total
158
  // it consumes: 1 * arrayFactor.
159
210
  if (isResourceType(type))
160
188
    return arrayFactor;
161
162
  // In the case of a struct of resources, we need to sum up the number of
163
  // bindings for the struct members. So in total it consumes:
164
  //  sum(bindings of struct members) * arrayFactor.
165
22
  if (isResourceOnlyStructure(type)) {
166
22
    uint32_t sumOfMemberBindings = 0;
167
22
    const auto *structDecl = type->getAs<RecordType>()->getDecl();
168
22
    assert(structDecl);
169
22
    for (const auto *field : structDecl->fields())
170
40
      sumOfMemberBindings += getNumBindingsUsedByResourceType(field->getType());
171
172
22
    return sumOfMemberBindings * arrayFactor;
173
22
  }
174
175
22
  
llvm_unreachable0
(
176
22
      "getNumBindingsUsedByResourceType was called with unknown resource type");
177
22
}
178
179
QualType getUintTypeWithSourceComponents(const ASTContext &astContext,
                                         QualType sourceType) {
  // Returns `uint` for a scalar and `uintN` for an N-component vector. Any
  // other kind of source type is a caller error.
  if (isScalarType(sourceType))
    return astContext.UnsignedIntTy;

  uint32_t componentCount = 0;
  if (isVectorType(sourceType, nullptr, &componentCount))
    return astContext.getExtVectorType(astContext.UnsignedIntTy,
                                       componentCount);

  llvm_unreachable("only scalar and vector types are supported in "
                   "getUintTypeWithSourceComponents");
}
191
192
LocationAndComponent getLocationAndComponentCount(const ASTContext &astContext,
                                                  QualType type) {
  // See Vulkan spec 14.1.4. Location Assignment for the complete set of rules.

  // Always classify the canonical type. Canonicalization strips all type sugar
  // (including typedefs), so no separate typedef handling is needed below.
  const auto canonicalType = type.getCanonicalType();
  if (canonicalType != type)
    return getLocationAndComponentCount(astContext, canonicalType);

  // Inputs and outputs of the following types consume a single interface
  // location:
  // * 16-bit scalar and vector types, and
  // * 32-bit scalar and vector types, and
  // * 64-bit scalar and 2-component vector types.

  // 64-bit three- and four- component vectors consume two consecutive
  // locations.

  // Primitive types
  if (isScalarType(type)) {
    const auto *builtinType = type->getAs<BuiltinType>();
    if (builtinType != nullptr) {
      switch (builtinType->getKind()) {
      case BuiltinType::Double:
      case BuiltinType::LongLong:
      case BuiltinType::ULongLong:
        return {1, 2, true};
      default:
        return {1, 1, false};
      }
    }
    return {1, 1, false};
  }

  // Vector types
  {
    QualType elemType = {};
    uint32_t elemCount = {};
    if (isVectorType(type, &elemType, &elemCount)) {
      // Guard against a non-builtin element type the same way the scalar case
      // does, instead of dereferencing a null pointer.
      const auto *builtinType = elemType->getAs<BuiltinType>();
      if (builtinType != nullptr) {
        switch (builtinType->getKind()) {
        case BuiltinType::Double:
        case BuiltinType::LongLong:
        case BuiltinType::ULongLong:
          // 64-bit components: 3- and 4-component vectors take two locations.
          if (elemCount >= 3)
            return {2, 4, true};
          return {1, 2 * elemCount, true};
        default:
          // Filter switch only interested in types occupying 2 locations.
          break;
        }
      }
      return {1, elemCount, false};
    }
  }

  // If the declared input or output is an n * m 16- , 32- or 64- bit matrix,
  // it will be assigned multiple locations starting with the location
  // specified. The number of locations assigned for each matrix will be the
  // same as for an n-element array of m-component vectors.

  // Matrix types
  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      auto locComponentCount = getLocationAndComponentCount(
          astContext, astContext.getExtVectorType(elemType, colCount));
      return {locComponentCount.location * rowCount,
              locComponentCount.component,
              locComponentCount.componentAlignment};
    }
  }

  // Reference types
  if (const auto *refType = type->getAs<ReferenceType>())
    return getLocationAndComponentCount(astContext, refType->getPointeeType());

  // Pointer types
  if (const auto *ptrType = type->getAs<PointerType>())
    return getLocationAndComponentCount(astContext, ptrType->getPointeeType());

  // If a declared input or output is an array of size n and each element takes
  // m locations, it will be assigned m * n consecutive locations starting with
  // the location specified.

  // Array types
  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
    auto locComponentCount =
        getLocationAndComponentCount(astContext, arrayType->getElementType());
    uint32_t arrayLength =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    return {locComponentCount.location * arrayLength,
            locComponentCount.component, locComponentCount.componentAlignment};
  }

  // Struct type
  if (type->getAs<RecordType>()) {
    assert(false && "all structs should already be flattened");
    return {0, 0, false};
  }

  llvm_unreachable(
      "calculating number of occupied locations for type unimplemented");
  return {0, 0, false};
}
301
302
52.4k
bool shouldSkipInStructLayout(const Decl *decl) {
303
  // Ignore implicit generated struct declarations/constructors/destructors
304
52.4k
  if (decl->isImplicit())
305
43.5k
    return true;
306
  // Ignore embedded type decls
307
8.95k
  if (isa<TypeDecl>(decl))
308
2.57k
    return true;
309
  // Ignore embeded function decls
310
6.37k
  if (isa<FunctionDecl>(decl))
311
636
    return true;
312
  // Ignore empty decls
313
5.73k
  if (isa<EmptyDecl>(decl))
314
700
    return true;
315
316
  // For the $Globals cbuffer, we only care about externally-visible
317
  // non-resource-type variables. The rest should be filtered out.
318
319
5.03k
  const auto *declContext = decl->getDeclContext();
320
321
  // $Globals' "struct" is the TranslationUnit, so we should ignore resources
322
  // in the TranslationUnit "struct" and its child namespaces.
323
5.03k
  if (declContext->isTranslationUnit() || 
declContext->isNamespace()3.38k
) {
324
325
3.67k
    if (decl->hasAttr<VKConstantIdAttr>()) {
326
4
      return true;
327
4
    }
328
329
3.67k
    if (decl->hasAttr<VKPushConstantAttr>()) {
330
4
      return true;
331
4
    }
332
333
    // External visibility
334
3.66k
    if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
335
896
      if (!declDecl->hasExternalFormalLinkage())
336
40
        return true;
337
338
    // cbuffer/tbuffer
339
3.62k
    if (isa<HLSLBufferDecl>(decl))
340
24
      return true;
341
342
    // 'groupshared' variables should not be placed in $Globals cbuffer.
343
3.60k
    if (decl->hasAttr<HLSLGroupSharedAttr>())
344
8
      return true;
345
346
    // Other resource types
347
3.59k
    if (const auto *valueDecl = dyn_cast<ValueDecl>(decl)) {
348
848
      const auto declType = valueDecl->getType();
349
848
      if (isResourceType(declType) || 
isResourceOnlyStructure(declType)500
)
350
352
        return true;
351
848
    }
352
3.59k
  }
353
354
4.60k
  return false;
355
5.03k
}
356
357
void collectDeclsInField(const Decl *field,
358
51.8k
                         llvm::SmallVector<const Decl *, 4> *decls) {
359
360
  // Case of nested namespaces.
361
51.8k
  if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
362
10.4k
    for (const auto *decl : nsDecl->decls()) {
363
10.4k
      collectDeclsInField(decl, decls);
364
10.4k
    }
365
880
  }
366
367
51.8k
  if (shouldSkipInStructLayout(field))
368
47.8k
    return;
369
370
4.00k
  if (!isa<DeclaratorDecl>(field)) {
371
2.84k
    return;
372
2.84k
  }
373
374
1.15k
  decls->push_back(field);
375
1.15k
}
376
377
llvm::SmallVector<const Decl *, 4>
collectDeclsInDeclContext(const DeclContext *declContext) {
  // Runs collectDeclsInField over every decl of the context, in order.
  llvm::SmallVector<const Decl *, 4> collected;
  for (const Decl *member : declContext->decls())
    collectDeclsInField(member, &collected);
  return collected;
}
385
386
/// \brief Returns true if the given decl is a boolean stage I/O variable.
387
/// Returns false if the type is not boolean, or the decl is a built-in stage
388
/// variable.
389
bool isBooleanStageIOVar(const NamedDecl *decl, QualType type,
390
                         const hlsl::DXIL::SemanticKind semanticKind,
391
8.00k
                         const hlsl::SigPoint::Kind sigPointKind) {
392
  // [[vk::builtin(...)]] makes the decl a built-in stage variable.
393
  // IsFrontFace (if used as PSIn) is the only known boolean built-in stage
394
  // variable.
395
8.00k
  bool isBooleanBuiltin = false;
396
397
8.00k
  if ((decl->getAttr<VKBuiltInAttr>() != nullptr))
398
98
    isBooleanBuiltin = true;
399
7.90k
  else if (semanticKind == hlsl::Semantic::Kind::IsFrontFace &&
400
7.90k
           
sigPointKind == hlsl::SigPoint::Kind::PSIn12
) {
401
8
    isBooleanBuiltin = true;
402
7.89k
  } else if (semanticKind == hlsl::Semantic::Kind::CullPrimitive) {
403
4
    isBooleanBuiltin = true;
404
4
  }
405
406
  // TODO: support boolean matrix stage I/O variable if needed.
407
8.00k
  QualType elemType = {};
408
8.00k
  const bool isBooleanType =
409
8.00k
      ((isScalarType(type, &elemType) || 
isVectorType(type, &elemType)5.97k
) &&
410
8.00k
       
elemType->isBooleanType()6.79k
);
411
412
8.00k
  return isBooleanType && 
!isBooleanBuiltin126
;
413
8.00k
}
414
415
/// \brief Returns the stage variable's register assignment for the given Decl.
416
4.29k
const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
417
4.29k
  for (auto *annotation : decl->getUnusualAnnotations()) {
418
1.00k
    if (auto *reg = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
419
1.00k
      return reg;
420
1.00k
    }
421
1.00k
  }
422
3.29k
  return nullptr;
423
4.29k
}
424
425
/// \brief Returns the stage variable's 'register(c#) assignment for the given
426
/// Decl. Return nullptr if the given variable does not have such assignment.
427
248
const hlsl::RegisterAssignment *getRegisterCAssignment(const NamedDecl *decl) {
428
248
  const auto *regAssignment = getResourceBinding(decl);
429
248
  if (regAssignment)
430
66
    return regAssignment->RegisterType == 'c' ? 
regAssignment64
:
nullptr2
;
431
182
  return nullptr;
432
248
}
433
434
/// \brief Returns true if the given declaration has a primitive type qualifier.
435
/// Returns false otherwise.
436
6.47k
inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
437
6.47k
  return decl->hasAttr<HLSLTriangleAttr>() ||
438
6.47k
         
decl->hasAttr<HLSLTriangleAdjAttr>()6.45k
||
439
6.47k
         
decl->hasAttr<HLSLPointAttr>()6.45k
||
decl->hasAttr<HLSLLineAttr>()6.43k
||
440
6.47k
         
decl->hasAttr<HLSLLineAdjAttr>()6.37k
;
441
6.47k
}
442
443
/// \brief Deduces the parameter qualifier for the given decl.
444
hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
445
4.72k
                                         bool asInput) {
446
4.72k
  const auto type = decl->getType();
447
448
4.72k
  if (hlsl::IsHLSLInputPatchType(type))
449
78
    return hlsl::DxilParamInputQual::InputPatch;
450
4.64k
  if (hlsl::IsHLSLOutputPatchType(type))
451
30
    return hlsl::DxilParamInputQual::OutputPatch;
452
  // TODO: Add support for multiple output streams.
453
4.61k
  if (hlsl::IsHLSLStreamOutputType(type))
454
46
    return hlsl::DxilParamInputQual::OutStream0;
455
456
  // The inputs to the geometry shader that have a primitive type qualifier
457
  // must use 'InputPrimitive'.
458
4.56k
  if (hasGSPrimitiveTypeQualifier(decl))
459
50
    return hlsl::DxilParamInputQual::InputPrimitive;
460
461
4.51k
  if (decl->hasAttr<HLSLIndicesAttr>())
462
0
    return hlsl::DxilParamInputQual::OutIndices;
463
4.51k
  if (decl->hasAttr<HLSLVerticesAttr>())
464
34
    return hlsl::DxilParamInputQual::OutVertices;
465
4.48k
  if (decl->hasAttr<HLSLPrimitivesAttr>())
466
8
    return hlsl::DxilParamInputQual::OutPrimitives;
467
4.47k
  if (decl->hasAttr<HLSLPayloadAttr>())
468
6
    return hlsl::DxilParamInputQual::InPayload;
469
470
4.47k
  if (hlsl::IsHLSLNodeType(type)) {
471
0
    return hlsl::DxilParamInputQual::NodeIO;
472
0
  }
473
474
4.47k
  return asInput ? 
hlsl::DxilParamInputQual::In1.59k
:
hlsl::DxilParamInputQual::Out2.88k
;
475
4.47k
}
476
477
/// \brief Deduces the HLSL SigPoint for the given decl appearing in the given
478
/// shader model.
479
const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput,
480
                                     const hlsl::ShaderModel::Kind kind,
481
4.93k
                                     bool forPCF) {
482
4.93k
  if (kind == hlsl::ShaderModel::Kind::Node) {
483
212
    return hlsl::SigPoint::GetSigPoint(hlsl::SigPoint::Kind::CSIn);
484
212
  }
485
4.72k
  return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
486
4.72k
      deduceParamQual(decl, asInput), kind, forPCF));
487
4.93k
}
488
489
/// Returns the type of the given decl. If the given decl is a FunctionDecl,
490
/// returns its result type.
491
45.7k
inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
492
45.7k
  if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
493
13.8k
    return funcDecl->getReturnType();
494
13.8k
  }
495
31.8k
  return decl->getType();
496
45.7k
}
497
498
/// Returns the number of base classes if this type is a derived class/struct.
499
/// Returns zero otherwise.
500
4.10k
inline uint32_t getNumBaseClasses(QualType type) {
501
4.10k
  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
502
4.10k
    return cxxDecl->getNumBases();
503
0
  return 0;
504
4.10k
}
505
506
/// Returns the appropriate storage class for an extern variable of the given
507
/// type.
508
spv::StorageClass getStorageClassForExternVar(QualType type,
509
3.88k
                                              bool hasGroupsharedAttr) {
510
  // For CS groupshared variables
511
3.88k
  if (hasGroupsharedAttr)
512
88
    return spv::StorageClass::Workgroup;
513
514
3.79k
  if (isAKindOfStructuredOrByteBuffer(type) || 
isConstantTextureBuffer(type)2.32k
)
515
1.59k
    return spv::StorageClass::Uniform;
516
517
2.19k
  return spv::StorageClass::UniformConstant;
518
3.79k
}
519
520
/// Returns the appropriate layout rule for an extern variable of the given
521
/// type.
522
SpirvLayoutRule getLayoutRuleForExternVar(QualType type,
523
3.88k
                                          const SpirvCodeGenOptions &opts) {
524
3.88k
  if (isAKindOfStructuredOrByteBuffer(type))
525
1.47k
    return opts.sBufferLayoutRule;
526
2.40k
  if (isConstantBuffer(type))
527
98
    return opts.cBufferLayoutRule;
528
2.31k
  if (isTextureBuffer(type))
529
26
    return opts.tBufferLayoutRule;
530
2.28k
  return SpirvLayoutRule::Void;
531
2.31k
}
532
533
std::optional<spv::ImageFormat>
534
3.64k
getSpvImageFormat(const VKImageFormatAttr *imageFormatAttr) {
535
3.64k
  if (imageFormatAttr == nullptr)
536
3.55k
    return std::nullopt;
537
538
94
  switch (imageFormatAttr->getImageFormat()) {
539
4
  case VKImageFormatAttr::unknown:
540
4
    return spv::ImageFormat::Unknown;
541
10
  case VKImageFormatAttr::rgba32f:
542
10
    return spv::ImageFormat::Rgba32f;
543
12
  case VKImageFormatAttr::rgba16f:
544
12
    return spv::ImageFormat::Rgba16f;
545
4
  case VKImageFormatAttr::r32f:
546
4
    return spv::ImageFormat::R32f;
547
0
  case VKImageFormatAttr::rgba8:
548
0
    return spv::ImageFormat::Rgba8;
549
4
  case VKImageFormatAttr::rgba8snorm:
550
4
    return spv::ImageFormat::Rgba8Snorm;
551
0
  case VKImageFormatAttr::rg32f:
552
0
    return spv::ImageFormat::Rg32f;
553
4
  case VKImageFormatAttr::rg16f:
554
4
    return spv::ImageFormat::Rg16f;
555
4
  case VKImageFormatAttr::r11g11b10f:
556
4
    return spv::ImageFormat::R11fG11fB10f;
557
0
  case VKImageFormatAttr::r16f:
558
0
    return spv::ImageFormat::R16f;
559
0
  case VKImageFormatAttr::rgba16:
560
0
    return spv::ImageFormat::Rgba16;
561
4
  case VKImageFormatAttr::rgb10a2:
562
4
    return spv::ImageFormat::Rgb10A2;
563
0
  case VKImageFormatAttr::rg16:
564
0
    return spv::ImageFormat::Rg16;
565
4
  case VKImageFormatAttr::rg8:
566
4
    return spv::ImageFormat::Rg8;
567
0
  case VKImageFormatAttr::r16:
568
0
    return spv::ImageFormat::R16;
569
4
  case VKImageFormatAttr::r8:
570
4
    return spv::ImageFormat::R8;
571
0
  case VKImageFormatAttr::rgba16snorm:
572
0
    return spv::ImageFormat::Rgba16Snorm;
573
4
  case VKImageFormatAttr::rg16snorm:
574
4
    return spv::ImageFormat::Rg16Snorm;
575
0
  case VKImageFormatAttr::rg8snorm:
576
0
    return spv::ImageFormat::Rg8Snorm;
577
0
  case VKImageFormatAttr::r16snorm:
578
0
    return spv::ImageFormat::R16Snorm;
579
0
  case VKImageFormatAttr::r8snorm:
580
0
    return spv::ImageFormat::R8Snorm;
581
8
  case VKImageFormatAttr::rgba32i:
582
8
    return spv::ImageFormat::Rgba32i;
583
0
  case VKImageFormatAttr::rgba16i:
584
0
    return spv::ImageFormat::Rgba16i;
585
0
  case VKImageFormatAttr::rgba8i:
586
0
    return spv::ImageFormat::Rgba8i;
587
0
  case VKImageFormatAttr::r32i:
588
0
    return spv::ImageFormat::R32i;
589
0
  case VKImageFormatAttr::rg32i:
590
0
    return spv::ImageFormat::Rg32i;
591
0
  case VKImageFormatAttr::rg16i:
592
0
    return spv::ImageFormat::Rg16i;
593
4
  case VKImageFormatAttr::rg8i:
594
4
    return spv::ImageFormat::Rg8i;
595
0
  case VKImageFormatAttr::r16i:
596
0
    return spv::ImageFormat::R16i;
597
0
  case VKImageFormatAttr::r8i:
598
0
    return spv::ImageFormat::R8i;
599
2
  case VKImageFormatAttr::rgba32ui:
600
2
    return spv::ImageFormat::Rgba32ui;
601
6
  case VKImageFormatAttr::rgba16ui:
602
6
    return spv::ImageFormat::Rgba16ui;
603
0
  case VKImageFormatAttr::rgba8ui:
604
0
    return spv::ImageFormat::Rgba8ui;
605
0
  case VKImageFormatAttr::r32ui:
606
0
    return spv::ImageFormat::R32ui;
607
4
  case VKImageFormatAttr::rgb10a2ui:
608
4
    return spv::ImageFormat::Rgb10a2ui;
609
0
  case VKImageFormatAttr::rg32ui:
610
0
    return spv::ImageFormat::Rg32ui;
611
0
  case VKImageFormatAttr::rg16ui:
612
0
    return spv::ImageFormat::Rg16ui;
613
0
  case VKImageFormatAttr::rg8ui:
614
0
    return spv::ImageFormat::Rg8ui;
615
0
  case VKImageFormatAttr::r16ui:
616
0
    return spv::ImageFormat::R16ui;
617
0
  case VKImageFormatAttr::r8ui:
618
0
    return spv::ImageFormat::R8ui;
619
6
  case VKImageFormatAttr::r64ui:
620
6
    return spv::ImageFormat::R64ui;
621
6
  case VKImageFormatAttr::r64i:
622
6
    return spv::ImageFormat::R64i;
623
94
  }
624
0
  return spv::ImageFormat::Unknown;
625
94
}
626
627
// Inserts seen semantics for entryPoint to seenSemanticsForEntryPoints. Returns
628
// whether it does not already exist in seenSemanticsForEntryPoints.
629
bool insertSeenSemanticsForEntryPointIfNotExist(
630
    llvm::SmallDenseMap<SpirvFunction *, llvm::StringSet<>>
631
        *seenSemanticsForEntryPoints,
632
2.86k
    SpirvFunction *entryPoint, const std::string &semantics) {
633
2.86k
  auto seenSemanticsForEntryPointsItr =
634
2.86k
      seenSemanticsForEntryPoints->find(entryPoint);
635
2.86k
  if (seenSemanticsForEntryPointsItr == seenSemanticsForEntryPoints->end()) {
636
1.93k
    bool insertResult = false;
637
1.93k
    std::tie(seenSemanticsForEntryPointsItr, insertResult) =
638
1.93k
        seenSemanticsForEntryPoints->insert(
639
1.93k
            std::make_pair(entryPoint, llvm::StringSet<>()));
640
1.93k
    assert(insertResult);
641
1.93k
    seenSemanticsForEntryPointsItr->second.insert(semantics);
642
1.93k
    return true;
643
1.93k
  }
644
645
926
  auto &seenSemantics = seenSemanticsForEntryPointsItr->second;
646
926
  if (seenSemantics.count(semantics)) {
647
10
    return false;
648
10
  }
649
916
  seenSemantics.insert(semantics);
650
916
  return true;
651
926
}
652
653
// Returns whether the type is translated to a 32-bit floating point type,
654
// depending on whether SPIR-V codegen options are configured to use 16-bit
655
// types when possible.
656
116
bool is32BitFloatingPointType(BuiltinType::Kind kind, bool use16Bit) {
657
  // Always translated into 32-bit floating point types.
658
116
  if (kind == BuiltinType::Float || 
kind == BuiltinType::LitFloat14
)
659
102
    return true;
660
661
  // Translated into 32-bit floating point types when run without
662
  // -enable-16bit-types.
663
14
  if (kind == BuiltinType::Half || 
kind == BuiltinType::HalfFloat12
||
664
14
      
kind == BuiltinType::Min10Float10
||
kind == BuiltinType::Min16Float6
)
665
10
    return !use16Bit;
666
667
4
  return false;
668
14
}
669
670
// Returns whether the type is a 4-component 32-bit float or a composite type
671
// recursively including only such a vector e.g., float4, float4[1], struct S {
672
// float4 foo[1]; }.
673
244
bool containOnlyVecWithFourFloats(QualType type, bool use16Bit) {
674
244
  if (type->isReferenceType())
675
14
    type = type->getPointeeType();
676
677
244
  if (is1xNMatrix(type, nullptr, nullptr))
678
2
    return false;
679
680
242
  uint32_t elemCount = 0;
681
242
  if (type->isConstantArrayType()) {
682
2
    const ConstantArrayType *arrayType =
683
2
        (const ConstantArrayType *)type->getAsArrayTypeUnsafe();
684
2
    elemCount = hlsl::GetArraySize(type);
685
2
    return elemCount == 1 &&
686
2
           
containOnlyVecWithFourFloats(arrayType->getElementType(), use16Bit)0
;
687
2
  }
688
689
240
  if (const auto *structType = type->getAs<RecordType>()) {
690
120
    uint32_t fieldCount = 0;
691
120
    for (const auto *field : structType->getDecl()->fields()) {
692
120
      if (fieldCount != 0)
693
0
        return false;
694
120
      if (!containOnlyVecWithFourFloats(field->getType(), use16Bit))
695
12
        return false;
696
108
      ++fieldCount;
697
108
    }
698
108
    return fieldCount == 1;
699
120
  }
700
701
120
  QualType elemType = {};
702
120
  if (isVectorType(type, &elemType, &elemCount)) {
703
120
    if (const auto *builtinType = elemType->getAs<BuiltinType>()) {
704
120
      return elemCount == 4 &&
705
120
             
is32BitFloatingPointType(builtinType->getKind(), use16Bit)116
;
706
120
    }
707
0
    return false;
708
120
  }
709
0
  return false;
710
120
}
711
712
} // anonymous namespace
713
714
17.7k
std::string StageVar::getSemanticStr() const {
715
  // A special case for zero index, which is equivalent to no index.
716
  // Use what is in the source code.
717
  // TODO: this looks like a hack to make the current tests happy.
718
  // Should consider remove it and fix all tests.
719
17.7k
  if (semanticInfo.index == 0)
720
16.8k
    return semanticInfo.str;
721
722
936
  std::ostringstream ss;
723
936
  ss << semanticInfo.name.str() << semanticInfo.index;
724
936
  return ss.str();
725
17.7k
}
726
727
40
SpirvInstruction *CounterIdAliasPair::getAliasAddress() const {
  // Only valid for alias pairs. For them, counterVar is the variable that
  // getCounterVariable() loads the counter pointer from.
  assert(isAlias);
  return counterVar;
}
731
732
SpirvInstruction *
CounterIdAliasPair::getCounterVariable(SpirvBuilder &builder,
                                       SpirvContext &spvContext) const {
  // A non-alias pair holds the counter variable directly.
  if (!isAlias)
    return counterVar;

  // An alias pair holds a variable whose value is a Uniform pointer to the
  // ACS buffer counter, so load that pointer.
  const auto *counterPtrType = spvContext.getPointerType(
      spvContext.getACSBufferCounterType(), spv::StorageClass::Uniform);
  return builder.createLoad(counterPtrType, counterVar,
                            /* SourceLocation */ {});
}
744
745
const CounterIdAliasPair *
CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
  // Returns the counter of the first field whose index path equals `indices`,
  // or nullptr if no field matches.
  const auto match =
      std::find_if(fields.begin(), fields.end(), [&indices](const auto &field) {
        return field.indices == indices;
      });
  return match == fields.end() ? nullptr : &match->counterVar;
}
752
753
bool CounterVarFields::assign(const CounterVarFields &srcFields,
754
                              SpirvBuilder &builder,
755
28
                              SpirvContext &context) const {
756
138
  for (const auto &field : fields) {
757
138
    const auto *srcField = srcFields.get(field.indices);
758
138
    if (!srcField)
759
0
      return false;
760
761
138
    field.counterVar.assign(srcField->getCounterVariable(builder, context),
762
138
                            builder);
763
138
  }
764
765
28
  return true;
766
28
}
767
768
bool CounterVarFields::assign(const CounterVarFields &srcFields,
769
                              const llvm::SmallVector<uint32_t, 4> &dstPrefix,
770
                              const llvm::SmallVector<uint32_t, 4> &srcPrefix,
771
                              SpirvBuilder &builder,
772
38
                              SpirvContext &context) const {
773
38
  if (dstPrefix.empty() && 
srcPrefix.empty()32
)
774
28
    return assign(srcFields, builder, context);
775
776
10
  llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
777
778
  // If whole has the given prefix, appends all elements after the prefix in
779
  // whole to srcIndices.
780
10
  const auto applyDiff =
781
10
      [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
782
38
                    const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
783
38
    uint32_t i = 0;
784
76
    for (; i < prefix.size(); 
++i38
)
785
54
      if (whole[i] != prefix[i]) {
786
16
        break;
787
16
      }
788
38
    if (i == prefix.size()) {
789
44
      for (; i < whole.size(); 
++i22
)
790
22
        srcIndices.push_back(whole[i]);
791
22
      return true;
792
22
    }
793
16
    return false;
794
38
  };
795
796
10
  for (const auto &field : fields)
797
38
    if (applyDiff(field.indices, dstPrefix)) {
798
22
      const auto *srcField = srcFields.get(srcIndices);
799
22
      if (!srcField)
800
0
        return false;
801
802
22
      field.counterVar.assign(srcField->getCounterVariable(builder, context),
803
22
                              builder);
804
44
      for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); 
++i22
)
805
22
        srcIndices.pop_back();
806
22
    }
807
808
10
  return true;
809
10
}
810
811
5.45k
// Returns the semantic information from the first hlsl::SemanticDecl
// annotation attached to `decl`, or a default-constructed SemanticInfo if the
// decl carries none.
//
// The full semantic string is split by hlsl::Semantic::DecomposeNameAndIndex
// into a name and a numeric index. The name is also resolved to a known
// hlsl::Semantic (if any).
SemanticInfo DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
  for (auto *unusual : decl->getUnusualAnnotations()) {
    auto *semanticDecl = dyn_cast<hlsl::SemanticDecl>(unusual);
    if (!semanticDecl)
      continue;

    llvm::StringRef fullStr = semanticDecl->SemanticName;
    llvm::StringRef baseName;
    uint32_t semIndex = 0;
    hlsl::Semantic::DecomposeNameAndIndex(fullStr, &baseName, &semIndex);
    const auto *knownSemantic = hlsl::Semantic::GetByName(baseName);
    return {fullStr, knownSemantic, baseName, semIndex, semanticDecl->Loc};
  }

  return {};
}
824
825
bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
826
                                              SpirvInstruction *storedValue,
827
3.15k
                                              bool forPCF) {
828
3.15k
  QualType type = getTypeOrFnRetType(decl);
829
3.15k
  uint32_t arraySize = 0;
830
831
  // Output stream types (PointStream, LineStream, TriangleStream) are
832
  // translated as their underlying struct types.
833
3.15k
  if (hlsl::IsHLSLStreamOutputType(type))
834
46
    type = hlsl::GetHLSLResourceResultType(type);
835
836
3.15k
  if (decl->hasAttr<HLSLIndicesAttr>() || 
decl->hasAttr<HLSLVerticesAttr>()3.11k
||
837
3.15k
      
decl->hasAttr<HLSLPrimitivesAttr>()3.08k
) {
838
76
    const auto *typeDecl = astContext.getAsConstantArrayType(type);
839
76
    type = typeDecl->getElementType();
840
76
    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
841
76
    if (decl->hasAttr<HLSLIndicesAttr>()) {
842
      // create SPIR-V builtin array PrimitiveIndicesNV of type
843
      // "uint [MaxPrimitiveCount * verticesPerPrim]"
844
34
      uint32_t verticesPerPrim = 1;
845
34
      if (!isVectorType(type, nullptr, &verticesPerPrim)) {
846
8
        assert(isScalarType(type));
847
8
      }
848
849
34
      spv::BuiltIn builtinID = spv::BuiltIn::Max;
850
34
      if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
851
        // For EXT_mesh_shader, set builtin type as
852
        // PrimitivePoint/Line/TriangleIndicesEXT based on the vertices per
853
        // primitive
854
10
        switch (verticesPerPrim) {
855
0
        case 1:
856
0
          builtinID = spv::BuiltIn::PrimitivePointIndicesEXT;
857
0
          break;
858
0
        case 2:
859
0
          builtinID = spv::BuiltIn::PrimitiveLineIndicesEXT;
860
0
          break;
861
10
        case 3:
862
10
          builtinID = spv::BuiltIn::PrimitiveTriangleIndicesEXT;
863
10
          break;
864
0
        default:
865
0
          break;
866
10
        }
867
10
        QualType arrayType = astContext.getConstantArrayType(
868
10
            type, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
869
870
10
        msOutIndicesBuiltin =
871
10
            getBuiltinVar(builtinID, arrayType, decl->getLocation());
872
24
      } else {
873
        // For NV_mesh_shader, the built type is PrimitiveIndicesNV
874
24
        builtinID = spv::BuiltIn::PrimitiveIndicesNV;
875
876
24
        arraySize = arraySize * verticesPerPrim;
877
24
        QualType arrayType = astContext.getConstantArrayType(
878
24
            astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
879
24
            clang::ArrayType::Normal, 0);
880
881
24
        msOutIndicesBuiltin =
882
24
            getBuiltinVar(builtinID, arrayType, decl->getLocation());
883
24
      }
884
885
34
      return true;
886
34
    }
887
76
  }
888
889
3.11k
  const auto *sigPoint = deduceSigPoint(
890
3.11k
      decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
891
892
  // HS output variables are created using the other overload. For the rest,
893
  // none of them should be created as arrays.
894
3.11k
  assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);
895
896
3.11k
  SemanticInfo inheritSemantic = {};
897
898
  // If storedValue is 0, it means this parameter in the original source code is
899
  // not used at all. Avoid writing back.
900
  //
901
  // Write back of stage output variables in GS is manually controlled by
902
  // .Append() intrinsic method, implemented in writeBackOutputStream(). So
903
  // ignoreValue should be set to true for GS.
904
3.11k
  const bool noWriteBack =
905
3.11k
      storedValue == nullptr || 
spvContext.isGS()3.01k
||
spvContext.isMS()2.96k
;
906
907
3.11k
  StageVarDataBundle stageVarData = {
908
3.11k
      decl, &inheritSemantic, false,     sigPoint,
909
3.11k
      type, arraySize,        "out.var", llvm::None};
910
3.11k
  return createStageVars(stageVarData, /*asInput=*/false, &storedValue,
911
3.11k
                         noWriteBack);
912
3.15k
}
913
914
bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
915
                                              uint32_t arraySize,
916
                                              SpirvInstruction *invocationId,
917
80
                                              SpirvInstruction *storedValue) {
918
80
  assert(spvContext.isHS());
919
920
80
  QualType type = getTypeOrFnRetType(decl);
921
922
80
  const auto *sigPoint =
923
80
      hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);
924
925
80
  SemanticInfo inheritSemantic = {};
926
927
80
  StageVarDataBundle stageVarData = {
928
80
      decl, &inheritSemantic, false,     sigPoint,
929
80
      type, arraySize,        "out.var", invocationId};
930
80
  return createStageVars(stageVarData, /*asInput=*/false, &storedValue,
931
80
                         /*noWriteBack=*/false);
932
80
}
933
934
bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
935
                                             SpirvInstruction **loadedValue,
936
1.81k
                                             bool forPCF) {
937
1.81k
  uint32_t arraySize = 0;
938
1.81k
  QualType type = paramDecl->getType();
939
940
  // Deprive the outermost arrayness for HS/DS/GS and use arraySize
941
  // to convey that information
942
1.81k
  if (hlsl::IsHLSLInputPatchType(type)) {
943
78
    arraySize = hlsl::GetHLSLInputPatchCount(type);
944
78
    type = hlsl::GetHLSLInputPatchElementType(type);
945
1.74k
  } else if (hlsl::IsHLSLOutputPatchType(type)) {
946
30
    arraySize = hlsl::GetHLSLOutputPatchCount(type);
947
30
    type = hlsl::GetHLSLOutputPatchElementType(type);
948
30
  }
949
1.81k
  if (hasGSPrimitiveTypeQualifier(paramDecl)) {
950
50
    const auto *typeDecl = astContext.getAsConstantArrayType(type);
951
50
    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
952
50
    type = typeDecl->getElementType();
953
50
  }
954
955
1.81k
  const auto *sigPoint =
956
1.81k
      deduceSigPoint(paramDecl, /*asInput=*/true,
957
1.81k
                     spvContext.getCurrentShaderModelKind(), forPCF);
958
959
1.81k
  SemanticInfo inheritSemantic = {};
960
961
1.81k
  if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
962
6
    spv::StorageClass sc =
963
6
        (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader))
964
6
            ? 
spv::StorageClass::TaskPayloadWorkgroupEXT2
965
6
            : 
getStorageClassForSigPoint(sigPoint)4
;
966
6
    return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
967
6
                                  type, "in.var", loadedValue);
968
1.81k
  } else {
969
1.81k
    StageVarDataBundle stageVarData = {
970
1.81k
        paramDecl,
971
1.81k
        &inheritSemantic,
972
1.81k
        paramDecl->hasAttr<HLSLNoInterpolationAttr>(),
973
1.81k
        sigPoint,
974
1.81k
        type,
975
1.81k
        arraySize,
976
1.81k
        "in.var",
977
1.81k
        llvm::None};
978
1.81k
    return createStageVars(stageVarData, /*asInput=*/true, loadedValue,
979
1.81k
                           /*noWriteBack=*/false);
980
1.81k
  }
981
1.81k
}
982
983
// Returns the registered SPIR-V info for `decl`, or nullptr if the decl has
// not been registered. Uses find(), so no entry is inserted.
const DeclResultIdMapper::DeclSpirvInfo *
DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
  auto found = astDecls.find(decl);
  if (found == astDecls.end())
    return nullptr;
  return &found->second;
}
991
992
// Returns the SPIR-V instruction that `decl` evaluates to at a use site.
//
// The cases are checked in this order:
//  - Decls carrying VKExtBuiltinInputAttr / VKExtBuiltinOutputAttr map to a
//    builtin variable with Input / Output storage class.
//  - Members of a cbuffer/tbuffer (indexInCTBuffer >= 0) get an
//    OpAccessChain into the single variable created for the whole buffer.
//  - Decls whose instruction has a HybridPointerType result and whose type
//    is a bindless opaque array are loaded, because a pointer to the global
//    holding the array is what gets passed around.
//  - Otherwise the registered instruction itself is returned.
//
// Unregistered decls are first given a chance to be created lazily as
// implicit constants. If that fails, a fatal error is emitted and nullptr is
// returned.
SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
                                                      SourceLocation loc,
                                                      SourceRange range) {
  if (auto *builtinAttr = decl->getAttr<VKExtBuiltinInputAttr>()) {
    return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
                         decl->getType(), spv::StorageClass::Input, loc);
  }
  if (auto *builtinAttr = decl->getAttr<VKExtBuiltinOutputAttr>()) {
    return getBuiltinVar(spv::BuiltIn(builtinAttr->getBuiltInID()),
                         decl->getType(), spv::StorageClass::Output, loc);
  }

  const DeclSpirvInfo *info = getDeclSpirvInfo(decl);

  // Implicit VarDecls are created lazily to avoid emitting large numbers of
  // unused variables/constants/enums, so a miss here may just mean "not yet
  // created".
  if (!info) {
    tryToCreateImplicitConstVar(decl);
    info = getDeclSpirvInfo(decl);
  }

  if (!info) {
    emitFatalError("found unregistered decl %0", decl->getLocation())
        << decl->getName();
    emitNote("please file a bug report on "
             "https://github.com/Microsoft/DirectXShaderCompiler/issues with "
             "source code if possible",
             {});
    return nullptr;
  }

  if (info->indexInCTBuffer >= 0) {
    // One variable represents the whole HLSLBufferDecl, so an extra access
    // chain is needed to reach this member. Only VarDecls appear in an
    // HLSLBufferDecl.
    QualType valueType = cast<VarDecl>(decl)->getType();
    return spvBuilder.createAccessChain(
        valueType, info->instr,
        {spvBuilder.getConstantInt(
            astContext.IntTy, llvm::APInt(32, info->indexInCTBuffer, true))},
        loc, range);
  }

  if (auto *resultType = info->instr->getResultType()) {
    const auto *ptrTy = dyn_cast<HybridPointerType>(resultType);
    // Local variables / function parameters holding a bindless array of an
    // opaque type hold a pointer to the global variable with that array, so
    // load through it.
    if (ptrTy != nullptr && isBindlessOpaqueArray(decl->getType())) {
      auto *load = spvBuilder.createLoad(ptrTy, info->instr, loc, range);
      load->setRValue(false);
      return load;
    }
  }

  return *info;
}
1052
1053
// Creates the SPIR-V function parameter for `param` and registers it.
//
// The parameter is flagged with whether it contains an alias component, as
// reported by getTypeAndCreateCounterForPotentialAliasVar. With rich debug
// info enabled, it is also described by a DebugLocalVariable (argument
// number `dbgArgNumber`) attached through DebugDeclare.
SpirvFunctionParameter *
DeclResultIdMapper::createFnParam(const ParmVarDecl *param,
                                  uint32_t dbgArgNumber) {
  const auto paramType = getTypeOrFnRetType(param);
  const auto paramLoc = param->getLocation();
  const auto paramRange = param->getSourceRange();
  const auto paramName = param->getName();

  SpirvFunctionParameter *fnParamInstr = spvBuilder.addFnParam(
      paramType, param->hasAttr<HLSLPreciseAttr>(),
      param->hasAttr<HLSLNoInterpolationAttr>(), paramLoc, paramName);

  // Only the alias flag is needed here; the returned type is discarded.
  bool containsAlias = false;
  (void)getTypeAndCreateCounterForPotentialAliasVar(param, &containsAlias);
  fnParamInstr->setContainsAliasComponent(containsAlias);

  assert(astDecls[param].instr == nullptr);
  registerVariableForDecl(param, fnParamInstr);

  if (!spirvOptions.debugInfoRich)
    return fnParamInstr;

  const auto &srcMgr = astContext.getSourceManager();
  const uint32_t line = srcMgr.getPresumedLineNumber(paramLoc);
  const uint32_t column = srcMgr.getPresumedColumnNumber(paramLoc);
  const auto *richInfo = theEmitter.getOrCreateRichDebugInfo(paramLoc);
  // TODO: replace this with FlagIsLocal enum.
  const uint32_t flags = 1u << 2;
  auto *dbgLocalVar = spvBuilder.createDebugLocalVariable(
      paramType, paramName, richInfo->source, line, column,
      richInfo->scopeStack.back(), flags, dbgArgNumber);
  spvBuilder.createDebugDeclare(dbgLocalVar, fnParamInstr, paramLoc,
                                paramRange);
  return fnParamInstr;
}
1086
1087
2.87k
void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
1088
2.87k
  const QualType declType = getTypeOrFnRetType(decl);
1089
1090
2.87k
  if (!counterVars.count(decl) && 
isRWAppendConsumeSBuffer(declType)2.44k
) {
1091
286
    createCounterVar(decl, /*declId=*/0, /*isAlias=*/true);
1092
2.59k
  } else if (!fieldCounterVars.count(decl) && 
declType->isStructureType()2.56k
&&
1093
             // Exclude other resource types which are represented as structs
1094
2.59k
             
!hlsl::IsHLSLResourceType(declType)348
) {
1095
302
    createFieldCounterVars(decl);
1096
302
  }
1097
2.87k
}
1098
1099
// Creates (or returns the already-registered) function-scope variable for
// `var`, optionally initialized with `init`.
//
// The variable is flagged with whether it contains an alias component and
// registered for the decl.
SpirvVariable *
DeclResultIdMapper::createFnVar(const VarDecl *var,
                                llvm::Optional<SpirvInstruction *> init) {
  if (auto *existing = astDecls[var].instr)
    return cast<SpirvVariable>(existing);

  const auto varType = getTypeOrFnRetType(var);
  const auto varLoc = var->getLocation();
  const auto varName = var->getName();
  const bool precise = var->hasAttr<HLSLPreciseAttr>();
  const bool noInterp = var->hasAttr<HLSLNoInterpolationAttr>();

  SpirvInstruction *initValue = nullptr;
  if (init.hasValue())
    initValue = init.getValue();

  SpirvVariable *varInstr = spvBuilder.addFnVar(varType, varLoc, varName,
                                                precise, noInterp, initValue);

  bool containsAlias = false;
  (void)getTypeAndCreateCounterForPotentialAliasVar(var, &containsAlias);
  varInstr->setContainsAliasComponent(containsAlias);
  registerVariableForDecl(var, varInstr);
  return varInstr;
}
1120
1121
// Emits a DebugGlobalVariable describing `var` when rich debug info is
// enabled and returns it. Returns nullptr (and emits nothing) otherwise.
SpirvDebugGlobalVariable *DeclResultIdMapper::createDebugGlobalVariable(
    SpirvVariable *var, const QualType &type, const SourceLocation &loc,
    const StringRef &name) {
  if (!spirvOptions.debugInfoRich)
    return nullptr;

  const auto &srcMgr = astContext.getSourceManager();
  const uint32_t line = srcMgr.getPresumedLineNumber(loc);
  const uint32_t column = srcMgr.getPresumedColumnNumber(loc);
  const auto *richInfo = theEmitter.getOrCreateRichDebugInfo(loc);
  // TODO: replace this with FlagIsDefinition enum.
  const uint32_t flags = 1u << 3;
  // TODO: update linkageName correctly.
  auto *dbgGlobalVar = spvBuilder.createDebugGlobalVariable(
      type, name, richInfo->source, line, column, richInfo->scopeStack.back(),
      /* linkageName */ name, var, flags);
  dbgGlobalVar->setDebugSpirvType(var->getResultType());
  dbgGlobalVar->setLayoutRule(var->getLayoutRule());
  return dbgGlobalVar;
}
1142
1143
// Creates (or returns the already-registered) module-scope Private variable
// for `var`, optionally initialized with `init`. It is registered and, with
// rich debug info, described by a DebugGlobalVariable.
SpirvVariable *
DeclResultIdMapper::createFileVar(const VarDecl *var,
                                  llvm::Optional<SpirvInstruction *> init) {
  // With template specialization the same VarDecl node may be traversed more
  // than once; reuse the variable created the first time.
  if (auto *existing = astDecls[var].instr)
    return cast<SpirvVariable>(existing);

  const auto varType = getTypeOrFnRetType(var);
  const auto varLoc = var->getLocation();
  const auto varName = var->getName();
  const bool precise = var->hasAttr<HLSLPreciseAttr>();
  const bool noInterp = var->hasAttr<HLSLNoInterpolationAttr>();

  SpirvVariable *varInstr =
      spvBuilder.addModuleVar(varType, spv::StorageClass::Private, precise,
                              noInterp, varName, init, varLoc);

  bool containsAlias = false;
  (void)getTypeAndCreateCounterForPotentialAliasVar(var, &containsAlias);
  varInstr->setContainsAliasComponent(containsAlias);
  registerVariableForDecl(var, varInstr);

  createDebugGlobalVariable(varInstr, varType, varLoc, varName);
  return varInstr;
}
1168
1169
// Creates the external variable backing a resource heap, typed as an
// incomplete (unsized) array of `ResourceType`.
SpirvVariable *DeclResultIdMapper::createResourceHeap(const VarDecl *var,
                                                      QualType ResourceType) {
  const QualType heapArrayType = astContext.getIncompleteArrayType(
      ResourceType, clang::ArrayType::Normal, 0);
  return createExternVar(var, heapArrayType);
}
1175
1176
3.80k
// Convenience overload: creates the external variable using the decl's own
// declared type.
SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
  const QualType declaredType = var->getType();
  return createExternVar(var, declaredType);
}
1179
1180
// Creates the module-scope variable for an externally visible decl `var`
// of type `type`.
//
//  - Stand-alone non-resource globals (not groupshared, not a resource, not
//    a resource-only struct) become members of the $Globals cbuffer.
//  - Resource-only structs are accepted unless they contain any kind of
//    buffer, and they mark composite resources for flattening.
//  - Everything else gets its own variable with the storage class and layout
//    rule derived from the type.
//
// Returns nullptr after emitting an error for unsupported cases.
SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var,
                                                   QualType type) {
  const bool isGroupShared = var->hasAttr<HLSLGroupSharedAttr>();
  const bool isACSBuffer =
      isAppendStructuredBuffer(type) || isConsumeStructuredBuffer(type);
  const bool isRWSBuffer = isRWStructuredBuffer(type);
  const auto storageClass = getStorageClassForExternVar(type, isGroupShared);
  const auto rule = getLayoutRuleForExternVar(type, spirvOptions);
  const auto loc = var->getLocation();

  const bool isPlainGlobal = !isGroupShared && !isResourceType(type) &&
                             !isResourceOnlyStructure(type);
  if (isPlainGlobal) {
    // Mixing resources and non-resources in a global struct would require
    // rewriting field decls and QualTypes, and inserting the non-resources
    // into $Globals (which changes its offset decorations). Not supported.
    if (isStructureContainingMixOfResourcesAndNonResources(type)) {
      emitError("global structures containing both resources and non-resources "
                "are not supported",
                loc);
      return nullptr;
    }

    // The $Globals cbuffer, with all such variables recorded inside it, is
    // created when the first one is seen.
    if (astDecls.count(var) == 0)
      createGlobalsCBuffer(var);

    auto *globalsMember = astDecls[var].instr;
    if (!globalsMember)
      return nullptr;
    return cast<SpirvVariable>(globalsMember);
  }

  if (isResourceOnlyStructure(type)) {
    // Global structs containing buffers are rejected for two reasons:
    // 1. Buffers use the Uniform storage class but textures/samplers use
    //    UniformConstant, so no single storage class fits the struct, and
    //    legalization cannot deduce per-member storage classes.
    // 2. Structured buffers have associated counters. Placing those counters
    //    inside a struct, with correctly offset bindings, would need
    //    non-trivial changes in DXC and in spirv-opt's flattening.
    if (isStructureContainingAnyKindOfBuffer(type)) {
      emitError("global structures containing buffers are not supported", loc);
      return nullptr;
    }
    needsFlatteningCompositeResources = true;
  }

  const auto name = var->getName();
  SpirvVariable *varInstr = spvBuilder.addModuleVar(
      type, storageClass, var->hasAttr<HLSLPreciseAttr>(),
      var->hasAttr<HLSLNoInterpolationAttr>(), name, llvm::None, loc);
  varInstr->setLayoutRule(rule);

  // [[vk::combinedImageSampler]] / [[vk::image_format("..")]] information is
  // recorded in the SpirvContext for use when lowering QualType to SpirvType.
  VkImageFeatures vkImgFeatures = {
      var->getAttr<VKCombinedImageSamplerAttr>() != nullptr,
      getSpvImageFormat(var->getAttr<VKImageFormatAttr>())};
  // An explicit image format requires legalization, both to propagate the
  // correct image type to instructions and for cases where the resource is
  // assigned to another variable or function parameter.
  if (vkImgFeatures.format)
    needsLegalization = true;
  if (vkImgFeatures.isCombinedImageSampler || vkImgFeatures.format)
    spvContext.registerVkImageFeaturesForSpvVariable(varInstr, vkImgFeatures);

  const auto *recordTy = type->getAs<RecordType>();
  if (recordTy) {
    StringRef recordName = recordTy->getDecl()->getName();
    if (recordName.startswith("FeedbackTexture")) {
      emitError("Texture resource type '%0' is not supported with -spirv", loc)
          << recordName;
      return nullptr;
    }
  }

  if (hlsl::IsHLSLResourceType(type) &&
      !areFormatAndTypeCompatible(
          vkImgFeatures.format.value_or(spv::ImageFormat::Unknown),
          hlsl::GetHLSLResourceResultType(type))) {
    emitError("The image format and the sampled type are not compatible.\n"
              "For the table of compatible types, see "
              "https://docs.vulkan.org/spec/latest/appendices/"
              "spirvenv.html#spirvenv-format-type-matching.",
              loc);
    return nullptr;
  }

  registerVariableForDecl(var, createDeclSpirvInfo(varInstr));
  createDebugGlobalVariable(varInstr, type, loc, name);

  // Workgroup-storage variables need no descriptor decorations.
  if (storageClass == spv::StorageClass::Workgroup)
    return varInstr;

  const auto *bindingAttr = var->getAttr<VKBindingAttr>();
  resourceVars.emplace_back(varInstr, var, loc, getResourceBinding(var),
                            bindingAttr, var->getAttr<VKCounterBindingAttr>());

  if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
    spvBuilder.decorateInputAttachmentIndex(varInstr,
                                            inputAttachment->getIndex(), loc);

  // Append/Consume buffers always get a separate, non-alias counter variable.
  // RW structured buffers are recorded in declRWSBuffers.
  if (isACSBuffer)
    createCounterVar(var, varInstr, /*isAlias=*/false);
  else if (isRWSBuffer)
    declRWSBuffers[var] = varInstr;

  return varInstr;
}
1310
1311
2
// Registers `var` (of an ext result-id type) as the instruction produced by
// evaluating its initializer, and returns that instruction.
//
// An initializer is required. Without one, an error is emitted and nullptr
// is returned.
SpirvInstruction *DeclResultIdMapper::createResultId(const VarDecl *var) {
  assert(isExtResultIdType(var->getType()));

  if (var->hasInit()) {
    SpirvInstruction *resultInstr = theEmitter.doExpr(var->getInit());
    registerVariableForDecl(var, createDeclSpirvInfo(resultInstr));
    return resultInstr;
  }

  emitError("Found uninitialized variable for result id.",
            var->getLocation());
  return nullptr;
}
1325
1326
// Maps a string-typed variable to an OpString built from its initializer and
// registers it for the decl.
//
// The initializer must be a string literal (after stripping parentheses and
// casts). For a missing or non-literal initializer an error is emitted and
// nullptr is returned.
SpirvInstruction *
DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
  assert(hlsl::IsStringType(var->getType()) ||
         hlsl::IsStringLiteralType(var->getType()));

  if (!var->hasInit()) {
    emitError("Found uninitialized string variable.", var->getLocation());
    return nullptr;
  }

  const StringLiteral *stringLiteral =
      dyn_cast<StringLiteral>(var->getInit()->IgnoreParenCasts());
  // dyn_cast yields nullptr when the initializer is not a string literal.
  // Report it instead of dereferencing a null pointer.
  if (!stringLiteral) {
    emitError("Found string variable not initialized with a string literal.",
              var->getLocation());
    return nullptr;
  }

  SpirvString *init = spvBuilder.getString(stringLiteral->getString());
  registerVariableForDecl(var, createDeclSpirvInfo(init));
  return init;
}
1344
1345
// Creates a module-scope variable whose type is an explicitly laid-out struct
// built from the decls in `decl`. The struct is optionally wrapped in arrays.
//
// Field selection:
//  - Static and resource-typed VarDecls are skipped, since they are not part
//    of the struct layout.
//  - register(c#) is honored only for $Globals members.
//  - The const qualifier every member carries is removed so it does not
//    affect the debug name.
//
// Each entry of `arraySize` wraps the type built so far, so arraySize[0] is
// the innermost dimension. -1 requests a runtime array.
//
// Storage class: PushConstant, ShaderRecordBufferNV/KHR, or Uniform.
// tbuffers are read-only StorageBuffer-interface structs.
SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
    const DeclContext *decl, llvm::ArrayRef<int> arraySize,
    const ContextUsageKind usageKind, llvm::StringRef typeName,
    llvm::StringRef varName) {
  const bool forCBuffer = usageKind == ContextUsageKind::CBuffer;
  const bool forTBuffer = usageKind == ContextUsageKind::TBuffer;
  const bool forGlobals = usageKind == ContextUsageKind::Globals;
  const bool forPC = usageKind == ContextUsageKind::PushConstant;
  const bool forShaderRecordNV =
      usageKind == ContextUsageKind::ShaderRecordBufferNV;
  const bool forShaderRecordEXT =
      usageKind == ContextUsageKind::ShaderRecordBufferKHR;

  const auto &declGroup = collectDeclsInDeclContext(decl);

  llvm::SmallVector<HybridStructType::FieldInfo, 4> fields;
  for (const auto *member : declGroup) {
    // Only FieldDecls (normal structs) or VarDecls (HLSLBufferDecls) occur.
    assert(isa<VarDecl>(member) || isa<FieldDecl>(member));
    const auto *memberDecl = cast<DeclaratorDecl>(member);
    auto memberType = memberDecl->getType();

    if (const auto *memberVar = dyn_cast<VarDecl>(member)) {
      if (memberVar->getStorageClass() == StorageClass::SC_Static ||
          isResourceType(memberType))
        continue;
    }

    const hlsl::RegisterAssignment *registerC =
        forGlobals ? getRegisterCAssignment(memberDecl) : nullptr;

    llvm::Optional<BitfieldInfo> bitfieldInfo;
    const FieldDecl *asField = dyn_cast<FieldDecl>(member);
    if (asField && asField->isBitField()) {
      bitfieldInfo = BitfieldInfo();
      bitfieldInfo->sizeInBits =
          asField->getBitWidthValue(asField->getASTContext());
    }

    memberType.removeLocalConst();
    HybridStructType::FieldInfo fieldInfo(
        memberType, memberDecl->getName(),
        /*vkoffset*/ memberDecl->getAttr<VKOffsetAttr>(),
        /*packoffset*/ getPackOffset(memberDecl),
        /*RegisterAssignment*/ registerC,
        /*isPrecise*/ memberDecl->hasAttr<HLSLPreciseAttr>(),
        /*bitfield*/ bitfieldInfo);
    fields.push_back(fieldInfo);
  }

  const StructInterfaceType interfaceKind =
      forTBuffer ? StructInterfaceType::StorageBuffer
                 : StructInterfaceType::UniformBuffer;
  const SpirvType *resultType = spvContext.getHybridStructType(
      fields, typeName, /*isReadOnly*/ forTBuffer, interfaceKind);

  for (const int dim : arraySize) {
    if (dim == -1)
      resultType = spvContext.getRuntimeArrayType(resultType,
                                                  /*ArrayStride*/ llvm::None);
    else
      resultType = spvContext.getArrayType(resultType, dim,
                                           /*ArrayStride*/ llvm::None);
  }

  spv::StorageClass sc = spv::StorageClass::Uniform;
  if (forPC)
    sc = spv::StorageClass::PushConstant;
  else if (forShaderRecordNV)
    sc = spv::StorageClass::ShaderRecordBufferNV;
  else if (forShaderRecordEXT)
    sc = spv::StorageClass::ShaderRecordBufferKHR;

  // Individual fields may be 'precise'; the enclosing variable is not.
  SpirvVariable *var = spvBuilder.addModuleVar(
      resultType, sc, /*isPrecise*/ false, /*isNoInterp*/ false, varName);

  SpirvLayoutRule layoutRule = spirvOptions.sBufferLayoutRule;
  if (forCBuffer || forGlobals)
    layoutRule = spirvOptions.cBufferLayoutRule;
  else if (forTBuffer)
    layoutRule = spirvOptions.tBufferLayoutRule;

  const char *hlslUserType = "";
  if (forCBuffer)
    hlslUserType = "cbuffer";
  else if (forTBuffer)
    hlslUserType = "tbuffer";

  var->setHlslUserType(hlslUserType);
  var->setLayoutRule(layoutRule);
  return var;
}
1453
1454
// Single-dimension convenience overload. A positive `arraySize` creates a
// one-dimensional array of the struct. Any other value creates a plain
// (non-array) struct variable.
SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
    const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
    llvm::StringRef typeName, llvm::StringRef varName) {
  llvm::SmallVector<int, 1> dims;
  if (arraySize > 0)
    dims.push_back(arraySize);

  return createStructOrStructArrayVarOfExplicitLayout(decl, dims, usageKind,
                                                      typeName, varName);
}
1464
1465
1.41k
void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
1466
1.41k
  const auto *valueDecl = dyn_cast<ValueDecl>(decl);
1467
1.41k
  const auto enumConstant =
1468
1.41k
      spvBuilder.getConstantInt(astContext.IntTy, decl->getInitVal());
1469
1.41k
  SpirvVariable *varInstr = spvBuilder.addModuleVar(
1470
1.41k
      astContext.IntTy, spv::StorageClass::Private, /*isPrecise*/ false, false,
1471
1.41k
      decl->getName(), enumConstant, decl->getLocation());
1472
1.41k
  astDecls[valueDecl] = createDeclSpirvInfo(varInstr);
1473
1.41k
}
1474
1475
186
void DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
1476
1477
186
  SmallVector<const VarDecl *, 4> variablesToDeclare;
1478
600
  for (const auto *subDecl : decl->decls()) {
1479
600
    if (shouldSkipInStructLayout(subDecl))
1480
16
      continue;
1481
1482
    // If the subDecl is a resource, it is lowered as a standalone variable.
1483
584
    const auto *varDecl = cast<VarDecl>(subDecl);
1484
584
    if (isResourceType(varDecl->getType())) {
1485
6
      createExternVar(varDecl);
1486
6
      continue;
1487
6
    }
1488
1489
578
    variablesToDeclare.push_back(varDecl);
1490
578
  }
1491
1492
  // If the cbuffer is empty or only contains resources, skip the variable
1493
  // creation.
1494
186
  if (variablesToDeclare.size() == 0)
1495
4
    return;
1496
1497
  // This function handles creation of cbuffer or tbuffer.
1498
182
  const auto usageKind =
1499
182
      decl->isCBuffer() ? 
ContextUsageKind::CBuffer150
:
ContextUsageKind::TBuffer32
;
1500
182
  const std::string structName = "type." + decl->getName().str();
1501
  // The front-end does not allow arrays of cbuffer/tbuffer.
1502
182
  SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
1503
182
      decl, /*arraySize*/ 0, usageKind, structName, decl->getName());
1504
1505
  // We still register all VarDecls seperately here. All the VarDecls are
1506
  // mapped to the <result-id> of the buffer object, which means when querying
1507
  // querying the <result-id> for a certain VarDecl, we need to do an extra
1508
  // OpAccessChain.
1509
760
  for (unsigned I = 0; I < variablesToDeclare.size(); 
++I578
)
1510
578
    registerVariableForDecl(variablesToDeclare[I],
1511
578
                            createDeclSpirvInfo(bufferVar, I));
1512
1513
182
  resourceVars.emplace_back(
1514
182
      bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
1515
182
      decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
1516
1517
182
  if (!spirvOptions.debugInfoRich)
1518
172
    return;
1519
1520
10
  auto *dbgGlobalVar = createDebugGlobalVariable(
1521
10
      bufferVar, QualType(), decl->getLocation(), decl->getName());
1522
10
  assert(dbgGlobalVar);
1523
10
  (void)dbgGlobalVar; // For NDEBUG builds.
1524
1525
10
  auto *resultType = bufferVar->getResultType();
1526
  // Depending on the requested layout (DX or VK), constant buffers is either a
1527
  // struct containing every constant fields, or a pointer to the type. This is
1528
  // caused by the workaround we implemented to support FXC/DX layout. See #3672
1529
  // for more details.
1530
10
  assert(isa<SpirvPointerType>(resultType) ||
1531
10
         isa<HybridStructType>(resultType));
1532
10
  if (auto *ptr = dyn_cast<SpirvPointerType>(resultType))
1533
2
    resultType = ptr->getPointeeType();
1534
  // Debug type lowering requires the HLSLBufferDecl. Updating the type<>decl
1535
  // mapping.
1536
10
  spvContext.registerStructDeclForSpirvType(resultType, decl);
1537
1538
10
  return;
1539
182
}
1540
1541
38
SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
1542
  // The front-end errors out if non-struct type push constant is used.
1543
38
  const QualType type = decl->getType();
1544
38
  const auto *recordType = type->getAs<RecordType>();
1545
1546
38
  SpirvVariable *var = nullptr;
1547
1548
38
  if (isConstantBuffer(type)) {
1549
    // Constant buffers already have Block decoration. The variable will need
1550
    // the PushConstant storage class.
1551
1552
    // Create the variable for the whole struct / struct array.
1553
    // The fields may be 'precise', but the structure itself is not.
1554
6
    var = spvBuilder.addModuleVar(type, spv::StorageClass::PushConstant,
1555
6
                                  /*isPrecise*/ false,
1556
6
                                  /*isNoInterp*/ false, decl->getName());
1557
1558
6
    const SpirvLayoutRule layoutRule = spirvOptions.sBufferLayoutRule;
1559
1560
6
    var->setHlslUserType("");
1561
6
    var->setLayoutRule(layoutRule);
1562
32
  } else {
1563
32
    assert(recordType);
1564
32
    const std::string structName =
1565
32
        "type.PushConstant." + recordType->getDecl()->getName().str();
1566
32
    var = createStructOrStructArrayVarOfExplicitLayout(
1567
32
        recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
1568
32
        structName, decl->getName());
1569
32
  }
1570
1571
  // Register the VarDecl
1572
38
  registerVariableForDecl(decl, createDeclSpirvInfo(var));
1573
1574
  // Do not push this variable into resourceVars since it does not need
1575
  // descriptor set.
1576
1577
38
  return var;
1578
38
}
1579
1580
SpirvVariable *
1581
DeclResultIdMapper::createShaderRecordBuffer(const VarDecl *decl,
1582
14
                                             ContextUsageKind kind) {
1583
14
  const QualType type = decl->getType();
1584
14
  const auto *recordType =
1585
14
      hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
1586
14
  assert(recordType);
1587
1588
14
  assert(kind == ContextUsageKind::ShaderRecordBufferKHR ||
1589
14
         kind == ContextUsageKind::ShaderRecordBufferNV);
1590
1591
14
  SpirvVariable *var = nullptr;
1592
14
  if (isConstantBuffer(type)) {
1593
    // Constant buffers already have Block decoration. The variable will need
1594
    // the appropriate storage class.
1595
1596
14
    const auto sc = kind == ContextUsageKind::ShaderRecordBufferNV
1597
14
                        ? 
spv::StorageClass::ShaderRecordBufferNV6
1598
14
                        : 
spv::StorageClass::ShaderRecordBufferKHR8
;
1599
1600
    // Create the variable for the whole struct / struct array.
1601
    // The fields may be 'precise', but the structure itself is not.
1602
14
    var = spvBuilder.addModuleVar(type, sc,
1603
14
                                  /*isPrecise*/ false,
1604
14
                                  /*isNoInterp*/ false, decl->getName());
1605
1606
14
    const SpirvLayoutRule layoutRule = spirvOptions.sBufferLayoutRule;
1607
1608
14
    var->setHlslUserType("");
1609
14
    var->setLayoutRule(layoutRule);
1610
14
  } else {
1611
0
    const auto typeName = kind == ContextUsageKind::ShaderRecordBufferKHR
1612
0
                              ? "type.ShaderRecordBufferKHR."
1613
0
                              : "type.ShaderRecordBufferNV.";
1614
1615
0
    const std::string structName =
1616
0
        typeName + recordType->getDecl()->getName().str();
1617
0
    var = createStructOrStructArrayVarOfExplicitLayout(
1618
0
        recordType->getDecl(), /*arraySize*/ 0, kind, structName,
1619
0
        decl->getName());
1620
0
  }
1621
1622
  // Register the VarDecl
1623
14
  registerVariableForDecl(decl, createDeclSpirvInfo(var));
1624
1625
  // Do not push this variable into resourceVars since it does not need
1626
  // descriptor set.
1627
1628
14
  return var;
1629
14
}
1630
1631
SpirvVariable *
1632
DeclResultIdMapper::createShaderRecordBuffer(const HLSLBufferDecl *decl,
1633
4
                                             ContextUsageKind kind) {
1634
4
  assert(kind == ContextUsageKind::ShaderRecordBufferKHR ||
1635
4
         kind == ContextUsageKind::ShaderRecordBufferNV);
1636
1637
4
  const auto typeName = kind == ContextUsageKind::ShaderRecordBufferKHR
1638
4
                            ? 
"type.ShaderRecordBufferKHR."2
1639
4
                            : 
"type.ShaderRecordBufferNV."2
;
1640
1641
4
  const std::string structName = typeName + decl->getName().str();
1642
  // The front-end does not allow arrays of cbuffer/tbuffer.
1643
4
  SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
1644
4
      decl, /*arraySize*/ 0, kind, structName, decl->getName());
1645
1646
  // We still register all VarDecls seperately here. All the VarDecls are
1647
  // mapped to the <result-id> of the buffer object, which means when
1648
  // querying the <result-id> for a certain VarDecl, we need to do an extra
1649
  // OpAccessChain.
1650
4
  int index = 0;
1651
20
  for (const auto *subDecl : decl->decls()) {
1652
20
    if (shouldSkipInStructLayout(subDecl))
1653
0
      continue;
1654
1655
    // If subDecl is a variable with resource type, we already added a separate
1656
    // OpVariable for it in createStructOrStructArrayVarOfExplicitLayout().
1657
20
    const auto *varDecl = cast<VarDecl>(subDecl);
1658
20
    if (isResourceType(varDecl->getType()))
1659
0
      continue;
1660
1661
20
    registerVariableForDecl(varDecl, createDeclSpirvInfo(bufferVar, index++));
1662
20
  }
1663
4
  return bufferVar;
1664
4
}
1665
1666
142
void DeclResultIdMapper::recordsSpirvTypeAlias(const Decl *decl) {
1667
142
  auto *typedefDecl = dyn_cast<TypedefNameDecl>(decl);
1668
142
  if (!typedefDecl)
1669
0
    return;
1670
1671
142
  if (!typedefDecl->hasAttr<VKCapabilityExtAttr>() &&
1672
142
      
!typedefDecl->hasAttr<VKExtensionExtAttr>()134
)
1673
134
    return;
1674
1675
8
  typeAliasesWithAttributes.push_back(typedefDecl);
1676
8
}
1677
1678
130
void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
1679
130
  if (astDecls.count(var) != 0)
1680
0
    return;
1681
1682
130
  const auto *context = var->getTranslationUnitDecl();
1683
130
  SpirvVariable *globals = createStructOrStructArrayVarOfExplicitLayout(
1684
130
      context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
1685
130
      "$Globals");
1686
1687
130
  uint32_t index = 0;
1688
248
  for (const auto *decl : collectDeclsInDeclContext(context)) {
1689
248
    if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
1690
248
      if (!spirvOptions.noWarnIgnoredFeatures) {
1691
246
        if (const auto *init = varDecl->getInit())
1692
2
          emitWarning(
1693
2
              "variable '%0' will be placed in $Globals so initializer ignored",
1694
2
              init->getExprLoc())
1695
2
              << var->getName() << init->getSourceRange();
1696
246
      }
1697
248
      if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
1698
2
        emitError("variable '%0' will be placed in $Globals so cannot have "
1699
2
                  "vk::binding attribute",
1700
2
                  attr->getLocation())
1701
2
            << var->getName();
1702
2
        return;
1703
2
      }
1704
1705
      // If subDecl is a variable with resource type, we already added a
1706
      // separate OpVariable for it in
1707
      // createStructOrStructArrayVarOfExplicitLayout().
1708
246
      if (isResourceType(varDecl->getType()))
1709
0
        continue;
1710
1711
246
      registerVariableForDecl(varDecl, createDeclSpirvInfo(globals, index++));
1712
246
    }
1713
248
  }
1714
1715
  // If it does not contains a member with non-resource type, we do not want to
1716
  // set a dedicated binding number.
1717
128
  if (index != 0) {
1718
128
    resourceVars.emplace_back(globals, /*decl*/ nullptr, SourceLocation(),
1719
128
                              nullptr, nullptr, nullptr, /*isCounterVar*/ false,
1720
128
                              /*isGlobalsCBuffer*/ true);
1721
128
  }
1722
128
}
1723
1724
6.29k
SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
1725
  // Return it if it's already been created.
1726
6.29k
  auto it = astFunctionDecls.find(fn);
1727
6.29k
  if (it != astFunctionDecls.end()) {
1728
1.71k
    return it->second;
1729
1.71k
  }
1730
1731
4.57k
  bool isAlias = false;
1732
4.57k
  (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
1733
1734
4.57k
  const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
1735
4.57k
  const bool isNoInline = fn->hasAttr<NoInlineAttr>();
1736
  // Note: we do not need to worry about function parameter types at this point
1737
  // as this is used when function declarations are seen. When function
1738
  // definition is seen, the parameter types will be set properly and take into
1739
  // account whether the function is a member function of a class/struct (in
1740
  // which case a 'this' parameter is added at the beginnig).
1741
4.57k
  SpirvFunction *spirvFunction = spvBuilder.createSpirvFunction(
1742
4.57k
      fn->getReturnType(), fn->getLocation(),
1743
4.57k
      getFunctionOrOperatorName(fn, true), isPrecise, isNoInline);
1744
1745
4.57k
  if (fn->getAttr<HLSLExportAttr>()) {
1746
8
    spvBuilder.decorateLinkage(nullptr, spirvFunction, fn->getName(),
1747
8
                               spv::LinkageType::Export, fn->getLocation());
1748
8
  }
1749
1750
  // No need to dereference to get the pointer. Function returns that are
1751
  // stand-alone aliases are already pointers to values. All other cases should
1752
  // be normal rvalues.
1753
4.57k
  if (!isAlias || 
!isAKindOfStructuredOrByteBuffer(fn->getReturnType())46
)
1754
4.54k
    spirvFunction->setRValue();
1755
1756
4.57k
  spirvFunction->setConstainsAliasComponent(isAlias);
1757
1758
4.57k
  astFunctionDecls[fn] = spirvFunction;
1759
4.57k
  return spirvFunction;
1760
6.29k
}
1761
1762
const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
1763
20.4k
    const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
1764
20.4k
  if (!decl)
1765
6.08k
    return nullptr;
1766
1767
14.3k
  if (indices) {
1768
    // Indices are provided. Walk through the fields of the decl.
1769
1.06k
    const auto counter = fieldCounterVars.find(decl);
1770
1.06k
    if (counter != fieldCounterVars.end())
1771
72
      return counter->second.get(*indices);
1772
13.2k
  } else {
1773
    // No indices. Check the stand-alone entities. If not found,
1774
    // likely a deferred RWStructuredBuffer counter, so try
1775
    // creating it now.
1776
13.2k
    auto counter = counterVars.find(decl);
1777
13.2k
    if (counter == counterVars.end()) {
1778
12.5k
      auto declInstr = declRWSBuffers[decl];
1779
12.5k
      if (declInstr) {
1780
150
        createCounterVar(decl, declInstr, /*isAlias*/ false);
1781
150
        counter = counterVars.find(decl);
1782
150
      }
1783
12.5k
    }
1784
13.2k
    if (counter != counterVars.end())
1785
918
      return &counter->second;
1786
13.2k
  }
1787
1788
13.3k
  return nullptr;
1789
14.3k
}
1790
1791
const CounterIdAliasPair *
1792
13.2k
DeclResultIdMapper::createOrGetCounterIdAliasPair(const DeclaratorDecl *decl) {
1793
13.2k
  auto counterPair = getCounterIdAliasPair(decl);
1794
13.2k
  if (counterPair)
1795
918
    return counterPair;
1796
12.3k
  if (!decl)
1797
0
    return nullptr;
1798
  // If deferred RWStructuredBuffer, try creating the counter now
1799
12.3k
  auto declInstr = declRWSBuffers[decl];
1800
12.3k
  if (declInstr) {
1801
0
    createCounterVar(decl, declInstr, /*isAlias*/ false);
1802
0
    auto counter = counterVars.find(decl);
1803
0
    assert(counter != counterVars.end() && "counter not found");
1804
0
    return &counter->second;
1805
0
  }
1806
12.3k
  return nullptr;
1807
12.3k
}
1808
1809
const CounterVarFields *
1810
27.7k
DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
1811
27.7k
  if (!decl)
1812
11.6k
    return nullptr;
1813
1814
16.0k
  const auto found = fieldCounterVars.find(decl);
1815
16.0k
  if (found != fieldCounterVars.end())
1816
88
    return &found->second;
1817
1818
15.9k
  return nullptr;
1819
16.0k
}
1820
1821
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
1822
44
                                              SpirvInstruction *specConstant) {
1823
44
  specConstant->setRValue();
1824
44
  registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
1825
44
}
1826
1827
void DeclResultIdMapper::createCounterVar(
1828
    const DeclaratorDecl *decl, SpirvInstruction *declInstr, bool isAlias,
1829
790
    const llvm::SmallVector<uint32_t, 4> *indices) {
1830
790
  std::string counterName = "counter.var." + decl->getName().str();
1831
790
  if (indices) {
1832
    // Append field indices to the name
1833
182
    for (const auto index : *indices)
1834
446
      counterName += "." + std::to_string(index);
1835
182
  }
1836
1837
790
  const SpirvType *counterType = spvContext.getACSBufferCounterType();
1838
790
  llvm::Optional<uint32_t> noArrayStride;
1839
790
  QualType declType = decl->getType();
1840
790
  if (declType->isArrayType()) {
1841
    // Vulkan does not support multi-dimentional arrays of resource, so we
1842
    // assume the array is a single dimensional array.
1843
26
    assert(!declType->getArrayElementTypeNoTypeQual()->isArrayType());
1844
1845
26
    if (const auto *constArrayType =
1846
26
            astContext.getAsConstantArrayType(declType)) {
1847
22
      counterType = spvContext.getArrayType(
1848
22
          counterType, constArrayType->getSize().getZExtValue(), noArrayStride);
1849
22
    } else {
1850
4
      assert(declType->isIncompleteArrayType());
1851
4
      counterType = spvContext.getRuntimeArrayType(counterType, noArrayStride);
1852
4
    }
1853
764
  } else if (isResourceDescriptorHeap(decl->getType()) ||
1854
764
             
isSamplerDescriptorHeap(decl->getType())738
) {
1855
26
    counterType = spvContext.getRuntimeArrayType(counterType, noArrayStride);
1856
26
  }
1857
1858
  // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
1859
  // Alias counter variables should be created into the Private storage class.
1860
790
  const spv::StorageClass sc =
1861
790
      isAlias ? 
spv::StorageClass::Private468
:
spv::StorageClass::Uniform322
;
1862
1863
790
  if (isAlias) {
1864
    // Apply an extra level of pointer for alias counter variable
1865
468
    counterType =
1866
468
        spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
1867
468
  }
1868
1869
790
  SpirvVariable *counterInstr = spvBuilder.addModuleVar(
1870
790
      counterType, sc, /*isPrecise*/ false, false, declInstr, counterName);
1871
1872
790
  if (!isAlias) {
1873
    // Non-alias counter variables should be put in to resourceVars so that
1874
    // descriptors can be allocated for them.
1875
322
    resourceVars.emplace_back(counterInstr, decl, decl->getLocation(),
1876
322
                              getResourceBinding(decl),
1877
322
                              decl->getAttr<VKBindingAttr>(),
1878
322
                              decl->getAttr<VKCounterBindingAttr>(), true);
1879
322
    assert(declInstr);
1880
322
    spvBuilder.decorateCounterBuffer(declInstr, counterInstr,
1881
322
                                     decl->getLocation());
1882
322
  }
1883
1884
790
  if (indices)
1885
182
    fieldCounterVars[decl].append(*indices, counterInstr);
1886
608
  else
1887
608
    counterVars[decl] = {counterInstr, isAlias};
1888
790
}
1889
1890
void DeclResultIdMapper::createFieldCounterVars(
1891
    const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
1892
396
    llvm::SmallVector<uint32_t, 4> *indices) {
1893
396
  const QualType type = getTypeOrFnRetType(decl);
1894
396
  const auto *recordType = type->getAs<RecordType>();
1895
396
  assert(recordType);
1896
396
  const auto *recordDecl = recordType->getDecl();
1897
1898
758
  for (const auto *field : recordDecl->fields()) {
1899
    // Build up the index chain
1900
758
    indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());
1901
1902
758
    const QualType fieldType = field->getType();
1903
758
    if (isRWAppendConsumeSBuffer(fieldType))
1904
182
      createCounterVar(rootDecl, /*declId=*/0, /*isAlias=*/true, indices);
1905
576
    else if (fieldType->isStructureType() &&
1906
576
             
!hlsl::IsHLSLResourceType(fieldType)132
)
1907
      // Go recursively into all nested structs
1908
94
      createFieldCounterVars(rootDecl, field, indices);
1909
1910
758
    indices->pop_back();
1911
758
  }
1912
396
}
1913
1914
std::vector<SpirvVariable *>
1915
2.88k
DeclResultIdMapper::collectStageVars(SpirvFunction *entryPoint) const {
1916
2.88k
  std::vector<SpirvVariable *> vars;
1917
1918
2.88k
  for (auto var : glPerVertex.getStageInVars())
1919
60
    vars.push_back(var);
1920
2.88k
  for (auto var : glPerVertex.getStageOutVars())
1921
32
    vars.push_back(var);
1922
1923
5.21k
  for (const auto &var : stageVars) {
1924
    // We must collect stage variables that are included in entryPoint and stage
1925
    // variables that are not included in any specific entryPoint i.e.,
1926
    // var.getEntryPoint() is nullptr. Note that stage variables without any
1927
    // specific entry point are common stage variables among all entry points.
1928
5.21k
    if (var.getEntryPoint() && 
var.getEntryPoint() != entryPoint4.15k
)
1929
250
      continue;
1930
4.96k
    auto *instr = var.getSpirvInstr();
1931
4.96k
    if (instr->getStorageClass() == spv::StorageClass::Private)
1932
2
      continue;
1933
4.96k
    vars.push_back(instr);
1934
4.96k
  }
1935
1936
2.88k
  return vars;
1937
2.88k
}
1938
1939
namespace {
1940
/// A class for managing stage input/output locations to avoid duplicate uses of
1941
/// the same location.
1942
class LocationSet {
1943
public:
1944
  /// Maximum number of indices supported
1945
  const static uint32_t kMaxIndex = 2;
1946
1947
  // Creates an empty set.
1948
3.09k
  LocationSet() {
1949
9.29k
    for (uint32_t i = 0; i < kMaxIndex; 
++i6.19k
) {
1950
      // Default size. 64 should cover most cases without having to resize.
1951
6.19k
      usedLocations[i].resize(64);
1952
6.19k
      nextAvailableLocation[i] = 0;
1953
6.19k
    }
1954
3.09k
  }
1955
1956
  /// Marks a given location as used.
1957
772
  void useLocation(uint32_t loc, uint32_t index = 0) {
1958
772
    assert(index < kMaxIndex);
1959
1960
772
    auto &set = usedLocations[index];
1961
772
    if (loc >= set.size()) {
1962
2
      set.resize(std::max<size_t>(loc + 1, set.size() * 2));
1963
2
    }
1964
772
    set.set(loc);
1965
772
    nextAvailableLocation[index] =
1966
772
        std::max(loc + 1, nextAvailableLocation[index]);
1967
772
  }
1968
1969
  // Find the first range of size |count| of unused locations,
1970
  // and marks them as used.
1971
  // Returns the first index of this range.
1972
1.91k
  int useNextNLocations(uint32_t count, uint32_t index = 0) {
1973
1.91k
    auto res = findUnusedRange(index, count);
1974
1.91k
    auto &locations = usedLocations[index];
1975
1976
    // Simple case: no hole large enough left, resizing.
1977
1.91k
    if (res == std::nullopt) {
1978
12
      const uint32_t spaceLeft =
1979
12
          locations.size() - nextAvailableLocation[index];
1980
12
      assert(spaceLeft < count && "There is a bug.");
1981
1982
12
      const uint32_t requiredAlloc = count - spaceLeft;
1983
12
      locations.resize(locations.size() + requiredAlloc);
1984
12
      res = nextAvailableLocation[index];
1985
12
    }
1986
1987
11.5k
    for (uint32_t i = res.value(); i < res.value() + count; 
i++9.67k
) {
1988
9.67k
      locations.set(i);
1989
9.67k
    }
1990
1991
1.91k
    nextAvailableLocation[index] =
1992
1.91k
        std::max(res.value() + count, nextAvailableLocation[index]);
1993
1.91k
    return res.value();
1994
1.91k
  }
1995
1996
  /// Returns true if the given location number is already used.
1997
94
  bool isLocationUsed(uint32_t loc, uint32_t index = 0) {
1998
94
    assert(index < kMaxIndex);
1999
94
    if (loc >= usedLocations[index].size())
2000
2
      return false;
2001
92
    return usedLocations[index][loc];
2002
94
  }
2003
2004
private:
2005
  // Find the first unused range of size |size| in the given set.
2006
  // If the set contains such range, returns the first usable index.
2007
  // Otherwise, nullopt is returned.
2008
1.91k
  std::optional<uint32_t> findUnusedRange(uint32_t index, uint32_t size) {
2009
1.91k
    if (size == 0) {
2010
0
      return 0;
2011
0
    }
2012
2013
1.91k
    assert(index < kMaxIndex);
2014
1.91k
    const auto &locations = usedLocations[index];
2015
2016
1.91k
    uint32_t required_size = size;
2017
1.91k
    uint32_t start = 0;
2018
16.7k
    for (uint32_t i = 0; i < locations.size() && 
required_size > 016.7k
;
i++14.8k
) {
2019
14.8k
      if (!locations[i]) {
2020
3.05k
        --required_size;
2021
3.05k
        continue;
2022
3.05k
      }
2023
2024
11.7k
      required_size = size;
2025
11.7k
      start = i + 1;
2026
11.7k
    }
2027
2028
1.91k
    return required_size == 0 ? 
std::optional(start)1.90k
:
std::nullopt12
;
2029
1.91k
  }
2030
2031
  // The sets to remember used locations. A set bit means the location is used.
2032
  /// All previously used locations
2033
  llvm::SmallBitVector usedLocations[kMaxIndex];
2034
2035
  // The position of the last bit set in the usedLocation vector.
2036
  uint32_t nextAvailableLocation[kMaxIndex];
2037
};
2038
2039
} // namespace
2040
2041
/// A class for managing resource bindings to avoid duplicate uses of the same
2042
/// set and binding number.
2043
class DeclResultIdMapper::BindingSet {
2044
public:
2045
  /// Uses the given set and binding number. Returns false if the binding number
2046
  /// was already occupied in the set, and returns true otherwise.
2047
986
  bool useBinding(uint32_t binding, uint32_t set) {
2048
986
    bool inserted = false;
2049
986
    std::tie(std::ignore, inserted) = usedBindings[set].insert(binding);
2050
986
    return inserted;
2051
986
  }
2052
2053
  /// Uses the next available binding number in |set|. If more than one binding
2054
  /// number is to be occupied, it finds the next available chunk that can fit
2055
  /// |numBindingsToUse| in the |set|.
2056
  uint32_t useNextBinding(uint32_t set, uint32_t numBindingsToUse = 1,
2057
3.11k
                          uint32_t bindingShift = 0) {
2058
3.11k
    uint32_t bindingNoStart =
2059
3.11k
        getNextBindingChunk(set, numBindingsToUse, bindingShift);
2060
3.11k
    auto &binding = usedBindings[set];
2061
6.41k
    for (uint32_t i = 0; i < numBindingsToUse; 
++i3.29k
)
2062
3.29k
      binding.insert(bindingNoStart + i);
2063
3.11k
    return bindingNoStart;
2064
3.11k
  }
2065
2066
  /// Returns the first available binding number in the |set| for which |n|
2067
  /// consecutive binding numbers are unused starting at |bindingShift|.
2068
  uint32_t getNextBindingChunk(uint32_t set, uint32_t n,
2069
3.11k
                               uint32_t bindingShift) {
2070
3.11k
    auto &existingBindings = usedBindings[set];
2071
2072
    // There were no bindings in this set. Can start at binding zero.
2073
3.11k
    if (existingBindings.empty())
2074
1.02k
      return bindingShift;
2075
2076
    // Check whether the chunk of |n| binding numbers can be fitted at the
2077
    // very beginning of the list (start at binding 0 in the current set).
2078
2.08k
    uint32_t curBinding = *existingBindings.begin();
2079
2.08k
    if (curBinding >= (n + bindingShift))
2080
98
      return bindingShift;
2081
2082
1.98k
    auto iter = std::next(existingBindings.begin());
2083
11.5k
    while (iter != existingBindings.end()) {
2084
      // There exists a next binding number that is used. Check to see if the
2085
      // gap between current binding number and next binding number is large
2086
      // enough to accommodate |n|.
2087
9.70k
      uint32_t nextBinding = *iter;
2088
9.70k
      if ((bindingShift > 0) && 
(curBinding < (bindingShift - 1))1.43k
)
2089
952
        curBinding = bindingShift - 1;
2090
2091
9.70k
      if (curBinding < nextBinding && 
n <= nextBinding - curBinding - 18.79k
)
2092
100
        return curBinding + 1;
2093
2094
9.60k
      curBinding = nextBinding;
2095
2096
      // Peek at the next binding that has already been used (if any).
2097
9.60k
      ++iter;
2098
9.60k
    }
2099
2100
    // |curBinding| was the last binding that was used in this set. The next
2101
    // chunk of |n| bindings can start at |curBinding|+1.
2102
1.88k
    return std::max(curBinding + 1, bindingShift);
2103
1.98k
  }
2104
2105
private:
2106
  ///< set number -> set of used binding number
2107
  llvm::DenseMap<uint32_t, std::set<uint32_t>> usedBindings;
2108
};
2109
2110
5.31k
bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
2111
  // Mapping from entry points to the corresponding set of semantics.
2112
5.31k
  llvm::SmallDenseMap<SpirvFunction *, llvm::StringSet<>>
2113
5.31k
      seenSemanticsForEntryPoints;
2114
5.31k
  bool success = true;
2115
8.08k
  for (const auto &var : stageVars) {
2116
8.08k
    auto s = var.getSemanticStr();
2117
2118
8.08k
    if (s.empty()) {
2119
      // We translate WaveGetLaneCount(), WaveGetLaneIndex() and 'payload' param
2120
      // block declaration into builtin variables. Those variables are inserted
2121
      // into the normal stage IO processing pipeline, but with the semantics as
2122
      // empty strings.
2123
220
      assert(var.isSpirvBuitin());
2124
220
      continue;
2125
220
    }
2126
2127
7.86k
    if (forInput && 
var.getSigPoint()->IsInput()3.93k
) {
2128
1.39k
      bool insertionSuccess = insertSeenSemanticsForEntryPointIfNotExist(
2129
1.39k
          &seenSemanticsForEntryPoints, var.getEntryPoint(), s);
2130
1.39k
      if (!insertionSuccess) {
2131
6
        emitError("input semantic '%0' used more than once",
2132
6
                  var.getSemanticInfo().loc)
2133
6
            << s;
2134
6
        success = false;
2135
6
      }
2136
6.47k
    } else if (!forInput && 
var.getSigPoint()->IsOutput()3.93k
) {
2137
1.47k
      bool insertionSuccess = insertSeenSemanticsForEntryPointIfNotExist(
2138
1.47k
          &seenSemanticsForEntryPoints, var.getEntryPoint(), s);
2139
1.47k
      if (!insertionSuccess) {
2140
4
        emitError("output semantic '%0' used more than once",
2141
4
                  var.getSemanticInfo().loc)
2142
4
            << s;
2143
4
        success = false;
2144
4
      }
2145
1.47k
    }
2146
7.86k
  }
2147
2148
5.31k
  return success;
2149
5.31k
}
2150
2151
bool DeclResultIdMapper::isDuplicatedStageVarLocation(
2152
    llvm::DenseSet<StageVariableLocationInfo, StageVariableLocationInfo>
2153
        *stageVariableLocationInfo,
2154
2.66k
    const StageVar &var, uint32_t location, uint32_t index) {
2155
2.66k
  if (!stageVariableLocationInfo
2156
2.66k
           ->insert({var.getEntryPoint(),
2157
2.66k
                     var.getSpirvInstr()->getStorageClass(), location, index})
2158
2.66k
           .second) {
2159
10
    emitError("Multiple stage variables have a duplicated pair of "
2160
10
              "location and index at %0 / %1",
2161
10
              var.getSpirvInstr()->getSourceLocation())
2162
10
        << location << index;
2163
10
    return false;
2164
10
  }
2165
2.65k
  return true;
2166
2.66k
}
2167
2168
bool DeclResultIdMapper::assignLocations(
2169
    const std::vector<const StageVar *> &vars,
2170
    llvm::function_ref<uint32_t(uint32_t)> nextLocs,
2171
    llvm::DenseSet<StageVariableLocationInfo, StageVariableLocationInfo>
2172
980
        *stageVariableLocationInfo) {
2173
1.95k
  for (const auto *var : vars) {
2174
1.95k
    if (hlsl::IsHLSLNodeType(var->getAstType()))
2175
64
      continue;
2176
1.89k
    auto locCount = var->getLocationCount();
2177
1.89k
    uint32_t location = nextLocs(locCount);
2178
1.89k
    spvBuilder.decorateLocation(var->getSpirvInstr(), location);
2179
2180
1.89k
    if (!isDuplicatedStageVarLocation(stageVariableLocationInfo, *var, location,
2181
1.89k
                                      0)) {
2182
0
      return false;
2183
0
    }
2184
1.89k
  }
2185
980
  return true;
2186
980
}
2187
2188
bool DeclResultIdMapper::finalizeStageIOLocationsForASingleEntryPoint(
2189
3.09k
    bool forInput, ArrayRef<StageVar> functionStageVars) {
2190
  // Returns false if the given StageVar is an input/output variable without
2191
  // explicit location assignment. Otherwise, returns true.
2192
5.37k
  const auto locAssigned = [forInput, this](const StageVar &v) {
2193
5.37k
    if (forInput == isInputStorageClass(v)) {
2194
      // No need to assign location for builtins. Treat as assigned.
2195
2.62k
      return v.isSpirvBuitin() || 
v.hasLocOrBuiltinDecorateAttr()1.76k
||
2196
2.62k
             
v.getLocationAttr() != nullptr1.74k
;
2197
2.62k
    }
2198
    // For the ones we don't care, treat as assigned.
2199
2.75k
    return true;
2200
5.37k
  };
2201
2202
  /// Set of locations of assigned stage variables used to correctly report
2203
  /// duplicated stage variable locations.
2204
3.09k
  llvm::DenseSet<StageVariableLocationInfo, StageVariableLocationInfo>
2205
3.09k
      stageVariableLocationInfo;
2206
2207
  // If we have explicit location specified for all input/output variables,
2208
  // use them instead assign by ourselves.
2209
3.09k
  if (std::all_of(functionStageVars.begin(), functionStageVars.end(),
2210
3.09k
                  locAssigned)) {
2211
1.44k
    LocationSet locSet;
2212
1.44k
    bool noError = true;
2213
2214
2.24k
    for (const auto &var : functionStageVars) {
2215
      // Skip builtins & those stage variables we are not handling for this call
2216
2.24k
      if (var.isSpirvBuitin() || 
var.hasLocOrBuiltinDecorateAttr()884
||
2217
2.24k
          
forInput != isInputStorageClass(var)848
) {
2218
2.15k
        continue;
2219
2.15k
      }
2220
2221
90
      const auto *attr = var.getLocationAttr();
2222
90
      const auto loc = attr->getNumber();
2223
90
      const auto locCount = var.getLocationCount();
2224
90
      const auto attrLoc = attr->getLocation(); // Attr source code location
2225
90
      const auto idx = var.getIndexAttr() ? 
var.getIndexAttr()->getNumber()4
:
086
;
2226
2227
      // Make sure the same location is not assigned more than once
2228
184
      for (uint32_t l = loc; l < loc + locCount; 
++l94
) {
2229
94
        if (locSet.isLocationUsed(l, idx)) {
2230
10
          emitError("stage %select{output|input}0 location #%1 already "
2231
10
                    "consumed by semantic '%2'",
2232
10
                    attrLoc)
2233
10
              << forInput << l << functionStageVars[idx].getSemanticStr();
2234
10
          noError = false;
2235
10
        }
2236
2237
94
        locSet.useLocation(l, idx);
2238
94
      }
2239
2240
90
      spvBuilder.decorateLocation(var.getSpirvInstr(), loc);
2241
90
      if (var.getIndexAttr())
2242
4
        spvBuilder.decorateIndex(var.getSpirvInstr(), idx,
2243
4
                                 var.getSemanticInfo().loc);
2244
2245
90
      if (!isDuplicatedStageVarLocation(&stageVariableLocationInfo, var, loc,
2246
90
                                        idx)) {
2247
8
        return false;
2248
8
      }
2249
90
    }
2250
2251
1.44k
    return noError;
2252
1.44k
  }
2253
2254
1.65k
  std::vector<const StageVar *> vars;
2255
1.65k
  LocationSet locSet;
2256
2257
5.79k
  for (const auto &var : functionStageVars) {
2258
5.79k
    if (var.isSpirvBuitin() || 
var.hasLocOrBuiltinDecorateAttr()4.68k
||
2259
5.79k
        
forInput != isInputStorageClass(var)4.68k
) {
2260
3.11k
      continue;
2261
3.11k
    }
2262
2263
2.68k
    if (var.getLocationAttr()) {
2264
      // We have checked that not all of the stage variables have explicit
2265
      // location assignment.
2266
2
      emitError("partial explicit stage %select{output|input}0 location "
2267
2
                "assignment via vk::location(X) unsupported",
2268
2
                {})
2269
2
          << forInput;
2270
2
      return false;
2271
2
    }
2272
2273
2.67k
    const auto &semaInfo = var.getSemanticInfo();
2274
2275
    // We should special rules for SV_Target: the location number comes from the
2276
    // semantic string index.
2277
2.67k
    if (semaInfo.isTarget()) {
2278
678
      spvBuilder.decorateLocation(var.getSpirvInstr(), semaInfo.index);
2279
678
      locSet.useLocation(semaInfo.index);
2280
2281
678
      if (!isDuplicatedStageVarLocation(&stageVariableLocationInfo, var,
2282
678
                                        semaInfo.index, 0)) {
2283
2
        return false;
2284
2
      }
2285
2.00k
    } else {
2286
2.00k
      vars.push_back(&var);
2287
2.00k
    }
2288
2.67k
  }
2289
2290
1.64k
  if (vars.empty())
2291
660
    return true;
2292
2293
1.91k
  
auto nextLocs = [&locSet](uint32_t locCount) 986
{
2294
1.91k
    return locSet.useNextNLocations(locCount);
2295
1.91k
  };
2296
2297
  // If alphabetical ordering was requested, sort by semantic string.
2298
986
  if (spirvOptions.stageIoOrder == "alpha") {
2299
    // Sort stage input/output variables alphabetically
2300
4
    std::stable_sort(vars.begin(), vars.end(),
2301
56
                     [](const StageVar *a, const StageVar *b) {
2302
56
                       return a->getSemanticStr() < b->getSemanticStr();
2303
56
                     });
2304
4
    return assignLocations(vars, nextLocs, &stageVariableLocationInfo);
2305
4
  }
2306
2307
  // Pack signature if it is enabled. Vertext shader input and pixel
2308
  // shader output are special. We have to preserve the given signature.
2309
982
  auto sigPointKind = vars[0]->getSigPoint()->GetKind();
2310
982
  if (spirvOptions.signaturePacking &&
2311
982
      
sigPointKind != hlsl::SigPoint::Kind::VSIn8
&&
2312
982
      
sigPointKind != hlsl::SigPoint::Kind::PSOut6
) {
2313
6
    return packSignature(spvBuilder, vars, nextLocs, forInput);
2314
6
  }
2315
2316
  // Since HS includes 2 sets of outputs (patch-constant output and
2317
  // OutputPatch), running into location mismatches between HS and DS is very
2318
  // likely. In order to avoid location mismatches between HS and DS, use
2319
  // alphabetical ordering.
2320
976
  if ((!forInput && 
spvContext.isHS()336
) ||
(904
forInput904
&&
spvContext.isDS()640
)) {
2321
    // Sort stage input/output variables alphabetically
2322
96
    std::stable_sort(vars.begin(), vars.end(),
2323
654
                     [](const StageVar *a, const StageVar *b) {
2324
654
                       return a->getSemanticStr() < b->getSemanticStr();
2325
654
                     });
2326
96
  }
2327
976
  return assignLocations(vars, nextLocs, &stageVariableLocationInfo);
2328
982
}
2329
2330
llvm::DenseMap<const SpirvFunction *, SmallVector<StageVar, 8>>
DeclResultIdMapper::getStageVarsPerFunction() {
  // Buckets a copy of every stage variable under the entry point it belongs
  // to. Within a bucket, variables keep their relative order from stageVars.
  llvm::DenseMap<const SpirvFunction *, SmallVector<StageVar, 8>> byEntryPoint;
  for (const StageVar &stageVar : stageVars) {
    auto &bucket = byEntryPoint[stageVar.getEntryPoint()];
    bucket.push_back(stageVar);
  }
  return byEntryPoint;
}
2338
2339
5.31k
bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
2340
5.31k
  if (!checkSemanticDuplication(forInput))
2341
6
    return false;
2342
2343
5.31k
  auto stageVarPerFunction = getStageVarsPerFunction();
2344
5.31k
  for (const auto &functionStageVars : stageVarPerFunction) {
2345
3.09k
    if (!finalizeStageIOLocationsForASingleEntryPoint(
2346
3.09k
            forInput, functionStageVars.getSecond())) {
2347
14
      return false;
2348
14
    }
2349
3.09k
  }
2350
5.29k
  return true;
2351
5.31k
}
2352
2353
namespace {
2354
/// A class for maintaining the binding number shift requested for descriptor
2355
/// sets.
2356
class BindingShiftMapper {
2357
public:
2358
  explicit BindingShiftMapper(const llvm::SmallVectorImpl<int32_t> &shifts)
2359
10.8k
      : masterShift(0) {
2360
10.8k
    assert(shifts.size() % 2 == 0);
2361
10.8k
    if (shifts.size() == 2 && 
shifts[1] == -126
) {
2362
16
      masterShift = shifts[0];
2363
10.8k
    } else {
2364
10.8k
      for (uint32_t i = 0; i < shifts.size(); 
i += 248
)
2365
48
        perSetShift[shifts[i + 1]] = shifts[i];
2366
10.8k
    }
2367
10.8k
  }
2368
2369
  /// Returns the shift amount for the given set.
2370
792
  int32_t getShiftForSet(int32_t set) const {
2371
792
    const auto found = perSetShift.find(set);
2372
792
    if (found != perSetShift.end())
2373
92
      return found->second;
2374
700
    return masterShift;
2375
792
  }
2376
2377
private:
2378
  uint32_t masterShift; /// Shift amount applies to all sets.
2379
  llvm::DenseMap<int32_t, int32_t> perSetShift;
2380
};
2381
2382
/// A class for maintaining the mapping from source code register attributes to
2383
/// descriptor set and number settings.
2384
class RegisterBindingMapper {
2385
public:
2386
  /// Takes in the relation between register attributes and descriptor settings.
2387
  /// Each relation is represented by four strings:
2388
  ///   <register-type-number> <space> <descriptor-binding> <set>
2389
  bool takeInRelation(const std::vector<std::string> &relation,
2390
18
                      std::string *error) {
2391
18
    assert(relation.size() % 4 == 0);
2392
18
    mapping.clear();
2393
2394
36
    for (uint32_t i = 0; i < relation.size(); 
i += 418
) {
2395
24
      int32_t spaceNo = -1, setNo = -1, bindNo = -1;
2396
24
      if (StringRef(relation[i + 1]).getAsInteger(10, spaceNo) || 
spaceNo < 022
) {
2397
2
        *error = "space number: " + relation[i + 1];
2398
2
        return false;
2399
2
      }
2400
22
      if (StringRef(relation[i + 2]).getAsInteger(10, bindNo) || bindNo < 0) {
2401
2
        *error = "binding number: " + relation[i + 2];
2402
2
        return false;
2403
2
      }
2404
20
      if (StringRef(relation[i + 3]).getAsInteger(10, setNo) || 
setNo < 018
) {
2405
2
        *error = "set number: " + relation[i + 3];
2406
2
        return false;
2407
2
      }
2408
18
      mapping[relation[i + 1] + relation[i]] = std::make_pair(setNo, bindNo);
2409
18
    }
2410
12
    return true;
2411
18
  }
2412
2413
  /// Returns true and set the correct set and binding number if we can find a
2414
  /// descriptor setting for the given register. False otherwise.
2415
  bool getSetBinding(const hlsl::RegisterAssignment *regAttr,
2416
18
                     uint32_t defaultSpace, int *setNo, int *bindNo) const {
2417
18
    std::ostringstream iss;
2418
18
    iss << regAttr->RegisterSpace.getValueOr(defaultSpace)
2419
18
        << regAttr->RegisterType << regAttr->RegisterNumber;
2420
2421
18
    auto found = mapping.find(iss.str());
2422
18
    if (found != mapping.end()) {
2423
16
      *setNo = found->second.first;
2424
16
      *bindNo = found->second.second;
2425
16
      return true;
2426
16
    }
2427
2428
2
    return false;
2429
18
  }
2430
2431
private:
2432
  llvm::StringMap<std::pair<int, int>> mapping;
2433
};
2434
} // namespace
2435
2436
2.72k
bool DeclResultIdMapper::decorateResourceBindings() {
2437
  // For normal resource, we support 4 approaches of setting binding numbers:
2438
  // - m1: [[vk::binding(...)]]
2439
  // - m2: :register(xX, spaceY)
2440
  // - m3: None
2441
  // - m4: :register(spaceY)
2442
  //
2443
  // For associated counters, we support 2 approaches:
2444
  // - c1: [[vk::counter_binding(...)]
2445
  // - c2: None
2446
  //
2447
  // In combination, we need to handle 12 cases:
2448
  // - 4 cases for nomral resoures (m1, m2, m3, m4)
2449
  // - 8 cases for associated counters (mX * cY)
2450
  //
2451
  // In the following order:
2452
  // - m1, mX * c1
2453
  // - m2
2454
  // - m3, m4, mX * c2
2455
2456
  // The "-auto-binding-space" command line option can be used to specify a
2457
  // certain space as default. UINT_MAX means the user has not provided this
2458
  // option. If not provided, the SPIR-V backend uses space "0" as default.
2459
2.72k
  auto defaultSpaceOpt =
2460
2.72k
      theEmitter.getCompilerInstance().getCodeGenOpts().HLSLDefaultSpace;
2461
2.72k
  uint32_t defaultSpace = (defaultSpaceOpt == UINT_MAX) ? 
02.72k
:
defaultSpaceOpt8
;
2462
2463
2.72k
  const bool bindGlobals = !spirvOptions.bindGlobals.empty();
2464
2.72k
  int32_t globalsBindNo = -1, globalsSetNo = -1;
2465
2.72k
  if (bindGlobals) {
2466
4
    assert(spirvOptions.bindGlobals.size() == 2);
2467
4
    if (StringRef(spirvOptions.bindGlobals[0])
2468
4
            .getAsInteger(10, globalsBindNo) ||
2469
4
        globalsBindNo < 0) {
2470
0
      emitError("invalid -fvk-bind-globals binding number: %0", {})
2471
0
          << spirvOptions.bindGlobals[0];
2472
0
      return false;
2473
0
    }
2474
4
    if (StringRef(spirvOptions.bindGlobals[1]).getAsInteger(10, globalsSetNo) ||
2475
4
        globalsSetNo < 0) {
2476
0
      emitError("invalid -fvk-bind-globals set number: %0", {})
2477
0
          << spirvOptions.bindGlobals[1];
2478
0
      return false;
2479
0
    }
2480
4
  }
2481
2482
  // Special handling of -fvk-bind-register, which requires
2483
  // * All resources are annoated with :register() in the source code
2484
  // * -fvk-bind-register is specified for every resource
2485
2.72k
  if (!spirvOptions.bindRegister.empty()) {
2486
18
    RegisterBindingMapper bindingMapper;
2487
18
    std::string error;
2488
2489
18
    if (!bindingMapper.takeInRelation(spirvOptions.bindRegister, &error)) {
2490
6
      emitError("invalid -fvk-bind-register %0", {}) << error;
2491
6
      return false;
2492
6
    }
2493
2494
12
    for (const auto &var : resourceVars)
2495
26
      if (const auto *regAttr = var.getRegister()) {
2496
20
        if (var.isCounter()) {
2497
2
          emitError("-fvk-bind-register for RW/Append/Consume StructuredBuffer "
2498
2
                    "unimplemented",
2499
2
                    var.getSourceLocation());
2500
18
        } else {
2501
18
          int setNo = 0, bindNo = 0;
2502
18
          if (!bindingMapper.getSetBinding(regAttr, defaultSpace, &setNo,
2503
18
                                           &bindNo)) {
2504
2
            emitError("missing -fvk-bind-register for resource",
2505
2
                      var.getSourceLocation());
2506
2
            return false;
2507
2
          }
2508
16
          spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindNo);
2509
16
        }
2510
20
      } else 
if (6
var.isGlobalsBuffer()6
) {
2511
4
        if (!bindGlobals) {
2512
2
          emitError("-fvk-bind-register requires Globals buffer to be bound "
2513
2
                    "with -fvk-bind-globals",
2514
2
                    var.getSourceLocation());
2515
2
          return false;
2516
2
        }
2517
2518
2
        spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
2519
2
                                       globalsBindNo);
2520
2
      } else {
2521
2
        emitError(
2522
2
            "-fvk-bind-register requires register annotations on all resources",
2523
2
            var.getSourceLocation());
2524
2
        return false;
2525
2
      }
2526
2527
6
    return true;
2528
12
  }
2529
2530
2.71k
  BindingSet bindingSet;
2531
2532
  // If some bindings are reserved for heaps, mark those are used.
2533
2.71k
  if (spirvOptions.resourceHeapBinding)
2534
8
    bindingSet.useBinding(spirvOptions.resourceHeapBinding->binding,
2535
8
                          spirvOptions.resourceHeapBinding->set);
2536
2.71k
  if (spirvOptions.samplerHeapBinding)
2537
6
    bindingSet.useBinding(spirvOptions.samplerHeapBinding->binding,
2538
6
                          spirvOptions.samplerHeapBinding->set);
2539
2.71k
  if (spirvOptions.counterHeapBinding)
2540
6
    bindingSet.useBinding(spirvOptions.counterHeapBinding->binding,
2541
6
                          spirvOptions.counterHeapBinding->set);
2542
2543
  // Decorates the given varId of the given category with set number
2544
  // setNo, binding number bindingNo. Ignores overlaps.
2545
2.71k
  const auto tryToDecorate = [this, &bindingSet](const ResourceVar &var,
2546
2.71k
                                                 const uint32_t setNo,
2547
2.71k
                                                 const uint32_t bindingNo) {
2548
    // By default we use one binding number per resource, and an array of
2549
    // resources also gets only one binding number. However, for array of
2550
    // resources (e.g. array of textures), DX uses one binding number per array
2551
    // element. We can match this behavior via a command line option.
2552
892
    uint32_t numBindingsToUse = 1;
2553
892
    if (spirvOptions.flattenResourceArrays || 
needsFlatteningCompositeResources864
)
2554
30
      numBindingsToUse = getNumBindingsUsedByResourceType(
2555
30
          var.getSpirvInstr()->getAstResultType());
2556
2557
1.85k
    for (uint32_t i = 0; i < numBindingsToUse; 
++i966
) {
2558
966
      bool success = bindingSet.useBinding(bindingNo + i, setNo);
2559
      // We will not emit an error if we find a set/binding overlap because it
2560
      // is possible that the optimizer optimizes away a resource which resolves
2561
      // the overlap.
2562
966
      (void)success;
2563
966
    }
2564
2565
    // No need to decorate multiple binding numbers for arrays. It will be done
2566
    // by legalization/optimization.
2567
892
    spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindingNo);
2568
892
  };
2569
2570
4.04k
  for (const auto &var : resourceVars) {
2571
4.04k
    if (var.isCounter()) {
2572
316
      if (const auto *vkCBinding = var.getCounterBinding()) {
2573
        // Process mX * c1
2574
14
        uint32_t set = defaultSpace;
2575
14
        if (const auto *vkBinding = var.getBinding())
2576
10
          set = getVkBindingAttrSet(vkBinding, defaultSpace);
2577
4
        else if (const auto *reg = var.getRegister())
2578
2
          set = reg->RegisterSpace.getValueOr(defaultSpace);
2579
2580
14
        tryToDecorate(var, set, vkCBinding->getBinding());
2581
14
      }
2582
3.72k
    } else {
2583
3.72k
      if (const auto *vkBinding = var.getBinding()) {
2584
        // Process m1
2585
132
        tryToDecorate(var, getVkBindingAttrSet(vkBinding, defaultSpace),
2586
132
                      vkBinding->getBinding());
2587
132
      }
2588
3.72k
    }
2589
4.04k
  }
2590
2591
2.71k
  BindingShiftMapper bShiftMapper(spirvOptions.bShift);
2592
2.71k
  BindingShiftMapper tShiftMapper(spirvOptions.tShift);
2593
2.71k
  BindingShiftMapper sShiftMapper(spirvOptions.sShift);
2594
2.71k
  BindingShiftMapper uShiftMapper(spirvOptions.uShift);
2595
2596
  // Process m2
2597
2.71k
  for (const auto &var : resourceVars)
2598
4.04k
    if (!var.isCounter() && 
!var.getBinding()3.72k
)
2599
3.59k
      if (const auto *reg = var.getRegister()) {
2600
        // Skip space-only register() annotations
2601
778
        if (reg->isSpaceOnly())
2602
32
          continue;
2603
2604
746
        const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
2605
746
        uint32_t binding = reg->RegisterNumber;
2606
746
        switch (reg->RegisterType) {
2607
62
        case 'b':
2608
62
          binding += bShiftMapper.getShiftForSet(set);
2609
62
          break;
2610
402
        case 't':
2611
402
          binding += tShiftMapper.getShiftForSet(set);
2612
402
          break;
2613
184
        case 's':
2614
          // For combined texture and sampler resources, always use the t shift
2615
          // value and ignore the s shift value.
2616
184
          if (const auto *decl = var.getDeclaration()) {
2617
184
            if (decl->getAttr<VKCombinedImageSamplerAttr>() != nullptr) {
2618
12
              binding += tShiftMapper.getShiftForSet(set);
2619
12
              break;
2620
12
            }
2621
184
          }
2622
172
          binding += sShiftMapper.getShiftForSet(set);
2623
172
          break;
2624
98
        case 'u':
2625
98
          binding += uShiftMapper.getShiftForSet(set);
2626
98
          break;
2627
0
        case 'c':
2628
          // For setting packing offset. Does not affect binding.
2629
0
          break;
2630
0
        default:
2631
0
          llvm_unreachable("unknown register type found");
2632
746
        }
2633
2634
746
        tryToDecorate(var, set, binding);
2635
746
      }
2636
2637
4.04k
  
for (const auto &var : resourceVars)2.71k
{
2638
    // By default we use one binding number per resource, and an array of
2639
    // resources also gets only one binding number. However, for array of
2640
    // resources (e.g. array of textures), DX uses one binding number per array
2641
    // element. We can match this behavior via a command line option.
2642
4.04k
    uint32_t numBindingsToUse = 1;
2643
4.04k
    if (spirvOptions.flattenResourceArrays || 
needsFlatteningCompositeResources3.93k
)
2644
150
      numBindingsToUse = getNumBindingsUsedByResourceType(
2645
150
          var.getSpirvInstr()->getAstResultType());
2646
2647
4.04k
    BindingShiftMapper *bindingShiftMapper = nullptr;
2648
4.04k
    if (spirvOptions.autoShiftBindings) {
2649
58
      char registerType = '\0';
2650
58
      if (getImplicitRegisterType(var, &registerType)) {
2651
54
        switch (registerType) {
2652
4
        case 'b':
2653
4
          bindingShiftMapper = &bShiftMapper;
2654
4
          break;
2655
26
        case 't':
2656
26
          bindingShiftMapper = &tShiftMapper;
2657
26
          break;
2658
4
        case 's':
2659
4
          bindingShiftMapper = &sShiftMapper;
2660
4
          break;
2661
20
        case 'u':
2662
20
          bindingShiftMapper = &uShiftMapper;
2663
20
          break;
2664
0
        default:
2665
0
          llvm_unreachable("unknown register type found");
2666
54
        }
2667
54
      }
2668
58
    }
2669
2670
4.04k
    if (var.getDeclaration()) {
2671
3.92k
      const VarDecl *decl = dyn_cast<VarDecl>(var.getDeclaration());
2672
3.92k
      if (decl && 
(3.74k
isResourceDescriptorHeap(decl->getType())3.74k
||
2673
3.74k
                   
isSamplerDescriptorHeap(decl->getType())3.66k
))
2674
100
        continue;
2675
3.92k
    }
2676
2677
3.94k
    if (var.isCounter()) {
2678
2679
290
      if (!var.getCounterBinding()) {
2680
        // Process mX * c2
2681
276
        uint32_t set = defaultSpace;
2682
276
        if (const auto *vkBinding = var.getBinding())
2683
14
          set = getVkBindingAttrSet(vkBinding, defaultSpace);
2684
262
        else if (const auto *reg = var.getRegister())
2685
30
          set = reg->RegisterSpace.getValueOr(defaultSpace);
2686
2687
276
        uint32_t bindingShift = 0;
2688
276
        if (bindingShiftMapper)
2689
0
          bindingShift = bindingShiftMapper->getShiftForSet(set);
2690
276
        spvBuilder.decorateDSetBinding(
2691
276
            var.getSpirvInstr(), set,
2692
276
            bindingSet.useNextBinding(set, numBindingsToUse, bindingShift));
2693
276
      }
2694
3.65k
    } else if (!var.getBinding()) {
2695
3.52k
      const auto *reg = var.getRegister();
2696
3.52k
      if (reg && 
reg->isSpaceOnly()778
) {
2697
32
        const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
2698
32
        uint32_t bindingShift = 0;
2699
32
        if (bindingShiftMapper)
2700
0
          bindingShift = bindingShiftMapper->getShiftForSet(set);
2701
32
        spvBuilder.decorateDSetBinding(
2702
32
            var.getSpirvInstr(), set,
2703
32
            bindingSet.useNextBinding(set, numBindingsToUse, bindingShift));
2704
3.48k
      } else if (!reg) {
2705
        // Process m3 (no 'vk::binding' and no ':register' assignment)
2706
2707
        // There is a special case for the $Globals cbuffer. The $Globals buffer
2708
        // doesn't have either 'vk::binding' or ':register', but the user may
2709
        // ask for a specific binding for it via command line options.
2710
2.74k
        if (bindGlobals && 
var.isGlobalsBuffer()2
) {
2711
2
          uint32_t bindingShift = 0;
2712
2
          if (bindingShiftMapper)
2713
0
            bindingShift = bindingShiftMapper->getShiftForSet(globalsSetNo);
2714
2
          spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
2715
2
                                         globalsBindNo + bindingShift);
2716
2
        }
2717
        // The normal case
2718
2.74k
        else {
2719
2.74k
          uint32_t bindingShift = 0;
2720
2.74k
          if (bindingShiftMapper)
2721
46
            bindingShift = bindingShiftMapper->getShiftForSet(defaultSpace);
2722
2.74k
          spvBuilder.decorateDSetBinding(
2723
2.74k
              var.getSpirvInstr(), defaultSpace,
2724
2.74k
              bindingSet.useNextBinding(defaultSpace, numBindingsToUse,
2725
2.74k
                                        bindingShift));
2726
2.74k
        }
2727
2.74k
      }
2728
3.52k
    }
2729
3.94k
  }
2730
2731
2.71k
  decorateResourceHeapsBindings(bindingSet);
2732
2.71k
  return true;
2733
2.71k
}
2734
2735
SpirvCodeGenOptions::BindingInfo DeclResultIdMapper::getBindingInfo(
    BindingSet &bindingSet,
    const std::optional<SpirvCodeGenOptions::BindingInfo> &userProvidedInfo) {
  // Without a user-provided binding, claim the next free binding in set 0.
  // Otherwise the user's choice is returned unchanged.
  if (!userProvidedInfo)
    return {bindingSet.useNextBinding(0), /* set= */ 0};
  return *userProvidedInfo;
}
2743
2744
2.71k
void DeclResultIdMapper::decorateResourceHeapsBindings(BindingSet &bindingSet) {
  bool hasResource = false;
  bool hasSamplers = false;
  bool hasCounters = false;

  // First pass: find out which kinds of heap are referenced, so bindings are
  // allocated lazily, only for heaps that are actually used.
  for (const auto &var : resourceVars) {
    const VarDecl *decl = dyn_cast_or_null<VarDecl>(var.getDeclaration());
    if (decl == nullptr)
      continue;

    const QualType declType = decl->getType();
    const bool isResourceHeap = isResourceDescriptorHeap(declType);
    const bool isSamplerHeap = isSamplerDescriptorHeap(declType);

    assert(!(var.isCounter() && isSamplerHeap));

    hasResource |= isResourceHeap;
    hasSamplers |= isSamplerHeap;
    hasCounters |= isResourceHeap && var.isCounter();
  }

  // The allocation order matters: resource heap first, then sampler heap,
  // and the counter heap last.
  SpirvCodeGenOptions::BindingInfo resourceBinding = {/* binding= */ 0,
                                                      /* set= */ 0};
  SpirvCodeGenOptions::BindingInfo samplersBinding = {/* binding= */ 0,
                                                      /* set= */ 0};
  SpirvCodeGenOptions::BindingInfo countersBinding = {/* binding= */ 0,
                                                      /* set= */ 0};
  if (hasResource)
    resourceBinding =
        getBindingInfo(bindingSet, spirvOptions.resourceHeapBinding);
  if (hasSamplers)
    samplersBinding =
        getBindingInfo(bindingSet, spirvOptions.samplerHeapBinding);
  if (hasCounters)
    countersBinding =
        getBindingInfo(bindingSet, spirvOptions.counterHeapBinding);

  // Second pass: decorate each heap variable with the binding of its kind.
  for (const auto &var : resourceVars) {
    const VarDecl *decl = dyn_cast_or_null<VarDecl>(var.getDeclaration());
    if (decl == nullptr)
      continue;

    const QualType declType = decl->getType();
    const bool isSamplerHeap = isSamplerDescriptorHeap(declType);
    if (!isSamplerHeap && !isResourceDescriptorHeap(declType))
      continue;

    const SpirvCodeGenOptions::BindingInfo *info = &resourceBinding;
    if (isSamplerHeap)
      info = &samplersBinding;
    else if (var.isCounter())
      info = &countersBinding;

    spvBuilder.decorateDSetBinding(var.getSpirvInstr(), info->set,
                                   info->binding);
  }
}
2804
2805
2.71k
bool DeclResultIdMapper::decorateResourceCoherent() {
2806
4.05k
  for (const auto &var : resourceVars) {
2807
4.05k
    if (const auto *decl = var.getDeclaration()) {
2808
3.93k
      if (decl->getAttr<HLSLGloballyCoherentAttr>()) {
2809
12
        spvBuilder.decorateCoherent(var.getSpirvInstr(),
2810
12
                                    var.getSourceLocation());
2811
12
      }
2812
3.93k
    }
2813
4.05k
  }
2814
2815
2.71k
  return true;
2816
2.71k
}
2817
2818
bool DeclResultIdMapper::createStructOutputVar(
2819
    const StageVarDataBundle &stageVarData, SpirvInstruction *value,
2820
452
    bool noWriteBack) {
2821
  // If we have base classes, we need to handle them first.
2822
452
  if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) {
2823
452
    uint32_t baseIndex = 0;
2824
452
    for (auto base : cxxDecl->bases()) {
2825
6
      SpirvInstruction *subValue = nullptr;
2826
6
      if (!noWriteBack)
2827
2
        subValue = spvBuilder.createCompositeExtract(
2828
2
            base.getType(), value, {baseIndex++},
2829
2
            stageVarData.decl->getLocation());
2830
2831
6
      StageVarDataBundle memberVarData = stageVarData;
2832
6
      memberVarData.decl = base.getType()->getAsCXXRecordDecl();
2833
6
      memberVarData.type = base.getType();
2834
6
      if (!createStageVars(memberVarData, false, &subValue, noWriteBack))
2835
0
        return false;
2836
6
    }
2837
452
  }
2838
2839
  // Unlike reading, which may require us to read stand-alone builtins and
2840
  // stage input variables and compose an array of structs out of them,
2841
  // it happens that we don't need to write an array of structs in a bunch
2842
  // for all shader stages:
2843
  //
2844
  // * VS: output is a single struct, without extra arrayness
2845
  // * HS: output is an array of structs, with extra arrayness,
2846
  //       but we only write to the struct at the InvocationID index
2847
  // * DS: output is a single struct, without extra arrayness
2848
  // * GS: output is controlled by OpEmitVertex, one vertex per time
2849
  // * MS: output is an array of structs, with extra arrayness
2850
  //
2851
  // The interesting shader stage is HS. We need the InvocationID to write
2852
  // out the value to the correct array element.
2853
452
  const auto *structDecl = stageVarData.type->getAs<RecordType>()->getDecl();
2854
1.09k
  for (const auto *field : structDecl->fields()) {
2855
1.09k
    const auto fieldType = field->getType();
2856
1.09k
    SpirvInstruction *subValue = nullptr;
2857
1.09k
    if (!noWriteBack) {
2858
868
      subValue = spvBuilder.createCompositeExtract(
2859
868
          fieldType, value,
2860
868
          {getNumBaseClasses(stageVarData.type) + field->getFieldIndex()},
2861
868
          stageVarData.decl->getLocation());
2862
868
      if (field->hasAttr<HLSLNoInterpolationAttr>() ||
2863
868
          
structDecl->hasAttr<HLSLNoInterpolationAttr>()848
)
2864
20
        subValue->setNoninterpolated();
2865
868
    }
2866
2867
1.09k
    StageVarDataBundle memberVarData = stageVarData;
2868
1.09k
    memberVarData.decl = field;
2869
1.09k
    memberVarData.type = field->getType();
2870
1.09k
    memberVarData.asNoInterp |= field->hasAttr<HLSLNoInterpolationAttr>();
2871
1.09k
    if (!createStageVars(memberVarData, false, &subValue, noWriteBack))
2872
2
      return false;
2873
1.09k
  }
2874
450
  return true;
2875
452
}
2876
2877
SpirvInstruction *
2878
DeclResultIdMapper::createStructInputVar(const StageVarDataBundle &stageVarData,
2879
410
                                         bool noWriteBack) {
2880
  // If this decl translates into multiple stage input variables, we need to
2881
  // load their values into a composite.
2882
410
  llvm::SmallVector<SpirvInstruction *, 4> subValues;
2883
2884
  // If we have base classes, we need to handle them first.
2885
410
  if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) {
2886
410
    for (auto base : cxxDecl->bases()) {
2887
8
      SpirvInstruction *subValue = nullptr;
2888
8
      StageVarDataBundle memberVarData = stageVarData;
2889
8
      memberVarData.decl = base.getType()->getAsCXXRecordDecl();
2890
8
      memberVarData.type = base.getType();
2891
8
      if (!createStageVars(memberVarData, true, &subValue, noWriteBack))
2892
0
        return nullptr;
2893
8
      subValues.push_back(subValue);
2894
8
    }
2895
410
  }
2896
2897
410
  const auto *structDecl = stageVarData.type->getAs<RecordType>()->getDecl();
2898
838
  for (const auto *field : structDecl->fields()) {
2899
838
    SpirvInstruction *subValue = nullptr;
2900
838
    StageVarDataBundle memberVarData = stageVarData;
2901
838
    memberVarData.decl = field;
2902
838
    memberVarData.type = field->getType();
2903
838
    memberVarData.asNoInterp |= field->hasAttr<HLSLNoInterpolationAttr>();
2904
838
    if (!createStageVars(memberVarData, true, &subValue, noWriteBack))
2905
0
      return nullptr;
2906
838
    subValues.push_back(subValue);
2907
838
  }
2908
2909
410
  if (stageVarData.arraySize == 0) {
2910
270
    SpirvInstruction *value = spvBuilder.createCompositeConstruct(
2911
270
        stageVarData.type, subValues, stageVarData.decl->getLocation());
2912
270
    for (auto *subInstr : subValues)
2913
570
      spvBuilder.addPerVertexStgInputFuncVarEntry(subInstr, value);
2914
270
    return value;
2915
270
  }
2916
2917
  // Handle the extra level of arrayness.
2918
2919
  // We need to return an array of structs. But we get arrays of fields
2920
  // from visiting all fields. So now we need to extract all the elements
2921
  // at the same index of each field arrays and compose a new struct out
2922
  // of them.
2923
140
  const auto structType = stageVarData.type;
2924
140
  const auto arrayType = astContext.getConstantArrayType(
2925
140
      structType, llvm::APInt(32, stageVarData.arraySize),
2926
140
      clang::ArrayType::Normal, 0);
2927
2928
140
  llvm::SmallVector<SpirvInstruction *, 16> arrayElements;
2929
2930
1.14k
  for (uint32_t arrayIndex = 0; arrayIndex < stageVarData.arraySize;
2931
1.00k
       ++arrayIndex) {
2932
1.00k
    llvm::SmallVector<SpirvInstruction *, 8> fields;
2933
2934
    // If we have base classes, we need to handle them first.
2935
1.00k
    if (const auto *cxxDecl = stageVarData.type->getAsCXXRecordDecl()) {
2936
1.00k
      uint32_t baseIndex = 0;
2937
1.00k
      for (auto base : cxxDecl->bases()) {
2938
8
        const auto baseType = base.getType();
2939
8
        fields.push_back(spvBuilder.createCompositeExtract(
2940
8
            baseType, subValues[baseIndex++], {arrayIndex},
2941
8
            stageVarData.decl->getLocation()));
2942
8
      }
2943
1.00k
    }
2944
2945
    // Extract the element at index arrayIndex from each field
2946
2.41k
    for (const auto *field : structDecl->fields()) {
2947
2.41k
      const auto fieldType = field->getType();
2948
2.41k
      fields.push_back(spvBuilder.createCompositeExtract(
2949
2.41k
          fieldType,
2950
2.41k
          subValues[getNumBaseClasses(stageVarData.type) +
2951
2.41k
                    field->getFieldIndex()],
2952
2.41k
          {arrayIndex}, stageVarData.decl->getLocation()));
2953
2.41k
    }
2954
    // Compose a new struct out of them
2955
1.00k
    arrayElements.push_back(spvBuilder.createCompositeConstruct(
2956
1.00k
        structType, fields, stageVarData.decl->getLocation()));
2957
1.00k
  }
2958
2959
140
  return spvBuilder.createCompositeConstruct(arrayType, arrayElements,
2960
140
                                             stageVarData.decl->getLocation());
2961
410
}
2962
2963
void DeclResultIdMapper::storeToShaderOutputVariable(
2964
    SpirvVariable *varInstr, SpirvInstruction *value,
2965
1.74k
    const StageVarDataBundle &stageVarData) {
2966
1.74k
  SpirvInstruction *ptr = varInstr;
2967
2968
  // Since boolean output stage variables are represented as unsigned
2969
  // integers, we must cast the value to uint before storing.
2970
1.74k
  if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type,
2971
1.74k
                          stageVarData.semantic->getKind(),
2972
1.74k
                          stageVarData.sigPoint->GetKind())) {
2973
16
    QualType finalType = varInstr->getAstResultType();
2974
16
    if (stageVarData.arraySize != 0) {
2975
      // We assume that we will only have to write to a single value of the
2976
      // array, so we have to cast to the element type of the array, and not the
2977
      // array type.
2978
2
      assert(stageVarData.invocationId.hasValue());
2979
2
      finalType = finalType->getAsArrayTypeUnsafe()->getElementType();
2980
2
    }
2981
16
    value = theEmitter.castToType(value, stageVarData.type, finalType,
2982
16
                                  stageVarData.decl->getLocation());
2983
16
  }
2984
2985
  // Special handling of SV_TessFactor HS patch constant output.
2986
  // TessLevelOuter is always an array of size 4 in SPIR-V, but
2987
  // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
2988
  // relevant indexes must be written to.
2989
1.74k
  if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor &&
2990
1.74k
      
hlsl::GetArraySize(stageVarData.type) != 478
) {
2991
16
    const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type);
2992
64
    for (uint32_t i = 0; i < tessFactorSize; 
++i48
) {
2993
48
      ptr = spvBuilder.createAccessChain(
2994
48
          astContext.FloatTy, varInstr,
2995
48
          {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
2996
48
                                     llvm::APInt(32, i))},
2997
48
          stageVarData.decl->getLocation());
2998
48
      spvBuilder.createStore(
2999
48
          ptr,
3000
48
          spvBuilder.createCompositeExtract(astContext.FloatTy, value, {i},
3001
48
                                            stageVarData.decl->getLocation()),
3002
48
          stageVarData.decl->getLocation());
3003
48
    }
3004
16
  }
3005
  // Special handling of SV_InsideTessFactor HS patch constant output.
3006
  // TessLevelInner is always an array of size 2 in SPIR-V, but
3007
  // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
3008
  // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of
3009
  // TessLevelInner.
3010
1.73k
  else if (stageVarData.semantic->getKind() ==
3011
1.73k
               hlsl::Semantic::Kind::InsideTessFactor &&
3012
           // Some developers use float[1] instead of a scalar float.
3013
1.73k
           
(76
!stageVarData.type->isArrayType()76
||
3014
76
            
hlsl::GetArraySize(stageVarData.type) == 168
)) {
3015
14
    ptr = spvBuilder.createAccessChain(
3016
14
        astContext.FloatTy, varInstr,
3017
14
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)),
3018
14
        stageVarData.decl->getLocation());
3019
14
    if (stageVarData.type->isArrayType()) // float[1]
3020
6
      value = spvBuilder.createCompositeExtract(
3021
6
          astContext.FloatTy, value, {0}, stageVarData.decl->getLocation());
3022
14
    spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
3023
14
  }
3024
  // Special handling of SV_Coverage, which is an unit value. We need to
3025
  // write it to the first element in the SampleMask builtin.
3026
1.71k
  else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) {
3027
4
    ptr = spvBuilder.createAccessChain(
3028
4
        stageVarData.type, varInstr,
3029
4
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)),
3030
4
        stageVarData.decl->getLocation());
3031
4
    ptr->setStorageClass(spv::StorageClass::Output);
3032
4
    spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
3033
4
  }
3034
  // Special handling of HS ouput, for which we write to only one
3035
  // element in the per-vertex data array: the one indexed by
3036
  // SV_ControlPointID.
3037
1.71k
  else if (stageVarData.invocationId.hasValue() &&
3038
1.71k
           
stageVarData.invocationId.getValue() != nullptr94
) {
3039
    // Remove the arrayness to get the element type.
3040
92
    assert(isa<ConstantArrayType>(varInstr->getAstResultType()));
3041
92
    const auto elementType =
3042
92
        astContext.getAsArrayType(varInstr->getAstResultType())
3043
92
            ->getElementType();
3044
92
    auto index = stageVarData.invocationId.getValue();
3045
92
    ptr = spvBuilder.createAccessChain(elementType, varInstr, index,
3046
92
                                       stageVarData.decl->getLocation());
3047
92
    ptr->setStorageClass(spv::StorageClass::Output);
3048
92
    spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
3049
92
  }
3050
  // For all normal cases
3051
1.62k
  else {
3052
1.62k
    spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
3053
1.62k
  }
3054
1.74k
}
3055
3056
SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable(
3057
2.09k
    SpirvVariable *varInstr, const StageVarDataBundle &stageVarData) {
3058
2.09k
  SpirvInstruction *load = spvBuilder.createLoad(
3059
2.09k
      varInstr->getAstResultType(), varInstr, stageVarData.decl->getLocation());
3060
  // Fix ups for corner cases
3061
3062
  // Special handling of SV_TessFactor DS patch constant input.
3063
  // TessLevelOuter is always an array of size 4 in SPIR-V, but
3064
  // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
3065
  // relevant indexes must be loaded.
3066
2.09k
  if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::TessFactor &&
3067
2.09k
      
hlsl::GetArraySize(stageVarData.type) != 430
) {
3068
6
    llvm::SmallVector<SpirvInstruction *, 4> components;
3069
6
    const auto tessFactorSize = hlsl::GetArraySize(stageVarData.type);
3070
6
    const auto arrType = astContext.getConstantArrayType(
3071
6
        astContext.FloatTy, llvm::APInt(32, tessFactorSize),
3072
6
        clang::ArrayType::Normal, 0);
3073
24
    for (uint32_t i = 0; i < tessFactorSize; 
++i18
)
3074
18
      components.push_back(spvBuilder.createCompositeExtract(
3075
18
          astContext.FloatTy, load, {i}, stageVarData.decl->getLocation()));
3076
6
    load = spvBuilder.createCompositeConstruct(
3077
6
        arrType, components, stageVarData.decl->getLocation());
3078
6
  }
3079
  // Special handling of SV_InsideTessFactor DS patch constant input.
3080
  // TessLevelInner is always an array of size 2 in SPIR-V, but
3081
  // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
3082
  // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of
3083
  // TessLevelInner.
3084
2.08k
  else if (stageVarData.semantic->getKind() ==
3085
2.08k
               hlsl::Semantic::Kind::InsideTessFactor &&
3086
           // Some developers use float[1] instead of a scalar float.
3087
2.08k
           
(30
!stageVarData.type->isArrayType()30
||
3088
30
            
hlsl::GetArraySize(stageVarData.type) == 126
)) {
3089
6
    load = spvBuilder.createCompositeExtract(astContext.FloatTy, load, {0},
3090
6
                                             stageVarData.decl->getLocation());
3091
6
    if (stageVarData.type->isArrayType()) { // float[1]
3092
2
      const auto arrType = astContext.getConstantArrayType(
3093
2
          astContext.FloatTy, llvm::APInt(32, 1), clang::ArrayType::Normal, 0);
3094
2
      load = spvBuilder.createCompositeConstruct(
3095
2
          arrType, {load}, stageVarData.decl->getLocation());
3096
2
    }
3097
6
  }
3098
  // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord
3099
  // is always a float3. To ensure SPIR-V validity, a float3 stage variable
3100
  // is created, and we must extract a float2 from it before passing it to
3101
  // the main function.
3102
2.07k
  else if (stageVarData.semantic->getKind() ==
3103
2.07k
               hlsl::Semantic::Kind::DomainLocation &&
3104
2.07k
           
hlsl::GetHLSLVecSize(stageVarData.type) != 316
) {
3105
14
    const auto domainLocSize = hlsl::GetHLSLVecSize(stageVarData.type);
3106
14
    load = spvBuilder.createVectorShuffle(
3107
14
        astContext.getExtVectorType(astContext.FloatTy, domainLocSize), load,
3108
14
        load, {0, 1}, stageVarData.decl->getLocation());
3109
14
  }
3110
  // Special handling of SV_Coverage, which is an uint value. We need to
3111
  // read SampleMask and extract its first element.
3112
2.06k
  else if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Coverage) {
3113
4
    load = spvBuilder.createCompositeExtract(stageVarData.type, load, {0},
3114
4
                                             stageVarData.decl->getLocation());
3115
4
  }
3116
  // Special handling of SV_InnerCoverage, which is an uint value. We need
3117
  // to read FullyCoveredEXT, which is a boolean value, and convert it to an
3118
  // uint value. According to D3D12 "Conservative Rasterization" doc: "The
3119
  // Pixel Shader has a 32-bit scalar integer System Generate Value
3120
  // available: InnerCoverage. This is a bit-field that has bit 0 from the
3121
  // LSB set to 1 for a given conservatively rasterized pixel, only when
3122
  // that pixel is guaranteed to be entirely inside the current primitive.
3123
  // All other input register bits must be set to 0 when bit 0 is not set,
3124
  // but are undefined when bit 0 is set to 1 (essentially, this bit-field
3125
  // represents a Boolean value where false must be exactly 0, but true can
3126
  // be any odd (i.e. bit 0 set) non-zero value)."
3127
2.06k
  else if (stageVarData.semantic->getKind() ==
3128
2.06k
           hlsl::Semantic::Kind::InnerCoverage) {
3129
2
    const auto constOne =
3130
2
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
3131
2
    const auto constZero =
3132
2
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
3133
2
    load = spvBuilder.createSelect(astContext.UnsignedIntTy, load, constOne,
3134
2
                                   constZero, stageVarData.decl->getLocation());
3135
2
  }
3136
  // Special handling of SV_Barycentrics, which is a float3, but the
3137
  // The 3 values are NOT guaranteed to add up to floating-point 1.0
3138
  // exactly. Calculate the third element here.
3139
2.05k
  else if (stageVarData.semantic->getKind() ==
3140
2.05k
           hlsl::Semantic::Kind::Barycentrics) {
3141
12
    const auto x = spvBuilder.createCompositeExtract(
3142
12
        astContext.FloatTy, load, {0}, stageVarData.decl->getLocation());
3143
12
    const auto y = spvBuilder.createCompositeExtract(
3144
12
        astContext.FloatTy, load, {1}, stageVarData.decl->getLocation());
3145
12
    const auto xy =
3146
12
        spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, x, y,
3147
12
                                  stageVarData.decl->getLocation());
3148
12
    const auto z = spvBuilder.createBinaryOp(
3149
12
        spv::Op::OpFSub, astContext.FloatTy,
3150
12
        spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)),
3151
12
        xy, stageVarData.decl->getLocation());
3152
12
    load = spvBuilder.createCompositeConstruct(
3153
12
        astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z},
3154
12
        stageVarData.decl->getLocation());
3155
12
  }
3156
  // Special handling of SV_DispatchThreadID and SV_GroupThreadID, which may
3157
  // be a uint or uint2, but the underlying stage input variable is a uint3.
3158
  // The last component(s) should be discarded in needed.
3159
2.04k
  else if ((stageVarData.semantic->getKind() ==
3160
2.04k
                hlsl::Semantic::Kind::DispatchThreadID ||
3161
2.04k
            stageVarData.semantic->getKind() ==
3162
1.80k
                hlsl::Semantic::Kind::GroupThreadID ||
3163
2.04k
            stageVarData.semantic->getKind() ==
3164
1.72k
                hlsl::Semantic::Kind::GroupID) &&
3165
2.04k
           
(354
!hlsl::IsHLSLVecType(stageVarData.type)354
||
3166
354
            
hlsl::GetHLSLVecSize(stageVarData.type) != 3276
)) {
3167
104
    const auto srcVecElemType =
3168
104
        hlsl::IsHLSLVecType(stageVarData.type)
3169
104
            ? 
hlsl::GetHLSLVecElementType(stageVarData.type)26
3170
104
            : 
stageVarData.type78
;
3171
104
    const auto vecSize = hlsl::IsHLSLVecType(stageVarData.type)
3172
104
                             ? 
hlsl::GetHLSLVecSize(stageVarData.type)26
3173
104
                             : 
178
;
3174
104
    if (vecSize == 1)
3175
78
      load = spvBuilder.createCompositeExtract(
3176
78
          srcVecElemType, load, {0}, stageVarData.decl->getLocation());
3177
26
    else if (vecSize == 2)
3178
26
      load = spvBuilder.createVectorShuffle(
3179
26
          astContext.getExtVectorType(srcVecElemType, 2), load, load, {0, 1},
3180
26
          stageVarData.decl->getLocation());
3181
104
  }
3182
3183
  // Reciprocate SV_Position.w if requested
3184
2.09k
  if (stageVarData.semantic->getKind() == hlsl::Semantic::Kind::Position)
3185
106
    load = invertWIfRequested(load, stageVarData.decl->getLocation());
3186
3187
  // Since boolean stage input variables are represented as unsigned
3188
  // integers, after loading them, we should cast them to boolean.
3189
2.09k
  if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type,
3190
2.09k
                          stageVarData.semantic->getKind(),
3191
2.09k
                          stageVarData.sigPoint->GetKind())) {
3192
3193
36
    if (stageVarData.arraySize == 0) {
3194
34
      load = theEmitter.castToType(load, varInstr->getAstResultType(),
3195
34
                                   stageVarData.type,
3196
34
                                   stageVarData.decl->getLocation());
3197
34
    } else {
3198
2
      llvm::SmallVector<SpirvInstruction *, 8> fields;
3199
2
      SourceLocation loc = stageVarData.decl->getLocation();
3200
2
      QualType originalScalarType = varInstr->getAstResultType()
3201
2
                                        ->castAsArrayTypeUnsafe()
3202
2
                                        ->getElementType();
3203
8
      for (uint32_t idx = 0; idx < stageVarData.arraySize; 
++idx6
) {
3204
6
        SpirvInstruction *field = spvBuilder.createCompositeExtract(
3205
6
            originalScalarType, load, {idx}, loc);
3206
6
        field = theEmitter.castToType(field, field->getAstResultType(),
3207
6
                                      stageVarData.type, loc);
3208
6
        fields.push_back(field);
3209
6
      }
3210
3211
2
      QualType finalType = astContext.getConstantArrayType(
3212
2
          stageVarData.type, llvm::APInt(32, stageVarData.arraySize),
3213
2
          clang::ArrayType::Normal, 0);
3214
2
      load = spvBuilder.createCompositeConstruct(finalType, fields, loc);
3215
2
    }
3216
36
  }
3217
2.09k
  return load;
3218
2.09k
}
3219
3220
bool DeclResultIdMapper::validateShaderStageVar(
3221
4.22k
    const StageVarDataBundle &stageVarData) {
3222
4.22k
  if (!validateVKAttributes(stageVarData.decl))
3223
4
    return false;
3224
3225
4.22k
  if (!isValidSemanticInShaderModel(stageVarData)) {
3226
8
    emitError("invalid usage of semantic '%0' in shader profile %1",
3227
8
              stageVarData.decl->getLocation())
3228
8
        << stageVarData.semantic->str
3229
8
        << hlsl::ShaderModel::GetKindName(
3230
8
               spvContext.getCurrentShaderModelKind());
3231
8
    return false;
3232
8
  }
3233
3234
4.21k
  if (!validateVKBuiltins(stageVarData))
3235
8
    return false;
3236
3237
4.20k
  if (!validateShaderStageVarType(stageVarData))
3238
2
    return false;
3239
4.20k
  return true;
3240
4.20k
}
3241
3242
4.22k
bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) {
3243
4.22k
  bool success = true;
3244
4.22k
  if (const auto *idxAttr = decl->getAttr<VKIndexAttr>()) {
3245
8
    if (!spvContext.isPS()) {
3246
2
      emitError("vk::index only allowed in pixel shader",
3247
2
                idxAttr->getLocation());
3248
2
      success = false;
3249
2
    }
3250
3251
8
    const auto *locAttr = decl->getAttr<VKLocationAttr>();
3252
3253
8
    if (!locAttr) {
3254
2
      emitError("vk::index should be used together with vk::location for "
3255
2
                "dual-source blending",
3256
2
                idxAttr->getLocation());
3257
2
      success = false;
3258
6
    } else {
3259
6
      const auto locNumber = locAttr->getNumber();
3260
6
      if (locNumber != 0) {
3261
2
        emitError("dual-source blending should use vk::location 0",
3262
2
                  locAttr->getLocation());
3263
2
        success = false;
3264
2
      }
3265
6
    }
3266
3267
8
    const auto idxNumber = idxAttr->getNumber();
3268
8
    if (idxNumber != 0 && 
idxNumber != 16
) {
3269
2
      emitError("dual-source blending only accepts 0 or 1 as vk::index",
3270
2
                idxAttr->getLocation());
3271
2
      success = false;
3272
2
    }
3273
8
  }
3274
3275
4.22k
  return success;
3276
4.22k
}
3277
3278
bool DeclResultIdMapper::validateVKBuiltins(
3279
4.21k
    const StageVarDataBundle &stageVarData) {
3280
4.21k
  bool success = true;
3281
3282
4.21k
  if (const auto *builtinAttr = stageVarData.decl->getAttr<VKBuiltInAttr>()) {
3283
    // The front end parsing only allows vk::builtin to be attached to a
3284
    // function/parameter/variable; all of them are DeclaratorDecls.
3285
60
    const auto declType =
3286
60
        getTypeOrFnRetType(cast<DeclaratorDecl>(stageVarData.decl));
3287
60
    const auto loc = builtinAttr->getLocation();
3288
3289
60
    if (stageVarData.decl->hasAttr<VKLocationAttr>()) {
3290
4
      emitError("cannot use vk::builtin and vk::location together", loc);
3291
4
      success = false;
3292
4
    }
3293
3294
60
    const llvm::StringRef builtin = builtinAttr->getBuiltIn();
3295
3296
60
    if (builtin == "HelperInvocation") {
3297
6
      if (!declType->isBooleanType()) {
3298
2
        emitError("HelperInvocation builtin must be of boolean type", loc);
3299
2
        success = false;
3300
2
      }
3301
3302
6
      if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) {
3303
2
        emitError(
3304
2
            "HelperInvocation builtin can only be used as pixel shader input",
3305
2
            loc);
3306
2
        success = false;
3307
2
      }
3308
54
    } else if (builtin == "PointSize") {
3309
22
      if (!declType->isFloatingType()) {
3310
2
        emitError("PointSize builtin must be of float type", loc);
3311
2
        success = false;
3312
2
      }
3313
3314
22
      switch (stageVarData.sigPoint->GetKind()) {
3315
4
      case hlsl::SigPoint::Kind::VSOut:
3316
6
      case hlsl::SigPoint::Kind::HSCPIn:
3317
8
      case hlsl::SigPoint::Kind::HSCPOut:
3318
10
      case hlsl::SigPoint::Kind::DSCPIn:
3319
12
      case hlsl::SigPoint::Kind::DSOut:
3320
14
      case hlsl::SigPoint::Kind::GSVIn:
3321
16
      case hlsl::SigPoint::Kind::GSOut:
3322
16
      case hlsl::SigPoint::Kind::PSIn:
3323
20
      case hlsl::SigPoint::Kind::MSOut:
3324
20
        break;
3325
2
      default:
3326
2
        emitError("PointSize builtin cannot be used as %0", loc)
3327
2
            << stageVarData.sigPoint->GetName();
3328
2
        success = false;
3329
22
      }
3330
32
    } else if (builtin == "BaseVertex" || 
builtin == "BaseInstance"22
||
3331
32
               
builtin == "DrawIndex"20
) {
3332
24
      if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
3333
24
          
!declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)6
) {
3334
2
        emitError("%0 builtin must be of 32-bit scalar integer type", loc)
3335
2
            << builtin;
3336
2
        success = false;
3337
2
      }
3338
3339
24
      switch (stageVarData.sigPoint->GetKind()) {
3340
14
      case hlsl::SigPoint::Kind::VSIn:
3341
14
        break;
3342
4
      case hlsl::SigPoint::Kind::MSIn:
3343
8
      case hlsl::SigPoint::Kind::ASIn:
3344
8
        if (builtin != "DrawIndex") {
3345
0
          emitError("%0 builtin cannot be used as %1", loc)
3346
0
              << builtin << stageVarData.sigPoint->GetName();
3347
0
          success = false;
3348
0
        }
3349
8
        break;
3350
2
      default:
3351
2
        emitError("%0 builtin cannot be used as %1", loc)
3352
2
            << builtin << stageVarData.sigPoint->GetName();
3353
2
        success = false;
3354
24
      }
3355
24
    } else 
if (8
builtin == "DeviceIndex"8
) {
3356
6
      if (getStorageClassForSigPoint(stageVarData.sigPoint) !=
3357
6
          spv::StorageClass::Input) {
3358
2
        emitError("%0 builtin can only be used as shader input", loc)
3359
2
            << builtin;
3360
2
        success = false;
3361
2
      }
3362
6
      if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
3363
6
          
!declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)2
) {
3364
2
        emitError("%0 builtin must be of 32-bit scalar integer type", loc)
3365
2
            << builtin;
3366
2
        success = false;
3367
2
      }
3368
6
    } else 
if (2
builtin == "ViewportMaskNV"2
) {
3369
2
      if (stageVarData.sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) {
3370
0
        emitError("%0 builtin can only be used as 'primitives' output in MS",
3371
0
                  loc)
3372
0
            << builtin;
3373
0
        success = false;
3374
0
      }
3375
2
      if (!declType->isArrayType() ||
3376
2
          !declType->getArrayElementTypeNoTypeQual()->isSpecificBuiltinType(
3377
2
              BuiltinType::Kind::Int)) {
3378
0
        emitError("%0 builtin must be of type array of integers", loc)
3379
0
            << builtin;
3380
0
        success = false;
3381
0
      }
3382
2
    }
3383
60
  }
3384
3385
4.21k
  return success;
3386
4.21k
}
3387
3388
bool DeclResultIdMapper::validateShaderStageVarType(
3389
4.20k
    const StageVarDataBundle &stageVarData) {
3390
3391
4.20k
  switch (stageVarData.semantic->getKind()) {
3392
4
  case hlsl::Semantic::Kind::InnerCoverage:
3393
4
    if (!stageVarData.type->isSpecificBuiltinType(BuiltinType::UInt)) {
3394
2
      emitError("SV_InnerCoverage must be of uint type.",
3395
2
                stageVarData.decl->getLocation());
3396
2
      return false;
3397
2
    }
3398
2
    break;
3399
4.20k
  default:
3400
4.20k
    break;
3401
4.20k
  }
3402
4.20k
  return true;
3403
4.20k
}
3404
3405
bool DeclResultIdMapper::isValidSemanticInShaderModel(
3406
4.22k
    const StageVarDataBundle &stageVarData) {
3407
  // Error out when the given semantic is invalid in this shader model
3408
4.22k
  if (hlsl::SigPoint::GetInterpretation(
3409
4.22k
          stageVarData.semantic->getKind(), stageVarData.sigPoint->GetKind(),
3410
4.22k
          spvContext.getMajorVersion(), spvContext.getMinorVersion()) ==
3411
4.22k
      hlsl::DXIL::SemanticInterpretationKind::NA) {
3412
    // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex".
3413
16
    switch (stageVarData.sigPoint->GetKind()) {
3414
6
    case hlsl::SigPoint::Kind::MSIn:
3415
14
    case hlsl::SigPoint::Kind::ASIn:
3416
14
      if (const auto *builtinAttr =
3417
14
              stageVarData.decl->getAttr<VKBuiltInAttr>()) {
3418
8
        const llvm::StringRef builtin = builtinAttr->getBuiltIn();
3419
8
        if (builtin == "DrawIndex") {
3420
8
          break;
3421
8
        }
3422
8
      }
3423
14
      
LLVM_FALLTHROUGH6
;6
3424
8
    default:
3425
8
      return false;
3426
16
    }
3427
16
  }
3428
4.21k
  return true;
3429
4.22k
}
3430
3431
/// Emits `SV_InstanceID = InstanceIndex - BaseInstance` into a new
/// function-scope variable named "SV_InstanceID" and returns that variable.
///
/// Both operands are loaded using the AST result type of \p instanceIndexVar.
/// Every emitted instruction carries the source location of
/// \p instanceIndexVar.
SpirvVariable *DeclResultIdMapper::getInstanceIdFromIndexAndBase(
    SpirvVariable *instanceIndexVar, SpirvVariable *baseInstanceVar) {
  const auto srcLoc = instanceIndexVar->getSourceLocation();
  QualType idType = instanceIndexVar->getAstResultType();

  auto *result = spvBuilder.addFnVar(idType, srcLoc, "SV_InstanceID");
  auto *index = spvBuilder.createLoad(idType, instanceIndexVar, srcLoc);
  auto *base = spvBuilder.createLoad(idType, baseInstanceVar, srcLoc);
  auto *zeroBased = spvBuilder.createBinaryOp(spv::Op::OpISub, idType, index,
                                              base, srcLoc);
  spvBuilder.createStore(result, zeroBased, srcLoc);
  return result;
}
3447
3448
/// Emits `SV_VertexID = VertexIndex - BaseVertex` into a new function-scope
/// variable named "SV_VertexID" and returns that variable.
///
/// Both operands are loaded using the AST result type of \p vertexIndexVar.
/// Every emitted instruction carries the source location of \p vertexIndexVar.
SpirvVariable *
DeclResultIdMapper::getVertexIdFromIndexAndBase(SpirvVariable *vertexIndexVar,
                                                SpirvVariable *baseVertexVar) {
  const auto srcLoc = vertexIndexVar->getSourceLocation();
  QualType idType = vertexIndexVar->getAstResultType();

  auto *result = spvBuilder.addFnVar(idType, srcLoc, "SV_VertexID");
  auto *index = spvBuilder.createLoad(idType, vertexIndexVar, srcLoc);
  auto *base = spvBuilder.createLoad(idType, baseVertexVar, srcLoc);
  auto *zeroBased = spvBuilder.createBinaryOp(spv::Op::OpISub, idType, index,
                                              base, srcLoc);
  spvBuilder.createStore(result, zeroBased, srcLoc);
  return result;
}
3465
3466
/// Creates an Input variable decorated with the BaseInstance builtin.
///
/// The variable is recorded in `stageVars` as a SPIR-V builtin stage variable
/// and then returned. \p type must be a 32-bit scalar integer (int or uint).
SpirvVariable *
DeclResultIdMapper::getBaseInstanceVariable(const hlsl::SigPoint *sigPoint,
                                            QualType type) {
  assert((type->isSpecificBuiltinType(BuiltinType::Kind::Int) ||
          type->isSpecificBuiltinType(BuiltinType::Kind::UInt)) &&
         "BaseInstance must be a 32-bit scalar integer");

  auto *builtinVar = spvBuilder.addStageBuiltinVar(
      type, spv::StorageClass::Input, spv::BuiltIn::BaseInstance, false, {});

  StageVar record(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
                  getLocationAndComponentCount(astContext, type));
  record.setSpirvInstr(builtinVar);
  record.setIsSpirvBuiltin();
  stageVars.push_back(record);

  return builtinVar;
}
3480
3481
/// Creates an Input variable decorated with the BaseVertex builtin.
///
/// The variable is recorded in `stageVars` as a SPIR-V builtin stage variable
/// and then returned. \p type must be a 32-bit scalar integer (int or uint).
SpirvVariable *
DeclResultIdMapper::getBaseVertexVariable(const hlsl::SigPoint *sigPoint,
                                          QualType type) {
  assert((type->isSpecificBuiltinType(BuiltinType::Kind::Int) ||
          type->isSpecificBuiltinType(BuiltinType::Kind::UInt)) &&
         "BaseVertex must be a 32-bit scalar integer");

  auto *builtinVar = spvBuilder.addStageBuiltinVar(
      type, spv::StorageClass::Input, spv::BuiltIn::BaseVertex, false, {});

  StageVar record(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
                  getLocationAndComponentCount(astContext, type));
  record.setSpirvInstr(builtinVar);
  record.setIsSpirvBuiltin();
  stageVars.push_back(record);

  return builtinVar;
}
3495
3496
/// Creates the SPIR-V interface variable for one HLSL stage variable.
///
/// The new variable is registered in `stageVars` and `stageVarInstructions`
/// and receives its semantic, patch/per-primitive and interpolation
/// decorations.
///
/// Usually the interface variable itself is returned. For a vertex shader
/// SV_InstanceID / SV_VertexID input with non-zero base instance/vertex
/// support enabled, the return value is instead a function-scope variable
/// holding `index - base`.
///
/// \returns nullptr if createSpirvStageVar fails.
SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable(
    const StageVarDataBundle &stageVarData) {
  auto *decl = stageVarData.decl;
  const auto *sigPoint = stageVarData.sigPoint;

  // The SPIR-V variable may need a different type than the HLSL one; the
  // function body keeps using `stageVarData.type`.
  QualType spirvType = getTypeForSpirvStageVariable(stageVarData);

  const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>();
  // Node input types get a zero location/component count. For HS/DS/GS the
  // outermost arrayness has already been stripped from `stageVarData.type`.
  const LocationAndComponent locationInfo =
      hlsl::IsHLSLNodeInputType(stageVarData.type)
          ? LocationAndComponent({0, 0, false})
          : getLocationAndComponentCount(astContext, stageVarData.type);
  StageVar record(sigPoint, *stageVarData.semantic, builtinAttr, spirvType,
                  locationInfo);

  const auto varName =
      stageVarData.namePrefix.str() + "." + record.getSemanticStr();
  SpirvVariable *interfaceVar = createSpirvStageVar(
      &record, decl, varName, stageVarData.semantic->loc);
  if (interfaceVar == nullptr)
    return nullptr;

  if (stageVarData.asNoInterp)
    interfaceVar->setNoninterpolated();

  record.setSpirvInstr(interfaceVar);
  record.setLocationAttr(decl->getAttr<VKLocationAttr>());
  record.setIndexAttr(decl->getAttr<VKIndexAttr>());

  // Only Input/Output variables are associated with the entry function.
  const auto storageClass = record.getStorageClass();
  if (storageClass == spv::StorageClass::Input ||
      storageClass == spv::StorageClass::Output)
    record.setEntryPoint(entryFunction);

  decorateStageVarWithIntrinsicAttrs(decl, &record, interfaceVar);
  stageVars.push_back(record);

  // Link the SPIR-V variable back to the HLSL semantic it was created for.
  spvBuilder.decorateHlslSemantic(interfaceVar, record.getSemanticStr());

  // Patch-constant / primitive signature variables: mesh per-primitive
  // outputs get PerPrimitiveNV; everything else except SV_DomainLocation
  // gets Patch.
  // TODO: the following may not be correct?
  const bool isPatchConstOrPrim = sigPoint->GetSignatureKind() ==
                                  hlsl::DXIL::SignatureKind::PatchConstOrPrim;
  if (isPatchConstOrPrim) {
    const auto varLoc = interfaceVar->getSourceLocation();
    if (sigPoint->GetKind() == hlsl::SigPoint::Kind::MSPOut)
      spvBuilder.decoratePerPrimitiveNV(interfaceVar, varLoc);
    else if (record.getSemanticInfo().getKind() !=
             hlsl::Semantic::Kind::DomainLocation)
      spvBuilder.decoratePatch(interfaceVar, varLoc);
  }

  // Interpolation modes apply to pixel shader inputs and to vertex/mesh
  // shader outputs.
  const bool needsInterpolationMode =
      (spvContext.isPS() && sigPoint->IsInput()) ||
      (spvContext.isVS() && sigPoint->IsOutput()) ||
      (spvContext.isMS() && sigPoint->IsOutput());
  if (needsInterpolationMode)
    decorateInterpolationMode(decl, stageVarData.type, interfaceVar,
                              *stageVarData.semantic);

  // DX12's SV_InstanceID always counts from 0, even with a non-zero
  // StartInstanceLocation. Vulkan's InstanceIndex instead begins at
  // firstInstance, and BaseInstance holds that firstInstance value. To
  // emulate DX12 we emit `SV_InstanceID = InstanceIndex - BaseInstance`; the
  // same adjustment is applied to SV_VertexID with VertexIndex/BaseVertex.
  // See the Vulkan spec on builtin variables:
  // www.khronos.org/registry/vulkan/specs/1.1-extensions/html/vkspec.html#interfaces-builtin-variables
  const auto semanticKind = stageVarData.semantic->getKind();
  const bool isVertexShaderInput =
      sigPoint->GetKind() == hlsl::SigPoint::Kind::VSIn;

  if (spirvOptions.supportNonzeroBaseInstance &&
      semanticKind == hlsl::Semantic::Kind::InstanceID && isVertexShaderInput) {
    // `interfaceVar` is the InstanceIndex variable created above.
    SpirvVariable *baseInstance =
        getBaseInstanceVariable(sigPoint, stageVarData.type);
    interfaceVar = getInstanceIdFromIndexAndBase(interfaceVar, baseInstance);
  }

  if (spirvOptions.supportNonzeroBaseVertex &&
      semanticKind == hlsl::Semantic::Kind::VertexID && isVertexShaderInput) {
    SpirvVariable *baseVertex =
        getBaseVertexVariable(sigPoint, stageVarData.type);
    interfaceVar = getVertexIdFromIndexAndBase(interfaceVar, baseVertex);
  }

  // Anything carrying a semantic is a function, parameter or variable, all of
  // which are DeclaratorDecls.
  stageVarInstructions[cast<DeclaratorDecl>(decl)] = interfaceVar;
  return interfaceVar;
}
3609
3610
/// Computes the type of the SPIR-V interface variable that backs a stage
/// variable.
///
/// The result may differ from the HLSL-declared `stageVarData.type` for three
/// reasons:
///  * several SPIR-V builtins require a fixed type;
///  * boolean stage I/O is carried as unsigned integers;
///  * a non-zero `stageVarData.arraySize` adds an extra outer array dimension.
QualType DeclResultIdMapper::getTypeForSpirvStageVariable(
    const StageVarDataBundle &stageVarData) {
  const QualType declaredType = stageVarData.type;

  // Builds the type `float[count]`.
  const auto makeFloatArray = [this](uint64_t count) -> QualType {
    return astContext.getConstantArrayType(astContext.FloatTy,
                                           llvm::APInt(32, count),
                                           clang::ArrayType::Normal, 0);
  };

  QualType spirvType = [&]() -> QualType {
    switch (stageVarData.semantic->getKind()) {
    case hlsl::Semantic::Kind::DomainLocation:
      // SV_DomainLocation may be a float2, but TessCoord is a float3; the
      // float2 is extracted from it before reaching the entry function.
    case hlsl::Semantic::Kind::Barycentrics:
      return astContext.getExtVectorType(astContext.FloatTy, 3);
    case hlsl::Semantic::Kind::TessFactor:
      // float[2]/[3]/[4] in HLSL depending on the patch kind, but always
      // float[4] in SPIR-V for Vulkan.
      return makeFloatArray(4);
    case hlsl::Semantic::Kind::InsideTessFactor:
      // A single float or float[2] in HLSL, but always float[2] in SPIR-V.
      return makeFloatArray(2);
    case hlsl::Semantic::Kind::Coverage:
      // HLSL uses a single uint; SampleMask must be an array of integers.
      return astContext.getConstantArrayType(astContext.UnsignedIntTy,
                                             llvm::APInt(32, 1),
                                             clang::ArrayType::Normal, 0);
    case hlsl::Semantic::Kind::InnerCoverage:
      // HLSL uses a uint; FullyCoveredEXT must be a boolean.
      return astContext.BoolTy;
    case hlsl::Semantic::Kind::DispatchThreadID:
    case hlsl::Semantic::Kind::GroupThreadID:
    case hlsl::Semantic::Kind::GroupID: {
      // HLSL accepts uint, uint2 or uint3; GlobalInvocationId,
      // LocalInvocationId and WorkgroupId must be 3-vectors. The original
      // element signedness is kept.
      const QualType elementType =
          hlsl::IsHLSLVecType(declaredType)
              ? hlsl::GetHLSLVecElementType(declaredType)
              : declaredType;
      return astContext.getExtVectorType(elementType, 3);
    }
    default:
      // All other semantics keep the declared type.
      return declaredType;
    }
  }();

  // Boolean stage I/O variables are represented as unsigned integers.
  // Boolean built-in variables stay bool; isBooleanStageIOVar decides which
  // case applies.
  if (isBooleanStageIOVar(stageVarData.decl, declaredType,
                          stageVarData.semantic->getKind(),
                          stageVarData.sigPoint->GetKind()))
    spirvType = getUintTypeWithSourceComponents(astContext, declaredType);

  if (stageVarData.arraySize == 0)
    return spirvType;

  // Re-apply the extra outer arrayness.
  return astContext.getConstantArrayType(
      spirvType, llvm::APInt(32, stageVarData.arraySize),
      clang::ArrayType::Normal, 0);
}
3684
3685
bool DeclResultIdMapper::createStageVars(StageVarDataBundle &stageVarData,
3686
                                         bool asInput, SpirvInstruction **value,
3687
6.95k
                                         bool noWriteBack) {
3688
6.95k
  assert(value);
3689
  // invocationId should only be used for handling HS per-vertex output.
3690
6.95k
  if (stageVarData.invocationId.hasValue()) {
3691
188
    assert(spvContext.isHS() && stageVarData.arraySize != 0 && !asInput);
3692
188
  }
3693
3694
6.95k
  assert(stageVarData.semantic);
3695
3696
6.95k
  if (stageVarData.type->isVoidType()) {
3697
    // No stage variables will be created for void type.
3698
1.79k
    return true;
3699
1.79k
  }
3700
3701
  // We have several cases regarding HLSL semantics to handle here:
3702
  // * If the current decl inherits a semantic from some enclosing entity,
3703
  //   use the inherited semantic no matter whether there is a semantic
3704
  //   attached to the current decl.
3705
  // * If there is no semantic to inherit,
3706
  //   * If the current decl is a struct,
3707
  //     * If the current decl has a semantic, all its members inherit this
3708
  //       decl's semantic, with the index sequentially increasing;
3709
  //     * If the current decl does not have a semantic, all its members
3710
  //       should have semantics attached;
3711
  //   * If the current decl is not a struct, it should have semantic attached.
3712
3713
5.15k
  auto thisSemantic = getStageVarSemantic(stageVarData.decl);
3714
3715
  // Which semantic we should use for this decl
3716
  // Enclosing semantics override internal ones
3717
5.15k
  if (stageVarData.semantic->isValid()) {
3718
160
    if (thisSemantic.isValid()) {
3719
46
      emitWarning(
3720
46
          "internal semantic '%0' overridden by enclosing semantic '%1'",
3721
46
          thisSemantic.loc)
3722
46
          << thisSemantic.str << stageVarData.semantic->str;
3723
46
    }
3724
4.99k
  } else {
3725
4.99k
    stageVarData.semantic = &thisSemantic;
3726
4.99k
  }
3727
3728
5.15k
  if (hlsl::IsHLSLNodeType(stageVarData.type)) {
3729
    // Hijack the notion of semantic to use createSpirvInterfaceVariable
3730
64
    StringRef str = stageVarData.decl->getName();
3731
64
    stageVarData.semantic->str = stageVarData.semantic->name = str;
3732
64
    stageVarData.semantic->semantic = hlsl::Semantic::GetArbitrary();
3733
64
    SpirvVariable *varInstr = createSpirvInterfaceVariable(stageVarData);
3734
64
    if (!varInstr) {
3735
0
      return false;
3736
0
    }
3737
3738
64
    *value = hlsl::IsHLSLNodeInputType(stageVarData.type)
3739
64
                 ? varInstr
3740
64
                 : 
loadShaderInputVariable(varInstr, stageVarData)0
;
3741
64
    return true;
3742
64
  }
3743
3744
5.08k
  if (stageVarData.semantic->isValid() &&
3745
      // Structs with attached semantics will be handled later.
3746
5.08k
      
!stageVarData.type->isStructureType()4.30k
) {
3747
    // Found semantic attached directly to this Decl. This means we need to
3748
    // map this decl to a single stage variable.
3749
3750
4.22k
    const auto semanticKind = stageVarData.semantic->getKind();
3751
4.22k
    const auto sigPointKind = stageVarData.sigPoint->GetKind();
3752
3753
4.22k
    if (!validateShaderStageVar(stageVarData)) {
3754
22
      return false;
3755
22
    }
3756
3757
    // Special handling of certain mappings between HLSL semantics and
3758
    // SPIR-V builtins:
3759
    // * SV_CullDistance/SV_ClipDistance are outsourced to GlPerVertex.
3760
4.20k
    if (glPerVertex.tryToAccess(
3761
4.20k
            sigPointKind, semanticKind, stageVarData.semantic->index,
3762
4.20k
            stageVarData.invocationId, value, noWriteBack,
3763
4.20k
            /*vecComponent=*/nullptr, stageVarData.decl->getLocation()))
3764
146
      return true;
3765
3766
4.05k
    SpirvVariable *varInstr = createSpirvInterfaceVariable(stageVarData);
3767
4.05k
    if (!varInstr) {
3768
0
      return false;
3769
0
    }
3770
3771
    // Mark that we have used one index for this semantic
3772
4.05k
    ++stageVarData.semantic->index;
3773
3774
4.05k
    if (asInput) {
3775
2.09k
      *value = loadShaderInputVariable(varInstr, stageVarData);
3776
2.09k
      if ((stageVarData.decl->hasAttr<HLSLNoInterpolationAttr>() ||
3777
2.09k
           
stageVarData.asNoInterp2.03k
) &&
3778
2.09k
          
sigPointKind == hlsl::SigPoint::Kind::PSIn60
)
3779
58
        spvBuilder.addPerVertexStgInputFuncVarEntry(varInstr, *value);
3780
3781
2.09k
    } else {
3782
1.96k
      if (noWriteBack)
3783
222
        return true;
3784
      // Negate SV_Position.y if requested
3785
1.74k
      if (semanticKind == hlsl::Semantic::Kind::Position)
3786
144
        *value = theEmitter.invertYIfRequested(*value, thisSemantic.loc);
3787
1.74k
      storeToShaderOutputVariable(varInstr, *value, stageVarData);
3788
1.74k
    }
3789
3790
3.83k
    return true;
3791
4.05k
  }
3792
3793
  // If the decl itself doesn't have semantic string attached and there is no
3794
  // one to inherit, it should be a struct having all its fields with semantic
3795
  // strings.
3796
862
  if (!stageVarData.semantic->isValid() &&
3797
862
      
!stageVarData.type->isStructureType()782
) {
3798
0
    emitError("semantic string missing for shader %select{output|input}0 "
3799
0
              "variable '%1'",
3800
0
              stageVarData.decl->getLocation())
3801
0
        << asInput << stageVarData.decl->getName();
3802
0
    return false;
3803
0
  }
3804
3805
862
  if (asInput) {
3806
410
    *value = createStructInputVar(stageVarData, noWriteBack);
3807
410
    return (*value) != nullptr;
3808
452
  } else {
3809
452
    return createStructOutputVar(stageVarData, *value, noWriteBack);
3810
452
  }
3811
862
}
3812
3813
bool DeclResultIdMapper::createPayloadStageVars(
3814
    const hlsl::SigPoint *sigPoint, spv::StorageClass sc, const NamedDecl *decl,
3815
    bool asInput, QualType type, const llvm::StringRef namePrefix,
3816
34
    SpirvInstruction **value, uint32_t payloadMemOffset) {
3817
34
  assert(spvContext.isMS() || spvContext.isAS());
3818
34
  assert(value);
3819
3820
34
  if (type->isVoidType()) {
3821
    // No stage variables will be created for void type.
3822
0
    return true;
3823
0
  }
3824
3825
34
  const auto loc = decl->getLocation();
3826
3827
  // Most struct type stage vars must be flattened, but for EXT_mesh_shaders the
3828
  // mesh payload struct should be decorated with TaskPayloadWorkgroupEXT and
3829
  // used directly as the OpEntryPoint variable.
3830
34
  if (!type->isStructureType() ||
3831
34
      
featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)20
) {
3832
3833
24
    SpirvVariable *varInstr = nullptr;
3834
3835
    // Check whether a mesh payload module variable has already been added, as
3836
    // is the case for the groupshared payload variable parameter of
3837
    // DispatchMesh. In this case, change the storage class from Workgroup to
3838
    // TaskPayloadWorkgroupEXT.
3839
24
    if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3840
44
      for (SpirvVariable *moduleVar : spvBuilder.getModule()->getVariables()) {
3841
44
        if (moduleVar->getAstResultType() == type) {
3842
8
          moduleVar->setStorageClass(
3843
8
              spv::StorageClass::TaskPayloadWorkgroupEXT);
3844
8
          varInstr = moduleVar;
3845
8
        }
3846
44
      }
3847
10
    }
3848
3849
    // If necessary, create new stage variable for mesh payload.
3850
24
    if (!varInstr) {
3851
16
      LocationAndComponent locationAndComponentCount =
3852
16
          type->isStructureType()
3853
16
              ? 
LocationAndComponent({0, 0, false})2
3854
16
              : 
getLocationAndComponentCount(astContext, type)14
;
3855
16
      StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr,
3856
16
                        type, locationAndComponentCount);
3857
16
      const auto name = namePrefix.str() + "." + decl->getNameAsString();
3858
16
      varInstr = spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false,
3859
16
                                          /*isNointerp=*/false, loc);
3860
3861
16
      if (!varInstr)
3862
0
        return false;
3863
3864
      // Even though these as user defined IO stage variables, set them as
3865
      // SPIR-V builtins in order to bypass any semantic string checks and
3866
      // location assignment.
3867
16
      stageVar.setIsSpirvBuiltin();
3868
16
      stageVar.setSpirvInstr(varInstr);
3869
16
      if (stageVar.getStorageClass() == spv::StorageClass::Input ||
3870
16
          stageVar.getStorageClass() == spv::StorageClass::Output) {
3871
0
        stageVar.setEntryPoint(entryFunction);
3872
0
      }
3873
16
      stageVars.push_back(stageVar);
3874
3875
16
      if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
3876
        // Decorate with PerTaskNV for mesh/amplification shader payload
3877
        // variables.
3878
14
        spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
3879
14
                                     varInstr->getSourceLocation());
3880
14
      }
3881
16
    }
3882
3883
24
    if (asInput) {
3884
8
      *value = spvBuilder.createLoad(type, varInstr, loc);
3885
16
    } else {
3886
16
      spvBuilder.createStore(varInstr, *value, loc);
3887
16
    }
3888
24
    return true;
3889
24
  }
3890
3891
  // This decl translates into multiple stage input/output payload variables
3892
  // and we need to load/store these individual member variables.
3893
10
  const auto *structDecl = type->getAs<RecordType>()->getDecl();
3894
10
  llvm::SmallVector<SpirvInstruction *, 4> subValues;
3895
10
  AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
3896
10
  uint32_t nextMemberOffset = 0;
3897
3898
14
  for (const auto *field : structDecl->fields()) {
3899
14
    const auto fieldType = field->getType();
3900
14
    SpirvInstruction *subValue = nullptr;
3901
14
    uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
3902
3903
    // The next avaiable offset after laying out the previous members.
3904
14
    std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
3905
14
        field->getType(), spirvOptions.ampPayloadLayoutRule,
3906
14
        /*isRowMajor*/ llvm::None, &stride);
3907
14
    alignmentCalc.alignUsingHLSLRelaxedLayout(
3908
14
        field->getType(), memberSize, memberAlignment, &nextMemberOffset);
3909
3910
    // The vk::offset attribute takes precedence over all.
3911
14
    if (field->getAttr<VKOffsetAttr>()) {
3912
0
      nextMemberOffset = field->getAttr<VKOffsetAttr>()->getOffset();
3913
0
    }
3914
3915
    // Each payload member must have an Offset Decoration.
3916
14
    payloadMemOffset = nextMemberOffset;
3917
14
    nextMemberOffset += memberSize;
3918
3919
14
    if (!asInput) {
3920
8
      subValue = spvBuilder.createCompositeExtract(
3921
8
          fieldType, *value, {getNumBaseClasses(type) + field->getFieldIndex()},
3922
8
          loc);
3923
8
    }
3924
3925
14
    if (!createPayloadStageVars(sigPoint, sc, field, asInput, field->getType(),
3926
14
                                namePrefix, &subValue, payloadMemOffset))
3927
0
      return false;
3928
3929
14
    if (asInput) {
3930
6
      subValues.push_back(subValue);
3931
6
    }
3932
14
  }
3933
10
  if (asInput) {
3934
4
    *value = spvBuilder.createCompositeConstruct(type, subValues, loc);
3935
4
  }
3936
10
  return true;
3937
10
}
3938
3939
bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
3940
                                               QualType type,
3941
                                               SpirvInstruction *value,
3942
86
                                               SourceRange range) {
3943
86
  assert(spvContext.isGS()); // Only for GS use
3944
3945
86
  if (hlsl::IsHLSLStreamOutputType(type))
3946
26
    type = hlsl::GetHLSLResourceResultType(type);
3947
86
  if (hasGSPrimitiveTypeQualifier(decl))
3948
0
    type = astContext.getAsConstantArrayType(type)->getElementType();
3949
3950
86
  auto semanticInfo = getStageVarSemantic(decl);
3951
86
  const auto loc = decl->getLocation();
3952
3953
86
  if (semanticInfo.isValid()) {
3954
    // Found semantic attached directly to this Decl. Write the value for this
3955
    // Decl to the corresponding stage output variable.
3956
3957
    // Handle SV_ClipDistance, and SV_CullDistance
3958
50
    if (glPerVertex.tryToAccess(
3959
50
            hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
3960
50
            semanticInfo.index, llvm::None, &value,
3961
50
            /*noWriteBack=*/false, /*vecComponent=*/nullptr, loc, range))
3962
6
      return true;
3963
3964
    // Query the <result-id> for the stage output variable generated out
3965
    // of this decl.
3966
    // We have semantic string attached to this decl; therefore, it must be a
3967
    // DeclaratorDecl.
3968
44
    const auto found = stageVarInstructions.find(cast<DeclaratorDecl>(decl));
3969
3970
    // We should have recorded its stage output variable previously.
3971
44
    assert(found != stageVarInstructions.end());
3972
3973
    // Negate SV_Position.y if requested
3974
44
    if (semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position)
3975
12
      value = theEmitter.invertYIfRequested(value, loc, range);
3976
3977
    // Boolean stage output variables are represented as unsigned integers.
3978
44
    if (isBooleanStageIOVar(decl, type, semanticInfo.semantic->GetKind(),
3979
44
                            hlsl::SigPoint::Kind::GSOut)) {
3980
2
      QualType uintType = getUintTypeWithSourceComponents(astContext, type);
3981
2
      value = theEmitter.castToType(value, type, uintType, loc, range);
3982
2
    }
3983
3984
44
    spvBuilder.createStore(found->second, value, loc, range);
3985
44
    return true;
3986
50
  }
3987
3988
  // If the decl itself doesn't have semantic string attached, it should be
3989
  // a struct having all its fields with semantic strings.
3990
36
  if (!type->isStructureType()) {
3991
0
    emitError("semantic string missing for shader output variable '%0'", loc)
3992
0
        << decl->getName();
3993
0
    return false;
3994
0
  }
3995
3996
  // If we have base classes, we need to handle them first.
3997
36
  if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
3998
36
    uint32_t baseIndex = 0;
3999
36
    for (auto base : cxxDecl->bases()) {
4000
4
      auto *subValue = spvBuilder.createCompositeExtract(
4001
4
          base.getType(), value, {baseIndex++}, loc, range);
4002
4003
4
      if (!writeBackOutputStream(base.getType()->getAsCXXRecordDecl(),
4004
4
                                 base.getType(), subValue, range))
4005
0
        return false;
4006
4
    }
4007
36
  }
4008
4009
36
  const auto *structDecl = type->getAs<RecordType>()->getDecl();
4010
4011
  // Write out each field
4012
56
  for (const auto *field : structDecl->fields()) {
4013
56
    const auto fieldType = field->getType();
4014
56
    auto *subValue = spvBuilder.createCompositeExtract(
4015
56
        fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()},
4016
56
        loc, range);
4017
4018
56
    if (!writeBackOutputStream(field, field->getType(), subValue, range))
4019
0
      return false;
4020
56
  }
4021
4022
36
  return true;
4023
36
}
4024
4025
/// Replaces the w component of \p position (a float4) with 1/w when the
/// invertW option is enabled and the current stage is a pixel shader.
/// Otherwise \p position is returned unchanged.
SpirvInstruction *
DeclResultIdMapper::invertWIfRequested(SpirvInstruction *position,
                                       SourceLocation loc) {
  if (!spirvOptions.invertW || !spvContext.isPS())
    return position;

  auto *w = spvBuilder.createCompositeExtract(astContext.FloatTy, position,
                                              {3}, loc);
  auto *one =
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f));
  auto *reciprocal = spvBuilder.createBinaryOp(
      spv::Op::OpFDiv, astContext.FloatTy, one, w, loc);
  return spvBuilder.createCompositeInsert(
      astContext.getExtVectorType(astContext.FloatTy, 4), position, {3},
      reciprocal, loc);
}
4042
4043
void DeclResultIdMapper::decorateInterpolationMode(
4044
    const NamedDecl *decl, QualType type, SpirvVariable *varInstr,
4045
1.33k
    const SemanticInfo semanticInfo) {
4046
1.33k
  if (varInstr->getStorageClass() != spv::StorageClass::Input &&
4047
1.33k
      
varInstr->getStorageClass() != spv::StorageClass::Output586
) {
4048
2
    return;
4049
2
  }
4050
1.33k
  const bool isBaryCoord =
4051
1.33k
      (semanticInfo.getKind() == hlsl::Semantic::Kind::Barycentrics);
4052
1.33k
  uint32_t semanticIndex = semanticInfo.index;
4053
4054
1.33k
  if (isBaryCoord) {
4055
    // BaryCentrics inputs cannot have attrib 'nointerpolation'.
4056
12
    if (decl->getAttr<HLSLNoInterpolationAttr>()) {
4057
0
      emitError(
4058
0
          "SV_BaryCentrics inputs cannot have attribute 'nointerpolation'.",
4059
0
          decl->getLocation());
4060
0
    }
4061
    // SV_BaryCentrics could only have two index and apply to different inputs.
4062
    // The index should be 0 or 1, each index should be mapped to different
4063
    // interpolation type.
4064
12
    if (semanticIndex > 1) {
4065
0
      emitError("The index SV_BaryCentrics semantics could only be 1 or 0.",
4066
0
                decl->getLocation());
4067
12
    } else if (noPerspBaryCentricsIndex < 2 && 
perspBaryCentricsIndex < 20
) {
4068
0
      emitError(
4069
0
          "Cannot have more than 2 inputs with SV_BaryCentrics semantics.",
4070
0
          decl->getLocation());
4071
12
    } else if (decl->getAttr<HLSLNoPerspectiveAttr>()) {
4072
6
      if (noPerspBaryCentricsIndex == 2 &&
4073
6
          perspBaryCentricsIndex != semanticIndex) {
4074
6
        noPerspBaryCentricsIndex = semanticIndex;
4075
6
      } else {
4076
0
        emitError("Cannot have more than 1 noperspective inputs with "
4077
0
                  "SV_BaryCentrics semantics.",
4078
0
                  decl->getLocation());
4079
0
      }
4080
6
    } else {
4081
6
      if (perspBaryCentricsIndex == 2 &&
4082
6
          noPerspBaryCentricsIndex != semanticIndex) {
4083
6
        perspBaryCentricsIndex = semanticIndex;
4084
6
      } else {
4085
0
        emitError("Cannot have more than 1 perspective-correct inputs with "
4086
0
                  "SV_BaryCentrics semantics.",
4087
0
                  decl->getLocation());
4088
0
      }
4089
6
    }
4090
12
  }
4091
4092
1.33k
  const auto loc = decl->getLocation();
4093
1.33k
  if (isUintOrVecMatOfUintType(type) || 
isSintOrVecMatOfSintType(type)1.16k
||
4094
1.33k
      
isBoolOrVecMatOfBoolType(type)1.02k
) {
4095
    // TODO: Probably we can call hlsl::ValidateSignatureElement() for the
4096
    // following check.
4097
310
    if (decl->getAttr<HLSLLinearAttr>() || 
decl->getAttr<HLSLCentroidAttr>()304
||
4098
310
        
decl->getAttr<HLSLNoPerspectiveAttr>()298
||
4099
310
        
decl->getAttr<HLSLSampleAttr>()292
) {
4100
24
      emitError("only nointerpolation mode allowed for integer input "
4101
24
                "parameters in pixel shader or integer output in vertex shader",
4102
24
                decl->getLocation());
4103
286
    } else {
4104
286
      spvBuilder.decorateFlat(varInstr, loc);
4105
286
    }
4106
1.02k
  } else {
4107
    // Do nothing for HLSLLinearAttr since its the default
4108
    // Attributes can be used together. So cannot use else if.
4109
1.02k
    if (decl->getAttr<HLSLCentroidAttr>())
4110
10
      spvBuilder.decorateCentroid(varInstr, loc);
4111
1.02k
    if (decl->getAttr<HLSLNoInterpolationAttr>() && 
!isBaryCoord38
)
4112
38
      spvBuilder.decorateFlat(varInstr, loc);
4113
1.02k
    if (decl->getAttr<HLSLNoPerspectiveAttr>() && 
!isBaryCoord20
)
4114
14
      spvBuilder.decorateNoPerspective(varInstr, loc);
4115
1.02k
    if (decl->getAttr<HLSLSampleAttr>()) {
4116
16
      spvBuilder.decorateSample(varInstr, loc);
4117
16
    }
4118
1.02k
  }
4119
1.33k
}
4120
4121
/// Returns the SPIR-V builtin variable for `builtIn`.
///
/// On the first request, the variable is created with type `type` and storage
/// class `sc`. A placeholder StageVar is registered for it, and it is cached
/// in builtinToVarMap. Later requests for the same builtin return the cached
/// variable; `type` and `sc` are ignored in that case.
SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
                                                 QualType type,
                                                 spv::StorageClass sc,
                                                 SourceLocation loc) {
  // Each builtin gets exactly one variable; reuse it if it already exists.
  const uint32_t builtinKey = static_cast<uint32_t>(builtIn);
  const auto cached = builtinToVarMap.find(builtinKey);
  if (cached != builtinToVarMap.end())
    return cached->second;

  // Using any of these builtins requires the module to be legalized.
  if (builtIn == spv::BuiltIn::HelperInvocation ||
      builtIn == spv::BuiltIn::SubgroupSize ||
      builtIn == spv::BuiltIn::SubgroupLocalInvocationId)
    needsLegalization = true;

  auto *builtinVar = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
                                                   /*isPrecise*/ false, loc);

  // Integer/bool pixel-shader inputs must be decorated Flat.
  const bool isPixelShaderInput =
      spvContext.isPS() && sc == spv::StorageClass::Input;
  if (isPixelShaderInput &&
      (isUintOrVecMatOfUintType(type) || isSintOrVecMatOfSintType(type) ||
       isBoolOrVecMatOfBoolType(type)))
    spvBuilder.decorateFlat(builtinVar, loc);

  // Track the builtin through a placeholder StageVar alongside the other
  // stage variables.
  const hlsl::SigPoint *sigPoint =
      hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
          hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
          /*isPatchConstant=*/false));
  StageVar placeholder(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr,
                       type, /*locAndComponentCount=*/{0, 0, false});
  placeholder.setIsSpirvBuiltin();
  placeholder.setSpirvInstr(builtinVar);
  stageVars.push_back(placeholder);

  // Cache for subsequent requests.
  builtinToVarMap[builtinKey] = builtinVar;
  return builtinVar;
}
4166
4167
/// Overload of getBuiltinVar() that infers the storage class from `builtIn`.
///
/// Only the builtins listed below are supported. Any other builtin trips an
/// assertion. When assertions are compiled out, StorageClass::Max is passed
/// through instead.
SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
                                                 QualType type,
                                                 SourceLocation loc) {
  // Maps each supported builtin to the storage class it is declared with.
  const auto inferStorageClass = [](spv::BuiltIn bi) -> spv::StorageClass {
    switch (bi) {
    case spv::BuiltIn::HelperInvocation:
    case spv::BuiltIn::SubgroupSize:
    case spv::BuiltIn::SubgroupLocalInvocationId:
    case spv::BuiltIn::HitTNV:
    case spv::BuiltIn::RayTmaxNV:
    case spv::BuiltIn::RayTminNV:
    case spv::BuiltIn::HitKindNV:
    case spv::BuiltIn::IncomingRayFlagsNV:
    case spv::BuiltIn::InstanceCustomIndexNV:
    case spv::BuiltIn::RayGeometryIndexKHR:
    case spv::BuiltIn::PrimitiveId:
    case spv::BuiltIn::InstanceId:
    case spv::BuiltIn::WorldRayDirectionNV:
    case spv::BuiltIn::WorldRayOriginNV:
    case spv::BuiltIn::ObjectRayDirectionNV:
    case spv::BuiltIn::ObjectRayOriginNV:
    case spv::BuiltIn::ObjectToWorldNV:
    case spv::BuiltIn::WorldToObjectNV:
    case spv::BuiltIn::LaunchIdNV:
    case spv::BuiltIn::LaunchSizeNV:
    case spv::BuiltIn::GlobalInvocationId:
    case spv::BuiltIn::WorkgroupId:
    case spv::BuiltIn::LocalInvocationIndex:
    case spv::BuiltIn::RemainingRecursionLevelsAMDX:
    case spv::BuiltIn::ShaderIndexAMDX:
      return spv::StorageClass::Input;
    case spv::BuiltIn::TaskCountNV:
    case spv::BuiltIn::PrimitiveCountNV:
    case spv::BuiltIn::PrimitiveIndicesNV:
    case spv::BuiltIn::PrimitivePointIndicesEXT:
    case spv::BuiltIn::PrimitiveLineIndicesEXT:
    case spv::BuiltIn::PrimitiveTriangleIndicesEXT:
    case spv::BuiltIn::CullPrimitiveEXT:
      return spv::StorageClass::Output;
    default:
      assert(false && "cannot infer storage class for SPIR-V builtin");
      return spv::StorageClass::Max;
    }
  };

  return getBuiltinVar(builtIn, type, inferStorageClass(builtIn), loc);
}
4217
4218
/// Returns the entry function recorded for the ray tracing stage variable
/// `var`, or nullptr if no entry function has been recorded for it.
///
/// Uses find() instead of operator[] so that a lookup of an unrecorded
/// variable does not insert a null entry into rayTracingStageVarToEntryPoints.
SpirvFunction *
DeclResultIdMapper::getRayTracingStageVarEntryFunction(SpirvVariable *var) {
  const auto it = rayTracingStageVarToEntryPoints.find(var);
  return it != rayTracingStageVarToEntryPoints.end() ? it->second : nullptr;
}
4222
4223
/// Creates the SPIR-V variable backing `stageVar`, which was created for
/// `decl`, and records the chosen storage class on `stageVar`.
///
/// The storage class is NodePayloadAMDX for node input types and is otherwise
/// derived from the signature point. An explicit [[vk::builtin(...)]]
/// attribute takes precedence. Otherwise the HLSL semantic decides whether a
/// SPIR-V builtin variable or a regular stage input/output variable named
/// `name` is created. When the semantic mapping produces a builtin variable,
/// `stageVar` is marked as a SPIR-V builtin.
///
/// Returns nullptr in two cases: no storage class can be determined for the
/// signature point, or (after emitting an error) the semantic cannot be
/// translated for this signature point.
SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
    StageVar *stageVar, const NamedDecl *decl, const llvm::StringRef name,
    SourceLocation srcLoc) {
  using spv::BuiltIn;

  const auto sigPoint = stageVar->getSigPoint();
  const auto semanticKind = stageVar->getSemanticInfo().getKind();
  const auto sigPointKind = sigPoint->GetKind();
  const auto type = stageVar->getAstType();
  const auto isPrecise = decl->hasAttr<HLSLPreciseAttr>();
  auto isNointerp = decl->hasAttr<HLSLNoInterpolationAttr>();
  spv::StorageClass sc = hlsl::IsHLSLNodeInputType(stageVar->getAstType())
                             ? spv::StorageClass::NodePayloadAMDX
                             : getStorageClassForSigPoint(sigPoint);
  if (sc == spv::StorageClass::Max)
    return nullptr;
  stageVar->setStorageClass(sc);

  // [[vk::builtin(...)]] takes precedence.
  if (const auto *builtinAttr = stageVar->getBuiltInAttr()) {
    const auto spvBuiltIn =
        llvm::StringSwitch<BuiltIn>(builtinAttr->getBuiltIn())
            .Case("PointSize", BuiltIn::PointSize)
            .Case("HelperInvocation", BuiltIn::HelperInvocation)
            .Case("BaseVertex", BuiltIn::BaseVertex)
            .Case("BaseInstance", BuiltIn::BaseInstance)
            .Case("DrawIndex", BuiltIn::DrawIndex)
            .Case("DeviceIndex", BuiltIn::DeviceIndex)
            .Case("ViewportMaskNV", BuiltIn::ViewportMaskNV)
            .Default(BuiltIn::Max);

    assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
    if (spvBuiltIn == BuiltIn::HelperInvocation &&
        !featureManager.isTargetEnvVulkan1p3OrAbove()) {
      // If [[vk::HelperInvocation]] is used for Vulkan 1.2 or less, we enable
      // SPV_EXT_demote_to_helper_invocation extension to use
      // OpIsHelperInvocationEXT instruction.
      featureManager.allowExtension("SPV_EXT_demote_to_helper_invocation");
      return spvBuilder.addVarForHelperInvocation(type, isPrecise, srcLoc);
    }
    return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn, isPrecise,
                                         srcLoc);
  }

  // The following translation assumes that semantic validity in the current
  // shader model is already checked, so it only covers valid SigPoints for
  // each semantic.
  switch (semanticKind) {
  // According to DXIL spec, the Position SV can be used by all SigPoints
  // other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  // According to Vulkan spec, the Position BuiltIn can only be used
  // by VSOut, HS/DS/GS In/Out, MSOut.
  case hlsl::Semantic::Kind::Position: {
    if (sigPointKind == hlsl::SigPoint::Kind::VSOut &&
        !containOnlyVecWithFourFloats(
            type, theEmitter.getSpirvOptions().enable16BitTypes)) {
      emitError("SV_Position must be a 4-component 32-bit float vector or a "
                "composite which recursively contains only such a vector",
                srcLoc);
    }

    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::MSOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
                                           isPrecise, srcLoc);
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord,
                                           isPrecise, srcLoc);
    default:
      llvm_unreachable("invalid usage of SV_Position sneaked in");
    }
  }
  // According to DXIL spec, the VertexID SV can only be used by VSIn.
  // According to Vulkan spec, the VertexIndex BuiltIn can only be used by
  // VSIn.
  case hlsl::Semantic::Kind::VertexID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
  // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
  // According to Vulkan spec, the InstanceIndex BuiltIn can only be used by
  // VSIn.
  case hlsl::Semantic::Kind::InstanceID: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex,
                                           isPrecise, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    default:
      llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
    }
  }
  // According to DXIL spec, the StartVertexLocation SV can only be used by
  // VSIn. According to Vulkan spec, the BaseVertex BuiltIn can only be used by
  // VSIn.
  case hlsl::Semantic::Kind::StartVertexLocation: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::BaseVertex,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the StartInstanceLocation SV can only be used by
  // VSIn. According to Vulkan spec, the BaseInstance BuiltIn can only be used
  // by VSIn.
  case hlsl::Semantic::Kind::StartInstanceLocation: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::BaseInstance,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the Depth{|GreaterEqual|LessEqual} SV can only be
  // used by PSOut.
  // According to Vulkan spec, the FragDepth BuiltIn can only be used by PSOut.
  case hlsl::Semantic::Kind::Depth:
  case hlsl::Semantic::Kind::DepthGreaterEqual:
  case hlsl::Semantic::Kind::DepthLessEqual: {
    stageVar->setIsSpirvBuiltin();
    // Vulkan requires the DepthReplacing execution mode to write to FragDepth.
    spvBuilder.addExecutionMode(entryFunction,
                                spv::ExecutionMode::DepthReplacing, {}, srcLoc);
    if (semanticKind == hlsl::Semantic::Kind::DepthGreaterEqual)
      spvBuilder.addExecutionMode(entryFunction,
                                  spv::ExecutionMode::DepthGreater, {}, srcLoc);
    else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
      spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::DepthLess,
                                  {}, srcLoc);
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
  // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  // According to Vulkan spec, the ClipDistance/CullDistance
  // BuiltIn can only be used by VSOut, HS/DS/GS In/Out, MSOut.
  case hlsl::Semantic::Kind::ClipDistance:
  case hlsl::Semantic::Kind::CullDistance: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
    case hlsl::SigPoint::Kind::MSOut:
      llvm_unreachable("should be handled in gl_PerVertex struct");
    default:
      llvm_unreachable(
          "invalid usage of SV_ClipDistance/SV_CullDistance sneaked in");
    }
  }
  // According to DXIL spec, the IsFrontFace SV can only be used by GSOut and
  // PSIn.
  // According to Vulkan spec, the FrontFacing BuiltIn can only be used in PSIn.
  case hlsl::Semantic::Kind::IsFrontFace: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::GSOut:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing,
                                           isPrecise, srcLoc);
    default:
      llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
    }
  }
  // According to DXIL spec, the Target SV can only be used by PSOut.
  // There is no corresponding builtin decoration in SPIR-V. So generate normal
  // Vulkan stage input/output variables.
  case hlsl::Semantic::Kind::Target:
  // An arbitrary semantic is defined by users. Generate normal Vulkan stage
  // input/output variables.
  case hlsl::Semantic::Kind::Arbitrary: {
    return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, isNointerp,
                                    srcLoc);
    // TODO: patch constant function in hull shader
  }
  // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
  // According to Vulkan spec, the GlobalInvocationId can only be used in CSIn.
  case hlsl::Semantic::Kind::DispatchThreadID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the GroupID SV can only be used by CSIn.
  // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
  case hlsl::Semantic::Kind::GroupID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
  // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
  case hlsl::Semantic::Kind::GroupThreadID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
  // According to Vulkan spec, the LocalInvocationIndex can only be used in
  // CSIn.
  case hlsl::Semantic::Kind::GroupIndex: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(
        type, sc, BuiltIn::LocalInvocationIndex, isPrecise, srcLoc);
  }
  // According to DXIL spec, the OutputControlID SV can only be used by HSIn.
  // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  // HS/GS In.
  case hlsl::Semantic::Kind::OutputControlPointID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
  // DSIn, GSIn, GSOut, PSIn, and MSPOut.
  // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
  // HS/DS/PS In, GS In/Out, MSPOut.
  case hlsl::Semantic::Kind::PrimitiveID: {
    // Translate to PrimitiveId BuiltIn for all valid SigPoints.
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the TessFactor SV can only be used by PCOut and
  // DSIn.
  // According to Vulkan spec, the TessLevelOuter BuiltIn can only be used in
  // PCOut and DSIn.
  case hlsl::Semantic::Kind::TessFactor: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
  // and DSIn.
  // According to Vulkan spec, the TessLevelInner BuiltIn can only be used in
  // PCOut and DSIn.
  case hlsl::Semantic::Kind::InsideTessFactor: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
  // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
  case hlsl::Semantic::Kind::DomainLocation: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
  // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  // HS/GS In.
  case hlsl::Semantic::Kind::GSInstanceID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
  // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
  case hlsl::Semantic::Kind::SampleIndex: {
    setInterlockExecutionMode(spv::ExecutionMode::SampleInterlockOrderedEXT);
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId, isPrecise,
                                         srcLoc);
  }
  // According to DXIL spec, the StencilRef SV can only be used by PSOut.
  case hlsl::Semantic::Kind::StencilRef: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
  case hlsl::Semantic::Kind::Barycentrics: {
    stageVar->setIsSpirvBuiltin();

    // Selecting the correct builtin according to interpolation mode
    auto bi = BuiltIn::Max;
    if (decl->hasAttr<HLSLNoPerspectiveAttr>()) {
      bi = BuiltIn::BaryCoordNoPerspKHR;
    } else {
      bi = BuiltIn::BaryCoordKHR;
    }

    return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
  }
  // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut,
  // PSIn, and MSPOut.
  case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::GSVIn:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::DSOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
                                           srcLoc);
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
    case hlsl::SigPoint::Kind::MSPOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
                                           srcLoc);
    default:
      llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
    }
  }
  // According to DXIL spec, the ViewportArrayIndex SV can only be used by
  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
  // GSOut, PSIn, and MSPOut.
  case hlsl::Semantic::Kind::ViewPortArrayIndex: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::GSVIn:
      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise,
                                      isNointerp, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::DSOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
                                           isPrecise, srcLoc);
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
    case hlsl::SigPoint::Kind::MSPOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
                                           isPrecise, srcLoc);
    default:
      llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
    }
  }
  // According to DXIL spec, the Coverage SV can only be used by PSIn and PSOut.
  // According to Vulkan spec, the SampleMask BuiltIn can only be used in
  // PSIn and PSOut.
  case hlsl::Semantic::Kind::Coverage: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
  // HSIn, DSIn, GSIn, PSIn.
  // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
  // VS/HS/DS/GS/PS input.
  case hlsl::Semantic::Kind::ViewID: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
  // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
  // PSIn.
  case hlsl::Semantic::Kind::InnerCoverage: {
    stageVar->setIsSpirvBuiltin();
    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT,
                                         isPrecise, srcLoc);
  }
  // According to DXIL spec, the ShadingRate SV can only be used by GSOut,
  // VSOut, or PSIn. In SPIR-V, PSIn maps to the ShadingRateKHR BuiltIn.
  // The pre-rasterization outputs (VSOut, GSOut, MSOut, MSPOut) map to the
  // PrimitiveShadingRateKHR BuiltIn.
  case hlsl::Semantic::Kind::ShadingRate: {
    setInterlockExecutionMode(
        spv::ExecutionMode::ShadingRateInterlockOrderedEXT);
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ShadingRateKHR,
                                           isPrecise, srcLoc);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::MSOut:
    case hlsl::SigPoint::Kind::MSPOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(
          type, sc, BuiltIn::PrimitiveShadingRateKHR, isPrecise, srcLoc);
    default:
      emitError("semantic ShadingRate must be used only for PSIn, VSOut, "
                "GSOut, MSOut",
                srcLoc);
      break;
    }
    break;
  }
  // According to DXIL spec, the CullPrimitive SV can only be used by
  // MSPOut or PSIn.
  // According to Vulkan spec, the CullPrimitiveEXT BuiltIn can only
  // be used as MSOut.
  // NOTE(review): the PSIn mapping below does not match that Vulkan
  // restriction; confirm it is intended.
  case hlsl::Semantic::Kind::CullPrimitive: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::CullPrimitiveEXT,
                                           isPrecise, srcLoc);
    case hlsl::SigPoint::Kind::MSPOut:
      stageVar->setIsSpirvBuiltin();
      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::CullPrimitiveEXT,
                                           isPrecise, srcLoc);
    default:
      emitError("semantic CullPrimitive must be used only for PSIn, MSPOut",
                srcLoc);
      break;
    }
    break;
  }
  default:
    emitError("semantic %0 unimplemented", srcLoc)
        << stageVar->getSemanticStr();
    break;
  }

  // Reached only after an error was emitted for an unsupported
  // semantic/SigPoint combination.
  return nullptr;
}
4678
4679
spv::StorageClass
DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
  // Maps an HLSL signature point to the SPIR-V storage class used for its
  // stage variable. The mapping follows the HLSL reference (docs/dxil.rst).
  const auto pointKind = sigPoint->GetKind();
  const auto sigKind = sigPoint->GetSignatureKind();

  switch (sigKind) {
  case hlsl::DXIL::SignatureKind::Input:
    return spv::StorageClass::Input;

  case hlsl::DXIL::SignatureKind::Output:
    return spv::StorageClass::Output;

  case hlsl::DXIL::SignatureKind::Invalid:
    // Special cases (docs/dxil.rst): the signature kind is "invalid" for
    // these signature points, all of which are stage inputs.
    switch (pointKind) {
    case hlsl::DXIL::SigPointKind::PCIn:
    case hlsl::DXIL::SigPointKind::HSIn:
    case hlsl::DXIL::SigPointKind::GSIn:
    case hlsl::DXIL::SigPointKind::CSIn:
    case hlsl::DXIL::SigPointKind::MSIn:
    case hlsl::DXIL::SigPointKind::ASIn:
      return spv::StorageClass::Input;
    default:
      llvm_unreachable("Found invalid SigPoint kind for semantic");
    }

  case hlsl::DXIL::SignatureKind::PatchConstOrPrim:
    // Special cases (docs/dxil.rst): "PatchConstOrPrim" covers PCOut, MSPOut
    // and DSIn.
    switch (pointKind) {
    case hlsl::DXIL::SigPointKind::PCOut:  // hull -> domain patch constants
    case hlsl::DXIL::SigPointKind::MSPOut: // mesh per-primitive outputs
      return spv::StorageClass::Output;
    case hlsl::DXIL::SigPointKind::DSIn: // domain patch constants + SVs
      return spv::StorageClass::Input;
    default:
      llvm_unreachable("Found invalid SigPoint kind for semantic");
    }

  default:
    llvm_unreachable("Found invalid SigPoint kind for semantic");
  }
}
4733
4734
QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
    const DeclaratorDecl *decl, bool *shouldBeAlias) {
  // Intended only for SPIR-V variables in the Function or Private storage
  // class, so external variables must never be passed in.
  assert(!isa<VarDecl>(decl) ||
         !SpirvEmitter::isExternalVar(cast<VarDecl>(decl)));

  const QualType declType = getTypeOrFnRetType(decl);

  // ConstantBuffers, TextureBuffers, and anything that is or contains a
  // StructuredBuffer/ByteAddressBuffer are generated as alias variables.
  const bool asAlias = isConstantTextureBuffer(declType) ||
                       isOrContainsAKindOfStructuredOrByteBuffer(declType);

  // Optionally report the decision back to the caller.
  if (shouldBeAlias != nullptr)
    *shouldBeAlias = asAlias;

  if (asAlias) {
    needsLegalization = true;
    createCounterVarForDecl(decl);
  }

  return declType;
}
4763
4764
bool DeclResultIdMapper::getImplicitRegisterType(const ResourceVar &var,
4765
58
                                                 char *registerTypeOut) const {
4766
58
  assert(registerTypeOut);
4767
4768
58
  if (var.getSpirvInstr()) {
4769
58
    if (var.getSpirvInstr()->hasAstResultType()) {
4770
52
      QualType type = var.getSpirvInstr()->getAstResultType();
4771
      // Strip outer arrayness first
4772
58
      while (type->isArrayType())
4773
6
        type = type->getAsArrayTypeUnsafe()->getElementType();
4774
4775
      // t - for shader resource views (SRV)
4776
52
      if (isTexture(type) || 
isNonWritableStructuredBuffer(type)34
||
4777
52
          
isByteAddressBuffer(type)32
||
isBuffer(type)30
) {
4778
24
        *registerTypeOut = 't';
4779
24
        return true;
4780
24
      }
4781
      // s - for samplers
4782
28
      else if (isSampler(type)) {
4783
4
        *registerTypeOut = 's';
4784
4
        return true;
4785
4
      }
4786
      // u - for unordered access views (UAV)
4787
24
      else if (isRWByteAddressBuffer(type) || 
isRWAppendConsumeSBuffer(type)22
||
4788
24
               
isRWBuffer(type)16
||
isRWTexture(type)14
) {
4789
20
        *registerTypeOut = 'u';
4790
20
        return true;
4791
20
      }
4792
4793
      // b - for constant buffer
4794
      // views (CBV)
4795
4
      else if (isConstantBuffer(type)) {
4796
4
        *registerTypeOut = 'b';
4797
4
        return true;
4798
4
      }
4799
52
    } else {
4800
6
      llvm::StringRef hlslUserType = var.getSpirvInstr()->getHlslUserType();
4801
      // b - for constant buffer views (CBV)
4802
6
      if (var.isGlobalsBuffer() || hlslUserType == "cbuffer" ||
4803
6
          hlslUserType == "ConstantBuffer") {
4804
0
        *registerTypeOut = 'b';
4805
0
        return true;
4806
0
      }
4807
6
      if (hlslUserType == "tbuffer") {
4808
2
        *registerTypeOut = 't';
4809
2
        return true;
4810
2
      }
4811
6
    }
4812
58
  }
4813
4814
4
  *registerTypeOut = '\0';
4815
4
  return false;
4816
58
}
4817
4818
SpirvVariable *
DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
                                               const VarDecl *decl) {
  // Convenience overload: forwards the declaration's type, name, and
  // precise / nointerpolation attributes to the general overload.
  const bool isPrecise = decl->hasAttr<HLSLPreciseAttr>();
  const bool isNointerp = decl->hasAttr<HLSLNoInterpolationAttr>();
  return createRayTracingNVStageVar(sc, decl->getType(), decl->getName().str(),
                                    isPrecise, isNointerp);
}
4825
4826
SpirvVariable *DeclResultIdMapper::createRayTracingNVStageVar(
    spv::StorageClass sc, QualType type, std::string name, bool isPrecise,
    bool isNointerp) {
  // Raytracing interface variables are special since they do not participate
  // in any interface matching and hence do not create StageVar and are not
  // tracked under the StageVars vector. Instead, the new module variable is
  // recorded in rayTracingStageVarToEntryPoints against the current
  // entryFunction.
  switch (sc) {
  case spv::StorageClass::IncomingRayPayloadNV:
  case spv::StorageClass::IncomingCallableDataNV:
  case spv::StorageClass::HitAttributeNV:
  case spv::StorageClass::RayPayloadNV:
  case spv::StorageClass::CallableDataNV:
    break;
  default:
    assert(false && "Unsupported SPIR-V storage class for raytracing");
    // The assert is compiled out in release builds. Bail out rather than
    // registering a null variable in rayTracingStageVarToEntryPoints.
    return nullptr;
  }

  SpirvVariable *retVal =
      spvBuilder.addModuleVar(type, sc, isPrecise, isNointerp, name);
  rayTracingStageVarToEntryPoints[retVal] = entryFunction;
  return retVal;
}
4852
4853
4
void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
4854
4
  const VarDecl *varDecl = dyn_cast<VarDecl>(decl);
4855
4
  if (!varDecl || !varDecl->isImplicit())
4856
2
    return;
4857
4858
2
  APValue *val = varDecl->evaluateValue();
4859
2
  if (!val)
4860
0
    return;
4861
4862
2
  SpirvInstruction *constVal =
4863
2
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
4864
2
  constVal->setRValue(true);
4865
2
  registerVariableForDecl(varDecl, constVal);
4866
2
}
4867
4868
void DeclResultIdMapper::decorateWithIntrinsicAttrs(
4869
    const NamedDecl *decl, SpirvVariable *varInst,
4870
4.51k
    llvm::function_ref<void(VKDecorateExtAttr *)> extraFunctionForDecoAttr) {
4871
4.51k
  if (!decl->hasAttrs())
4872
2.67k
    return;
4873
4874
  // TODO: Handle member field in a struct and function parameter.
4875
1.97k
  
for (auto &attr : decl->getAttrs())1.84k
{
4876
1.97k
    if (auto decoAttr = dyn_cast<VKDecorateExtAttr>(attr)) {
4877
24
      spvBuilder.decorateWithLiterals(
4878
24
          varInst, decoAttr->getDecorate(),
4879
24
          {decoAttr->literals_begin(), decoAttr->literals_end()},
4880
24
          varInst->getSourceLocation());
4881
24
      extraFunctionForDecoAttr(decoAttr);
4882
24
      continue;
4883
24
    }
4884
1.94k
    if (auto decoAttr = dyn_cast<VKDecorateIdExtAttr>(attr)) {
4885
4
      llvm::SmallVector<SpirvInstruction *, 2> args;
4886
4
      for (Expr *arg : decoAttr->arguments()) {
4887
4
        args.push_back(theEmitter.doExpr(arg));
4888
4
      }
4889
4
      spvBuilder.decorateWithIds(varInst, decoAttr->getDecorate(), args,
4890
4
                                 varInst->getSourceLocation());
4891
4
      continue;
4892
4
    }
4893
1.94k
    if (auto decoAttr = dyn_cast<VKDecorateStringExtAttr>(attr)) {
4894
4
      llvm::SmallVector<llvm::StringRef, 2> args(decoAttr->arguments_begin(),
4895
4
                                                 decoAttr->arguments_end());
4896
4
      spvBuilder.decorateWithStrings(varInst, decoAttr->getDecorate(), args,
4897
4
                                     varInst->getSourceLocation());
4898
4
      continue;
4899
4
    }
4900
1.94k
  }
4901
1.84k
}
4902
4903
void DeclResultIdMapper::decorateStageVarWithIntrinsicAttrs(
4904
4.12k
    const NamedDecl *decl, StageVar *stageVar, SpirvVariable *varInst) {
4905
4.12k
  auto checkBuiltInLocationDecoration =
4906
4.12k
      [stageVar](const VKDecorateExtAttr *decoAttr) {
4907
20
        auto decorate = static_cast<spv::Decoration>(decoAttr->getDecorate());
4908
20
        if (decorate == spv::Decoration::BuiltIn ||
4909
20
            
decorate == spv::Decoration::Location4
) {
4910
          // This information will be used to avoid
4911
          // assigning multiple location decorations
4912
          // in finalizeStageIOLocations()
4913
18
          stageVar->setIsLocOrBuiltinDecorateAttr();
4914
18
        }
4915
20
      };
4916
4.12k
  decorateWithIntrinsicAttrs(decl, varInst, checkBuiltInLocationDecoration);
4917
4.12k
}
4918
4919
14
void DeclResultIdMapper::setInterlockExecutionMode(spv::ExecutionMode mode) {
  // Records the interlock execution mode; getInterlockExecutionMode()
  // returns it in preference to its default.
  this->interlockExecutionMode = mode;
}
4922
4923
36
spv::ExecutionMode DeclResultIdMapper::getInterlockExecutionMode() {
  // Returns the mode recorded via setInterlockExecutionMode(), or
  // PixelInterlockOrderedEXT when none was recorded.
  if (interlockExecutionMode)
    return *interlockExecutionMode;
  return spv::ExecutionMode::PixelInterlockOrderedEXT;
}
4927
4928
void DeclResultIdMapper::registerVariableForDecl(const VarDecl *var,
4929
13.3k
                                                 SpirvInstruction *varInstr) {
4930
13.3k
  DeclSpirvInfo spirvInfo;
4931
13.3k
  spirvInfo.instr = varInstr;
4932
13.3k
  spirvInfo.indexInCTBuffer = -1;
4933
13.3k
  registerVariableForDecl(var, spirvInfo);
4934
13.3k
}
4935
4936
void DeclResultIdMapper::registerVariableForDecl(const VarDecl *var,
4937
17.9k
                                                 DeclSpirvInfo spirvInfo) {
4938
17.9k
  for (const auto *v : var->redecls()) {
4939
17.9k
    astDecls[v] = spirvInfo;
4940
17.9k
  }
4941
17.9k
}
4942
4943
void DeclResultIdMapper::copyHullOutStageVarsToOutputPatch(
4944
    SpirvInstruction *hullMainOutputPatch, const ParmVarDecl *outputPatchDecl,
4945
8
    QualType outputControlPointType, uint32_t numOutputControlPoints) {
4946
58
  for (uint32_t outputCtrlPoint = 0; outputCtrlPoint < numOutputControlPoints;
4947
50
       ++outputCtrlPoint) {
4948
50
    SpirvConstant *index = spvBuilder.getConstantInt(
4949
50
        astContext.UnsignedIntTy, llvm::APInt(32, outputCtrlPoint));
4950
50
    auto *tempLocation = spvBuilder.createAccessChain(
4951
50
        outputControlPointType, hullMainOutputPatch, {index}, /*loc=*/{});
4952
50
    storeOutStageVarsToStorage(cast<DeclaratorDecl>(outputPatchDecl), index,
4953
50
                               outputControlPointType, tempLocation);
4954
50
  }
4955
8
}
4956
4957
void DeclResultIdMapper::storeOutStageVarsToStorage(
4958
    const DeclaratorDecl *outputPatchDecl, SpirvConstant *ctrlPointID,
4959
150
    QualType outputControlPointType, SpirvInstruction *ptr) {
4960
150
  if (!outputControlPointType->isStructureType()) {
4961
100
    const auto found = stageVarInstructions.find(outputPatchDecl);
4962
100
    if (found == stageVarInstructions.end()) {
4963
0
      emitError("Shader output variable '%0' was not created", {})
4964
0
          << outputPatchDecl->getName();
4965
0
    }
4966
100
    auto *ptrToOutputStageVar = spvBuilder.createAccessChain(
4967
100
        outputControlPointType, found->second, {ctrlPointID}, /*loc=*/{});
4968
100
    auto *load =
4969
100
        spvBuilder.createLoad(outputControlPointType, ptrToOutputStageVar,
4970
100
                              /*loc=*/{});
4971
100
    spvBuilder.createStore(ptr, load, /*loc=*/{});
4972
100
    return;
4973
100
  }
4974
4975
50
  const auto *recordType = outputControlPointType->getAs<RecordType>();
4976
50
  assert(recordType != nullptr);
4977
50
  const auto *structDecl = recordType->getDecl();
4978
50
  assert(structDecl != nullptr);
4979
4980
50
  uint32_t index = 0;
4981
100
  for (const auto *field : structDecl->fields()) {
4982
100
    SpirvConstant *indexInst = spvBuilder.getConstantInt(
4983
100
        astContext.UnsignedIntTy, llvm::APInt(32, index));
4984
100
    auto *tempLocation = spvBuilder.createAccessChain(field->getType(), ptr,
4985
100
                                                      {indexInst}, /*loc=*/{});
4986
100
    storeOutStageVarsToStorage(cast<DeclaratorDecl>(field), ctrlPointID,
4987
100
                               field->getType(), tempLocation);
4988
100
    ++index;
4989
100
  }
4990
50
}
4991
4992
void DeclResultIdMapper::registerCapabilitiesAndExtensionsForType(
4993
5.76k
    const TypedefType *type) {
4994
5.76k
  for (const auto *decl : typeAliasesWithAttributes) {
4995
10
    if (type == decl->getTypeForDecl()) {
4996
8
      for (auto *attribute : decl->specific_attrs<VKExtensionExtAttr>()) {
4997
0
        clang::StringRef extensionName = attribute->getName();
4998
0
        spvBuilder.requireExtension(extensionName, decl->getLocation());
4999
0
      }
5000
8
      for (auto *attribute : decl->specific_attrs<VKCapabilityExtAttr>()) {
5001
8
        spv::Capability cap = spv::Capability(attribute->getCapability());
5002
8
        spvBuilder.requireCapability(cap, decl->getLocation());
5003
8
      }
5004
8
    }
5005
10
  }
5006
5.76k
}
5007
5008
} // end namespace spirv
5009
} // end namespace clang