summaryrefslogtreecommitdiffstats
path: root/Graphics/GraphicsEngineVulkan
diff options
context:
space:
mode:
authorazhirnov <zh1dron@gmail.com>2020-11-03 10:52:24 +0000
committerazhirnov <zh1dron@gmail.com>2020-11-03 11:16:11 +0000
commitefa43e2bd2475a4dec6771bf9759f6a99f7d77ed (patch)
treebbb257c3825ff07078626e3e137468f99d235a07 /Graphics/GraphicsEngineVulkan
parentFew improvements to ray tracing tests (diff)
downloadDiligentCore-efa43e2bd2475a4dec6771bf9759f6a99f7d77ed.tar.gz
DiligentCore-efa43e2bd2475a4dec6771bf9759f6a99f7d77ed.zip
fixed resource state transitions, some improvements for ray tracing
Diffstat (limited to 'Graphics/GraphicsEngineVulkan')
-rw-r--r--Graphics/GraphicsEngineVulkan/include/RenderDeviceVkImpl.hpp4
-rw-r--r--Graphics/GraphicsEngineVulkan/include/ShaderBindingTableVkImpl.hpp7
-rw-r--r--Graphics/GraphicsEngineVulkan/include/TopLevelASVkImpl.hpp5
-rw-r--r--Graphics/GraphicsEngineVulkan/src/DeviceContextVkImpl.cpp111
-rw-r--r--Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp11
-rw-r--r--Graphics/GraphicsEngineVulkan/src/ShaderBindingTableVkImpl.cpp45
-rw-r--r--Graphics/GraphicsEngineVulkan/src/ShaderResourceCacheVk.cpp40
-rw-r--r--Graphics/GraphicsEngineVulkan/src/VulkanTypeConversions.cpp29
8 files changed, 108 insertions, 144 deletions
diff --git a/Graphics/GraphicsEngineVulkan/include/RenderDeviceVkImpl.hpp b/Graphics/GraphicsEngineVulkan/include/RenderDeviceVkImpl.hpp
index e13f1a76..5440a6c8 100644
--- a/Graphics/GraphicsEngineVulkan/include/RenderDeviceVkImpl.hpp
+++ b/Graphics/GraphicsEngineVulkan/include/RenderDeviceVkImpl.hpp
@@ -201,6 +201,10 @@ public:
{
return GetPhysicalDevice().GetExtProperties().RayTracing.shaderGroupHandleSize;
}
+ Uint32 GetMaxShaderRecordStride() const
+ {
+ return GetPhysicalDevice().GetExtProperties().RayTracing.maxShaderGroupStride;
+ }
private:
template <typename PSOCreateInfoType>
diff --git a/Graphics/GraphicsEngineVulkan/include/ShaderBindingTableVkImpl.hpp b/Graphics/GraphicsEngineVulkan/include/ShaderBindingTableVkImpl.hpp
index 1b2db950..cef50a4e 100644
--- a/Graphics/GraphicsEngineVulkan/include/ShaderBindingTableVkImpl.hpp
+++ b/Graphics/GraphicsEngineVulkan/include/ShaderBindingTableVkImpl.hpp
@@ -51,10 +51,6 @@ public:
bool bIsDeviceInternal = false);
~ShaderBindingTableVkImpl();
- virtual void DILIGENT_CALL_TYPE Verify() const override;
-
- virtual void DILIGENT_CALL_TYPE Reset(const ShaderBindingTableDesc& Desc) override;
-
virtual void DILIGENT_CALL_TYPE ResetHitGroups(Uint32 HitShadersPerInstance) override;
virtual void DILIGENT_CALL_TYPE BindAll(const BindAllAttribs& Attribs) override;
@@ -68,9 +64,6 @@ public:
IMPLEMENT_QUERY_INTERFACE_IN_PLACE(IID_ShaderBindingTableVk, TShaderBindingTableBase);
private:
- void ValidateDesc(const ShaderBindingTableDesc& Desc) const;
-
-private:
RefCntAutoPtr<IBuffer> m_pBuffer;
};
diff --git a/Graphics/GraphicsEngineVulkan/include/TopLevelASVkImpl.hpp b/Graphics/GraphicsEngineVulkan/include/TopLevelASVkImpl.hpp
index b2801eca..0f8a94c8 100644
--- a/Graphics/GraphicsEngineVulkan/include/TopLevelASVkImpl.hpp
+++ b/Graphics/GraphicsEngineVulkan/include/TopLevelASVkImpl.hpp
@@ -34,15 +34,16 @@
#include "RenderDeviceVkImpl.hpp"
#include "TopLevelASVk.h"
#include "TopLevelASBase.hpp"
+#include "BottomLevelASVkImpl.hpp"
#include "VulkanUtilities/VulkanObjectWrappers.hpp"
namespace Diligent
{
-class TopLevelASVkImpl final : public TopLevelASBase<ITopLevelASVk, RenderDeviceVkImpl>
+class TopLevelASVkImpl final : public TopLevelASBase<ITopLevelASVk, BottomLevelASVkImpl, RenderDeviceVkImpl>
{
public:
- using TTopLevelASBase = TopLevelASBase<ITopLevelASVk, RenderDeviceVkImpl>;
+ using TTopLevelASBase = TopLevelASBase<ITopLevelASVk, BottomLevelASVkImpl, RenderDeviceVkImpl>;
TopLevelASVkImpl(IReferenceCounters* pRefCounters,
RenderDeviceVkImpl* pRenderDeviceVk,
diff --git a/Graphics/GraphicsEngineVulkan/src/DeviceContextVkImpl.cpp b/Graphics/GraphicsEngineVulkan/src/DeviceContextVkImpl.cpp
index 69a3a1ba..a77fb96d 100644
--- a/Graphics/GraphicsEngineVulkan/src/DeviceContextVkImpl.cpp
+++ b/Graphics/GraphicsEngineVulkan/src/DeviceContextVkImpl.cpp
@@ -2334,6 +2334,22 @@ void DeviceContextVkImpl::TransitionImageLayout(ITexture* pTexture, VkImageLayou
}
}
+namespace
+{
+NODISCARD inline bool ResourceStateHasWriteAccess(RESOURCE_STATE State)
+{
+ static_assert(RESOURCE_STATE_MAX_BIT == RESOURCE_STATE_RAY_TRACING, "This function must be updated to handle new resource state flag");
+ constexpr RESOURCE_STATE WriteAccessStates =
+ RESOURCE_STATE_RENDER_TARGET |
+ RESOURCE_STATE_UNORDERED_ACCESS |
+ RESOURCE_STATE_COPY_DEST |
+ RESOURCE_STATE_RESOLVE_DEST |
+ RESOURCE_STATE_BUILD_AS_WRITE;
+
+ return State & WriteAccessStates;
+}
+} // namespace
+
void DeviceContextVkImpl::TransitionTextureState(TextureVkImpl& TextureVk,
RESOURCE_STATE OldState,
RESOURCE_STATE NewState,
@@ -2396,17 +2412,22 @@ void DeviceContextVkImpl::TransitionTextureState(TextureVkImpl& Textur
pSubresRange->aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
}
- // Note that when both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
- // to make sure that all UAV writes are complete and visible.
+ // Always add barrier after writes.
+ const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
+
auto OldLayout = ResourceStateToVkImageLayout(OldState);
auto NewLayout = ResourceStateToVkImageLayout(NewState);
auto OldStages = ResourceStateFlagsToVkPipelineStageFlags(OldState, m_CommandBuffer.GetEnabledShaderStages());
auto NewStages = ResourceStateFlagsToVkPipelineStageFlags(NewState, m_CommandBuffer.GetEnabledShaderStages());
- m_CommandBuffer.TransitionImageLayout(vkImg, OldLayout, NewLayout, *pSubresRange, OldStages, NewStages);
- if (UpdateTextureState)
+
+ if (((OldState & NewState) != NewState) || OldLayout != NewLayout || AfterWrite)
{
- TextureVk.SetState(NewState);
- VERIFY_EXPR(TextureVk.GetLayout() == NewLayout);
+ m_CommandBuffer.TransitionImageLayout(vkImg, OldLayout, NewLayout, *pSubresRange, OldStages, NewStages);
+ if (UpdateTextureState)
+ {
+ TextureVk.SetState(NewState);
+ VERIFY_EXPR(TextureVk.GetLayout() == NewLayout);
+ }
}
}
@@ -2421,10 +2442,7 @@ void DeviceContextVkImpl::TransitionOrVerifyTextureState(TextureVkImpl&
VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
if (Texture.IsInKnownState())
{
- if (!Texture.CheckState(RequiredState))
- {
- TransitionTextureState(Texture, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ TransitionTextureState(Texture, RESOURCE_STATE_UNKNOWN, RequiredState, true);
VERIFY_EXPR(Texture.GetLayout() == ExpectedLayout);
}
}
@@ -2489,9 +2507,10 @@ void DeviceContextVkImpl::TransitionBufferState(BufferVkImpl& BufferVk, RESOURCE
}
}
- // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
- // to make sure that all UAV writes are complete and visible.
- if (((OldState & NewState) != NewState) || NewState == RESOURCE_STATE_UNORDERED_ACCESS || NewState == RESOURCE_STATE_BUILD_AS_WRITE)
+ // Always add barrier after writes.
+ const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
+
+ if (((OldState & NewState) != NewState) || AfterWrite)
{
DEV_CHECK_ERR(BufferVk.m_VulkanBuffer != VK_NULL_HANDLE, "Cannot transition suballocated buffer");
VERIFY_EXPR(BufferVk.GetDynamicOffset(m_ContextId, this) == 0);
@@ -2521,10 +2540,7 @@ void DeviceContextVkImpl::TransitionOrVerifyBufferState(BufferVkImpl&
VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
if (Buffer.IsInKnownState())
{
- if (!Buffer.CheckState(RequiredState))
- {
- TransitionBufferState(Buffer, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ TransitionBufferState(Buffer, RESOURCE_STATE_UNKNOWN, RequiredState, true);
VERIFY_EXPR(Buffer.CheckAccessFlags(ExpectedAccessFlags));
}
}
@@ -2564,7 +2580,10 @@ void DeviceContextVkImpl::TransitionBLASState(BottomLevelASVkImpl& BLAS,
}
}
- if ((OldState & NewState) != NewState)
+ // Always add barrier after writes.
+ const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
+
+ if ((OldState & NewState) != NewState || AfterWrite)
{
EnsureVkCmdBuffer();
auto OldAccessFlags = ResourceStateFlagsToVkAccessFlags(OldState);
@@ -2584,8 +2603,6 @@ void DeviceContextVkImpl::TransitionTLASState(TopLevelASVkImpl& TLAS,
RESOURCE_STATE NewState,
bool UpdateInternalState)
{
- // AZ TODO: transit BLAS state too?
-
VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
if (OldState == RESOURCE_STATE_UNKNOWN)
{
@@ -2609,7 +2626,10 @@ void DeviceContextVkImpl::TransitionTLASState(TopLevelASVkImpl& TLAS,
}
}
- if ((OldState & NewState) != NewState)
+ // Always add barrier after writes.
+ const bool AfterWrite = ResourceStateHasWriteAccess(OldState);
+
+ if ((OldState & NewState) != NewState || AfterWrite)
{
EnsureVkCmdBuffer();
auto OldAccessFlags = ResourceStateFlagsToVkAccessFlags(OldState);
@@ -2634,10 +2654,7 @@ void DeviceContextVkImpl::TransitionOrVerifyBLASState(BottomLevelASVkImpl&
VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
if (BLAS.IsInKnownState())
{
- if (!BLAS.CheckState(RequiredState))
- {
- TransitionBLASState(BLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ TransitionBLASState(BLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
}
}
#ifdef DILIGENT_DEVELOPMENT
@@ -2658,10 +2675,7 @@ void DeviceContextVkImpl::TransitionOrVerifyTLASState(TopLevelASVkImpl&
VERIFY(m_pActiveRenderPass == nullptr, "State transitions are not allowed inside a render pass");
if (TLAS.IsInKnownState())
{
- if (!TLAS.CheckState(RequiredState))
- {
- TransitionTLASState(TLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ TransitionTLASState(TLAS, RESOURCE_STATE_UNKNOWN, RequiredState, true);
}
}
#ifdef DILIGENT_DEVELOPMENT
@@ -2669,6 +2683,11 @@ void DeviceContextVkImpl::TransitionOrVerifyTLASState(TopLevelASVkImpl&
{
DvpVerifyTLASState(TLAS, RequiredState, OperationName);
}
+
+ if (RequiredState & RESOURCE_STATE_RAY_TRACING)
+ {
+ TLAS.CheckBLASVersion();
+ }
#endif
}
@@ -2718,13 +2737,13 @@ void DeviceContextVkImpl::TransitionResourceStates(Uint32 BarrierCount, StateTra
{
TransitionBufferState(*pBuffer, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
}
- else if (RefCntAutoPtr<BottomLevelASVkImpl> pBLAS{Barrier.pResource, IID_BottomLevelAS})
+ else if (RefCntAutoPtr<BottomLevelASVkImpl> pBottomLevelAS{Barrier.pResource, IID_BottomLevelAS})
{
- TransitionBLASState(*pBLAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
+ TransitionBLASState(*pBottomLevelAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
}
- else if (RefCntAutoPtr<TopLevelASVkImpl> pTLAS{Barrier.pResource, IID_TopLevelAS})
+ else if (RefCntAutoPtr<TopLevelASVkImpl> pTopLevelAS{Barrier.pResource, IID_TopLevelAS})
{
- TransitionTLASState(*pTLAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
+ TransitionTLASState(*pTopLevelAS, Barrier.OldState, Barrier.NewState, Barrier.UpdateResourceState);
}
else
{
@@ -2808,7 +2827,7 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
const char* OpName = "Build BottomLevelAS (DeviceContextVkImpl::BuildBLAS)";
TransitionOrVerifyBLASState(*pBLASVk, Attribs.BLASTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
- TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VkAccessFlagBits(0), OpName);
+ TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR, OpName);
VkAccelerationStructureBuildGeometryInfoKHR Info = {};
std::vector<VkAccelerationStructureBuildOffsetInfoKHR> Offsets;
@@ -2845,7 +2864,7 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
vkTris.vertexStride = SrcTris.VertexStride;
vkTris.vertexData.deviceAddress = pVB->GetVkDeviceAddress() + SrcTris.VertexOffset;
- TransitionOrVerifyBufferState(*pVB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, static_cast<VkAccessFlagBits>(0), OpName);
+ TransitionOrVerifyBufferState(*pVB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
if (SrcTris.pIndexBuffer)
{
@@ -2854,7 +2873,7 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
vkTris.indexData.deviceAddress = pIB->GetVkDeviceAddress() + SrcTris.IndexOffset;
off.primitiveCount = SrcTris.IndexCount / 3;
- TransitionOrVerifyBufferState(*pIB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, static_cast<VkAccessFlagBits>(0), OpName);
+ TransitionOrVerifyBufferState(*pIB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
}
else
{
@@ -2870,7 +2889,7 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
auto* const pTB = ValidatedCast<BufferVkImpl>(SrcTris.pTransformBuffer);
vkTris.transformData.deviceAddress = pTB->GetVkDeviceAddress() + SrcTris.TransformBufferOffset;
- TransitionOrVerifyBufferState(*pTB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
+ TransitionOrVerifyBufferState(*pTB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
}
else
{
@@ -2913,7 +2932,7 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
vkAABBs.stride = SrcBoxes.BoxStride;
vkAABBs.data.deviceAddress = pBB->GetVkDeviceAddress() + SrcBoxes.BoxOffset;
- TransitionOrVerifyBufferState(*pBB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
+ TransitionOrVerifyBufferState(*pBB, Attribs.GeometryTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
off.firstVertex = 0;
off.transformOffset = 0;
@@ -2939,6 +2958,10 @@ void DeviceContextVkImpl::BuildBLAS(const BLASBuildAttribs& Attribs)
EnsureVkCmdBuffer();
m_CommandBuffer.BuildAccelerationStructure(1, &Info, &OffsetsPtr);
++m_State.NumCommands;
+
+#ifdef DILIGENT_DEVELOPMENT
+ pBLASVk->UpdateVersion();
+#endif
}
void DeviceContextVkImpl::BuildTLAS(const TLASBuildAttribs& Attribs)
@@ -2964,7 +2987,7 @@ void DeviceContextVkImpl::BuildTLAS(const TLASBuildAttribs& Attribs)
const char* OpName = "Build TopLevelAS (DeviceContextVkImpl::BuildTLAS)";
TransitionOrVerifyTLASState(*pTLASVk, Attribs.TLASTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, OpName);
- TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VkAccessFlagBits(0), OpName);
+ TransitionOrVerifyBufferState(*pScratchVk, Attribs.ScratchBufferTransitionMode, RESOURCE_STATE_BUILD_AS_WRITE, VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR, OpName);
pTLASVk->SetInstanceData(Attribs.pInstances, Attribs.InstanceCount, Attribs.HitShadersPerInstance);
@@ -2981,7 +3004,7 @@ void DeviceContextVkImpl::BuildTLAS(const TLASBuildAttribs& Attribs)
auto* const pBLASVk = ValidatedCast<BottomLevelASVkImpl>(Inst.pBLAS);
static_assert(sizeof(vkASInst.transform) == sizeof(Inst.Transform), "size mismatch");
- std::memcpy(&vkASInst.transform, Inst.Transform, sizeof(vkASInst.transform));
+ std::memcpy(&vkASInst.transform, Inst.Transform.data, sizeof(vkASInst.transform));
vkASInst.instanceCustomIndex = Inst.CustomId;
vkASInst.instanceShaderBindingTableRecordOffset = pTLASVk->GetInstanceDesc(Inst.InstanceName).ContributionToHitGroupIndex; // AZ TODO: optimize
@@ -2994,7 +3017,7 @@ void DeviceContextVkImpl::BuildTLAS(const TLASBuildAttribs& Attribs)
UpdateBufferRegion(pInstancesVk, Attribs.InstanceBufferOffset, Size, TmpSpace.vkBuffer, TmpSpace.AlignedOffset, Attribs.InstanceBufferTransitionMode);
}
- TransitionOrVerifyBufferState(*pInstancesVk, Attribs.InstanceBufferTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VkAccessFlagBits(0), OpName);
+ TransitionOrVerifyBufferState(*pInstancesVk, Attribs.InstanceBufferTransitionMode, RESOURCE_STATE_BUILD_AS_READ, VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR, OpName);
VkAccelerationStructureBuildGeometryInfoKHR vkASBuildInfo = {};
VkAccelerationStructureBuildOffsetInfoKHR vkASBuildOffset = {};
@@ -3060,6 +3083,10 @@ void DeviceContextVkImpl::CopyBLAS(const CopyBLASAttribs& Attribs)
m_CommandBuffer.CopyAccelerationStructure(Info);
++m_State.NumCommands;
+
+#ifdef DILIGENT_DEVELOPMENT
+ pDstVk->UpdateVersion();
+#endif
}
void DeviceContextVkImpl::CopyTLAS(const CopyTLASAttribs& Attribs)
@@ -3077,6 +3104,8 @@ void DeviceContextVkImpl::CopyTLAS(const CopyTLASAttribs& Attribs)
auto* pSrcVk = ValidatedCast<TopLevelASVkImpl>(Attribs.pSrc);
auto* pDstVk = ValidatedCast<TopLevelASVkImpl>(Attribs.pDst);
+ pDstVk->CopyInstancceData(*pSrcVk);
+
VkCopyAccelerationStructureInfoKHR Info = {};
Info.sType = VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR;
diff --git a/Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp b/Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp
index 36ca3a42..1cac2b6f 100644
--- a/Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp
+++ b/Graphics/GraphicsEngineVulkan/src/PipelineStateVkImpl.cpp
@@ -758,10 +758,16 @@ PipelineStateVkImpl::PipelineStateVkImpl(IReferenceCounters*
{
try
{
+ const auto& LogicalDevice = GetDevice()->GetLogicalDevice();
+ const auto ShaderGroupHandleSize = pDeviceVk->GetShaderGroupHandleSize();
+
+ if (LogicalDevice.GetEnabledExtFeatures().RayTracing.rayTracing == VK_FALSE)
+ LOG_ERROR_AND_THROW("Ray tracing is not supported by this device");
+
std::vector<VkPipelineShaderStageCreateInfo> vkShaderStages;
std::vector<VulkanUtilities::ShaderModuleWrapper> ShaderModules;
-
std::vector<VkRayTracingShaderGroupCreateInfoKHR> ShaderGroups;
+
InitInternalObjects(CreateInfo, vkShaderStages, ShaderModules,
[&](const RayTracingPipelineStateCreateInfo& CreateInfo, LinearAllocator& MemPool, TShaderStages& ShaderStages) //
{
@@ -773,9 +779,6 @@ PipelineStateVkImpl::PipelineStateVkImpl(IReferenceCounters*
CreateRayTracingPipeline(pDeviceVk, vkShaderStages, ShaderGroups, m_PipelineLayout, m_Desc, GetRayTracingPipelineDesc(), m_Pipeline);
- const auto& LogicalDevice = GetDevice()->GetLogicalDevice();
- const auto ShaderGroupHandleSize = pDeviceVk->GetShaderGroupHandleSize();
-
auto err = LogicalDevice.GetRayTracingShaderGroupHandles(m_Pipeline, 0, static_cast<uint32_t>(ShaderGroups.size()), ShaderGroupHandleSize, &m_pRayTracingPipelineData->Shaders[0]);
VERIFY(err == VK_SUCCESS, "Failed to get shader group handles");
(void)err;
diff --git a/Graphics/GraphicsEngineVulkan/src/ShaderBindingTableVkImpl.cpp b/Graphics/GraphicsEngineVulkan/src/ShaderBindingTableVkImpl.cpp
index 6f5091e0..3940769f 100644
--- a/Graphics/GraphicsEngineVulkan/src/ShaderBindingTableVkImpl.cpp
+++ b/Graphics/GraphicsEngineVulkan/src/ShaderBindingTableVkImpl.cpp
@@ -39,57 +39,12 @@ ShaderBindingTableVkImpl::ShaderBindingTableVkImpl(IReferenceCounters*
bool bIsDeviceInternal) :
TShaderBindingTableBase{pRefCounters, pRenderDeviceVk, Desc, bIsDeviceInternal}
{
- ValidateDesc(Desc);
-
- const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
- m_ShaderRecordStride = m_Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize;
}
ShaderBindingTableVkImpl::~ShaderBindingTableVkImpl()
{
}
-void ShaderBindingTableVkImpl::ValidateDesc(const ShaderBindingTableDesc& Desc) const
-{
- const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
-
- if (Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize > RTLimits.maxShaderGroupStride)
- {
- LOG_ERROR_AND_THROW("Description of Shader binding table '", (Desc.Name ? Desc.Name : ""),
- "' is invalid: ShaderRecordSize is too big, max size is: ", RTLimits.maxShaderGroupStride - RTLimits.shaderGroupHandleSize);
- }
-}
-
-void ShaderBindingTableVkImpl::Verify() const
-{
- // AZ TODO
-}
-
-void ShaderBindingTableVkImpl::Reset(const ShaderBindingTableDesc& Desc)
-{
- m_RayGenShaderRecord.clear();
- m_MissShadersRecord.clear();
- m_CallableShadersRecord.clear();
- m_HitGroupsRecord.clear();
- m_Changed = true;
-
- try
- {
- ValidateShaderBindingTableDesc(Desc);
- ValidateDesc(Desc);
- }
- catch (const std::runtime_error&)
- {
- // AZ TODO
- return;
- }
-
- m_Desc = Desc;
-
- const auto& RTLimits = GetDevice()->GetPhysicalDevice().GetExtProperties().RayTracing;
- m_ShaderRecordStride = m_Desc.ShaderRecordSize + RTLimits.shaderGroupHandleSize;
-}
-
void ShaderBindingTableVkImpl::ResetHitGroups(Uint32 HitShadersPerInstance)
{
// AZ TODO
diff --git a/Graphics/GraphicsEngineVulkan/src/ShaderResourceCacheVk.cpp b/Graphics/GraphicsEngineVulkan/src/ShaderResourceCacheVk.cpp
index 8101fefc..27e5ee72 100644
--- a/Graphics/GraphicsEngineVulkan/src/ShaderResourceCacheVk.cpp
+++ b/Graphics/GraphicsEngineVulkan/src/ShaderResourceCacheVk.cpp
@@ -166,10 +166,9 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
{
constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_CONSTANT_BUFFER;
VERIFY_EXPR((ResourceStateFlagsToVkAccessFlags(RequiredState) & VK_ACCESS_UNIFORM_READ_BIT) == VK_ACCESS_UNIFORM_READ_BIT);
- const bool IsInRequiredState = pBufferVk->CheckState(RequiredState);
if (VerifyOnly)
{
- if (!IsInRequiredState)
+ if (!pBufferVk->CheckState(RequiredState))
{
LOG_ERROR_MESSAGE("State of buffer '", pBufferVk->GetDesc().Name, "' is incorrect. Required state: ",
GetResourceStateString(RequiredState), ". Actual state: ",
@@ -181,10 +180,7 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
}
else
{
- if (!IsInRequiredState)
- {
- pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
VERIFY_EXPR(pBufferVk->CheckAccessFlags(VK_ACCESS_UNIFORM_READ_BIT));
}
}
@@ -211,11 +207,10 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
VERIFY_EXPR((ResourceStateFlagsToVkAccessFlags(RequiredState) & RequiredAccessFlags) == RequiredAccessFlags);
#endif
- const bool IsInRequiredState = pBufferVk->CheckState(RequiredState);
if (VerifyOnly)
{
- if (!IsInRequiredState)
+ if (!pBufferVk->CheckState(RequiredState))
{
LOG_ERROR_MESSAGE("State of buffer '", pBufferVk->GetDesc().Name, "' is incorrect. Required state: ",
GetResourceStateString(RequiredState), ". Actual state: ",
@@ -227,12 +222,7 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
}
else
{
- // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
- // to make sure that all UAV writes are complete and visible.
- if (!IsInRequiredState || RequiredState == RESOURCE_STATE_UNORDERED_ACCESS)
- {
- pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ pCtxVkImpl->TransitionBufferState(*pBufferVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
VERIFY_EXPR(pBufferVk->CheckAccessFlags(RequiredAccessFlags));
}
}
@@ -275,11 +265,10 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
VERIFY_EXPR(ResourceStateToVkImageLayout(RequiredState) == VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL);
}
}
- const bool IsInRequiredState = pTextureVk->CheckState(RequiredState);
if (VerifyOnly)
{
- if (!IsInRequiredState)
+ if (!pTextureVk->CheckState(RequiredState))
{
LOG_ERROR_MESSAGE("State of texture '", pTextureVk->GetDesc().Name, "' is incorrect. Required state: ",
GetResourceStateString(RequiredState), ". Actual state: ",
@@ -291,12 +280,7 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
}
else
{
- // When both old and new states are RESOURCE_STATE_UNORDERED_ACCESS, we need to execute UAV barrier
- // to make sure that all UAV writes are complete and visible.
- if (!IsInRequiredState || RequiredState == RESOURCE_STATE_UNORDERED_ACCESS)
- {
- pCtxVkImpl->TransitionTextureState(*pTextureVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ pCtxVkImpl->TransitionTextureState(*pTextureVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
}
}
}
@@ -327,11 +311,10 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
auto* pTLASVk = Res.pObject.RawPtr<TopLevelASVkImpl>();
if (pTLASVk != nullptr && pTLASVk->IsInKnownState())
{
- constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_RAY_TRACING;
- const bool IsInRequiredState = pTLASVk->CheckState(RequiredState);
+ constexpr RESOURCE_STATE RequiredState = RESOURCE_STATE_RAY_TRACING;
if (VerifyOnly)
{
- if (!IsInRequiredState)
+ if (!pTLASVk->CheckState(RequiredState))
{
LOG_ERROR_MESSAGE("State of TLAS '", pTLASVk->GetDesc().Name, "' is incorrect. Required state: ",
GetResourceStateString(RequiredState), ". Actual state: ",
@@ -340,13 +323,12 @@ void ShaderResourceCacheVk::TransitionResources(DeviceContextVkImpl* pCtxVkImpl)
"when calling IDeviceContext::CommitShaderResources() or explicitly transition the TLAS state "
"with IDeviceContext::TransitionResourceStates().");
}
+
+ pTLASVk->CheckBLASVersion();
}
else
{
- if (!IsInRequiredState)
- {
- pCtxVkImpl->TransitionTLASState(*pTLASVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
- }
+ pCtxVkImpl->TransitionTLASState(*pTLASVk, RESOURCE_STATE_UNKNOWN, RequiredState, true);
}
}
}
diff --git a/Graphics/GraphicsEngineVulkan/src/VulkanTypeConversions.cpp b/Graphics/GraphicsEngineVulkan/src/VulkanTypeConversions.cpp
index 143042b8..e039311b 100644
--- a/Graphics/GraphicsEngineVulkan/src/VulkanTypeConversions.cpp
+++ b/Graphics/GraphicsEngineVulkan/src/VulkanTypeConversions.cpp
@@ -1613,12 +1613,10 @@ VkBuildAccelerationStructureFlagsKHR BuildASFlagsToVkBuildAccelerationStructureF
"Please update the switch below to handle the new ray tracing build flag");
VkBuildAccelerationStructureFlagsKHR Result = 0;
- for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
+ while (Flags != RAYTRACING_BUILD_AS_NONE)
{
- if ((Flags & Bit) != Bit)
- continue;
-
- switch (static_cast<RAYTRACING_BUILD_AS_FLAGS>(Bit))
+ auto FlagBit = static_cast<RAYTRACING_BUILD_AS_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
+ switch (FlagBit)
{
// clang-format off
case RAYTRACING_BUILD_AS_ALLOW_UPDATE: Result |= VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_UPDATE_BIT_KHR; break;
@@ -1629,6 +1627,7 @@ VkBuildAccelerationStructureFlagsKHR BuildASFlagsToVkBuildAccelerationStructureF
// clang-format on
default: UNEXPECTED("unknown build AS flag");
}
+ Flags = Flags & ~FlagBit;
}
return Result;
}
@@ -1639,12 +1638,10 @@ VkGeometryFlagsKHR GeometryFlagsToVkGeometryFlags(RAYTRACING_GEOMETRY_FLAGS Flag
"Please update the switch below to handle the new ray tracing geometry flag");
VkGeometryFlagsKHR Result = 0;
- for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
+ while (Flags != RAYTRACING_GEOMETRY_NONE)
{
- if ((Flags & Bit) != Bit)
- continue;
-
- switch (static_cast<RAYTRACING_GEOMETRY_FLAGS>(Bit))
+ auto FlagBit = static_cast<RAYTRACING_GEOMETRY_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
+ switch (FlagBit)
{
// clang-format off
case RAYTRACING_GEOMETRY_OPAQUE: Result |= VK_GEOMETRY_OPAQUE_BIT_KHR; break;
@@ -1652,6 +1649,7 @@ VkGeometryFlagsKHR GeometryFlagsToVkGeometryFlags(RAYTRACING_GEOMETRY_FLAGS Flag
// clang-format on
default: UNEXPECTED("unknown geometry flag");
}
+ Flags = Flags & ~FlagBit;
}
return Result;
}
@@ -1662,12 +1660,10 @@ VkGeometryInstanceFlagsKHR InstanceFlagsToVkGeometryInstanceFlags(RAYTRACING_INS
"Please update the switch below to handle the new ray tracing instance flag");
VkGeometryInstanceFlagsKHR Result = 0;
- for (Uint32 Bit = 1; Bit <= Flags; Bit <<= 1)
+ while (Flags != RAYTRACING_INSTANCE_NONE)
{
- if ((Flags & Bit) != Bit)
- continue;
-
- switch (static_cast<RAYTRACING_INSTANCE_FLAGS>(Bit))
+ auto FlagBit = static_cast<RAYTRACING_INSTANCE_FLAGS>(1 << PlatformMisc::GetLSB(Uint32{Flags}));
+ switch (FlagBit)
{
// clang-format off
case RAYTRACING_INSTANCE_TRIANGLE_FACING_CULL_DISABLE: Result |= VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR; break;
@@ -1677,6 +1673,7 @@ VkGeometryInstanceFlagsKHR InstanceFlagsToVkGeometryInstanceFlags(RAYTRACING_INS
// clang-format on
default: UNEXPECTED("unknown instance flag");
}
+ Flags = Flags & ~FlagBit;
}
return Result;
}
@@ -1693,7 +1690,7 @@ VkCopyAccelerationStructureModeKHR CopyASModeToVkCopyAccelerationStructureMode(C
// clang-format on
default:
UNEXPECTED("unknown AS copy mode");
- return static_cast<VkCopyAccelerationStructureModeKHR>(0);
+ return VK_COPY_ACCELERATION_STRUCTURE_MODE_MAX_ENUM_KHR;
}
}