git.s-ol.nu ~forks/DiligentCore / efa43e2
fixed resource state transitions, some improvements for ray tracing azhirnov 10 months ago
34 changed file(s) with 2370 addition(s) and 715 deletion(s). Raw diff Collapse all Expand all
160160 VERIFY(m_pCurrPtr <= m_pBuffer + m_ReservedSize, "Buffer overflow");
161161 return m_pCurrPtr - m_pBuffer;
162162 }
163 size_t GetReservedSize() const
164 {
165 return m_ReservedSize;
166 }
163167
164168 private:
165169 Char* m_pBuffer = nullptr;
183183 return (this->m_State & State) == State;
184184 }
185185
186 #ifdef DILIGENT_DEVELOPMENT
187 void UpdateVersion()
188 {
189 m_Version.fetch_add(1);
190 }
191
192 Uint32 GetVersion() const
193 {
194 return m_Version.load();
195 }
196 #endif
197
186198 protected:
187199 static void ValidateBottomLevelASDesc(const BottomLevelASDesc& Desc)
188200 {
214226 std::unordered_map<HashMapStringKey, Uint32, HashMapStringKey::Hasher> m_NameToIndex;
215227
216228 StringPool m_StringPool;
229
230 #ifdef DILIGENT_DEVELOPMENT
231 std::atomic<Uint32> m_Version{0};
232 #endif
217233 };
218234
219235 } // namespace Diligent
18541854 DEV_CHECK_ERR(OldState != RESOURCE_STATE_UNKNOWN, "The state of buffer '", BuffDesc.Name, "' is unknown to the engine and is not explicitly specified in the barrier");
18551855 DEV_CHECK_ERR(VerifyResourceStates(OldState, false), "Invlaid old state specified for buffer '", BuffDesc.Name, "'");
18561856 }
1857 else if (RefCntAutoPtr<IBottomLevelAS> pBLAS{Barrier.pResource, IID_BottomLevelAS})
1858 {
1859 // AZ TODO
1860 }
1861 else if (RefCntAutoPtr<ITopLevelAS> pTLAS{Barrier.pResource, IID_TopLevelAS})
1862 {
1863 // AZ TODO
1857 else if (RefCntAutoPtr<IBottomLevelAS> pBottomLevelAS{Barrier.pResource, IID_BottomLevelAS})
1858 {
1859 const auto& BLASDesc = pBottomLevelAS->GetDesc();
1860 OldState = Barrier.OldState != RESOURCE_STATE_UNKNOWN ? Barrier.OldState : pBottomLevelAS->GetState();
1861 DEV_CHECK_ERR(OldState != RESOURCE_STATE_UNKNOWN, "The state of BLAS '", BLASDesc.Name, "' is unknown to the engine and is not explicitly specified in the barrier");
1862 DEV_CHECK_ERR(Barrier.NewState == RESOURCE_STATE_BUILD_AS_READ || Barrier.NewState == RESOURCE_STATE_BUILD_AS_WRITE || Barrier.NewState == RESOURCE_STATE_RAY_TRACING,
1863 "Invlaid new state specified for BLAS '", BLASDesc.Name, "'");
1864 DEV_CHECK_ERR(Barrier.TransitionType != STATE_TRANSITION_TYPE_IMMEDIATE, "Split barriers are not supported for BLAS");
1865 }
1866 else if (RefCntAutoPtr<ITopLevelAS> pTopLevelAS{Barrier.pResource, IID_TopLevelAS})
1867 {
1868 const auto& TLASDesc = pTopLevelAS->GetDesc();
1869 OldState = Barrier.OldState != RESOURCE_STATE_UNKNOWN ? Barrier.OldState : pTopLevelAS->GetState();
1870 DEV_CHECK_ERR(OldState != RESOURCE_STATE_UNKNOWN, "The state of TLAS '", TLASDesc.Name, "' is unknown to the engine and is not explicitly specified in the barrier");
1871 DEV_CHECK_ERR(Barrier.NewState == RESOURCE_STATE_BUILD_AS_READ || Barrier.NewState == RESOURCE_STATE_BUILD_AS_WRITE || Barrier.NewState == RESOURCE_STATE_RAY_TRACING,
1872 "Invlaid new state specified for TLAS '", TLASDesc.Name, "'");
1873 DEV_CHECK_ERR(Barrier.TransitionType != STATE_TRANSITION_TYPE_IMMEDIATE, "Split barriers are not supported for TLAS");
18641874 }
18651875 else
18661876 {
19421952 template <typename BaseInterface, typename ImplementationTraits>
19431953 bool DeviceContextBase<BaseInterface, ImplementationTraits>::BuildBLAS(const BLASBuildAttribs& Attribs, int)
19441954 {
1955 if (m_pActiveRenderPass != nullptr)
1956 {
1957 LOG_ERROR_MESSAGE("BuildBLAS command must be performed outside of render pass");
1958 return false;
1959 }
1960
19451961 if (Attribs.pBLAS == nullptr)
19461962 {
19471963 LOG_ERROR_MESSAGE("IDeviceContext::BuildBLAS: pBLAS must not be null");
20892105 return false;
20902106 }
20912107 }
2108 #endif // DILIGENT_DEVELOPMENT
20922109
20932110 const auto& BLASDesc = Attribs.pBLAS->GetDesc();
20942111
21122129 return false;
21132130 }
21142131
2115 if (ScratchDesc.uiSizeInBytes - Attribs.ScratchBufferOffset > Attribs.pBLAS->GetScratchBufferSizes().Build)
2132 if (ScratchDesc.uiSizeInBytes - Attribs.ScratchBufferOffset < Attribs.pBLAS->GetScratchBufferSizes().Build)
21162133 {
21172134 LOG_ERROR_MESSAGE("IDeviceContext::BuildBLAS: pScratchBuffer size is too small, use pBLAS->GetScratchBufferSizes().Build to get required size for scratch buffer");
21182135 return false;
21232140 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pScratchBuffer must be created with BIND_RAY_TRACING flag");
21242141 return false;
21252142 }
2126 #endif // DILIGENT_DEVELOPMENT
21272143
21282144 return true;
21292145 }
21312147 template <typename BaseInterface, typename ImplementationTraits>
21322148 bool DeviceContextBase<BaseInterface, ImplementationTraits>::BuildTLAS(const TLASBuildAttribs& Attribs, int)
21332149 {
2150 if (m_pActiveRenderPass != nullptr)
2151 {
2152 LOG_ERROR_MESSAGE("BuildTLAS command must be performed outside of render pass");
2153 return false;
2154 }
2155
21342156 if (Attribs.pTLAS == nullptr)
21352157 {
21362158 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pTLAS must not be null");
21612183 return false;
21622184 }
21632185
2186 const auto& TLASDesc = Attribs.pTLAS->GetDesc();
2187
2188 if (Attribs.InstanceCount > TLASDesc.MaxInstanceCount)
2189 {
2190 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: InstanceCount must be less than or equal to Attribs.pTLAS->GetDesc().MaxInstanceCount");
2191 return false;
2192 }
2193
2194 const auto& InstDesc = Attribs.pInstanceBuffer->GetDesc();
2195 const size_t InstDataSize = Attribs.InstanceCount * TLAS_INSTANCE_DATA_SIZE;
2196
21642197 #ifdef DILIGENT_DEVELOPMENT
2165 const auto& TLASDesc = Attribs.pTLAS->GetDesc();
2166
2167 if (Attribs.InstanceCount > TLASDesc.MaxInstanceCount)
2168 {
2169 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: InstanceCount must be less than or equal to Attribs.pTLAS->GetDesc().MaxInstanceCount");
2170 return false;
2171 }
2172
2173 const auto& InstDesc = Attribs.pInstanceBuffer->GetDesc();
2174 const size_t InstDataSize = Attribs.InstanceCount * TLAS_INSTANCE_DATA_SIZE;
2175 Uint32 AutoOffsetCounter = 0;
2198 Uint32 AutoOffsetCounter = 0;
21762199
21772200 // calculate instance data size
21782201 for (Uint32 i = 0; i < Attribs.InstanceCount; ++i)
22022225 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: exactly all pInstances[i].ContributionToHitGroupIndex must be TLAS_INSTANCE_OFFSET_AUTO or not");
22032226 return false;
22042227 }
2228 #endif // DILIGENT_DEVELOPMENT
22052229
22062230 if (Attribs.InstanceBufferOffset > InstDesc.uiSizeInBytes)
22072231 {
22092233 return false;
22102234 }
22112235
2212 if (InstDesc.uiSizeInBytes - Attribs.InstanceBufferOffset > InstDataSize)
2213 {
2214 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pInstanceaBuffer size is too small, ...");
2236 if (InstDesc.uiSizeInBytes - Attribs.InstanceBufferOffset < InstDataSize)
2237 {
2238 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pInstanceBuffer size is too small, ...");
22152239 return false;
22162240 }
22172241
22182242 if ((InstDesc.BindFlags & BIND_RAY_TRACING) != BIND_RAY_TRACING)
22192243 {
2220 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pInstanceaBuffer must be created with BIND_RAY_TRACING flag");
2244 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pInstanceBuffer must be created with BIND_RAY_TRACING flag");
22212245 return false;
22222246 }
22232247
22292253 return false;
22302254 }
22312255
2232 if (ScratchDesc.uiSizeInBytes - Attribs.ScratchBufferOffset > Attribs.pTLAS->GetScratchBufferSizes().Build)
2256 if (ScratchDesc.uiSizeInBytes - Attribs.ScratchBufferOffset < Attribs.pTLAS->GetScratchBufferSizes().Build)
22332257 {
22342258 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pScratchBuffer size is too small, use pTLAS->GetScratchBufferSizes().Build to get required size for scratch buffer");
22352259 return false;
22402264 LOG_ERROR_MESSAGE("IDeviceContext::BuildTLAS: pScratchBuffer must be created with BIND_RAY_TRACING flag");
22412265 return false;
22422266 }
2243 #endif // DILIGENT_DEVELOPMENT
22442267
22452268 return true;
22462269 }
22572280 if (Attribs.pDst == nullptr)
22582281 {
22592282 LOG_ERROR_MESSAGE("IDeviceContext::CopyBLAS: pDst must not be null");
2283 return false;
2284 }
2285
2286 if (m_pActiveRenderPass != nullptr)
2287 {
2288 LOG_ERROR_MESSAGE("CopyBLAS command must be performed outside of render pass");
22602289 return false;
22612290 }
22622291
23372366 return false;
23382367 }
23392368
2369 if (m_pActiveRenderPass != nullptr)
2370 {
2371 LOG_ERROR_MESSAGE("CopyTLAS command must be performed outside of render pass");
2372 return false;
2373 }
2374
23402375 #ifdef DILIGENT_DEVELOPMENT
2376 if (!ValidatedCast<TopLevelASType>(Attribs.pSrc)->CheckBLASVersion())
2377 {
2378 LOG_ERROR_MESSAGE("IDeviceContext::CopyTLAS: pSrc must be rebuilded to apply BLAS changes before being copied to another TLAS");
2379 return false;
2380 }
2381
23412382 if (Attribs.Mode == COPY_AS_MODE_CLONE)
23422383 {
23432384 auto& SrcDesc = Attribs.pSrc->GetDesc();
23692410 return false;
23702411 }
23712412
2413 #ifdef DILIGENT_DEVELOPMENT
2414 if (!Attribs.pSBT->Verify())
2415 {
2416 LOG_ERROR_MESSAGE("IDeviceContext::TraceRays: pSBT content is not valid");
2417 return false;
2418 }
2419 #endif // DILIGENT_DEVELOPMENT
2420
2421 if (!m_pPipelineState)
2422 {
2423 LOG_ERROR_MESSAGE("IDeviceContext::TraceRays command arguments are invalid: no pipeline state is bound.");
2424 return false;
2425 }
2426
2427 if (!m_pPipelineState->GetDesc().IsRayTracingPipeline())
2428 {
2429 LOG_ERROR_MESSAGE("IDeviceContext::TraceRays command arguments are invalid: pipeline state '", m_pPipelineState->GetDesc().Name, "' is not a ray tracing pipeline.");
2430 return false;
2431 }
2432
2433 if (Attribs.pSBT->GetDesc().pPSO != m_pPipelineState)
2434 {
2435 LOG_ERROR_MESSAGE("IDeviceContext::TraceRays command arguments are invalid: currently bound pipeline ", m_pPipelineState->GetDesc().Name,
2436 "doesn't match the pipeline ", Attribs.pSBT->GetDesc().pPSO->GetDesc().Name, " that was used in ShaderBindingTable");
2437 return false;
2438 }
2439
23722440 if (Attribs.DimensionX == 0)
23732441 LOG_WARNING_MESSAGE("IDeviceContext::TraceRays command arguments are invalid: DimensionX is zero.");
23742442
7878
7979 if ((ShdrDesc.ShaderType == SHADER_TYPE_AMPLIFICATION || ShdrDesc.ShaderType == SHADER_TYPE_MESH) && !deviceFeatures.MeshShaders)
8080 LOG_ERROR_AND_THROW("Mesh shaders are not supported by this device");
81
82 if ((ShdrDesc.ShaderType >= SHADER_TYPE_RAY_GEN && ShdrDesc.ShaderType <= SHADER_TYPE_CALLABLE) && !deviceFeatures.RayTracing)
83 LOG_ERROR_AND_THROW("Ray tracing shaders are not supported by this device");
8184 }
8285
8386 IMPLEMENT_QUERY_INTERFACE_IN_PLACE(IID_Shader, TDeviceObjectBase)
6464 TDeviceObjectBase{pRefCounters, pDevice, Desc, bIsDeviceInternal}
6565 {
6666 ValidateShaderBindingTableDesc(Desc);
67
68 this->m_pPSO = ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO);
69 this->m_ShaderRecordSize = this->m_pPSO->GetRayTracingPipelineDesc().ShaderRecordSize;
70 this->m_ShaderRecordStride = this->m_ShaderRecordSize + this->m_pDevice->GetShaderGroupHandleSize();
6771 }
6872
6973 ~ShaderBindingTableBase()
7074 {
7175 }
7276
73 void BindRayGenShader(const char* ShaderGroupName, const void* Data, Uint32 DataSize) override final
74 {
75 VERIFY(Data == nullptr && DataSize == 0, "not supported yet");
76
77 this->m_RayGenShaderRecord.resize(this->m_ShaderRecordStride);
78 ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO)->CopyShaderHandle(ShaderGroupName, this->m_RayGenShaderRecord.data(), this->m_ShaderRecordStride);
79 this->m_Changed = true;
80 }
81
82 void BindMissShader(const char* ShaderGroupName, Uint32 MissIndex, const void* Data, Uint32 DataSize) override final
83 {
84 VERIFY(Data == nullptr && DataSize == 0, "not supported yet");
85
86 const Uint32 Offset = MissIndex * this->m_ShaderRecordStride;
87 this->m_MissShadersRecord.resize(std::max<size_t>(this->m_MissShadersRecord.size(), Offset + this->m_ShaderRecordStride));
88
89 ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO)->CopyShaderHandle(ShaderGroupName, this->m_MissShadersRecord.data() + Offset, this->m_ShaderRecordStride);
90 this->m_Changed = true;
91 }
92
93 void BindHitGroup(ITopLevelAS* pTLAS,
94 const char* InstanceName,
95 const char* GeometryName,
96 Uint32 RayOffsetInHitGroupIndex,
97 const char* ShaderGroupName,
98 const void* Data,
99 Uint32 DataSize) override final
100 {
101 VERIFY(Data == nullptr && DataSize == 0, "not supported yet");
77 void DILIGENT_CALL_TYPE Reset(const ShaderBindingTableDesc& Desc) override final
78 {
79 this->m_RayGenShaderRecord.clear();
80 this->m_MissShadersRecord.clear();
81 this->m_CallableShadersRecord.clear();
82 this->m_HitGroupsRecord.clear();
83 this->m_Changed = true;
84 this->m_pPSO = nullptr;
85 this->m_Desc = {};
86
87 try
88 {
89 ValidateShaderBindingTableDesc(Desc);
90 }
91 catch (const std::runtime_error&)
92 {
93 return;
94 }
95
96 this->m_Desc = Desc;
97 this->m_pPSO = ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO);
98 this->m_ShaderRecordSize = this->m_pPSO->GetRayTracingPipelineDesc().ShaderRecordSize;
99 this->m_ShaderRecordStride = this->m_ShaderRecordSize + this->m_pDevice->GetShaderGroupHandleSize();
100 }
101
102 void DILIGENT_CALL_TYPE BindRayGenShader(const char* ShaderGroupName, const void* Data, Uint32 DataSize) override final
103 {
104 VERIFY_EXPR((Data == nullptr) == (DataSize == 0));
105 VERIFY_EXPR(Data == nullptr || (DataSize == this->m_ShaderRecordSize));
106
107 this->m_RayGenShaderRecord.resize(this->m_ShaderRecordStride, EmptyElem);
108 this->m_pPSO->CopyShaderHandle(ShaderGroupName, this->m_RayGenShaderRecord.data(), this->m_ShaderRecordStride);
109
110 const Uint32 GroupSize = this->m_pDevice->GetShaderGroupHandleSize();
111 std::memcpy(this->m_RayGenShaderRecord.data() + GroupSize, Data, DataSize);
112 this->m_Changed = true;
113 }
114
115 void DILIGENT_CALL_TYPE BindMissShader(const char* ShaderGroupName, Uint32 MissIndex, const void* Data, Uint32 DataSize) override final
116 {
117 VERIFY_EXPR((Data == nullptr) == (DataSize == 0));
118 VERIFY_EXPR(Data == nullptr || (DataSize == this->m_ShaderRecordSize));
119
120 const Uint32 GroupSize = this->m_pDevice->GetShaderGroupHandleSize();
121 const Uint32 Offset = MissIndex * this->m_ShaderRecordStride;
122 this->m_MissShadersRecord.resize(std::max<size_t>(this->m_MissShadersRecord.size(), Offset + this->m_ShaderRecordStride), EmptyElem);
123
124 this->m_pPSO->CopyShaderHandle(ShaderGroupName, this->m_MissShadersRecord.data() + Offset, this->m_ShaderRecordStride);
125 std::memcpy(this->m_MissShadersRecord.data() + Offset + GroupSize, Data, DataSize);
126 this->m_Changed = true;
127 }
128
129 void DILIGENT_CALL_TYPE BindHitGroup(ITopLevelAS* pTLAS,
130 const char* InstanceName,
131 const char* GeometryName,
132 Uint32 RayOffsetInHitGroupIndex,
133 const char* ShaderGroupName,
134 const void* Data,
135 Uint32 DataSize) override final
136 {
137 VERIFY_EXPR((Data == nullptr) == (DataSize == 0));
138 VERIFY_EXPR(Data == nullptr || (DataSize == this->m_ShaderRecordSize));
102139 VERIFY_EXPR(pTLAS != nullptr);
103140 VERIFY_EXPR(RayOffsetInHitGroupIndex < this->m_Desc.HitShadersPerInstance);
104141 VERIFY_EXPR(pTLAS->GetDesc().BindingMode == SHADER_BINDING_MODE_PER_GEOMETRY);
110147 const Uint32 GeometryIndex = Desc.pBLAS->GetGeometryIndex(GeometryName);
111148 const Uint32 Index = InstanceIndex + GeometryIndex * this->m_Desc.HitShadersPerInstance + RayOffsetInHitGroupIndex;
112149 const Uint32 Offset = Index * this->m_ShaderRecordStride;
113
114 this->m_HitGroupsRecord.resize(std::max<size_t>(this->m_HitGroupsRecord.size(), Offset + this->m_ShaderRecordStride));
115
116 ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO)->CopyShaderHandle(ShaderGroupName, this->m_HitGroupsRecord.data() + Offset, this->m_ShaderRecordStride);
117 this->m_Changed = true;
118 }
119
120 void BindHitGroups(ITopLevelAS* pTLAS,
121 const char* InstanceName,
122 Uint32 RayOffsetInHitGroupIndex,
123 const char* ShaderGroupName,
124 const void* Data,
125 Uint32 DataSize) override final
126 {
127 VERIFY(Data == nullptr && DataSize == 0, "not supported yet");
150 const Uint32 GroupSize = this->m_pDevice->GetShaderGroupHandleSize();
151
152 this->m_HitGroupsRecord.resize(std::max<size_t>(this->m_HitGroupsRecord.size(), Offset + this->m_ShaderRecordStride), EmptyElem);
153
154 this->m_pPSO->CopyShaderHandle(ShaderGroupName, this->m_HitGroupsRecord.data() + Offset, this->m_ShaderRecordStride);
155 std::memcpy(this->m_HitGroupsRecord.data() + Offset + GroupSize, Data, DataSize);
156 this->m_Changed = true;
157 }
158
159 void DILIGENT_CALL_TYPE BindHitGroups(ITopLevelAS* pTLAS,
160 const char* InstanceName,
161 Uint32 RayOffsetInHitGroupIndex,
162 const char* ShaderGroupName,
163 const void* Data,
164 Uint32 DataSize) override final
165 {
166 VERIFY_EXPR((Data == nullptr) == (DataSize == 0));
128167 VERIFY_EXPR(pTLAS != nullptr);
129168 VERIFY_EXPR(RayOffsetInHitGroupIndex < this->m_Desc.HitShadersPerInstance);
130169 VERIFY_EXPR(pTLAS->GetDesc().BindingMode == SHADER_BINDING_MODE_PER_GEOMETRY ||
133172 const auto Desc = pTLAS->GetInstanceDesc(InstanceName);
134173 VERIFY_EXPR(Desc.pBLAS != nullptr);
135174
136 const Uint32 InstanceIndex = Desc.ContributionToHitGroupIndex;
137 const auto& GeometryDesc = Desc.pBLAS->GetDesc();
138 const Uint32 GeometryCount = GeometryDesc.BoxCount + GeometryDesc.TriangleCount;
139 const Uint32 BeginIndex = InstanceIndex + 0 * this->m_Desc.HitShadersPerInstance + RayOffsetInHitGroupIndex;
140 const Uint32 EndIndex = InstanceIndex + GeometryCount * this->m_Desc.HitShadersPerInstance + RayOffsetInHitGroupIndex;
141 PipelineStateImplType* pPSO = ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO);
142
143 this->m_HitGroupsRecord.resize(std::max<size_t>(this->m_HitGroupsRecord.size(), EndIndex * this->m_ShaderRecordStride));
175 const Uint32 InstanceIndex = Desc.ContributionToHitGroupIndex;
176 const auto& GeometryDesc = Desc.pBLAS->GetDesc();
177 Uint32 GeometryCount = 0;
178
179 switch (pTLAS->GetDesc().BindingMode)
180 {
181 // clang-format off
182 case SHADER_BINDING_MODE_PER_GEOMETRY: GeometryCount = GeometryDesc.BoxCount + GeometryDesc.TriangleCount; break;
183 case SHADER_BINDING_MODE_PER_INSTANCE: GeometryCount = 1; break;
184 default: UNEXPECTED("unknown binding mode");
185 // clang-format on
186 }
187
188 VERIFY_EXPR(Data == nullptr || (DataSize == this->m_ShaderRecordSize * GeometryCount));
189
190 const Uint32 BeginIndex = InstanceIndex + 0 * this->m_Desc.HitShadersPerInstance + RayOffsetInHitGroupIndex;
191 const Uint32 EndIndex = InstanceIndex + GeometryCount * this->m_Desc.HitShadersPerInstance + RayOffsetInHitGroupIndex;
192 const Uint32 GroupSize = this->m_pDevice->GetShaderGroupHandleSize();
193 const auto* DataPtr = static_cast<const Uint8*>(Data);
194
195 this->m_HitGroupsRecord.resize(std::max<size_t>(this->m_HitGroupsRecord.size(), EndIndex * this->m_ShaderRecordStride), EmptyElem);
144196
145197 for (Uint32 i = 0; i < GeometryCount; ++i)
146198 {
147199 Uint32 Offset = (BeginIndex + i) * this->m_ShaderRecordStride;
148 pPSO->CopyShaderHandle(ShaderGroupName, this->m_HitGroupsRecord.data() + Offset, this->m_ShaderRecordStride);
149 }
150 this->m_Changed = true;
151 }
152
153 void BindCallableShader(const char* ShaderGroupName,
154 Uint32 CallableIndex,
155 const void* Data,
156 Uint32 DataSize) override final
157 {
158 VERIFY(Data == nullptr && DataSize == 0, "not supported yet");
159
160 const Uint32 Offset = CallableIndex * this->m_ShaderRecordStride;
161 this->m_CallableShadersRecord.resize(std::max<size_t>(this->m_CallableShadersRecord.size(), Offset + this->m_ShaderRecordStride));
162
163 ValidatedCast<PipelineStateImplType>(this->m_Desc.pPSO)->CopyShaderHandle(ShaderGroupName, this->m_CallableShadersRecord.data() + Offset, this->m_ShaderRecordStride);
164 this->m_Changed = true;
200 this->m_pPSO->CopyShaderHandle(ShaderGroupName, this->m_HitGroupsRecord.data() + Offset, this->m_ShaderRecordStride);
201
202 std::memcpy(this->m_HitGroupsRecord.data() + Offset + GroupSize, DataPtr, this->m_ShaderRecordSize);
203 DataPtr += this->m_ShaderRecordSize;
204 }
205 this->m_Changed = true;
206 }
207
208 void DILIGENT_CALL_TYPE BindCallableShader(const char* ShaderGroupName,
209 Uint32 CallableIndex,
210 const void* Data,
211 Uint32 DataSize) override final
212 {
213 VERIFY_EXPR((Data == nullptr) == (DataSize == 0));
214 VERIFY_EXPR(Data == nullptr || (DataSize == this->m_ShaderRecordSize));
215
216 const Uint32 GroupSize = this->m_pDevice->GetShaderGroupHandleSize();
217 const Uint32 Offset = CallableIndex * this->m_ShaderRecordStride;
218 this->m_CallableShadersRecord.resize(std::max<size_t>(this->m_CallableShadersRecord.size(), Offset + this->m_ShaderRecordStride), EmptyElem);
219
220 this->m_pPSO->CopyShaderHandle(ShaderGroupName, this->m_CallableShadersRecord.data() + Offset, this->m_ShaderRecordStride);
221 std::memcpy(this->m_CallableShadersRecord.data() + Offset + GroupSize, Data, DataSize);
222 this->m_Changed = true;
223 }
224
225 Bool DILIGENT_CALL_TYPE Verify() const override final
226 {
227 // AZ TODO
228 return true;
165229 }
166230
167231 protected:
168 static void ValidateShaderBindingTableDesc(const ShaderBindingTableDesc& Desc)
232 void ValidateShaderBindingTableDesc(const ShaderBindingTableDesc& Desc) const
169233 {
170234 #define LOG_SBT_ERROR_AND_THROW(...) LOG_ERROR_AND_THROW("Description of Shader binding table '", (Desc.Name ? Desc.Name : ""), "' is invalid: ", ##__VA_ARGS__)
171235
179243 LOG_SBT_ERROR_AND_THROW("pPSO must be ray tracing pipeline");
180244 }
181245
246 const auto ShaderGroupHandleSize = this->m_pDevice->GetShaderGroupHandleSize();
247 const auto MaxShaderRecordStride = this->m_pDevice->GetMaxShaderRecordStride();
248 const auto ShaderRecordSize = Desc.pPSO->GetRayTracingPipelineDesc().ShaderRecordSize;
249 const auto ShaderRecordStride = ShaderRecordSize + ShaderGroupHandleSize;
250
251 if (ShaderRecordStride > MaxShaderRecordStride)
252 {
253 LOG_SBT_ERROR_AND_THROW("ShaderRecordSize(", ShaderRecordSize, ") is too big, max size is: ", MaxShaderRecordStride - ShaderGroupHandleSize);
254 }
255
256 if (ShaderRecordStride % ShaderGroupHandleSize != 0)
257 {
258 LOG_SBT_ERROR_AND_THROW("ShaderRecordSize(", ShaderRecordSize, ") plus ShaderGroupHandleSize(", ShaderGroupHandleSize, ") must be multiple of ", ShaderGroupHandleSize);
259 }
182260 #undef LOG_SBT_ERROR_AND_THROW
183261 }
184262
190268 std::vector<Uint8> m_CallableShadersRecord;
191269 std::vector<Uint8> m_HitGroupsRecord;
192270
271 RefCntAutoPtr<PipelineStateImplType> m_pPSO;
272
273 Uint32 m_ShaderRecordSize = 0;
193274 Uint32 m_ShaderRecordStride = 0;
194275 bool m_Changed = true;
276
277 static const Uint8 EmptyElem = 0xA7;
195278 };
196279
197280 } // namespace Diligent
4646 /// (Diligent::ITopLevelASD3D12 or Diligent::ITopLevelASVk).
4747 /// \tparam RenderDeviceImplType - type of the render device implementation
4848 /// (Diligent::RenderDeviceD3D12Impl or Diligent::RenderDeviceVkImpl)
49 template <class BaseInterface, class RenderDeviceImplType>
49 template <class BaseInterface, class BottomLevelASType, class RenderDeviceImplType>
5050 class TopLevelASBase : public DeviceObjectBase<BaseInterface, RenderDeviceImplType, TopLevelASDesc>
5151 {
5252 public:
7272
7373 void SetInstanceData(const TLASBuildInstanceData* pInstances, Uint32 InstanceCount, Uint32 HitShadersPerInstance)
7474 {
75 m_Instances.clear();
76 m_StringPool.Release();
75 this->m_Instances.clear();
76 this->m_StringPool.Release();
77 this->m_HitShadersPerInstance = HitShadersPerInstance;
7778
7879 size_t StringPoolSize = 0;
7980 for (Uint32 i = 0; i < InstanceCount; ++i)
8182 StringPoolSize += strlen(pInstances[i].InstanceName) + 1;
8283 }
8384
84 m_StringPool.Reserve(StringPoolSize, GetRawAllocator());
85 this->m_StringPool.Reserve(StringPoolSize, GetRawAllocator());
8586
8687 Uint32 InstanceOffset = 0;
8788
8889 for (Uint32 i = 0; i < InstanceCount; ++i)
8990 {
9091 auto& inst = pInstances[i];
91 const char* NameCopy = m_StringPool.CopyString(inst.InstanceName);
92 const char* NameCopy = this->m_StringPool.CopyString(inst.InstanceName);
9293 InstanceDesc Desc = {};
9394
9495 Desc.ContributionToHitGroupIndex = inst.ContributionToHitGroupIndex;
95 Desc.pBLAS = inst.pBLAS;
96 Desc.pBLAS = ValidatedCast<BottomLevelASType>(inst.pBLAS);
97
98 #ifdef DILIGENT_DEVELOPMENT
99 Desc.Version = Desc.pBLAS->GetVersion();
100 #endif
96101
97102 if (Desc.ContributionToHitGroupIndex == TLAS_INSTANCE_OFFSET_AUTO)
98103 {
99104 Desc.ContributionToHitGroupIndex = InstanceOffset;
100105 auto& BLASDesc = Desc.pBLAS->GetDesc();
101 InstanceOffset += (BLASDesc.TriangleCount + BLASDesc.BoxCount) * HitShadersPerInstance;
106 switch (this->m_Desc.BindingMode)
107 {
108 // clang-format off
109 case SHADER_BINDING_MODE_PER_GEOMETRY: InstanceOffset += (BLASDesc.TriangleCount + BLASDesc.BoxCount) * HitShadersPerInstance; break;
110 case SHADER_BINDING_MODE_PER_INSTANCE: InstanceOffset += HitShadersPerInstance; break;
111 case SHADER_BINDING_USER_DEFINED: UNEXPECTED("TLAS_INSTANCE_OFFSET_AUTO is not compatible with SHADER_BINDING_USER_DEFINED"); break;
112 default: UNEXPECTED("unknown ray tracing shader binding mode");
113 // clang-format on
114 }
102115 }
103116
104 bool IsUniqueName = m_Instances.emplace(NameCopy, Desc).second;
117 bool IsUniqueName = this->m_Instances.emplace(NameCopy, Desc).second;
105118 if (!IsUniqueName)
106119 LOG_ERROR_AND_THROW("Instance name must be unique!");
107120 }
121
122 VERIFY_EXPR(this->m_StringPool.GetRemainingSize() == 0);
123 }
124
125 void CopyInstancceData(const TopLevelASBase& Src)
126 {
127 this->m_Instances.clear();
128 this->m_StringPool.Release();
129 this->m_StringPool.Reserve(Src.m_StringPool.GetReservedSize(), GetRawAllocator());
130 this->m_HitShadersPerInstance = Src.m_HitShadersPerInstance;
131 this->m_Desc.BindingMode = Src.m_Desc.BindingMode;
132
133 for (auto& SrcInst : Src.m_Instances)
134 {
135 const char* NameCopy = this->m_StringPool.CopyString(SrcInst.first.GetStr());
136 this->m_Instances.emplace(NameCopy, SrcInst.second);
137 }
138
139 VERIFY_EXPR(this->m_StringPool.GetRemainingSize() == 0);
108140 }
109141
110142 virtual TLASInstanceDesc DILIGENT_CALL_TYPE GetInstanceDesc(const char* Name) const override final
113145
114146 TLASInstanceDesc Result = {};
115147
116 auto iter = m_Instances.find(Name);
117 if (iter != m_Instances.end())
148 auto iter = this->m_Instances.find(Name);
149 if (iter != this->m_Instances.end())
118150 {
119151 Result.ContributionToHitGroupIndex = iter->second.ContributionToHitGroupIndex;
120 Result.pBLAS = iter->second.pBLAS;
152 Result.pBLAS = iter->second.pBLAS.RawPtr<IBottomLevelAS>();
121153 }
122154 else
123155 {
149181 return (this->m_State & State) == State;
150182 }
151183
184 #ifdef DILIGENT_DEVELOPMENT
185 bool CheckBLASVersion() const
186 {
187 for (auto& NameAndInst : m_Instances)
188 {
189 auto& Inst = NameAndInst.second;
190 if (Inst.Version != Inst.pBLAS->GetVersion())
191 {
192 LOG_ERROR_MESSAGE("Instance with name ('", NameAndInst.first.GetStr(), "') has BLAS that was changed after TLAS build, you must rebuild TLAS.");
193 return false;
194 }
195 }
196 return true;
197 }
198 #endif
199
152200 protected:
153201 static void ValidateTopLevelASDesc(const TopLevelASDesc& Desc)
154202 {
171219 IMPLEMENT_QUERY_INTERFACE_IN_PLACE(IID_TopLevelAS, TDeviceObjectBase)
172220
173221 protected:
174 RESOURCE_STATE m_State = RESOURCE_STATE_UNKNOWN;
222 RESOURCE_STATE m_State = RESOURCE_STATE_UNKNOWN;
223 Uint32 m_HitShadersPerInstance = 0;
175224
176225 StringPool m_StringPool;
177226
178227 struct InstanceDesc
179228 {
180 Uint32 ContributionToHitGroupIndex = 0;
181 mutable RefCntAutoPtr<IBottomLevelAS> pBLAS;
229 Uint32 ContributionToHitGroupIndex = 0;
230 RefCntAutoPtr<BottomLevelASType> pBLAS;
231
232 #ifdef DILIGENT_DEVELOPMENT
233 Uint32 Version = 0;
234 #endif
182235 };
183236 std::unordered_map<HashMapStringKey, InstanceDesc, HashMapStringKey::Hasher> m_Instances;
184237 };
740740 /// geometries referenced by this instance. This behavior can be overridden by the SPIR-V OpaqueKHR ray flag.
741741 RAYTRACING_INSTANCE_FORCE_NO_OPAQUE = 0x08,
742742
743 RAYTRACING_INSTANCE_FLAGS_LAST = 0x08
743 RAYTRACING_INSTANCE_FLAGS_LAST = RAYTRACING_INSTANCE_FORCE_NO_OPAQUE
744744 };
745745 DEFINE_FLAG_ENUM_OPERATORS(RAYTRACING_INSTANCE_FLAGS)
746746
756756 // after the build of the acceleration structure specified by src.
757757 //COPY_AS_MODE_COMPACT,
758758
759 COPY_AS_MODE_LAST = 0,
759 COPY_AS_MODE_LAST = COPY_AS_MODE_CLONE,
760760 };
761761
762762 /// Defines geometry flags for ray tracing.
774774 /// If this bit is absent an implementation may invoke the any-hit shader more than once for this geometry.
775775 RAYTRACING_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION = 0x02,
776776
777 RAYTRACING_GEOMETRY_FLAGS_LAST = 0x02
777 RAYTRACING_GEOMETRY_FLAGS_LAST = RAYTRACING_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION
778778 };
779779 DEFINE_FLAG_ENUM_OPERATORS(RAYTRACING_GEOMETRY_FLAGS)
780780
909909 /// AZ TODO
910910 static const Uint32 TLAS_INSTANCE_DATA_SIZE = 64;
911911
/// Instance transformation stored as 3 rows by 4 columns of floats.
///
/// The first three columns of every row hold the rotation part and the last column
/// holds the translation. Element indices when the array is viewed as 12 consecutive floats:
///
///         rotation     translation
///       ( 0  1  2)       [ 3]
///       ( 4  5  6)       [ 7]
///       ( 8  9 10)       [11]
struct InstanceMatrix
{
    float data [3][4];

#if DILIGENT_CPP_INTERFACE
    /// Constructs the identity transform: unit rotation part and zero translation.
    InstanceMatrix() noexcept :
        data{{1.0f, 0.0f, 0.0f, 0.0f},
             {0.0f, 1.0f, 0.0f, 0.0f},
             {0.0f, 0.0f, 1.0f, 0.0f}}
    {}

    InstanceMatrix(const InstanceMatrix&) noexcept = default;

    // Declared explicitly because the copy constructor is user-declared:
    // relying on the implicitly generated copy assignment is deprecated in that case.
    InstanceMatrix& operator=(const InstanceMatrix&) noexcept = default;

    /// Writes the translation into the last column and leaves the rotation part untouched.
    ///
    /// \return reference to this matrix, so that calls can be chained.
    InstanceMatrix& SetTranslation(float x, float y, float z) noexcept
    {
        data[0][3] = x;
        data[1][3] = y;
        data[2][3] = z;
        return *this;
    }
#endif
};
typedef struct InstanceMatrix InstanceMatrix;
912941
913942 /// AZ TODO
914943 struct TLASBuildInstanceData
920949 IBottomLevelAS* pBLAS DEFAULT_INITIALIZER(nullptr); // can be null to deactive instance
921950
922951 /// AZ TODO
923 float Transform[3][4] DEFAULT_INITIALIZER({});
952 InstanceMatrix Transform;
924953
925954 /// AZ TODO
926955 Uint32 CustomId DEFAULT_INITIALIZER(0); // 24 bits, in shader: gl_InstanceCustomIndexNV for GLSL, InstanceID() for HLSL
298298 /// AZ TODO
299299 struct RayTracingPipelineDesc
300300 {
301 /// AZ TODO
302 Uint8 MaxRecursionDepth DEFAULT_INITIALIZER(0); // must be 0..31 (check current device limits)
301 // Size of the additional data passed to the shader.
302 Uint16 ShaderRecordSize DEFAULT_INITIALIZER(0);
303
304 /// AZ TODO
305 Uint8 MaxRecursionDepth DEFAULT_INITIALIZER(0); // must be 0..31 (check current device limits)
303306 };
304307 typedef struct RayTracingPipelineDesc RayTracingPipelineDesc;
305308
437440 struct RayTracingPipelineStateCreateInfo DILIGENT_DERIVE(PipelineStateCreateInfo)
438441
439442 /// AZ TODO
440 RayTracingPipelineDesc RayTracingPipeline;
443 RayTracingPipelineDesc RayTracingPipeline;
441444
442445 /// AZ TODO
443446 const RayTracingGeneralShaderGroup* pGeneralShaders DEFAULT_INITIALIZER(nullptr);
456459
457460 /// AZ TODO
458461 Uint16 ProceduralHitShaderCount DEFAULT_INITIALIZER(0);
462
463 /// Direct3D12 only: set name of constant buffer that will be used by local root signature.
464 /// Ignored if RayTracingPipelineDesc::ShaderRecordSize is zero.
465 const char* ShaderRecordName DEFAULT_INITIALIZER(nullptr);
459466 };
460467 typedef struct RayTracingPipelineStateCreateInfo RayTracingPipelineStateCreateInfo;
461468
5050
5151 /// AZ TODO
5252 IPipelineState* pPSO DEFAULT_INITIALIZER(nullptr);
53
54 // Size of the additional data passed to the shader, maximum size is 4064 bytes.
55 Uint32 ShaderRecordSize DEFAULT_INITIALIZER(0);
5653
5754 /// AZ TODO
5855 Uint32 HitShadersPerInstance DEFAULT_INITIALIZER(1);
113110 #endif
114111
115112 /// AZ TODO
116 VIRTUAL void METHOD(Verify)(THIS) CONST PURE;
113 VIRTUAL Bool METHOD(Verify)(THIS) CONST PURE;
117114
118115 /// AZ TODO
119116 VIRTUAL void METHOD(Reset)(THIS_
6060 using FramebufferType = FramebufferD3D11Impl;
6161 using RenderPassType = RenderPassD3D11Impl;
6262 using BottomLevelASType = BottomLevelASBase<IBottomLevelAS, RenderDeviceD3D11Impl>;
63 using TopLevelASType = TopLevelASBase<ITopLevelAS, RenderDeviceD3D11Impl>;
63 using TopLevelASType = TopLevelASBase<ITopLevelAS, BottomLevelASType, RenderDeviceD3D11Impl>;
6464 };
6565
6666 /// Device context implementation in Direct3D11 backend.
143143
144144 void Destruct();
145145
146 CComPtr<ID3D12DeviceChild> m_pd3d12PSO;
147 RootSignature m_RootSig;
146 void CreateLocalRootSignature(const RayTracingPipelineDesc& Desc);
147
148 CComPtr<ID3D12DeviceChild> m_pd3d12PSO;
149 RootSignature m_RootSig;
150 CComPtr<ID3D12RootSignature> m_LocalRootSignature;
148151
149152 // Must be defined before default SRB
150153 SRBMemoryAllocator m_SRBMemAllocator;
177177 ShaderVersion GetMaxShaderModel() const;
178178 D3D_FEATURE_LEVEL GetD3DFeatureLevel() const;
179179
180 static Uint32 GetShaderGroupHandleSize()
181 {
182 return D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
183 }
180 static Uint32 GetShaderGroupHandleSize() { return D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; }
181 static Uint32 GetMaxShaderRecordStride() { return D3D12_RAYTRACING_MAX_SHADER_RECORD_STRIDE; }
184182
185183 private:
186184 template <typename PSOCreateInfoType>
5353
5454 virtual void DILIGENT_CALL_TYPE QueryInterface(const INTERFACE_ID& IID, IObject** ppInterface) override final;
5555
56 virtual void DILIGENT_CALL_TYPE Verify() const override;
57
58 virtual void DILIGENT_CALL_TYPE Reset(const ShaderBindingTableDesc& Desc) override;
59
6056 virtual void DILIGENT_CALL_TYPE ResetHitGroups(Uint32 HitShadersPerInstance) override;
6157
6258 virtual void DILIGENT_CALL_TYPE BindAll(const BindAllAttribs& Attribs) override;
6965 D3D12_GPU_VIRTUAL_ADDRESS_RANGE_AND_STRIDE& CallableShaderBindingTable) override;
7066
7167 private:
72 void ValidateDesc(const ShaderBindingTableDesc& Desc) const;
73
74 private:
7568 RefCntAutoPtr<IBuffer> m_pBuffer;
7669 };
7770
3232 #include "TopLevelASD3D12.h"
3333 #include "RenderDeviceD3D12.h"
3434 #include "TopLevelASBase.hpp"
35 #include "BottomLevelASD3D12Impl.hpp"
3536 #include "D3D12ResourceBase.hpp"
3637 #include "RenderDeviceD3D12Impl.hpp"
3738
3940 {
4041
4142 /// Top-level acceleration structure object implementation in Direct3D12 backend.
42 class TopLevelASD3D12Impl final : public TopLevelASBase<ITopLevelASD3D12, RenderDeviceD3D12Impl>, public D3D12ResourceBase
43 class TopLevelASD3D12Impl final : public TopLevelASBase<ITopLevelASD3D12, BottomLevelASD3D12Impl, RenderDeviceD3D12Impl>, public D3D12ResourceBase
4344 {
4445 public:
45 using TTopLevelASBase = TopLevelASBase<ITopLevelASD3D12, RenderDeviceD3D12Impl>;
46 using TTopLevelASBase = TopLevelASBase<ITopLevelASD3D12, BottomLevelASD3D12Impl, RenderDeviceD3D12Impl>;
4647
4748 TopLevelASD3D12Impl(IReferenceCounters* pRefCounters,
4849 class RenderDeviceD3D12Impl* pDeviceD3D12,
6464
6565 #if DILIGENT_C_INTERFACE
6666
67 # define IShaderBindingTableD3D12_GetD3D12AddressRangeAndStride(This, ...) CALL_IFACE_METHOD(ShaderBindingTableD3D12, GetD3D12AddressRangeAndStride, This, __VA_ARGS__)
6768
6869 #endif
6970
21972197 RESOURCE_STATE RequiredState,
21982198 const char* OperationName)
21992199 {
2200 // AZ TODO: transit BLAS state too?
2201
22022200 if (TransitionMode == RESOURCE_STATE_TRANSITION_MODE_TRANSITION)
22032201 {
22042202 if (TLAS.IsInKnownState() && !TLAS.CheckState(RequiredState))
22082206 else if (TransitionMode == RESOURCE_STATE_TRANSITION_MODE_VERIFY)
22092207 {
22102208 DvpVerifyTLASState(TLAS, RequiredState, OperationName);
2209 }
2210
2211 if (RequiredState & RESOURCE_STATE_RAY_TRACING)
2212 {
2213 TLAS.CheckBLASVersion();
22112214 }
22122215 #endif
22132216 }
23162319 d3d12Tris.VertexCount = SrcTris.VertexCount;
23172320 d3d12Tris.VertexBuffer.StartAddress = pVB->GetGPUAddress() + SrcTris.VertexOffset;
23182321 d3d12Tris.VertexBuffer.StrideInBytes = SrcTris.VertexStride;
2322
2323 TransitionOrVerifyBufferState(CmdCtx, *pVB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, OpName);
23192324
23202325 if (SrcTris.pIndexBuffer)
23212326 {
23882393
23892394 CmdCtx.AsGraphicsContext4().BuildRaytracingAccelerationStructure(Desc, 0, nullptr);
23902395 ++m_State.NumCommands;
2396
2397 #ifdef DILIGENT_DEVELOPMENT
2398 pBLASD12->UpdateVersion();
2399 #endif
23912400 }
23922401
23932402 void DeviceContextD3D12Impl::BuildTLAS(const TLASBuildAttribs& Attribs)
24002409 auto* pTLASD12 = ValidatedCast<TopLevelASD3D12Impl>(Attribs.pTLAS);
24012410 auto* pScratchD12 = ValidatedCast<BufferD3D12Impl>(Attribs.pScratchBuffer);
24022411 auto* pInstancesD12 = ValidatedCast<BufferD3D12Impl>(Attribs.pInstanceBuffer);
2403 //auto& TLASDesc = pTLASD12->GetDesc();
24042412
24052413 auto& CmdCtx = GetCmdContext();
24062414 const char* OpName = "Build TopLevelAS (DeviceContextD3D12Impl::BuildTLAS)";
24222430 auto* const pBLASD12 = ValidatedCast<BottomLevelASD3D12Impl>(Inst.pBLAS);
24232431
24242432 static_assert(sizeof(d3d12Inst.Transform) == sizeof(Inst.Transform), "size mismatch");
2425 std::memcpy(&d3d12Inst.Transform, Inst.Transform, sizeof(d3d12Inst.Transform));
2433 std::memcpy(&d3d12Inst.Transform, Inst.Transform.data, sizeof(d3d12Inst.Transform));
24262434
24272435 d3d12Inst.InstanceID = Inst.CustomId;
24282436 d3d12Inst.InstanceContributionToHitGroupIndex = pTLASD12->GetInstanceDesc(Inst.InstanceName).ContributionToHitGroupIndex; // AZ TODO: optimize
24572465 if (!TDeviceContextBase::CopyBLAS(Attribs, 0))
24582466 return;
24592467
2460 // AZ TODO
2468 auto* pSrcD3D12 = ValidatedCast<BottomLevelASD3D12Impl>(Attribs.pSrc);
2469 auto* pDstD3D12 = ValidatedCast<BottomLevelASD3D12Impl>(Attribs.pDst);
2470 auto& CmdCtx = GetCmdContext();
2471
2472 const char* OpName = "Copy BottomLevelAS (DeviceContextD3D12Impl::CopyBLAS)";
2473 TransitionOrVerifyBLASState(CmdCtx, *pSrcD3D12, Attribs.TransitionMode, RESOURCE_STATE_BUILD_AS_READ, OpName);
2474 TransitionOrVerifyBLASState(CmdCtx, *pDstD3D12, Attribs.TransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
2475
2476 CmdCtx.AsGraphicsContext4().CopyRaytracingAccelerationStructure(pSrcD3D12->GetGPUAddress(), pDstD3D12->GetGPUAddress(), D3D12_RAYTRACING_ACCELERATION_STRUCTURE_COPY_MODE_CLONE);
2477 ++m_State.NumCommands;
2478
2479 #ifdef DILIGENT_DEVELOPMENT
2480 pDstD3D12->UpdateVersion();
2481 #endif
24612482 }
24622483
24632484 void DeviceContextD3D12Impl::CopyTLAS(const CopyTLASAttribs& Attribs)
24652486 if (!TDeviceContextBase::CopyTLAS(Attribs, 0))
24662487 return;
24672488
2468 // AZ TODO
2489 auto* pSrcD3D12 = ValidatedCast<TopLevelASD3D12Impl>(Attribs.pSrc);
2490 auto* pDstD3D12 = ValidatedCast<TopLevelASD3D12Impl>(Attribs.pDst);
2491 auto& CmdCtx = GetCmdContext();
2492
2493 pDstD3D12->CopyInstancceData(*pSrcD3D12);
2494
2495 const char* OpName = "Copy BottomLevelAS (DeviceContextD3D12Impl::CopyTLAS)";
2496 TransitionOrVerifyTLASState(CmdCtx, *pSrcD3D12, Attribs.TransitionMode, RESOURCE_STATE_BUILD_AS_READ, OpName);
2497 TransitionOrVerifyTLASState(CmdCtx, *pDstD3D12, Attribs.TransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
2498
2499 CmdCtx.AsGraphicsContext4().CopyRaytracingAccelerationStructure(pSrcD3D12->GetGPUAddress(), pDstD3D12->GetGPUAddress(), D3D12_RAYTRACING_ACCELERATION_STRUCTURE_COPY_MODE_CLONE);
2500 ++m_State.NumCommands;
24692501 }
24702502
24712503 void DeviceContextD3D12Impl::TraceRays(const TraceRaysAttribs& Attribs)
223223 }
224224
225225 template <typename TNameToGroupIndexMap>
226 void GetShaderIdentifiers(ID3D12StateObject* pSO,
226 void GetShaderIdentifiers(ID3D12DeviceChild* pSO,
227227 const RayTracingPipelineStateCreateInfo& CreateInfo,
228228 const TNameToGroupIndexMap& NameToGroupIndex,
229229 Uint8* ShaderData)
624624 {
625625 try
626626 {
627 CreateLocalRootSignature(CreateInfo.RayTracingPipeline);
628
627629 TShaderStages ShaderStages;
628630 std::vector<D3D12_STATE_SUBOBJECT> Subobjects;
629631 DynamicLinearAllocator TempPool{GetRawAllocator(), 4 << 10};
639641 D3D12_GLOBAL_ROOT_SIGNATURE GlobalRoot = {m_RootSig.GetD3D12RootSignature()};
640642 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE, &GlobalRoot});
641643
644 D3D12_LOCAL_ROOT_SIGNATURE LocalRoot = {m_LocalRootSignature};
645 if (m_LocalRootSignature)
646 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_LOCAL_ROOT_SIGNATURE, &LocalRoot});
647
642648 D3D12_STATE_OBJECT_DESC RTPipelineDesc = {};
643649 RTPipelineDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
644650 RTPipelineDesc.NumSubobjects = static_cast<UINT>(Subobjects.size());
645651 RTPipelineDesc.pSubobjects = Subobjects.data();
646652
647 CComPtr<ID3D12StateObject> pSO;
648
649653 auto pd3d12Device = pDeviceD3D12->GetD3D12Device5();
650 HRESULT hr = pd3d12Device->CreateStateObject(&RTPipelineDesc, IID_PPV_ARGS(&pSO));
654 HRESULT hr = pd3d12Device->CreateStateObject(&RTPipelineDesc, IID_PPV_ARGS(&m_pd3d12PSO));
651655 if (FAILED(hr))
652656 LOG_ERROR_AND_THROW("Failed to create ray tracing state object");
653657
654 m_pd3d12PSO = pSO;
655
656 GetShaderIdentifiers(pSO, CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex, m_pRayTracingPipelineData->Shaders);
658 GetShaderIdentifiers(m_pd3d12PSO, CreateInfo, m_pRayTracingPipelineData->NameToGroupIndex, m_pRayTracingPipelineData->Shaders);
657659
658660 if (*m_Desc.Name != 0)
659661 {
669671 Destruct();
670672 throw;
671673 }
674 }
675
676 void PipelineStateD3D12Impl::CreateLocalRootSignature(const RayTracingPipelineDesc& Desc)
677 {
678 // AZ TODO
679 /*if (Desc.ShaderRecordSize == 0)
680 return;
681
682 D3D12_ROOT_SIGNATURE_DESC d3d12RootSignatureDesc = {};
683 D3D12_ROOT_PARAMETER d3d12Params = {};
684
685 d3d12Params.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
686 d3d12Params.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
687 d3d12Params.Constants.Num32BitValues = Desc.ShaderRecordSize / 4;
688 d3d12Params.Constants.RegisterSpace = Desc.LocalRootRegisterSpace;
689 d3d12Params.Constants.ShaderRegister = 0;
690
691 d3d12RootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_LOCAL_ROOT_SIGNATURE;
692 d3d12RootSignatureDesc.NumParameters = 1;
693 d3d12RootSignatureDesc.pParameters = &d3d12Params;
694
695 CComPtr<ID3DBlob> signature;
696 auto hr = D3D12SerializeRootSignature(&d3d12RootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, nullptr);
697 CHECK_D3D_RESULT_THROW(hr, "Failed to serialize root signature");
698
699 auto pd3d12Device = GetDevice()->GetD3D12Device();
700
701 hr = pd3d12Device->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&m_LocalRootSignature));
702 CHECK_D3D_RESULT_THROW(hr, "Failed to create root signature");*/
672703 }
673704
674705 PipelineStateD3D12Impl::~PipelineStateD3D12Impl()
704704 {
705705 VERIFY(RangeType == D3D12_DESCRIPTOR_RANGE_TYPE_SRV, "Unexpected descriptor range type");
706706 auto* pTLASD3D12 = Res.pObject.RawPtr<TopLevelASD3D12Impl>();
707 if (pTLASD3D12->IsInKnownState() && !pTLASD3D12->CheckState(RESOURCE_STATE_RAY_TRACING))
707 if (pTLASD3D12->IsInKnownState())
708708 Ctx.TransitionResource(pTLASD3D12, RESOURCE_STATE_RAY_TRACING);
709709 }
710710 break;
4343 bool bIsDeviceInternal) :
4444 TShaderBindingTableBase{pRefCounters, pDeviceD3D12, Desc, bIsDeviceInternal}
4545 {
46 ValidateDesc(Desc);
47
48 m_ShaderRecordStride = m_Desc.ShaderRecordSize + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
4946 }
5047
5148 ShaderBindingTableD3D12Impl::~ShaderBindingTableD3D12Impl()
5350 }
5451
5552 IMPLEMENT_QUERY_INTERFACE(ShaderBindingTableD3D12Impl, IID_ShaderBindingTableD3D12, TShaderBindingTableBase)
56
57 void ShaderBindingTableD3D12Impl::ValidateDesc(const ShaderBindingTableDesc& Desc) const
58 {
59 if (Desc.ShaderRecordSize + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES > D3D12_RAYTRACING_MAX_SHADER_RECORD_STRIDE)
60 {
61 LOG_ERROR_AND_THROW("Description of Shader binding table '", (Desc.Name ? Desc.Name : ""),
62 "' is invalid: ShaderRecordSize is too big, max size is: ", D3D12_RAYTRACING_MAX_SHADER_RECORD_STRIDE - D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
63 }
64 }
65
66 void ShaderBindingTableD3D12Impl::Verify() const
67 {
68 // AZ TODO
69 }
70
71 void ShaderBindingTableD3D12Impl::Reset(const ShaderBindingTableDesc& Desc)
72 {
73 m_RayGenShaderRecord.clear();
74 m_MissShadersRecord.clear();
75 m_CallableShadersRecord.clear();
76 m_HitGroupsRecord.clear();
77 m_Changed = true;
78
79 try
80 {
81 ValidateShaderBindingTableDesc(Desc);
82 ValidateDesc(Desc);
83 }
84 catch (const std::runtime_error&)
85 {
86 // AZ TODO
87 return;
88 }
89
90 m_Desc = Desc;
91 m_ShaderRecordStride = m_Desc.ShaderRecordSize + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
92 }
9353
9454 void ShaderBindingTableD3D12Impl::ResetHitGroups(Uint32 HitShadersPerInstance)
9555 {
5555 using FramebufferType = FramebufferGLImpl;
5656 using RenderPassType = RenderPassGLImpl;
5757 using BottomLevelASType = BottomLevelASBase<IBottomLevelAS, RenderDeviceGLImpl>;
58 using TopLevelASType = TopLevelASBase<ITopLevelAS, RenderDeviceGLImpl>;
58 using TopLevelASType = TopLevelASBase<ITopLevelAS, BottomLevelASType, RenderDeviceGLImpl>;
5959 };
6060
6161 /// Device context implementation in OpenGL backend.
200200 {
201201 return GetPhysicalDevice().GetExtProperties().RayTracing.shaderGroupHandleSize;
202202 }
203 Uint32 GetMaxShaderRecordStride() const
204 {
205 return GetPhysicalDevice().GetExtProperties().RayTracing.maxShaderGroupStride;
206 }
203207
204208 private:
205209 template <typename PSOCreateInfoType>
5050 bool bIsDeviceInternal = false);
5151 ~ShaderBindingTableVkImpl();
5252
53 virtual void DILIGENT_CALL_TYPE Verify() const override;
54
55 virtual void DILIGENT_CALL_TYPE Reset(const ShaderBindingTableDesc& Desc) override;
56
5753 virtual void DILIGENT_CALL_TYPE ResetHitGroups(Uint32 HitShadersPerInstance) override;
5854 virtual void DILIGENT_CALL_TYPE BindAll(const BindAllAttribs& Attribs) override;
5955
6763 IMPLEMENT_QUERY_INTERFACE_IN_PLACE(IID_ShaderBindingTableVk, TShaderBindingTableBase);
6864
6965 private:
70 void ValidateDesc(const ShaderBindingTableDesc& Desc) const;
71
72 private:
7366 RefCntAutoPtr<IBuffer> m_pBuffer;
7467 };
7568
3333 #include "RenderDeviceVkImpl.hpp"
3434 #include "TopLevelASVk.h"
3535 #include "TopLevelASBase.hpp"
36 #include "BottomLevelASVkImpl.hpp"
3637 #include "VulkanUtilities/VulkanObjectWrappers.hpp"
3738
3839 namespace Diligent
3940 {
4041
41 class TopLevelASVkImpl final : public TopLevelASBase<ITopLevelASVk, RenderDeviceVkImpl>
42 class TopLevelASVkImpl final : public TopLevelASBase<ITopLevelASVk, BottomLevelASVkImpl, RenderDeviceVkImpl>
4243 {
4344 public:
44 using TTopLevelASBase = TopLevelASBase<ITopLevelASVk, RenderDeviceVkImpl>;
45 using TTopLevelASBase = TopLevelASBase<ITopLevelASVk, BottomLevelASVkImpl, RenderDeviceVkImpl>;
4546
4647 TopLevelASVkImpl(IReferenceCounters* pRefCounters,
4748 RenderDeviceVkImpl* pRenderDeviceVk,
23332333 }
23342334 }
23352335
2336 namespace
2337 {
2338 NODISCARD inline bool ResourceStateHasWriteAccess(RESOURCE_STATE State)
2339 {
2340 static_assert(RESOURCE_STATE_MAX_BIT == RESOURCE_STATE_RAY_TRACING, "This function must be updated to handle new resource state flag");
2341 constexpr RESOURCE_STATE WriteAccessStates =
2342 RESOURCE_STATE_RENDER_TARGET |
2343 RESOURCE_STATE_UNORDERED_ACCESS |
2344 RESOURCE_STATE_COPY_DEST |
2345 RESOURCE_STATE_RESOLVE_DEST |
2346 RESOURCE_STATE_BUILD_AS_WRITE;
2347
2348 return State & WriteAccessStates;
2349 }
2350 } // namespace
2351
23362352 void DeviceContextVkImpl::TransitionTextureState(TextureVkImpl& TextureVk,
23372353 RESOURCE_STATE OldState,
23382354 RESOURCE_STATE NewState,
23952411 pSubresRange->aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
23962412 }
23972413
2398 // Note that when both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
2399 // to make sure that all UAV writes are complete and visible.
2414 // Always add barrier after writes.
2415 const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
2416
24002417 auto OldLayout = ResourceStateToVkImageLayout(OldState);
24012418 auto NewLayout = ResourceStateToVkImageLayout(NewState);
24022419 auto OldStages = ResourceStateFlagsToVkPipelineStageFlags(OldState, m_CommandBuffer.GetEnabledShaderStages());
24032420 auto NewStages = ResourceStateFlagsToVkPipelineStageFlags(NewState, m_CommandBuffer.GetEnabledShaderStages());
2404 m_CommandBuffer.TransitionImageLayout(vkImg, OldLayout, NewLayout, *pSubresRange, OldStages, NewStages);
2405 if (UpdateTextureState)
2406 {
2407 TextureVk.SetState(NewState);
2408 VERIFY_EXPR(TextureVk.GetLayout() == NewLayout);
2421
2422 if (((OldState & NewState) != NewState) || OldLayout != NewLayout || AfterWrite)
2423 {
2424 m_CommandBuffer.TransitionImageLayout(vkImg, OldLayout, NewLayout, *pSubresRange, OldStages, NewStages);
2425 if (UpdateTextureState)
2426 {
2427 TextureVk.SetState(NewState);
2428 VERIFY_EXPR(TextureVk.GetLayout() == NewLayout);
2429 }
24092430 }
24102431 }
24112432
24202441 VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
24212442 if (Texture.IsInKnownState())
24222443 {
2423 if (!Texture.CheckState(RequiredState))
2424 {
2425 TransitionTextureState(Texture, RESOURCE_STATE_UNKNOWN, RequiredState, true);
2426 }
2444 TransitionTextureState(Texture, RESOURCE_STATE_UNKNOWN, RequiredState, true);
24272445 VERIFY_EXPR(Texture.GetLayout() == ExpectedLayout);
24282446 }
24292447 }
24882506 }
24892507 }
24902508
2491 // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
2492 // to make sure that all UAV writes are complete and visible.
2493 if (((OldState & NewState) != NewState) || NewState == RESOURCE_STATE_UNORDERED_ACCESS || NewState == RESOURCE_STATE_BUILD_AS_WRITE)
2509 // Always add barrier after writes.
2510 const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
2511
2512 if (((OldState & NewState) != NewState) || AfterWrite)
24942513 {
24952514 DEV_CHECK_ERR(BufferVk.m_VulkanBuffer != VK_NULL_HANDLE, "Cannot transition suballocated buffer");
24962515 VERIFY_EXPR(BufferVk.GetDynamicOffset(m_ContextId, this) == 0);
25202539 VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
25212540 if (Buffer.IsInKnownState())
25222541 {
2523 if (!Buffer.CheckState(RequiredState))
2524 {
2525 TransitionBufferState(Buffer, RESOURCE_STATE_UNKNOWN, RequiredState, true);
2526 }
2542 TransitionBufferState(Buffer, RESOURCE_STATE_UNKNOWN, RequiredState, true);
25272543 VERIFY_EXPR(Buffer.CheckAccessFlags(ExpectedAccessFlags));
25282544 }
25292545 }
25632579 }
25642580 }
25652581
2566 if ((OldState & NewState) != NewState)
2582 // Always add barrier after writes.
2583 const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
2584
2585 if ((OldState & NewState) != NewState || AfterWrite)
25672586 {
25682587 EnsureVkCmdBuffer();
25692588 auto OldAccessFlags = ResourceStateFlagsToVkAccessFlags(OldState);
25832602 RESOURCE_STATE NewState,
25842603 bool UpdateInternalState)
25852604 {
2586 // AZ TODO: transit BLAS state too?
2587
25882605 VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
25892606 if (OldState == RESOURCE_STATE_UNKNOWN)
25902607 {
26082625 }
26092626 }
26102627
2611 if ((OldState & NewState) != NewState)
2628 // Always add barrier after writes.
2629 const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
2630
2631 if ((OldState & NewState) != NewState || AfterWrite)
26122632 {
26132633 EnsureVkCmdBuffer();
26142634 auto OldAccessFlags = ResourceStateFlagsToVkAccessFlags(OldState);
26332653 VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
26342654 if (BLAS.IsInKnownState())
26352655 {
2636 if (!BLAS.CheckState(RequiredState))
2637 {
2638 TransitionBLASState(BLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
2639 }
2656 TransitionBLASState(BLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
26402657 }
26412658 }
26422659 #ifdef DILIGENT_DEVELOPMENT
26572674 VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
26582675 if (TLAS.IsInKnownState())
26592676 {
2660 if (!TLAS.CheckState(RequiredState))
2661 {
2662 TransitionTLASState(TLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
2663 }
2677 TransitionTLASState(TLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
26642678 }
26652679 }
26662680 #ifdef DILIGENT_DEVELOPMENT
26672681 else if (TransitionMode == RESOURCE_STATE_TRANSITION_MODE_VERIFY)
26682682 {
26692683 DvpVerifyTLASState(TLAS, RequiredState, OperationName);
2684 }
2685
2686 if (RequiredState & RESOURCE_STATE_RAY_TRACING)
2687 {
2688 TLAS.CheckBLASVersion();
26702689 }
26712690 #endif
26722691 }
27172736 {
27182737 TransitionBufferState(*pBuffer, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
27192738 }
2720 else if (RefCntAutoPtr<BottomLevelASVkImpl> pBLAS{Barrier.pResource, IID_BottomLevelAS})
2721 {
2722 TransitionBLASState(*pBLAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
2723 }
2724 else if (RefCntAutoPtr<TopLevelASVkImpl> pTLAS{Barrier.pResource, IID_TopLevelAS})
2725 {
2726 TransitionTLASState(*pTLAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
2739 else if (RefCntAutoPtr<BottomLevelASVkImpl> pBottomLevelAS{Barrier.pResource, IID_BottomLevelAS})
2740 {
2741 TransitionBLASState(*pBottomLevelAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
2742 }
2743 else if (RefCntAutoPtr<TopLevelASVkImpl> pTopLevelAS{Barrier.pResource, IID_TopLevelAS})
2744 {
2745 TransitionTLASState(*pTopLevelAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
27272746 }
27282747 else
27292748 {
28072826
28082827 const char* OpName = "Build BottomLevelAS (DeviceContextVkImpl::BuildBLAS)";
28092828 TransitionOrVerifyBLASState(*pBLASVk, Attribs.BLASTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
2810 TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VkAccessFlagBits(0), OpName);
2829 TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR, OpName);
28112830
28122831 VkAccelerationStructureBuildGeometryInfoKHR Info = {};
28132832 std::vector<VkAccelerationStructureBuildOffsetInfoKHR> Offsets;
28442863 vkTris.vertexStride = SrcTris.VertexStride;
28452864 vkTris.vertexData.deviceAddress = pVB->GetVkDeviceAddress() + SrcTris.VertexOffset;
28462865
2847 TransitionOrVerifyBufferState(*pVB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, static_cast<VkAccessFlagBits>(0), OpName);
2866 TransitionOrVerifyBufferState(*pVB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
28482867
28492868 if (SrcTris.pIndexBuffer)
28502869 {
28532872 vkTris.indexData.deviceAddress = pIB->GetVkDeviceAddress() + SrcTris.IndexOffset;
28542873 off.primitiveCount = SrcTris.IndexCount / 3;
28552874
2856 TransitionOrVerifyBufferState(*pIB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, static_cast<VkAccessFlagBits>(0), OpName);
2875 TransitionOrVerifyBufferState(*pIB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
28572876 }
28582877 else
28592878 {
28692888 auto* const pTB = ValidatedCast<BufferVkImpl>(SrcTris.pTransformBuffer);
28702889 vkTris.transformData.deviceAddress = pTB->GetVkDeviceAddress() + SrcTris.TransformBufferOffset;
28712890
2872 TransitionOrVerifyBufferState(*pTB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
2891 TransitionOrVerifyBufferState(*pTB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
28732892 }
28742893 else
28752894 {
29122931 vkAABBs.stride = SrcBoxes.BoxStride;
29132932 vkAABBs.data.deviceAddress = pBB->GetVkDeviceAddress() + SrcBoxes.BoxOffset;
29142933
2915 TransitionOrVerifyBufferState(*pBB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
2934 TransitionOrVerifyBufferState(*pBB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
29162935
29172936 off.firstVertex = 0;
29182937 off.transformOffset = 0;
29382957 EnsureVkCmdBuffer();
29392958 m_CommandBuffer.BuildAccelerationStructure(1, &Info, &OffsetsPtr);
29402959 ++m_State.NumCommands;
2960
2961 #ifdef DILIGENT_DEVELOPMENT
2962 pBLASVk->UpdateVersion();
2963 #endif
29412964 }
29422965
29432966 void DeviceContextVkImpl::BuildTLAS(const TLASBuildAttribs& Attribs)
29632986
29642987 const char* OpName = "Build TopLevelAS (DeviceContextVkImpl::BuildTLAS)";
29652988 TransitionOrVerifyTLASState(*pTLASVk, Attribs.TLASTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
2966 TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VkAccessFlagBits(0), OpName);
2989 TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR, OpName);
29672990
29682991 pTLASVk->SetInstanceData(Attribs.pInstances, Attribs.InstanceCount, Attribs.HitShadersPerInstance);
29692992
29803003 auto* const pBLASVk = ValidatedCast<BottomLevelASVkImpl>(Inst.pBLAS);
29813004
29823005 static_assert(sizeof(vkASInst.transform) == sizeof(Inst.Transform), "size mismatch");
2983 std::memcpy(&vkASInst.transform, Inst.Transform, sizeof(vkASInst.transform));
3006 std::memcpy(&vkASInst.transform, Inst.Transform.data, sizeof(vkASInst.transform));
29843007
29853008 vkASInst.instanceCustomIndex = Inst.CustomId;
29863009 vkASInst.instanceShaderBindingTableRecordOffset = pTLASVk->GetInstanceDesc(Inst.InstanceName).ContributionToHitGroupIndex; // AZ TODO: optimize
29933016
29943017 UpdateBufferRegion(pInstancesVk, Attribs.InstanceBufferOffset, Size, TmpSpace.vkBuffer, TmpSpace.AlignedOffset, Attribs.InstanceBufferTransitionMode);
29953018 }
2996 TransitionOrVerifyBufferState(*pInstancesVk, Attribs.InstanceBufferTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
3019 TransitionOrVerifyBufferState(*pInstancesVk, Attribs.InstanceBufferTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
29973020
29983021 VkAccelerationStructureBuildGeometryInfoKHR vkASBuildInfo = {};
29993022 VkAccelerationStructureBuildOffsetInfoKHR vkASBuildOffset = {};
30593082
30603083 m_CommandBuffer.CopyAccelerationStructure(Info);
30613084 ++m_State.NumCommands;
3085
3086 #ifdef DILIGENT_DEVELOPMENT
3087 pDstVk->UpdateVersion();
3088 #endif
30623089 }
30633090
30643091 void DeviceContextVkImpl::CopyTLAS(const CopyTLASAttribs& Attribs)
30753102
30763103 auto* pSrcVk = ValidatedCast<TopLevelASVkImpl>(Attribs.pSrc);
30773104 auto* pDstVk = ValidatedCast<TopLevelASVkImpl>(Attribs.pDst);
3105
3106 pDstVk->CopyInstancceData(*pSrcVk);
30783107
30793108 VkCopyAccelerationStructureInfoKHR Info = {};
30803109
757757 {
758758 try
759759 {
760 const auto& LogicalDevice = GetDevice()->GetLogicalDevice();
761 const auto ShaderGroupHandleSize = pDeviceVk->GetShaderGroupHandleSize();
762
763 if (LogicalDevice.GetEnabledExtFeatures().RayTracing.rayTracing == VK_FALSE)
764 LOG_ERROR_AND_THROW("Ray tracing is not supported by this device");
765
760766 std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
761767 std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
762
763768 std::vector<VkRayTracingShaderGroupCreateInfoKHR> ShaderGroups;
769
764770 InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules,
765771 [&](const RayTracingPipelineStateCreateInfo& CreateInfo, LinearAllocator& MemPool, TShaderStages& ShaderStages) //
766772 {
771777 );
772778
773779 CreateRayTracingPipeline(pDeviceVk, vkShaderStages, ShaderGroups, m_PipelineLayout, m_Desc, GetRayTracingPipelineDesc(), m_Pipeline);
774
775 const auto& LogicalDevice = GetDevice()->GetLogicalDevice();
776 const auto ShaderGroupHandleSize = pDeviceVk->GetShaderGroupHandleSize();
777780
778781 auto err = LogicalDevice.GetRayTracingShaderGroupHandles(m_Pipeline, 0, static_cast<uint32_t>(ShaderGroups.size()), ShaderGroupHandleSize, &m_pRayTracingPipelineData->Shaders[0]);
779782 VERIFY(err == VK_SUCCESS, "Failed to get shader group handles");
3838 bool bIsDeviceInternal) :
3939 TShaderBindingTableBase{pRefCounters, pRenderDeviceVk, Desc, bIsDeviceInternal}
4040 {
41 ValidateDesc(Desc);
42
43 const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
44 m_ShaderRecordStride = m_Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize;
4541 }
4642
4743 ShaderBindingTableVkImpl::~ShaderBindingTableVkImpl()
4844 {
49 }
50
51 void ShaderBindingTableVkImpl::ValidateDesc(const ShaderBindingTableDesc& Desc) const
52 {
53 const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
54
55 if (Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize > RTLimits.maxShaderGroupStride)
56 {
57 LOG_ERROR_AND_THROW("Description of Shader binding table '", (Desc.Name ? Desc.Name : ""),
58 "' is invalid: ShaderRecordSize is too big, max size is: ", RTLimits.maxShaderGroupStride - RTLimits.shaderGroupHandleSize);
59 }
60 }
61
62 void ShaderBindingTableVkImpl::Verify() const
63 {
64 // AZ TODO
65 }
66
67 void ShaderBindingTableVkImpl::Reset(const ShaderBindingTableDesc& Desc)
68 {
69 m_RayGenShaderRecord.clear();
70 m_MissShadersRecord.clear();
71 m_CallableShadersRecord.clear();
72 m_HitGroupsRecord.clear();
73 m_Changed = true;
74
75 try
76 {
77 ValidateShaderBindingTableDesc(Desc);
78 ValidateDesc(Desc);
79 }
80 catch (const std::runtime_error&)
81 {
82 // AZ TODO
83 return;
84 }
85
86 m_Desc = Desc;
87
88 const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
89 m_ShaderRecordStride = m_Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize;
9045 }
9146
9247 void ShaderBindingTableVkImpl::ResetHitGroups(Uint32 HitShadersPerInstance)
165165 {
166166 constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_CONSTANT_BUFFER;
167167 VERIFY_EXPR((ResourceStateFlagsToVkAccessFlags(RequiredState) & VK_ACCESS_UNIFORM_READ_BIT) == VK_ACCESS_UNIFORM_READ_BIT);
168 const bool IsInRequiredState = pBufferVk->CheckState(RequiredState);
169168 if (VerifyOnly)
170169 {
171 if (!IsInRequiredState)
170 if (!pBufferVk->CheckState(RequiredState))
172171 {
173172 LOG_ERROR_MESSAGE("State of buffer '", pBufferVk->GetDesc().Name, "' is incorrect. Required state: ",
174173 GetResourceStateString(RequiredState), ". Actual state: ",
180179 }
181180 else
182181 {
183 if (!IsInRequiredState)
184 {
185 pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
186 }
182 pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
187183 VERIFY_EXPR(pBufferVk->CheckAccessFlags(VK_ACCESS_UNIFORM_READ_BIT));
188184 }
189185 }
210206 (VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
211207 VERIFY_EXPR((ResourceStateFlagsToVkAccessFlags(RequiredState) & RequiredAccessFlags) == RequiredAccessFlags);
212208 #endif
213 const bool IsInRequiredState = pBufferVk->CheckState(RequiredState);
214209
215210 if (VerifyOnly)
216211 {
217 if (!IsInRequiredState)
212 if (!pBufferVk->CheckState(RequiredState))
218213 {
219214 LOG_ERROR_MESSAGE("State of buffer '", pBufferVk->GetDesc().Name, "' is incorrect. Required state: ",
220215 GetResourceStateString(RequiredState), ". Actual state: ",
226221 }
227222 else
228223 {
229 // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
230 // to make sure that all UAV writes are complete and visible.
231 if (!IsInRequiredState || RequiredState == RESOURCE_STATE_UNORDERED_ACCESS)
232 {
233 pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
234 }
224 pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
235225 VERIFY_EXPR(pBufferVk->CheckAccessFlags(RequiredAccessFlags));
236226 }
237227 }
274264 VERIFY_EXPR(ResourceStateToVkImageLayout(RequiredState) == VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
275265 }
276266 }
277 const bool IsInRequiredState = pTextureVk->CheckState(RequiredState);
278267
279268 if (VerifyOnly)
280269 {
281 if (!IsInRequiredState)
270 if (!pTextureVk->CheckState(RequiredState))
282271 {
283272 LOG_ERROR_MESSAGE("State of texture '", pTextureVk->GetDesc().Name, "' is incorrect. Required state: ",
284273 GetResourceStateString(RequiredState), ". Actual state: ",
290279 }
291280 else
292281 {
293 // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
294 // to make sure that all UAV writes are complete and visible.
295 if (!IsInRequiredState || RequiredState == RESOURCE_STATE_UNORDERED_ACCESS)
296 {
297 pCtxVkImpl->TransitionTextureState(*pTextureVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
298 }
282 pCtxVkImpl->TransitionTextureState(*pTextureVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
299283 }
300284 }
301285 }
326310 auto* pTLASVk = Res.pObject.RawPtr<TopLevelASVkImpl>();
327311 if (pTLASVk != nullptr && pTLASVk->IsInKnownState())
328312 {
329 constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_RAY_TRACING;
330 const bool IsInRequiredState = pTLASVk->CheckState(RequiredState);
313 constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_RAY_TRACING;
331314 if (VerifyOnly)
332315 {
333 if (!IsInRequiredState)
316 if (!pTLASVk->CheckState(RequiredState))
334317 {
335318 LOG_ERROR_MESSAGE("State of TLAS '", pTLASVk->GetDesc().Name, "' is incorrect. Required state: ",
336319 GetResourceStateString(RequiredState), ". Actual state: ",
339322 "when calling IDeviceContext::CommitShaderResources() or explicitly transition the TLAS state "
340323 "with IDeviceContext::TransitionResourceStates().");
341324 }
325
326 pTLASVk->CheckBLASVersion();
342327 }
343328 else
344329 {
345 if (!IsInRequiredState)
346 {
347 pCtxVkImpl->TransitionTLASState(*pTLASVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
348 }
330 pCtxVkImpl->TransitionTLASState(*pTLASVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
349331 }
350332 }
351333 }
16121612 "Please update the switch below to handle the new ray tracing build flag");
16131613
16141614 VkBuildAccelerationStructureFlagsKHR Result = 0;
1615 for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
1616 {
1617 if ((Flags & Bit) != Bit)
1618 continue;
1619
1620 switch (static_cast<RAYTRACING_BUILD_AS_FLAGS>(Bit))
1615 while (Flags != RAYTRACING_BUILD_AS_NONE)
1616 {
1617 auto FlagBit = static_cast<RAYTRACING_BUILD_AS_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
1618 switch (FlagBit)
16211619 {
16221620 // clang-format off
16231621 case RAYTRACING_BUILD_AS_ALLOW_UPDATE: Result |= VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_UPDATE_BIT_KHR; break;
16281626 // clang-format on
16291627 default: UNEXPECTED("unknown build AS flag");
16301628 }
1629 Flags = Flags & ~FlagBit;
16311630 }
16321631 return Result;
16331632 }
16381637 "Please update the switch below to handle the new ray tracing geometry flag");
16391638
16401639 VkGeometryFlagsKHR Result = 0;
1641 for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
1642 {
1643 if ((Flags & Bit) != Bit)
1644 continue;
1645
1646 switch (static_cast<RAYTRACING_GEOMETRY_FLAGS>(Bit))
1640 while (Flags != RAYTRACING_GEOMETRY_NONE)
1641 {
1642 auto FlagBit = static_cast<RAYTRACING_GEOMETRY_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
1643 switch (FlagBit)
16471644 {
16481645 // clang-format off
16491646 case RAYTRACING_GEOMETRY_OPAQUE: Result |= VK_GEOMETRY_OPAQUE_BIT_KHR; break;
16511648 // clang-format on
16521649 default: UNEXPECTED("unknown geometry flag");
16531650 }
1651 Flags = Flags & ~FlagBit;
16541652 }
16551653 return Result;
16561654 }
16611659 "Please update the switch below to handle the new ray tracing instance flag");
16621660
16631661 VkGeometryInstanceFlagsKHR Result = 0;
1664 for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
1665 {
1666 if ((Flags & Bit) != Bit)
1667 continue;
1668
1669 switch (static_cast<RAYTRACING_INSTANCE_FLAGS>(Bit))
1662 while (Flags != RAYTRACING_INSTANCE_NONE)
1663 {
1664 auto FlagBit = static_cast<RAYTRACING_INSTANCE_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
1665 switch (FlagBit)
16701666 {
16711667 // clang-format off
16721668 case RAYTRACING_INSTANCE_TRIANGLE_FACING_CULL_DISABLE: Result |= VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR; break;
16761672 // clang-format on
16771673 default: UNEXPECTED("unknown instance flag");
16781674 }
1675 Flags = Flags & ~FlagBit;
16791676 }
16801677 return Result;
16811678 }
16921689 // clang-format on
16931690 default:
16941691 UNEXPECTED("unknown AS copy mode");
1695 return static_cast<VkCopyAccelerationStructureModeKHR>(0);
1692 return VK_COPY_ACCELERATION_STRUCTURE_MODE_MAX_ENUM_KHR;
16961693 }
16971694 }
16981695
5454 gl_RayFlagsNoneEXT, // rayFlags
5555 0xFF, // cullMask
5656 0, // sbtRecordOffset
57 0, // sbtRecordStride
57 1, // sbtRecordStride
5858 0, // missIndex
5959 origin, // ray origin
6060 0.01, // ray min range
121121 gl_RayFlagsSkipClosestHitShaderEXT,
122122 0xFF, // cullMask
123123 0, // sbtRecordOffset
124 0, // sbtRecordStride
124 1, // sbtRecordStride
125125 0, // missIndex
126126 origin, // ray origin
127127 0.01, // ray min range
206206 gl_RayFlagsNoneEXT, // rayFlags
207207 0xFF, // cullMask
208208 0, // sbtRecordOffset
209 0, // sbtRecordStride
209 1, // sbtRecordStride
210210 0, // missIndex
211211 origin, // ray origin
212212 0.01, // ray min range
279279 // clang-format on
280280
281281
282 // clang-format off
283 const std::string RayTracingTest4_RG{
284 R"glsl(
285 #version 460
286 #extension GL_EXT_ray_tracing : require
287
288 layout(set=0, binding=0) uniform accelerationStructureEXT g_TLAS;
289 layout(set=0, binding=1, rgba8) uniform image2D g_ColorBuffer;
290
291 layout(location=0) rayPayloadEXT vec4 payload;
292
293 void main()
294 {
295 const vec2 uv = vec2(gl_LaunchIDEXT.xy) / vec2(gl_LaunchSizeEXT.xy - 1);
296 const vec3 origin = vec3(uv.x, 1.0 - uv.y, -1.0);
297 const vec3 direction = vec3(0.0, 0.0, 1.0);
298
299 payload = vec4(0.0);
300 traceRayEXT(g_TLAS, // acceleration structure
301 gl_RayFlagsNoneEXT, // rayFlags
302 0xFF, // cullMask
303 0, // sbtRecordOffset
304 1, // sbtRecordStride
305 0, // missIndex
306 origin, // ray origin
307 0.01, // ray min range
308 direction, // ray direction
309 10.0, // ray max range
310 0); // payload location
311
312 imageStore(g_ColorBuffer, ivec2(gl_LaunchIDEXT), payload);
313 }
314 )glsl"
315 };
316
317 const std::string RayTracingTest4_RM{
318 R"glsl(
319 #version 460
320 #extension GL_EXT_ray_tracing : require
321
322 layout(location=0) rayPayloadInEXT vec4 payload;
323
324 void main()
325 {
326 payload = vec4(0.0, 0.0, 0.2, 1.0);
327 }
328 )glsl"
329 };
330
331 const std::string RayTracingTest4_Uniforms{
332 R"glsl(
333 #version 460
334 #extension GL_EXT_ray_tracing : require
335
336 layout(shaderRecordEXT) buffer ShaderRecord
337 {
338 vec4 Weights;
339 };
340
341 layout(location=0) rayPayloadInEXT vec4 payload;
342 hitAttributeEXT vec2 hitAttribs;
343
344 layout(set=0, binding=2, std430) readonly buffer PerInstanceData {
345 uint PrimitiveOffsets[3];
346 } g_PerInstance[2];
347
348 layout(set=0, binding=3, std430) readonly buffer PrimitiveData {
349 uvec4 g_Primitives[9];
350 };
351
352 struct Vertex
353 {
354 vec4 Pos;
355 vec4 Color1;
356 vec4 Color2;
357 };
358 layout(set=0, binding=4, std430) readonly buffer VertexData {
359 Vertex g_Vertices[16];
360 };
361 )glsl"
362 };
363
364 const std::string RayTracingTest4_RCH1 = RayTracingTest4_Uniforms +
365 R"glsl(
366 void main()
367 {
368 vec3 barycentrics = vec3(1.0f - hitAttribs.x - hitAttribs.y, hitAttribs.x, hitAttribs.y);// * Weights.xyz;
369 uint primOffset = g_PerInstance[gl_InstanceID].PrimitiveOffsets[gl_GeometryIndexEXT];
370 uvec4 triFace = g_Primitives[primOffset + gl_PrimitiveID];
371 Vertex v0 = g_Vertices[triFace.x];
372 Vertex v1 = g_Vertices[triFace.y];
373 Vertex v2 = g_Vertices[triFace.z];
374 vec4 col = v0.Color2 * barycentrics.x + v1.Color2 * barycentrics.y + v2.Color2 * barycentrics.z;
375 payload = col;
376 }
377 )glsl";
378
379 const std::string RayTracingTest4_RCH2 = RayTracingTest4_Uniforms +
380 R"glsl(
381 void main()
382 {
383 vec3 barycentrics = vec3(1.0f - hitAttribs.x - hitAttribs.y, hitAttribs.x, hitAttribs.y);// * Weights.xyz;
384 uint primOffset = g_PerInstance[gl_InstanceID].PrimitiveOffsets[gl_GeometryIndexEXT];
385 uvec4 triFace = g_Primitives[primOffset + gl_PrimitiveID];
386 Vertex v0 = g_Vertices[triFace.x];
387 Vertex v1 = g_Vertices[triFace.y];
388 Vertex v2 = g_Vertices[triFace.z];
389 vec4 col = v0.Color1 * barycentrics.x + v1.Color1 * barycentrics.y + v2.Color1 * barycentrics.z;
390 payload = col;
391 }
392 )glsl";
393 // clang-format on
394
395
282396 } // namespace GLSL
283397
284398 } // namespace
8585 [shader("closesthit")]
8686 void main(inout RTPayload payload, in BuiltInTriangleIntersectionAttributes attr)
8787 {
88 float3 barycentrics = float3(1 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);
88 float3 barycentrics = float3(1.0 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);
8989 payload.Color = float4(barycentrics, 1.0);
9090 }
9191 )hlsl";
146146 [shader("anyhit")]
147147 void main(inout RTPayload payload, in BuiltInTriangleIntersectionAttributes attr)
148148 {
149 float3 barycentrics = float3(1 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);
149 float3 barycentrics = float3(1.0 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);
150150 if (barycentrics.y > barycentrics.x)
151151 IgnoreHit();
152152 else
240240 )hlsl";
241241 // clang-format on
242242
243
244 // clang-format off
245 const std::string RayTracingTest4_RG = RayTracingTest_Payload +
246 R"hlsl(
247 RaytracingAccelerationStructure g_TLAS : register(t0);
248 RWTexture2D<float4> g_ColorBuffer : register(u0);
249
250 [shader("raygeneration")]
251 void main()
252 {
253 const float2 uv = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy - 1);
254
255 RayDesc ray;
256 ray.Origin = float3(uv.x, 1.0 - uv.y, -1.0);
257 ray.Direction = float3(0.0, 0.0, 1.0);
258 ray.TMin = 0.01;
259 ray.TMax = 10.0;
260
261 RTPayload payload = {float4(0, 0, 0, 0)};
262 TraceRay(g_TLAS, // Acceleration Structure
263 RAY_FLAG_NONE, // Ray Flags
264 ~0, // Instance Inclusion Mask
265 0, // Ray Contribution To Hit Group Index
266 1, // Multiplier For Geometry Contribution To Hit Group Index
267 0, // Miss Shader Index
268 ray,
269 payload);
270
271 g_ColorBuffer[DispatchRaysIndex().xy] = payload.Color;
272 }
273 )hlsl";
274
275 const std::string RayTracingTest4_RM = RayTracingTest_Payload +
276 R"hlsl(
277 [shader("miss")]
278 void main(inout RTPayload payload)
279 {
280 payload.Color = float4(0.0, 0.0, 0.2, 1.0);
281 }
282 )hlsl";
283
284 const std::string RayTracingTest4_Uniforms = RayTracingTest_Payload +
285 R"hlsl(
286 struct Vertex
287 {
288 float4 Pos;
289 float4 Color1;
290 float4 Color2;
291 };
292 StructuredBuffer<Vertex> g_Vertices : register(t1); // array size = 16
293 StructuredBuffer<uint> g_PerInstance[2] : register(t2); // array size = 3
294 StructuredBuffer<uint4> g_Primitives : register(t4); // array size = 9
295
296 // local root constants
297 struct LocalRootConst
298 {
299 float4 Weight;
300 };
301 //[[vk::shader_record_ext]]
302 //ConstantBuffer<LocalRootConst> g_LocalRoot : register(b0);
303 )hlsl";
304
305 const std::string RayTracingTest4_RCH1 = RayTracingTest4_Uniforms +
306 R"hlsl(
307 [shader("closesthit")]
308 void main(inout RTPayload payload, in BuiltInTriangleIntersectionAttributes attr)
309 {
310 float3 barycentrics = float3(1.0 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);// * g_LocalRoot.Weight.xyz;
311 uint primOffset = g_PerInstance[InstanceIndex()][GeometryIndex()];
312 uint4 triFace = g_Primitives[primOffset + PrimitiveIndex()];
313 Vertex v0 = g_Vertices[triFace.x];
314 Vertex v1 = g_Vertices[triFace.y];
315 Vertex v2 = g_Vertices[triFace.z];
316 float4 col = v0.Color2 * barycentrics.x + v1.Color2 * barycentrics.y + v2.Color2 * barycentrics.z;
317 payload.Color = col;
318 }
319 )hlsl";
320
321 const std::string RayTracingTest4_RCH2 = RayTracingTest4_Uniforms +
322 R"hlsl(
323 [shader("closesthit")]
324 void main(inout RTPayload payload, in BuiltInTriangleIntersectionAttributes attr)
325 {
326 float3 barycentrics = float3(1.0 - attr.barycentrics.x - attr.barycentrics.y, attr.barycentrics.x, attr.barycentrics.y);// * g_LocalRoot.Weight.xyz;
327 uint primOffset = g_PerInstance[InstanceIndex()][GeometryIndex()];
328 uint4 triFace = g_Primitives[primOffset + PrimitiveIndex()];
329 Vertex v0 = g_Vertices[triFace.x];
330 Vertex v1 = g_Vertices[triFace.y];
331 Vertex v2 = g_Vertices[triFace.z];
332 float4 col = v0.Color1 * barycentrics.x + v1.Color1 * barycentrics.y + v2.Color1 * barycentrics.z;
333 payload.Color = col;
334 }
335 )hlsl";
336 // clang-format on
337
243338 } // namespace HLSL
244339
245340 } // namespace
0 /*
1 * Copyright 2019-2020 Diligent Graphics LLC
2 * Copyright 2015-2019 Egor Yusov
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * In no event and under no legal theory, whether in tort (including negligence),
17 * contract, or otherwise, unless required by applicable law (such as deliberate
18 * and grossly negligent acts) or agreed to in writing, shall any Contributor be
19 * liable for any damages, including any direct, indirect, special, incidental,
20 * or consequential damages of any character arising as a result of this License or
21 * out of the use or inability to use the software (including but not limited to damages
22 * for loss of goodwill, work stoppage, computer failure or malfunction, or any and
23 * all other commercial damages or losses), even if such Contributor has been advised
24 * of the possibility of such damages.
25 */
26
27 #include "BasicMath.hpp"
28
29 namespace Diligent
30 {
31
32 namespace TestingConstants
33 {
34 // clang-format off
35
36 namespace TriangleClosestHit
37 {
38 static const float3 Vertices[] =
39 {
40 float3{0.25f, 0.25f, 0.0f},
41 float3{0.75f, 0.25f, 0.0f},
42 float3{0.50f, 0.75f, 0.0f}
43 };
44 } // namespace TriangleClosestHit
45
46 namespace TriangleAnyHit
47 {
48 static const float3 Vertices[] =
49 {
50 float3{0.25f, 0.25f, 0.0f}, float3{0.75f, 0.25f, 0.0f}, float3{0.50f, 0.75f, 0.0f},
51 float3{0.50f, 0.10f, 0.1f}, float3{0.90f, 0.90f, 0.1f}, float3{0.10f, 0.90f, 0.1f},
52 float3{0.40f, 1.00f, 0.2f}, float3{0.20f, 0.40f, 0.2f}, float3{1.00f, 0.70f, 0.2f}
53 };
54 } // namespace TriangleAnyHit
55
56 namespace ProceduralIntersection
57 {
58 static const float3 Boxes[] =
59 {
60 float3{0.25f, 0.5f, 2.0f} - float3{1.0f, 1.0f, 1.0f},
61 float3{0.25f, 0.5f, 2.0f} + float3{1.0f, 1.0f, 1.0f}
62 };
63 } // namespace ProceduralIntersection
64
65 namespace MultiGeometry
66 {
67 struct VertexType
68 {
69 float4 Pos;
70 float4 Color1;
71 float4 Color2;
72
73 VertexType(float2 _Pos, float3 _Color1, float3 _Color2) :
74 Pos {_Pos.x, _Pos.y, 2.0f, 1.0f},
75 Color1{_Color1.x, _Color1.y, _Color1.z, 1.0f},
76 Color2{_Color2.x, _Color2.y, _Color2.z, 1.0f}
77 {}
78 };
79
80 static const VertexType Vertices[] =
81 {
82 // geometry 1
83 VertexType{{0.10f, 0.10f}, {0.7f, 0.3f, 0.1f}, {0.2f, 0.9f, 0.4f}}, // 0
84 VertexType{{0.17f, 0.30f}, {0.6f, 0.0f, 0.4f}, {0.2f, 0.5f, 0.8f}}, // 1
85 VertexType{{0.10f, 0.31f}, {0.3f, 0.7f, 0.4f}, {0.9f, 0.2f, 0.6f}}, // 2
86 VertexType{{0.22f, 0.45f}, {0.2f, 0.9f, 0.7f}, {0.1f, 0.7f, 0.1f}}, // 3
87 // geometry 2
88 VertexType{{0.27f, 0.10f}, {0.5f, 0.1f, 0.6f}, {0.3f, 0.1f, 0.5f}}, // 4
89 VertexType{{0.40f, 0.30f}, {1.0f, 1.0f, 1.0f}, {0.3f, 1.0f, 0.7f}}, // 5
90 VertexType{{0.26f, 0.30f}, {0.3f, 0.3f, 0.9f}, {1.0f, 0.0f, 0.3f}}, // 6
91 VertexType{{0.40f, 0.47f}, {0.8f, 1.0f, 0.2f}, {1.0f, 0.7f, 0.0f}}, // 7
92 VertexType{{0.54f, 0.30f}, {0.1f, 1.0f, 0.9f}, {0.0f, 1.0f, 0.6f}}, // 8
93 VertexType{{0.53f, 0.10f}, {1.0f, 0.0f, 1.0f}, {0.0f, 0.0f, 1.0f}}, // 9
94 // geometry 3
95 VertexType{{0.65f, 0.10f}, {0.3f, 0.6f, 0.8f}, {1.0f, 0.9f, 0.2f}}, // 10
96 VertexType{{0.63f, 0.25f}, {0.9f, 1.0f, 0.2f}, {0.1f, 0.2f, 0.3f}}, // 11
97 VertexType{{0.82f, 0.20f}, {0.4f, 0.5f, 0.0f}, {1.0f, 0.2f, 0.6f}}, // 12
98 VertexType{{0.76f, 0.30f}, {1.0f, 0.0f, 0.0f}, {0.4f, 0.7f, 0.2f}}, // 13
99 VertexType{{0.55f, 0.48f}, {0.5f, 0.1f, 0.2f}, {1.0f, 0.3f, 0.5f}}, // 14
100 VertexType{{0.90f, 0.40f}, {0.8f, 0.2f, 1.0f}, {0.3f, 0.6f, 0.4f}}, // 15
101 };
102 static const uint Indices[] =
103 {
104 0, 1, 2, 2, 1, 3, // geometry 1
105 4, 5, 6, 6, 7, 8, 8, 5, 9, // geometry 2
106 10, 12, 11, 11, 12, 13, 11, 13, 14, 13, 12, 15, // geometry 3
107 };
108 static const uint4 Primitives[] =
109 {
110 // geometry 1
111 {Indices[ 0], Indices[ 1], Indices[ 2], 0}, // 0
112 {Indices[ 3], Indices[ 4], Indices[ 5], 0}, // 1
113 // geometry 2
114 {Indices[ 6], Indices[ 7], Indices[ 8], 0}, // 2
115 {Indices[ 9], Indices[10], Indices[11], 0}, // 3
116 {Indices[12], Indices[13], Indices[14], 0}, // 4
117 // geometry 3
118 {Indices[15], Indices[16], Indices[17], 0}, // 5
119 {Indices[18], Indices[19], Indices[20], 0}, // 6
120 {Indices[21], Indices[22], Indices[23], 0}, // 7
121 {Indices[24], Indices[25], Indices[26], 0} // 8
122 };
123 static const uint PrimitiveOffsets[] =
124 {
125 0, 2, 5
126 };
127
128 struct ShaderRecord
129 {
130 float4 Weight;
131 float4 Padding;
132 };
133 static const ShaderRecord Weights[] =
134 {
135 ShaderRecord{{1.0f, 0.4f, 0.4f, 1.0f}, {}},
136 ShaderRecord{{0.4f, 1.0f, 0.4f, 1.0f}, {}},
137 ShaderRecord{{0.4f, 0.4f, 1.0f, 1.0f}, {}}
138 };
139 static constexpr Uint32 ShaderRecordSize = sizeof(Weights[0]);
140 static constexpr Uint32 InstanceCount = 2;
141
142 static_assert(_countof(Vertices) == 16, "Update array size in shaders");
143 static_assert(_countof(PrimitiveOffsets) == 3, "Update array size in shaders");
144 static_assert(_countof(Primitives) == 9, "Update array size in shaders");
145 static_assert(_countof(Indices) % 3 == 0, "Invalid index count");
146 static_assert(_countof(Indices) / 3 == _countof(Primitives), "Primitive count mismatch");
147
148 } // namespace MultiGeometry
149
150 // clang-format on
151
152 } // namespace TestingConstants
153
154 } // namespace Diligent
3333 #include "BasicMath.hpp"
3434
3535 #include "InlineShaders/RayTracingTestHLSL.h"
36 #include "RayTracingTestConstants.hpp"
3637
3738 namespace Diligent
3839 {
4546
4647 struct RTContext
4748 {
48 ID3D12Device5* pDevice = nullptr;
49 struct AccelStruct
50 {
51 CComPtr<ID3D12Resource> pAS;
52 UINT64 BuildScratchSize = 0;
53 UINT64 UpdateScratchSize = 0;
54 };
55
56 CComPtr<ID3D12Device5> pDevice;
4957 CComPtr<ID3D12GraphicsCommandList4> pCmdList;
5058 CComPtr<ID3D12StateObject> pRayTracingSO;
5159 CComPtr<ID3D12StateObjectProperties> pStateObjectProperties;
52 CComPtr<ID3D12RootSignature> pRootSignature;
53 CComPtr<ID3D12Resource> pBLAS;
54 UINT64 BLASBuildScratchSize = 0;
55 UINT64 BLASUpdateScratchSize = 0;
56 CComPtr<ID3D12Resource> pTLAS;
57 UINT64 TLASBuildScratchSize = 0;
58 UINT64 TLASUpdateScratchSize = 0;
60 CComPtr<ID3D12RootSignature> pGlobalRootSignature;
61 CComPtr<ID3D12RootSignature> pLocalRootSignature;
62 AccelStruct BLAS;
63 AccelStruct TLAS;
5964 CComPtr<ID3D12Resource> pScratchBuffer;
6065 CComPtr<ID3D12Resource> pVertexBuffer;
6166 CComPtr<ID3D12Resource> pIndexBuffer;
95100 static constexpr UINT DescriptorHeapSize = 16;
96101 };
97102
98 template <typename PSOCtorType>
99 void InitializeRTContext(RTContext& Ctx, ISwapChain* pSwapChain, PSOCtorType&& PSOCtor)
103 template <typename PSOCtorType, typename RootSigCtorType>
104 void InitializeRTContext(RTContext& Ctx, ISwapChain* pSwapChain, Uint32 ShaderRecordSize, PSOCtorType&& PSOCtor, RootSigCtorType&& RootSigCtor)
100105 {
101106 auto* pEnv = TestingEnvironmentD3D12::GetInstance();
102107 auto* pTestingSwapChainD3D12 = ValidatedCast<TestingSwapChainD3D12>(pSwapChain);
109114 hr = pEnv->CreateGraphicsCommandList()->QueryInterface(IID_PPV_ARGS(&Ctx.pCmdList));
110115 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to get ID3D12GraphicsCommandList4";
111116
112 // create root signature
113 {
114 D3D12_ROOT_SIGNATURE_DESC RootSignatureDesc = {};
115 D3D12_DESCRIPTOR_RANGE DescriptorRanges[2] = {};
116 D3D12_ROOT_PARAMETER Params[1] = {};
117
118 DescriptorRanges[0].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV;
119 DescriptorRanges[0].NumDescriptors = 1;
120 DescriptorRanges[0].BaseShaderRegister = 0;
121 DescriptorRanges[0].RegisterSpace = 0;
122 DescriptorRanges[0].OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND;
123
124 DescriptorRanges[1].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV;
125 DescriptorRanges[1].NumDescriptors = 1;
126 DescriptorRanges[1].BaseShaderRegister = 0;
127 DescriptorRanges[1].RegisterSpace = 0;
128 DescriptorRanges[1].OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND;
129
130 Params[0].ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
131 Params[0].ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
132 Params[0].DescriptorTable.NumDescriptorRanges = _countof(DescriptorRanges);
133 Params[0].DescriptorTable.pDescriptorRanges = DescriptorRanges;
117 // create descriptor heap
118 {
119 D3D12_DESCRIPTOR_HEAP_DESC Desc = {};
120
121 Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
122 Desc.NumDescriptors = Ctx.DescriptorHeapSize;
123 Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
124 Desc.NodeMask = 0;
125
126 hr = Ctx.pDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&Ctx.pDescHeap));
127 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create descriptor heap";
128
129 Ctx.DescHeapCount = 0;
130 Ctx.DescHandleSize = Ctx.pDevice->GetDescriptorHandleIncrementSize(Desc.Type);
131
132 D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {};
133
134 UAVDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
135 UAVDesc.ViewDimension = D3D12_UAV_DIMENSION_TEXTURE2D;
136
137 D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
138 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
139 ASSERT_TRUE(Ctx.DescHeapCount == 0);
140 UAVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
141 Ctx.pDevice->CreateUnorderedAccessView(pTestingSwapChainD3D12->GetD3D12RenderTarget(), nullptr, &UAVDesc, UAVHandle);
142 }
143
144 // create global root signature
145 {
146 D3D12_ROOT_SIGNATURE_DESC RootSignatureDesc = {};
147 D3D12_ROOT_PARAMETER Param = {};
148 D3D12_DESCRIPTOR_RANGE Range = {};
149 std::vector<D3D12_DESCRIPTOR_RANGE> DescriptorRanges;
150
151 RootSigCtor(DescriptorRanges);
152
153 Range.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV;
154 Range.NumDescriptors = 1;
155 Range.OffsetInDescriptorsFromTableStart = 0;
156 DescriptorRanges.push_back(Range); // g_ColorBuffer (UAV)
157
158 Range.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV;
159 Range.NumDescriptors = 1;
160 Range.OffsetInDescriptorsFromTableStart = 1;
161 DescriptorRanges.push_back(Range); // g_TLAS (SRV)
162
163 Param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE;
164 Param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
165 Param.DescriptorTable.NumDescriptorRanges = static_cast<Uint32>(DescriptorRanges.size());
166 Param.DescriptorTable.pDescriptorRanges = DescriptorRanges.data();
134167
135168 RootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
136 RootSignatureDesc.NumParameters = _countof(Params);
137 RootSignatureDesc.pParameters = Params;
169 RootSignatureDesc.NumParameters = 1;
170 RootSignatureDesc.pParameters = &Param;
138171
139172 CComPtr<ID3DBlob> signature;
140173 hr = D3D12SerializeRootSignature(&RootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, nullptr);
141174 ASSERT_HRESULT_SUCCEEDED(hr);
142175
143 hr = Ctx.pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&Ctx.pRootSignature));
176 hr = Ctx.pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&Ctx.pGlobalRootSignature));
177 ASSERT_HRESULT_SUCCEEDED(hr);
178 }
179
180 // create local root signature
181 if (ShaderRecordSize > 0)
182 {
183 D3D12_ROOT_SIGNATURE_DESC RootSignatureDesc = {};
184 D3D12_ROOT_PARAMETER Param = {};
185
186 Param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
187 Param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
188 Param.Constants.Num32BitValues = ShaderRecordSize / 4;
189 Param.Constants.RegisterSpace = 1;
190 Param.Constants.ShaderRegister = 0;
191
192 RootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_LOCAL_ROOT_SIGNATURE;
193 RootSignatureDesc.NumParameters = 1;
194 RootSignatureDesc.pParameters = &Param;
195
196 CComPtr<ID3DBlob> signature;
197 hr = D3D12SerializeRootSignature(&RootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, nullptr);
198 ASSERT_HRESULT_SUCCEEDED(hr);
199
200 hr = Ctx.pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&Ctx.pLocalRootSignature));
144201 ASSERT_HRESULT_SUCCEEDED(hr);
145202 }
146203
164221 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG, &ShaderConfig});
165222
166223 D3D12_GLOBAL_ROOT_SIGNATURE GlobalRoot;
167 GlobalRoot.pGlobalRootSignature = Ctx.pRootSignature;
224 GlobalRoot.pGlobalRootSignature = Ctx.pGlobalRootSignature;
168225 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE, &GlobalRoot});
226
227 D3D12_LOCAL_ROOT_SIGNATURE LocalRoot;
228 LocalRoot.pLocalRootSignature = Ctx.pLocalRootSignature;
229 if (Ctx.pLocalRootSignature)
230 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_LOCAL_ROOT_SIGNATURE, &LocalRoot});
169231
170232 D3D12_STATE_OBJECT_DESC RTPipelineDesc;
171233 RTPipelineDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
178240 hr = Ctx.pRayTracingSO->QueryInterface(IID_PPV_ARGS(&Ctx.pStateObjectProperties));
179241 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to get state object properties";
180242 }
181
182 // create descriptor heap
183 {
184 D3D12_DESCRIPTOR_HEAP_DESC Desc = {};
185
186 Desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
187 Desc.NumDescriptors = Ctx.DescriptorHeapSize;
188 Desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
189 Desc.NodeMask = 0;
190
191 hr = Ctx.pDevice->CreateDescriptorHeap(&Desc, IID_PPV_ARGS(&Ctx.pDescHeap));
192 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create descriptor heap";
193
194 Ctx.DescHeapCount = 0;
195 Ctx.DescHandleSize = Ctx.pDevice->GetDescriptorHandleIncrementSize(Desc.Type);
196
197 D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {};
198
199 UAVDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
200 UAVDesc.ViewDimension = D3D12_UAV_DIMENSION_TEXTURE2D;
201
202 D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
203 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
204 UAVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
205 Ctx.pDevice->CreateUnorderedAccessView(pTestingSwapChainD3D12->GetD3D12RenderTarget(), nullptr, &UAVDesc, UAVHandle);
206 }
243 }
244
245 template <typename PSOCtorType>
246 void InitializeRTContext(RTContext& Ctx, ISwapChain* pSwapChain, Uint32 ShaderRecordSize, PSOCtorType&& PSOCtor)
247 {
248 InitializeRTContext(Ctx, pSwapChain, ShaderRecordSize, PSOCtor, [](std::vector<D3D12_DESCRIPTOR_RANGE>&) {});
207249 }
208250
209251 void CreateBLAS(RTContext& Ctx, D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& BottomLevelInputs)
239281
240282 auto hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
241283 &ASDesc, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, nullptr,
242 IID_PPV_ARGS(&Ctx.pBLAS));
284 IID_PPV_ARGS(&Ctx.BLAS.pAS));
243285 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create acceleration structure";
244286
245 Ctx.BLASBuildScratchSize = BottomLevelPrebuildInfo.ScratchDataSizeInBytes;
246 Ctx.BLASUpdateScratchSize = BottomLevelPrebuildInfo.UpdateScratchDataSizeInBytes;
287 Ctx.BLAS.BuildScratchSize = BottomLevelPrebuildInfo.ScratchDataSizeInBytes;
288 Ctx.BLAS.UpdateScratchSize = BottomLevelPrebuildInfo.UpdateScratchDataSizeInBytes;
247289 }
248290
249291 void CreateTLAS(RTContext& Ctx, D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& TopLevelInputs)
279321
280322 auto hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
281323 &ASDesc, D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, nullptr,
282 IID_PPV_ARGS(&Ctx.pTLAS));
324 IID_PPV_ARGS(&Ctx.TLAS.pAS));
283325 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create acceleration structure";
284326
285 Ctx.TLASBuildScratchSize = TopLevelPrebuildInfo.ScratchDataSizeInBytes;
286 Ctx.TLASUpdateScratchSize = TopLevelPrebuildInfo.UpdateScratchDataSizeInBytes;
327 Ctx.TLAS.BuildScratchSize = TopLevelPrebuildInfo.ScratchDataSizeInBytes;
328 Ctx.TLAS.UpdateScratchSize = TopLevelPrebuildInfo.UpdateScratchDataSizeInBytes;
287329
288330 D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {};
289331 SRVDesc.ViewDimension = D3D12_SRV_DIMENSION_RAYTRACING_ACCELERATION_STRUCTURE;
290332 SRVDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
291333 SRVDesc.Format = DXGI_FORMAT_UNKNOWN;
292 SRVDesc.RaytracingAccelerationStructure.Location = Ctx.pTLAS->GetGPUVirtualAddress();
334 SRVDesc.RaytracingAccelerationStructure.Location = Ctx.TLAS.pAS->GetGPUVirtualAddress();
293335
294336 D3D12_CPU_DESCRIPTOR_HANDLE DescHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
295337 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
338 ASSERT_TRUE(Ctx.DescHeapCount == 1);
296339 DescHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
297340
298341 Ctx.pDevice->CreateShaderResourceView(nullptr, &SRVDesc, DescHandle);
299342 }
300343
301 void CreateRTBuffers(RTContext& Ctx, Uint32 VBSize, Uint32 IBSize, Uint32 InstanceCount, Uint32 NumMissShaders, Uint32 NumHitShaders)
344 void CreateRTBuffers(RTContext& Ctx, Uint32 VBSize, Uint32 IBSize, Uint32 InstanceCount, Uint32 NumMissShaders, Uint32 NumHitShaders, Uint32 ShaderRecordSize = 0, size_t UploadSize = 0)
302345 {
303346 D3D12_RESOURCE_DESC BuffDesc = {};
304347 BuffDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
319362 HeapProps.CreationNodeMask = 1;
320363 HeapProps.VisibleNodeMask = 1;
321364
322 BuffDesc.Width = std::max(Ctx.BLASBuildScratchSize, Ctx.BLASUpdateScratchSize);
323 BuffDesc.Width = std::max(BuffDesc.Width, Ctx.TLASBuildScratchSize);
324 BuffDesc.Width = std::max(BuffDesc.Width, Ctx.TLASUpdateScratchSize);
365 BuffDesc.Width = std::max(Ctx.BLAS.BuildScratchSize, Ctx.BLAS.UpdateScratchSize);
366 BuffDesc.Width = std::max(BuffDesc.Width, Ctx.TLAS.BuildScratchSize);
367 BuffDesc.Width = std::max(BuffDesc.Width, Ctx.TLAS.UpdateScratchSize);
325368
326369 auto hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
327370 &BuffDesc, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, nullptr,
328371 IID_PPV_ARGS(&Ctx.pScratchBuffer));
329372 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create buffer";
330
331 size_t UploadSize = 0;
332373
333374 if (VBSize > 0)
334375 {
365406
366407 // SBT
367408 {
368 const UINT64 handleSize = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
409 const UINT64 RecordSize = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES + ShaderRecordSize;
369410 const UINT64 align = D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT;
370411
371 BuffDesc.Width = Align(handleSize, align);
372 BuffDesc.Width = Align(BuffDesc.Width + NumMissShaders * handleSize, align);
373 BuffDesc.Width = Align(BuffDesc.Width + NumHitShaders * handleSize, align);
412 BuffDesc.Width = Align(RecordSize, align);
413 BuffDesc.Width = Align(BuffDesc.Width + NumMissShaders * RecordSize, align);
414 BuffDesc.Width = Align(BuffDesc.Width + NumHitShaders * RecordSize, align);
374415
375416 hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
376417 &BuffDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
419460 const auto& SCDesc = pSwapChain->GetDesc();
420461
421462 RTContext Ctx = {};
422 InitializeRTContext(Ctx, pSwapChain,
463 InitializeRTContext(Ctx, pSwapChain, 0,
423464 [pEnv](auto& Subobjects, auto& ExportDescs, auto& LibDescs, auto& HitGroups, auto& ShadersByteCode) {
424465 ShadersByteCode.resize(3);
425466 ExportDescs.resize(ShadersByteCode.size());
486527 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& TopLevelInputs = TLASDesc.Inputs;
487528 D3D12_RAYTRACING_INSTANCE_DESC Instance = {};
488529
489 const float3 Vertices[] = //
490 {
491 float3{0.25f, 0.25f, 0.0f},
492 float3{0.75f, 0.25f, 0.0f},
493 float3{0.50f, 0.75f, 0.0f} //
494 };
530 const auto& Vertices = TestingConstants::TriangleClosestHit::Vertices;
495531
496532 Geometry.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
497533 Geometry.Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;
517553 Instance.InstanceContributionToHitGroupIndex = 0;
518554 Instance.InstanceMask = 0xFF;
519555 Instance.Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE;
520 Instance.AccelerationStructure = Ctx.pBLAS->GetGPUVirtualAddress();
556 Instance.AccelerationStructure = Ctx.BLAS.pAS->GetGPUVirtualAddress();
521557 Instance.Transform[0][0] = 1.0f;
522558 Instance.Transform[1][1] = 1.0f;
523559 Instance.Transform[2][2] = 1.0f;
556592
557593 Geometry.Triangles.VertexBuffer.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
558594
559 BLASDesc.DestAccelerationStructureData = Ctx.pBLAS->GetGPUVirtualAddress();
595 BLASDesc.DestAccelerationStructureData = Ctx.BLAS.pAS->GetGPUVirtualAddress();
560596 BLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
561597 BLASDesc.SourceAccelerationStructureData = 0;
562598
578614
579615 TopLevelInputs.InstanceDescs = Ctx.pInstanceBuffer->GetGPUVirtualAddress();
580616
581 TLASDesc.DestAccelerationStructureData = Ctx.pTLAS->GetGPUVirtualAddress();
617 TLASDesc.DestAccelerationStructureData = Ctx.TLAS.pAS->GetGPUVirtualAddress();
582618 TLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
583619 TLASDesc.SourceAccelerationStructureData = 0;
584620
597633 ID3D12DescriptorHeap* DescHeaps[] = {Ctx.pDescHeap};
598634
599635 Ctx.pCmdList->SetPipelineState1(Ctx.pRayTracingSO);
600 Ctx.pCmdList->SetComputeRootSignature(Ctx.pRootSignature);
636 Ctx.pCmdList->SetComputeRootSignature(Ctx.pGlobalRootSignature);
601637
602638 Ctx.pCmdList->SetDescriptorHeaps(_countof(DescHeaps), &DescHeaps[0]);
603639 Ctx.pCmdList->SetComputeRootDescriptorTable(0, DescHeaps[0]->GetGPUDescriptorHandleForHeapStart());
656692 const auto& SCDesc = pSwapChain->GetDesc();
657693
658694 RTContext Ctx = {};
659 InitializeRTContext(Ctx, pSwapChain,
695 InitializeRTContext(Ctx, pSwapChain, 0,
660696 [pEnv](auto& Subobjects, auto& ExportDescs, auto& LibDescs, auto& HitGroups, auto& ShadersByteCode) {
661697 ShadersByteCode.resize(4);
662698 ExportDescs.resize(ShadersByteCode.size());
737773 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& TopLevelInputs = TLASDesc.Inputs;
738774 D3D12_RAYTRACING_INSTANCE_DESC Instance = {};
739775
740 const float3 Vertices[] = //
741 {
742 float3{0.25f, 0.25f, 0.0f}, float3{0.75f, 0.25f, 0.0f}, float3{0.50f, 0.75f, 0.0f},
743 float3{0.50f, 0.10f, 0.1f}, float3{0.90f, 0.90f, 0.1f}, float3{0.10f, 0.90f, 0.1f},
744 float3{0.40f, 1.00f, 0.2f}, float3{0.20f, 0.40f, 0.2f}, float3{1.00f, 0.70f, 0.2f} //
745 };
776 const auto& Vertices = TestingConstants::TriangleAnyHit::Vertices;
746777
747778 Geometry.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
748779 Geometry.Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_NONE;
768799 Instance.InstanceContributionToHitGroupIndex = 0;
769800 Instance.InstanceMask = 0xFF;
770801 Instance.Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE;
771 Instance.AccelerationStructure = Ctx.pBLAS->GetGPUVirtualAddress();
802 Instance.AccelerationStructure = Ctx.BLAS.pAS->GetGPUVirtualAddress();
772803 Instance.Transform[0][0] = 1.0f;
773804 Instance.Transform[1][1] = 1.0f;
774805 Instance.Transform[2][2] = 1.0f;
807838
808839 Geometry.Triangles.VertexBuffer.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
809840
810 BLASDesc.DestAccelerationStructureData = Ctx.pBLAS->GetGPUVirtualAddress();
841 BLASDesc.DestAccelerationStructureData = Ctx.BLAS.pAS->GetGPUVirtualAddress();
811842 BLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
812843 BLASDesc.SourceAccelerationStructureData = 0;
813844
829860
830861 TopLevelInputs.InstanceDescs = Ctx.pInstanceBuffer->GetGPUVirtualAddress();
831862
832 TLASDesc.DestAccelerationStructureData = Ctx.pTLAS->GetGPUVirtualAddress();
863 TLASDesc.DestAccelerationStructureData = Ctx.TLAS.pAS->GetGPUVirtualAddress();
833864 TLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
834865 TLASDesc.SourceAccelerationStructureData = 0;
835866
848879 ID3D12DescriptorHeap* DescHeaps[] = {Ctx.pDescHeap};
849880
850881 Ctx.pCmdList->SetPipelineState1(Ctx.pRayTracingSO);
851 Ctx.pCmdList->SetComputeRootSignature(Ctx.pRootSignature);
882 Ctx.pCmdList->SetComputeRootSignature(Ctx.pGlobalRootSignature);
852883
853884 Ctx.pCmdList->SetDescriptorHeaps(_countof(DescHeaps), &DescHeaps[0]);
854885 Ctx.pCmdList->SetComputeRootDescriptorTable(0, DescHeaps[0]->GetGPUDescriptorHandleForHeapStart());
907938 const auto& SCDesc = pSwapChain->GetDesc();
908939
909940 RTContext Ctx = {};
910 InitializeRTContext(Ctx, pSwapChain,
941 InitializeRTContext(Ctx, pSwapChain, 0,
911942 [pEnv](auto& Subobjects, auto& ExportDescs, auto& LibDescs, auto& HitGroups, auto& ShadersByteCode) {
912943 ShadersByteCode.resize(4);
913944 ExportDescs.resize(ShadersByteCode.size());
9881019 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& TopLevelInputs = TLASDesc.Inputs;
9891020 D3D12_RAYTRACING_INSTANCE_DESC Instance = {};
9901021
991 const float3 Boxes[] = //
992 {
993 float3{0.25f, 0.5f, 2.0f} - float3{1.0f, 1.0f, 1.0f},
994 float3{0.25f, 0.5f, 2.0f} + float3{1.0f, 1.0f, 1.0f} //
995 };
1022 const auto& Boxes = TestingConstants::ProceduralIntersection::Boxes;
9961023
9971024 Geometry.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS;
9981025 Geometry.Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;
10131040 Instance.InstanceContributionToHitGroupIndex = 0;
10141041 Instance.InstanceMask = 0xFF;
10151042 Instance.Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE;
1016 Instance.AccelerationStructure = Ctx.pBLAS->GetGPUVirtualAddress();
1043 Instance.AccelerationStructure = Ctx.BLAS.pAS->GetGPUVirtualAddress();
10171044 Instance.Transform[0][0] = 1.0f;
10181045 Instance.Transform[1][1] = 1.0f;
10191046 Instance.Transform[2][2] = 1.0f;
10521079
10531080 Geometry.AABBs.AABBs.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
10541081
1055 BLASDesc.DestAccelerationStructureData = Ctx.pBLAS->GetGPUVirtualAddress();
1082 BLASDesc.DestAccelerationStructureData = Ctx.BLAS.pAS->GetGPUVirtualAddress();
10561083 BLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
10571084 BLASDesc.SourceAccelerationStructureData = 0;
10581085
10741101
10751102 TopLevelInputs.InstanceDescs = Ctx.pInstanceBuffer->GetGPUVirtualAddress();
10761103
1077 TLASDesc.DestAccelerationStructureData = Ctx.pTLAS->GetGPUVirtualAddress();
1104 TLASDesc.DestAccelerationStructureData = Ctx.TLAS.pAS->GetGPUVirtualAddress();
10781105 TLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
10791106 TLASDesc.SourceAccelerationStructureData = 0;
10801107
10931120 ID3D12DescriptorHeap* DescHeaps[] = {Ctx.pDescHeap};
10941121
10951122 Ctx.pCmdList->SetPipelineState1(Ctx.pRayTracingSO);
1096 Ctx.pCmdList->SetComputeRootSignature(Ctx.pRootSignature);
1123 Ctx.pCmdList->SetComputeRootSignature(Ctx.pGlobalRootSignature);
10971124
10981125 Ctx.pCmdList->SetDescriptorHeaps(_countof(DescHeaps), &DescHeaps[0]);
10991126 Ctx.pCmdList->SetComputeRootDescriptorTable(0, DescHeaps[0]->GetGPUDescriptorHandleForHeapStart());
11431170 pEnv->ExecuteCommandList(Ctx.pCmdList, true);
11441171 }
11451172
1173
1174 void RayTracingMultiGeometryReferenceD3D12(ISwapChain* pSwapChain)
1175 {
1176 static constexpr Uint32 InstanceCount = TestingConstants::MultiGeometry::InstanceCount;
1177 static constexpr Uint32 GeometryCount = 3;
1178 static constexpr Uint32 HitGroupCount = InstanceCount * GeometryCount;
1179
1180 auto* pEnv = TestingEnvironmentD3D12::GetInstance();
1181 auto* pTestingSwapChainD3D12 = ValidatedCast<TestingSwapChainD3D12>(pSwapChain);
1182
1183 const auto& SCDesc = pSwapChain->GetDesc();
1184
1185 RTContext Ctx = {};
1186 InitializeRTContext(
1187 Ctx, pSwapChain,
1188 TestingConstants::MultiGeometry::ShaderRecordSize,
1189 [pEnv](auto& Subobjects, auto& ExportDescs, auto& LibDescs, auto& HitGroups, auto& ShadersByteCode) {
1190 ShadersByteCode.resize(4);
1191 ExportDescs.resize(ShadersByteCode.size());
1192 LibDescs.resize(ShadersByteCode.size());
1193 HitGroups.resize(2);
1194
1195 auto hr = pEnv->CompileDXILShader(HLSL::RayTracingTest4_RG, L"main", nullptr, 0, L"lib_6_5", &ShadersByteCode[0]);
1196 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to compile ray gen shader";
1197
1198 hr = pEnv->CompileDXILShader(HLSL::RayTracingTest4_RM, L"main", nullptr, 0, L"lib_6_5", &ShadersByteCode[1]);
1199 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to compile ray miss shader";
1200
1201 hr = pEnv->CompileDXILShader(HLSL::RayTracingTest4_RCH1, L"main", nullptr, 0, L"lib_6_5", &ShadersByteCode[2]);
1202 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to compile ray closest hit shader";
1203
1204 hr = pEnv->CompileDXILShader(HLSL::RayTracingTest4_RCH2, L"main", nullptr, 0, L"lib_6_5", &ShadersByteCode[3]);
1205 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to compile ray closest hit shader";
1206
1207 D3D12_EXPORT_DESC& RGExportDesc = ExportDescs[0];
1208 D3D12_DXIL_LIBRARY_DESC& RGLibDesc = LibDescs[0];
1209 RGExportDesc.Flags = D3D12_EXPORT_FLAG_NONE;
1210 RGExportDesc.ExportToRename = L"main"; // shader entry name
1211 RGExportDesc.Name = L"Main";
1212 RGLibDesc.DXILLibrary.BytecodeLength = ShadersByteCode[0]->GetBufferSize();
1213 RGLibDesc.DXILLibrary.pShaderBytecode = ShadersByteCode[0]->GetBufferPointer();
1214 RGLibDesc.NumExports = 1;
1215 RGLibDesc.pExports = &RGExportDesc;
1216 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY, &RGLibDesc});
1217
1218 D3D12_EXPORT_DESC& RMExportDesc = ExportDescs[1];
1219 D3D12_DXIL_LIBRARY_DESC& RMLibDesc = LibDescs[1];
1220 RMExportDesc.Flags = D3D12_EXPORT_FLAG_NONE;
1221 RMExportDesc.ExportToRename = L"main"; // shader entry name
1222 RMExportDesc.Name = L"Miss";
1223 RMLibDesc.DXILLibrary.BytecodeLength = ShadersByteCode[1]->GetBufferSize();
1224 RMLibDesc.DXILLibrary.pShaderBytecode = ShadersByteCode[1]->GetBufferPointer();
1225 RMLibDesc.NumExports = 1;
1226 RMLibDesc.pExports = &RMExportDesc;
1227 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY, &RMLibDesc});
1228
1229 D3D12_EXPORT_DESC& RCH1ExportDesc = ExportDescs[2];
1230 D3D12_DXIL_LIBRARY_DESC& RCH1LibDesc = LibDescs[2];
1231 RCH1ExportDesc.Flags = D3D12_EXPORT_FLAG_NONE;
1232 RCH1ExportDesc.ExportToRename = L"main"; // shader entry name
1233 RCH1ExportDesc.Name = L"ClosestHitShader1";
1234 RCH1LibDesc.DXILLibrary.BytecodeLength = ShadersByteCode[2]->GetBufferSize();
1235 RCH1LibDesc.DXILLibrary.pShaderBytecode = ShadersByteCode[2]->GetBufferPointer();
1236 RCH1LibDesc.NumExports = 1;
1237 RCH1LibDesc.pExports = &RCH1ExportDesc;
1238 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY, &RCH1LibDesc});
1239
1240 D3D12_EXPORT_DESC& RCH2ExportDesc = ExportDescs[3];
1241 D3D12_DXIL_LIBRARY_DESC& RCH2LibDesc = LibDescs[3];
1242 RCH2ExportDesc.Flags = D3D12_EXPORT_FLAG_NONE;
1243 RCH2ExportDesc.ExportToRename = L"main"; // shader entry name
1244 RCH2ExportDesc.Name = L"ClosestHitShader2";
1245 RCH2LibDesc.DXILLibrary.BytecodeLength = ShadersByteCode[3]->GetBufferSize();
1246 RCH2LibDesc.DXILLibrary.pShaderBytecode = ShadersByteCode[3]->GetBufferPointer();
1247 RCH2LibDesc.NumExports = 1;
1248 RCH2LibDesc.pExports = &RCH2ExportDesc;
1249 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY, &RCH2LibDesc});
1250
1251 D3D12_HIT_GROUP_DESC& HitGroup1Desc = HitGroups[0];
1252 HitGroup1Desc.HitGroupExport = L"HitGroup1";
1253 HitGroup1Desc.Type = D3D12_HIT_GROUP_TYPE_TRIANGLES;
1254 HitGroup1Desc.ClosestHitShaderImport = L"ClosestHitShader1";
1255 HitGroup1Desc.AnyHitShaderImport = nullptr;
1256 HitGroup1Desc.IntersectionShaderImport = nullptr;
1257 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP, &HitGroup1Desc});
1258
1259 D3D12_HIT_GROUP_DESC& HitGroup2Desc = HitGroups[1];
1260 HitGroup2Desc.HitGroupExport = L"HitGroup2";
1261 HitGroup2Desc.Type = D3D12_HIT_GROUP_TYPE_TRIANGLES;
1262 HitGroup2Desc.ClosestHitShaderImport = L"ClosestHitShader2";
1263 HitGroup2Desc.AnyHitShaderImport = nullptr;
1264 HitGroup2Desc.IntersectionShaderImport = nullptr;
1265 Subobjects.push_back({D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP, &HitGroup2Desc});
1266 },
1267 [](std::vector<D3D12_DESCRIPTOR_RANGE>& DescriptorRanges) {
1268 D3D12_DESCRIPTOR_RANGE Range = {};
1269 Range.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV;
1270 Range.NumDescriptors = 1;
1271
1272 Range.BaseShaderRegister = 1;
1273 Range.OffsetInDescriptorsFromTableStart = 2;
1274 DescriptorRanges.push_back(Range); // g_Vertices
1275
1276 Range.BaseShaderRegister = 4;
1277 Range.OffsetInDescriptorsFromTableStart = 3;
1278 DescriptorRanges.push_back(Range); // g_Primitives
1279
1280 Range.BaseShaderRegister = 2;
1281 Range.NumDescriptors = 2;
1282 Range.OffsetInDescriptorsFromTableStart = 4;
1283 DescriptorRanges.push_back(Range); // g_PerInstance[2]
1284 });
1285
1286 const auto& PrimitiveOffsets = TestingConstants::MultiGeometry::PrimitiveOffsets;
1287 const auto& Primitives = TestingConstants::MultiGeometry::Primitives;
1288 const auto& Vertices = TestingConstants::MultiGeometry::Vertices;
1289
1290 // create acceleration structurea
1291 {
1292 const auto& Indices = TestingConstants::MultiGeometry::Indices;
1293
1294 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC BLASDesc = {};
1295 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& BottomLevelInputs = BLASDesc.Inputs;
1296 D3D12_RAYTRACING_GEOMETRY_DESC Geometries[3] = {};
1297 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC TLASDesc = {};
1298 D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS& TopLevelInputs = TLASDesc.Inputs;
1299 D3D12_RAYTRACING_INSTANCE_DESC Instances[2] = {};
1300
1301 static_assert(GeometryCount == _countof(Geometries), "size mismatch");
1302 static_assert(InstanceCount == _countof(Instances), "size mismatch");
1303
1304 Geometries[0].Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
1305 Geometries[0].Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;
1306 Geometries[0].Triangles.VertexBuffer.StartAddress = 0;
1307 Geometries[0].Triangles.VertexBuffer.StrideInBytes = sizeof(Vertices[0]);
1308 Geometries[0].Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT;
1309 Geometries[0].Triangles.VertexCount = _countof(Vertices);
1310 Geometries[0].Triangles.IndexCount = PrimitiveOffsets[1] * 3;
1311 Geometries[0].Triangles.IndexFormat = DXGI_FORMAT_R32_UINT;
1312 Geometries[0].Triangles.IndexBuffer = 0;
1313 Geometries[0].Triangles.Transform3x4 = 0;
1314
1315 Geometries[1].Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
1316 Geometries[1].Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;
1317 Geometries[1].Triangles.VertexBuffer.StartAddress = 0;
1318 Geometries[1].Triangles.VertexBuffer.StrideInBytes = sizeof(Vertices[0]);
1319 Geometries[1].Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT;
1320 Geometries[1].Triangles.VertexCount = _countof(Vertices);
1321 Geometries[1].Triangles.IndexCount = (PrimitiveOffsets[2] - PrimitiveOffsets[1]) * 3;
1322 Geometries[1].Triangles.IndexFormat = DXGI_FORMAT_R32_UINT;
1323 Geometries[1].Triangles.IndexBuffer = 0;
1324 Geometries[1].Triangles.Transform3x4 = 0;
1325
1326 Geometries[2].Type = D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES;
1327 Geometries[2].Flags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;
1328 Geometries[2].Triangles.VertexBuffer.StartAddress = 0;
1329 Geometries[2].Triangles.VertexBuffer.StrideInBytes = sizeof(Vertices[0]);
1330 Geometries[2].Triangles.VertexFormat = DXGI_FORMAT_R32G32B32_FLOAT;
1331 Geometries[2].Triangles.VertexCount = _countof(Vertices);
1332 Geometries[2].Triangles.IndexCount = (_countof(Primitives) - PrimitiveOffsets[2]) * 3;
1333 Geometries[2].Triangles.IndexFormat = DXGI_FORMAT_R32_UINT;
1334 Geometries[2].Triangles.IndexBuffer = 0;
1335 Geometries[2].Triangles.Transform3x4 = 0;
1336
1337 BottomLevelInputs.pGeometryDescs = Geometries;
1338 BottomLevelInputs.NumDescs = _countof(Geometries);
1339
1340 TopLevelInputs.NumDescs = _countof(Instances);
1341
1342 CreateBLAS(Ctx, BottomLevelInputs);
1343 CreateTLAS(Ctx, TopLevelInputs);
1344 CreateRTBuffers(Ctx, sizeof(Vertices), sizeof(Indices), InstanceCount, 1, HitGroupCount,
1345 TestingConstants::MultiGeometry::ShaderRecordSize,
1346 sizeof(PrimitiveOffsets) + sizeof(Primitives));
1347
1348 Instances[0].InstanceID = 0;
1349 Instances[0].InstanceContributionToHitGroupIndex = 0;
1350 Instances[0].InstanceMask = 0xFF;
1351 Instances[0].Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE;
1352 Instances[0].AccelerationStructure = Ctx.BLAS.pAS->GetGPUVirtualAddress();
1353 Instances[0].Transform[0][0] = 1.0f;
1354 Instances[0].Transform[1][1] = 1.0f;
1355 Instances[0].Transform[2][2] = 1.0f;
1356
1357 Instances[1].InstanceID = 0;
1358 Instances[1].InstanceContributionToHitGroupIndex = HitGroupCount / 2;
1359 Instances[1].InstanceMask = 0xFF;
1360 Instances[1].Flags = D3D12_RAYTRACING_INSTANCE_FLAG_NONE;
1361 Instances[1].AccelerationStructure = Ctx.BLAS.pAS->GetGPUVirtualAddress();
1362 Instances[1].Transform[0][0] = 1.0f;
1363 Instances[1].Transform[1][1] = 1.0f;
1364 Instances[1].Transform[2][2] = 1.0f;
1365 Instances[1].Transform[0][3] = 0.1f;
1366 Instances[1].Transform[1][3] = 0.5f;
1367 Instances[1].Transform[2][3] = 0.0f;
1368
1369 UpdateBuffer(Ctx, Ctx.pVertexBuffer, 0, Vertices, sizeof(Vertices));
1370 UpdateBuffer(Ctx, Ctx.pIndexBuffer, 0, Indices, sizeof(Indices));
1371 UpdateBuffer(Ctx, Ctx.pInstanceBuffer, 0, Instances, sizeof(Instances));
1372
1373 // vertex & instance buffer barrier
1374 {
1375 std::vector<D3D12_RESOURCE_BARRIER> Barriers;
1376 D3D12_RESOURCE_BARRIER Barrier;
1377
1378 Barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
1379 Barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
1380 Barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
1381 Barrier.Transition.StateBefore = D3D12_RESOURCE_STATE_COPY_DEST;
1382 Barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE;
1383
1384 if (Ctx.pVertexBuffer)
1385 {
1386 Barrier.Transition.pResource = Ctx.pVertexBuffer;
1387 Barriers.push_back(Barrier);
1388 }
1389 if (Ctx.pIndexBuffer)
1390 {
1391 Barrier.Transition.pResource = Ctx.pIndexBuffer;
1392 Barriers.push_back(Barrier);
1393 }
1394 if (Ctx.pInstanceBuffer)
1395 {
1396 Barrier.Transition.pResource = Ctx.pInstanceBuffer;
1397 Barriers.push_back(Barrier);
1398 }
1399 Ctx.pCmdList->ResourceBarrier(static_cast<UINT>(Barriers.size()), Barriers.data());
1400 }
1401
1402 Geometries[0].Triangles.VertexBuffer.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
1403 Geometries[1].Triangles.VertexBuffer.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
1404 Geometries[2].Triangles.VertexBuffer.StartAddress = Ctx.pVertexBuffer->GetGPUVirtualAddress();
1405
1406 Geometries[0].Triangles.IndexBuffer = Ctx.pIndexBuffer->GetGPUVirtualAddress() + PrimitiveOffsets[0] * sizeof(uint) * 3;
1407 Geometries[1].Triangles.IndexBuffer = Ctx.pIndexBuffer->GetGPUVirtualAddress() + PrimitiveOffsets[1] * sizeof(uint) * 3;
1408 Geometries[2].Triangles.IndexBuffer = Ctx.pIndexBuffer->GetGPUVirtualAddress() + PrimitiveOffsets[2] * sizeof(uint) * 3;
1409
1410 BLASDesc.DestAccelerationStructureData = Ctx.BLAS.pAS->GetGPUVirtualAddress();
1411 BLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
1412 BLASDesc.SourceAccelerationStructureData = 0;
1413
1414 ASSERT_TRUE(BLASDesc.DestAccelerationStructureData != 0);
1415 ASSERT_TRUE(BLASDesc.ScratchAccelerationStructureData != 0);
1416
1417 Ctx.pCmdList->BuildRaytracingAccelerationStructure(&BLASDesc, 0, nullptr);
1418
1419 // UAV barrier for scratch buffer
1420 {
1421 D3D12_RESOURCE_BARRIER Barrier;
1422 Barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
1423 Barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
1424 Barrier.UAV.pResource = Ctx.pScratchBuffer;
1425
1426 Ctx.pCmdList->ResourceBarrier(1, &Barrier);
1427 }
1428
1429 TopLevelInputs.InstanceDescs = Ctx.pInstanceBuffer->GetGPUVirtualAddress();
1430
1431 TLASDesc.DestAccelerationStructureData = Ctx.TLAS.pAS->GetGPUVirtualAddress();
1432 TLASDesc.ScratchAccelerationStructureData = Ctx.pScratchBuffer->GetGPUVirtualAddress();
1433 TLASDesc.SourceAccelerationStructureData = 0;
1434
1435 ASSERT_TRUE(TLASDesc.DestAccelerationStructureData != 0);
1436 ASSERT_TRUE(TLASDesc.ScratchAccelerationStructureData != 0);
1437
1438 Ctx.pCmdList->BuildRaytracingAccelerationStructure(&TLASDesc, 0, nullptr);
1439 }
1440
1441 // update descriptors
1442 CComPtr<ID3D12Resource> pPerInstanceBuffer;
1443 CComPtr<ID3D12Resource> pPrimitiveBuffer;
1444 {
1445 D3D12_RESOURCE_DESC BuffDesc = {};
1446 BuffDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
1447 BuffDesc.Alignment = 0;
1448 BuffDesc.Width = sizeof(PrimitiveOffsets);
1449 BuffDesc.Height = 1;
1450 BuffDesc.DepthOrArraySize = 1;
1451 BuffDesc.MipLevels = 1;
1452 BuffDesc.Format = DXGI_FORMAT_UNKNOWN;
1453 BuffDesc.SampleDesc.Count = 1;
1454 BuffDesc.SampleDesc.Quality = 0;
1455 BuffDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
1456 BuffDesc.Flags = D3D12_RESOURCE_FLAG_NONE;
1457
1458 D3D12_HEAP_PROPERTIES HeapProps;
1459 HeapProps.Type = D3D12_HEAP_TYPE_DEFAULT;
1460 HeapProps.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
1461 HeapProps.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
1462 HeapProps.CreationNodeMask = 1;
1463 HeapProps.VisibleNodeMask = 1;
1464
1465 auto hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
1466 &BuffDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
1467 IID_PPV_ARGS(&pPerInstanceBuffer));
1468 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create per instance buffer";
1469
1470 BuffDesc.Width = sizeof(Primitives);
1471
1472 hr = Ctx.pDevice->CreateCommittedResource(&HeapProps, D3D12_HEAP_FLAG_NONE,
1473 &BuffDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
1474 IID_PPV_ARGS(&pPrimitiveBuffer));
1475 ASSERT_HRESULT_SUCCEEDED(hr) << "Failed to create per instance buffer";
1476
1477 UpdateBuffer(Ctx, pPrimitiveBuffer, 0, Primitives, sizeof(Primitives));
1478 UpdateBuffer(Ctx, pPerInstanceBuffer, 0, PrimitiveOffsets, sizeof(PrimitiveOffsets));
1479
1480 // buffer barrier
1481 {
1482 D3D12_RESOURCE_BARRIER Barrier = {};
1483 Barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
1484 Barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
1485 Barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
1486 Barrier.Transition.StateBefore = D3D12_RESOURCE_STATE_COPY_DEST;
1487 Barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE;
1488 Barrier.Transition.pResource = pPerInstanceBuffer;
1489 Ctx.pCmdList->ResourceBarrier(1, &Barrier);
1490
1491 Barrier.Transition.pResource = pPrimitiveBuffer;
1492 Ctx.pCmdList->ResourceBarrier(1, &Barrier);
1493 }
1494
1495 D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {};
1496 D3D12_CPU_DESCRIPTOR_HANDLE SRVHandle;
1497
1498 SRVDesc.Format = DXGI_FORMAT_UNKNOWN;
1499 SRVDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER;
1500 SRVDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
1501 SRVDesc.Buffer.NumElements = _countof(Vertices);
1502 SRVDesc.Buffer.StructureByteStride = sizeof(Vertices[0]);
1503
1504 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
1505 ASSERT_TRUE(Ctx.DescHeapCount == 2);
1506 SRVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
1507 SRVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
1508 Ctx.pDevice->CreateShaderResourceView(Ctx.pVertexBuffer, &SRVDesc, SRVHandle); // g_Vertices
1509
1510 SRVDesc.Buffer.NumElements = _countof(Primitives);
1511 SRVDesc.Buffer.StructureByteStride = sizeof(Primitives[0]);
1512 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
1513 ASSERT_TRUE(Ctx.DescHeapCount == 3);
1514 SRVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
1515 SRVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
1516 Ctx.pDevice->CreateShaderResourceView(pPrimitiveBuffer, &SRVDesc, SRVHandle); // g_Primitives
1517
1518 SRVDesc.Buffer.NumElements = _countof(PrimitiveOffsets);
1519 SRVDesc.Buffer.StructureByteStride = sizeof(PrimitiveOffsets[0]);
1520 ASSERT_LT(Ctx.DescHeapCount, Ctx.DescriptorHeapSize);
1521 ASSERT_TRUE(Ctx.DescHeapCount == 4);
1522 SRVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
1523 SRVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
1524 Ctx.pDevice->CreateShaderResourceView(pPerInstanceBuffer, &SRVDesc, SRVHandle); // g_PerInstance[0]
1525
1526 ASSERT_TRUE(Ctx.DescHeapCount == 5);
1527 SRVHandle = Ctx.pDescHeap->GetCPUDescriptorHandleForHeapStart();
1528 SRVHandle.ptr += Ctx.DescHandleSize * Ctx.DescHeapCount++;
1529 Ctx.pDevice->CreateShaderResourceView(pPerInstanceBuffer, &SRVDesc, SRVHandle); // g_PerInstance[1]
1530 }
1531
1532 Ctx.ClearRenderTarget(pTestingSwapChainD3D12);
1533
1534 // trace rays
1535 {
1536 pTestingSwapChainD3D12->TransitionRenderTarget(Ctx.pCmdList, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
1537
1538 ID3D12DescriptorHeap* DescHeaps[] = {Ctx.pDescHeap};
1539
1540 Ctx.pCmdList->SetPipelineState1(Ctx.pRayTracingSO);
1541 Ctx.pCmdList->SetComputeRootSignature(Ctx.pGlobalRootSignature);
1542
1543 Ctx.pCmdList->SetDescriptorHeaps(_countof(DescHeaps), &DescHeaps[0]);
1544 Ctx.pCmdList->SetComputeRootDescriptorTable(0, DescHeaps[0]->GetGPUDescriptorHandleForHeapStart());
1545
1546 D3D12_DISPATCH_RAYS_DESC Desc = {};
1547
1548 Desc.Width = SCDesc.Width;
1549 Desc.Height = SCDesc.Height;
1550 Desc.Depth = 1;
1551
1552 const UINT64 handleSize = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
1553 const UINT64 align = D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT;
1554 const UINT64 ShaderRecordSize = handleSize + TestingConstants::MultiGeometry::ShaderRecordSize;
1555 const size_t RayGenOffset = 0;
1556 const size_t RayMissOffset = Align(RayGenOffset + handleSize, align);
1557 const size_t HitGroupOffset = Align(RayMissOffset + handleSize, align);
1558 const auto& Weights = TestingConstants::MultiGeometry::Weights;
1559
1560 Desc.RayGenerationShaderRecord.StartAddress = Ctx.pSBTBuffer->GetGPUVirtualAddress() + RayGenOffset;
1561 Desc.RayGenerationShaderRecord.SizeInBytes = ShaderRecordSize;
1562 Desc.MissShaderTable.StartAddress = Ctx.pSBTBuffer->GetGPUVirtualAddress() + RayMissOffset;
1563 Desc.MissShaderTable.SizeInBytes = ShaderRecordSize;
1564 Desc.MissShaderTable.StrideInBytes = ShaderRecordSize;
1565 Desc.HitGroupTable.StartAddress = Ctx.pSBTBuffer->GetGPUVirtualAddress() + HitGroupOffset;
1566 Desc.HitGroupTable.SizeInBytes = ShaderRecordSize * HitGroupCount;
1567 Desc.HitGroupTable.StrideInBytes = ShaderRecordSize;
1568
1569 UpdateBuffer(Ctx, Ctx.pSBTBuffer, RayGenOffset, Ctx.pStateObjectProperties->GetShaderIdentifier(L"Main"), handleSize);
1570 UpdateBuffer(Ctx, Ctx.pSBTBuffer, RayMissOffset, Ctx.pStateObjectProperties->GetShaderIdentifier(L"Miss"), handleSize);
1571
1572 const auto SetHitGroup = [&](Uint32 Index, const wchar_t* GroupName, const void* ShaderRecord) {
1573 VERIFY_EXPR(Index < HitGroupCount);
1574 UINT64 Offset = HitGroupOffset + Index * ShaderRecordSize;
1575 UpdateBuffer(Ctx, Ctx.pSBTBuffer, Offset, Ctx.pStateObjectProperties->GetShaderIdentifier(GroupName), handleSize);
1576 UpdateBuffer(Ctx, Ctx.pSBTBuffer, Offset + handleSize, ShaderRecord, sizeof(Weights[0]));
1577 };
1578 // instance 1
1579 SetHitGroup(0, L"HitGroup1", &Weights[2]); // geometry 1
1580 SetHitGroup(1, L"HitGroup1", &Weights[0]); // geometry 2
1581 SetHitGroup(2, L"HitGroup1", &Weights[1]); // geometry 3
1582 // instance 2
1583 SetHitGroup(3, L"HitGroup2", &Weights[2]); // geometry 1
1584 SetHitGroup(4, L"HitGroup2", &Weights[1]); // geometry 2
1585 SetHitGroup(5, L"HitGroup2", &Weights[0]); // geometry 3
1586
1587 // SBT buffer barrier
1588 {
1589 D3D12_RESOURCE_BARRIER Barrier;
1590 Barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
1591 Barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
1592 Barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
1593 Barrier.Transition.StateBefore = D3D12_RESOURCE_STATE_COPY_DEST;
1594 Barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE;
1595 Barrier.Transition.pResource = Ctx.pSBTBuffer;
1596 Ctx.pCmdList->ResourceBarrier(1, &Barrier);
1597 }
1598
1599 Ctx.pCmdList->DispatchRays(&Desc);
1600 }
1601
1602 Ctx.pCmdList->Close();
1603
1604 pEnv->ExecuteCommandList(Ctx.pCmdList, true);
1605 }
1606
11461607 } // namespace Testing
11471608
11481609 } // namespace Diligent
3333 #include "gtest/gtest.h"
3434
3535 #include "InlineShaders/RayTracingTestHLSL.h"
36 #include "RayTracingTestConstants.hpp"
3637
3738 namespace Diligent
3839 {
4445 void RayTracingTriangleClosestHitReferenceD3D12(ISwapChain* pSwapChain);
4546 void RayTracingTriangleAnyHitReferenceD3D12(ISwapChain* pSwapChain);
4647 void RayTracingProceduralIntersectionReferenceD3D12(ISwapChain* pSwapChain);
48 void RayTracingMultiGeometryReferenceD3D12(ISwapChain* pSwapChain);
4749 #endif
4850
4951 #if VULKAN_SUPPORTED
5052 void RayTracingTriangleClosestHitReferenceVk(ISwapChain* pSwapChain);
5153 void RayTracingTriangleAnyHitReferenceVk(ISwapChain* pSwapChain);
5254 void RayTracingProceduralIntersectionReferenceVk(ISwapChain* pSwapChain);
55 void RayTracingMultiGeometryReferenceVk(ISwapChain* pSwapChain);
5356 #endif
5457
5558 } // namespace Testing
8285
8386 BottomLevelASDesc ASDesc;
8487 ASDesc.Name = "Triangle BLAS";
88 ASDesc.Flags = RAYTRACING_BUILD_AS_NONE;