Skip to content

Commit 0407f1e

Browse files
committed
address code review comments and fix tests
1 parent d35b1ca commit 0407f1e

File tree

13 files changed

+74
-42
lines changed

13 files changed

+74
-42
lines changed

include/dxc/DXIL/DxilModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ class DxilModule {
255255
unsigned GetNumThreads(unsigned idx) const;
256256

257257
unsigned GetGroupSharedLimit() const;
258+
// The total allocation size (in bytes) of all global variables in the group shared (TGSM) address space of the module.
259+
unsigned GetTGSMSizeInBytes() const;
258260

259261
// Compute shader
260262
DxilWaveSize &GetWaveSize();

include/dxc/DxilContainer/DxilPipelineStateValidation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 {
176176
};
177177

178178
struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 {
179-
uint32_t GroupSharedLimit;
179+
uint32_t NumBytesGroupSharedMemory;
180180
};
181181

182182
enum class PSVResourceType {

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,9 +1625,12 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
16251625
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
16261626
}
16271627

1628-
MDVals.emplace_back(
1629-
Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag));
1630-
MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes));
1628+
const hlsl::ShaderModel *SM = GetShaderModel();
1629+
if (SM->IsSM610Plus()) {
1630+
MDVals.emplace_back(
1631+
Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag));
1632+
MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes));
1633+
}
16311634
} break;
16321635
// Geometry shader.
16331636
case DXIL::ShaderKind::Geometry: {
@@ -1780,6 +1783,8 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
17801783
case DxilMDHelper::kDxilGroupSharedLimitTag: {
17811784
DXASSERT(props.IsCS(), "else invalid shader kind");
17821785
props.groupSharedLimitBytes = ConstMDToUint32(MDO);
1786+
if (!m_pSM->IsSMAtLeast(6, 10))
1787+
m_bExtraMetadata = true;
17831788
} break;
17841789

17851790
case DxilMDHelper::kDxilGSStateTag: {

lib/DXIL/DxilModule.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,19 @@ unsigned DxilModule::GetGroupSharedLimit() const {
421421
return props.groupSharedLimitBytes;
422422
}
423423

424+
unsigned DxilModule::GetTGSMSizeInBytes() const {
425+
const DataLayout &DL = m_pModule->getDataLayout();
426+
unsigned TGSMSize = 0;
427+
428+
for (GlobalVariable &GV : m_pModule->globals()) {
429+
if (GV.getType()->getAddressSpace() == DXIL::kTGSMAddrSpace) {
430+
TGSMSize += DL.getTypeAllocSize(GV.getType()->getElementType());
431+
}
432+
}
433+
434+
return TGSMSize;
435+
}
436+
424437
DxilWaveSize &DxilModule::GetWaveSize() {
425438
return const_cast<DxilWaveSize &>(
426439
static_cast<const DxilModule *>(this)->GetWaveSize());

lib/DxilContainer/DxilPipelineStateValidation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
312312
case ShaderModel::Kind::Compute:
313313
case ShaderModel::Kind::Mesh:
314314
case ShaderModel::Kind::Amplification:
315-
pInfo4->GroupSharedLimit = DM.GetGroupSharedLimit();
315+
pInfo4->NumBytesGroupSharedMemory = DM.GetTGSMSizeInBytes();
316316
break;
317317
default:
318318
break;
@@ -824,7 +824,7 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
824824
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
825825
}
826826
if (pInfo4) {
827-
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
827+
OS << Comment << " NumBytesGroupSharedMemory=" << pInfo4->NumBytesGroupSharedMemory << "\n";
828828
}
829829
break;
830830
case PSVShaderKind::Amplification:
@@ -834,7 +834,7 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
834834
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
835835
}
836836
if (pInfo4) {
837-
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
837+
OS << Comment << " NumBytesGroupSharedMemory=" << pInfo4->NumBytesGroupSharedMemory << "\n";
838838
}
839839
break;
840840
case PSVShaderKind::Mesh:
@@ -863,7 +863,7 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
863863
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
864864
}
865865
if (pInfo4) {
866-
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
866+
OS << Comment << " NumBytesGroupSharedMemory=" << pInfo4->NumBytesGroupSharedMemory << "\n";
867867
}
868868
break;
869869
case PSVShaderKind::Library:

lib/DxilValidation/DxilContainerValidation.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ class PSVContentVerifier {
185185
unsigned PSVVersion);
186186
void VerifyViewIDDependence(PSVRuntimeInfo1 *PSV1, unsigned PSVVersion);
187187
void VerifyEntryProperties(const ShaderModel *SM, PSVRuntimeInfo0 *PSV0,
188-
PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2);
188+
PSVRuntimeInfo1 *PSV1, PSVRuntimeInfo2 *PSV2,
189+
PSVRuntimeInfo3 *PSV3, PSVRuntimeInfo4 *PSV4);
189190
void EmitMismatchError(StringRef Name, StringRef PartContent,
190191
StringRef ModuleContent) {
191192
ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches,
@@ -412,7 +413,9 @@ void PSVContentVerifier::VerifyResources(unsigned PSVVersion) {
412413
void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
413414
PSVRuntimeInfo0 *PSV0,
414415
PSVRuntimeInfo1 *PSV1,
415-
PSVRuntimeInfo2 *PSV2) {
416+
PSVRuntimeInfo2 *PSV2,
417+
PSVRuntimeInfo3 *PSV3,
418+
PSVRuntimeInfo4 *PSV4) {
416419
PSVRuntimeInfo4 DMPSV;
417420
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4));
418421

@@ -445,6 +448,9 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
445448
else
446449
Mismatched = memcmp(PSV0, &DMPSV, sizeof(PSVRuntimeInfo0)) != 0;
447450

451+
if (PSV4 && PSV4->NumBytesGroupSharedMemory != DMPSV.NumBytesGroupSharedMemory)
452+
Mismatched = true;
453+
448454
if (Mismatched) {
449455
std::string Str;
450456
raw_string_ostream OS(Str);
@@ -477,9 +483,11 @@ void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor,
477483
PSVRuntimeInfo0 *PSV0 = PSV.GetPSVRuntimeInfo0();
478484
PSVRuntimeInfo1 *PSV1 = PSV.GetPSVRuntimeInfo1();
479485
PSVRuntimeInfo2 *PSV2 = PSV.GetPSVRuntimeInfo2();
486+
PSVRuntimeInfo3 *PSV3 = PSV.GetPSVRuntimeInfo3();
487+
PSVRuntimeInfo4 *PSV4 = PSV.GetPSVRuntimeInfo4();
480488

481489
const ShaderModel *SM = DM.GetShaderModel();
482-
VerifyEntryProperties(SM, PSV0, PSV1, PSV2);
490+
VerifyEntryProperties(SM, PSV0, PSV1, PSV2, PSV3, PSV4);
483491
if (PSVVersion > 0) {
484492
if (((PSV.GetSigInputElements() + PSV.GetSigOutputElements() +
485493
PSV.GetSigPatchConstOrPrimElements()) > 0) &&

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,23 +1648,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16481648

16491649
if (const HLSLGroupSharedLimitAttr *Attr =
16501650
FD->getAttr<HLSLGroupSharedLimitAttr>()) {
1651-
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
1652-
unsigned DiagID = Diags.getCustomDiagID(
1653-
DiagnosticsEngine::Error,
1654-
"attribute GroupSharedLimit only valid for CS/MS/AS.");
1655-
Diags.Report(Attr->getLocation(), DiagID);
1656-
return;
1657-
}
1658-
1659-
// Only valid for SM6.10+
1660-
if (!SM->IsSM610Plus()) {
1661-
unsigned DiagID = Diags.getCustomDiagID(
1662-
DiagnosticsEngine::Error, "attribute GroupSharedLimit only valid for "
1663-
"Shader Model 6.10 and above.");
1664-
Diags.Report(Attr->getLocation(), DiagID);
1665-
return;
1666-
}
1667-
16681651
funcProps->groupSharedLimitBytes = Attr->getLimit();
16691652
} else {
16701653
if (SM->IsMS()) { // Fallback to default limits

tools/clang/lib/Sema/SemaHLSL.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12971,6 +12971,22 @@ void DiagnoseEntryAttrAllowedOnStage(clang::Sema *self,
1297112971
}
1297212972
break;
1297312973
}
12974+
case clang::attr::HLSLGroupSharedLimit: {
12975+
switch (shaderKind) {
12976+
case DXIL::ShaderKind::Compute:
12977+
case DXIL::ShaderKind::Mesh:
12978+
case DXIL::ShaderKind::Amplification:
12979+
case DXIL::ShaderKind::Node:
12980+
break;
12981+
default:
12982+
self->Diag(pAttr->getRange().getBegin(),
12983+
diag::err_hlsl_attribute_unsupported_stage)
12984+
<< "GroupSharedLimit"
12985+
<< "compute, mesh, node, or amplification";
12986+
break;
12987+
}
12988+
break;
12989+
}
1297412990
}
1297512991
}
1297612992
}

tools/clang/test/DXC/dumpPSV_AS.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// CHECK-NEXT: PSVRuntimeInfo:
77
// CHECK-NEXT: Amplification Shader
88
// CHECK-NEXT: NumThreads=(32,1,1)
9+
// CHECK-NEXT: NumBytesGroupSharedMemory=0
910
// CHECK-NEXT: MinimumExpectedWaveLaneCount: 0
1011
// CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295
1112
// CHECK-NEXT: UsesViewID: false

tools/clang/test/DXC/dumpPSV_CS.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// CHECK-NEXT: PSVRuntimeInfo:
77
// CHECK-NEXT: Compute Shader
88
// CHECK-NEXT: NumThreads=(128,1,1)
9+
// CHECK-NEXT: NumBytesGroupSharedMemory=2048
910
// CHECK-NEXT: MinimumExpectedWaveLaneCount: 0
1011
// CHECK-NEXT: MaximumExpectedWaveLaneCount: 4294967295
1112
// CHECK-NEXT: UsesViewID: false

0 commit comments

Comments
 (0)