diff --git a/util/test/demos/d3d12/d3d12_helpers.cpp b/util/test/demos/d3d12/d3d12_helpers.cpp
index 09eb6074b..0e671b47c 100644
--- a/util/test/demos/d3d12/d3d12_helpers.cpp
+++ b/util/test/demos/d3d12/d3d12_helpers.cpp
@@ -1087,6 +1087,36 @@ D3D12PSOCreator &D3D12PSOCreator::VS(ID3DBlobPtr blob)
return *this;
}
+D3D12PSOCreator &D3D12PSOCreator::AS(ID3DBlobPtr blob)
+{
+ if(blob)
+ {
+ m_AS.pShaderBytecode = blob->GetBufferPointer();
+ m_AS.BytecodeLength = blob->GetBufferSize();
+ }
+ else
+ {
+ m_AS.pShaderBytecode = NULL;
+ m_AS.BytecodeLength = 0;
+ }
+ return *this;
+}
+
+D3D12PSOCreator &D3D12PSOCreator::MS(ID3DBlobPtr blob)
+{
+ if(blob)
+ {
+ m_MS.pShaderBytecode = blob->GetBufferPointer();
+ m_MS.BytecodeLength = blob->GetBufferSize();
+ }
+ else
+ {
+ m_MS.pShaderBytecode = NULL;
+ m_MS.BytecodeLength = 0;
+ }
+ return *this;
+}
+
D3D12PSOCreator &D3D12PSOCreator::HS(ID3DBlobPtr blob)
{
if(blob)
@@ -1223,6 +1253,10 @@ D3D12PSOCreator &D3D12PSOCreator::SampleCount(UINT Samples)
D3D12PSOCreator::operator ID3D12PipelineStatePtr() const
{
ID3D12PipelineStatePtr pso;
+ if((m_MS.BytecodeLength > 0) || (m_AS.BytecodeLength > 0))
+ {
+ return NULL;
+ }
if(ComputeDesc.CS.BytecodeLength > 0)
{
CHECK_HR(m_Dev->CreateComputePipelineState(&ComputeDesc, __uuidof(ID3D12PipelineState),
diff --git a/util/test/demos/d3d12/d3d12_helpers.h b/util/test/demos/d3d12/d3d12_helpers.h
index 480f40477..e9244dc4e 100644
--- a/util/test/demos/d3d12/d3d12_helpers.h
+++ b/util/test/demos/d3d12/d3d12_helpers.h
@@ -88,6 +88,8 @@ public:
D3D12PSOCreator(ID3D12DevicePtr dev);
D3D12PSOCreator &VS(ID3DBlobPtr blob);
+ D3D12PSOCreator &AS(ID3DBlobPtr blob);
+ D3D12PSOCreator &MS(ID3DBlobPtr blob);
D3D12PSOCreator &HS(ID3DBlobPtr blob);
D3D12PSOCreator &DS(ID3DBlobPtr blob);
D3D12PSOCreator &GS(ID3DBlobPtr blob);
@@ -113,7 +115,12 @@ public:
D3D12_GRAPHICS_PIPELINE_STATE_DESC GraphicsDesc = {};
D3D12_COMPUTE_PIPELINE_STATE_DESC ComputeDesc = {};
+ const D3D12_SHADER_BYTECODE &GetAS() const { return m_AS; };
+ const D3D12_SHADER_BYTECODE &GetMS() const { return m_MS; };
+
private:
+ D3D12_SHADER_BYTECODE m_AS = {};
+ D3D12_SHADER_BYTECODE m_MS = {};
ID3D12DevicePtr m_Dev;
};
diff --git a/util/test/demos/d3d12/d3d12_mesh_shader.cpp b/util/test/demos/d3d12/d3d12_mesh_shader.cpp
new file mode 100644
index 000000000..396f03d5c
--- /dev/null
+++ b/util/test/demos/d3d12/d3d12_mesh_shader.cpp
@@ -0,0 +1,340 @@
+/******************************************************************************
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2024 Baldur Karlsson
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ ******************************************************************************/
+
+#include "d3d12_test.h"
+
+// subobject headers have to be aligned to pointer boundaries
+#define SUBOBJECT_HEADER(subobj) \
+ D3D12_PIPELINE_STATE_SUBOBJECT_TYPE alignas(void *) CONCAT(header, subobj) = \
+ CONCAT(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_, subobj);
+
+struct GraphicsStreamData
+{
+ // graphics properties
+ SUBOBJECT_HEADER(ROOT_SIGNATURE);
+ ID3D12RootSignature *pRootSignature = NULL;
+ SUBOBJECT_HEADER(VS);
+ D3D12_SHADER_BYTECODE VS = {};
+ SUBOBJECT_HEADER(AS);
+ D3D12_SHADER_BYTECODE AS = {};
+ SUBOBJECT_HEADER(MS);
+ D3D12_SHADER_BYTECODE MS = {};
+ SUBOBJECT_HEADER(PS);
+ D3D12_SHADER_BYTECODE PS = {};
+ SUBOBJECT_HEADER(DS);
+ D3D12_SHADER_BYTECODE DS = {};
+ SUBOBJECT_HEADER(HS);
+ D3D12_SHADER_BYTECODE HS = {};
+ SUBOBJECT_HEADER(GS);
+ D3D12_SHADER_BYTECODE GS = {};
+ SUBOBJECT_HEADER(RENDER_TARGET_FORMATS);
+ D3D12_RT_FORMAT_ARRAY RTVFormats = {};
+ SUBOBJECT_HEADER(DEPTH_STENCIL_FORMAT);
+ DXGI_FORMAT DSVFormat = DXGI_FORMAT_UNKNOWN;
+ SUBOBJECT_HEADER(PRIMITIVE_TOPOLOGY);
+ D3D12_PRIMITIVE_TOPOLOGY_TYPE PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_UNDEFINED;
+ SUBOBJECT_HEADER(IB_STRIP_CUT_VALUE);
+ D3D12_INDEX_BUFFER_STRIP_CUT_VALUE IBStripCutValue = D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED;
+ SUBOBJECT_HEADER(NODE_MASK);
+ UINT NodeMask = 0;
+ SUBOBJECT_HEADER(SAMPLE_MASK);
+ UINT SampleMask = 0;
+ SUBOBJECT_HEADER(RASTERIZER);
+ D3D12_RASTERIZER_DESC RasterizerState;
+ SUBOBJECT_HEADER(FLAGS);
+ D3D12_PIPELINE_STATE_FLAGS Flags = D3D12_PIPELINE_STATE_FLAG_NONE;
+ SUBOBJECT_HEADER(BLEND);
+ D3D12_BLEND_DESC BlendState = {};
+ UINT pad0;
+ SUBOBJECT_HEADER(SAMPLE_DESC);
+ DXGI_SAMPLE_DESC SampleDesc = {};
+ UINT pad1;
+};
+
+#undef SUBOBJECT_HEADER
+
+std::string GlobalPayload_Shaders = R"EOSHADER(
+
+struct Payload
+{
+ uint tri[2];
+};
+
+groupshared Payload sPayload;
+
+[numthreads(2, 1, 1)]
+void as_amplify(uint gtid : SV_GroupThreadID, uint dtid : SV_DispatchThreadID, uint gid : SV_GroupIndex)
+{
+ sPayload.tri[gid] = dtid;
+ DispatchMesh(2, 1, 1, sPayload);
+}
+
+struct m2f
+{
+ float4 pos : SV_POSITION;
+ float4 col : COLOR0;
+ float2 uv : TEXCOORD0;
+};
+
+[outputtopology("triangle")]
+[numthreads(1, 1, 1)]
+void ms_amplify(uint gtid : SV_GroupThreadID, uint dtid : SV_DispatchThreadID, in payload Payload payload, out indices uint3 triangles[128], out vertices m2f vertices[64])
+{
+ SetMeshOutputCounts(3, 1);
+
+ uint tri = payload.tri[dtid];
+ uint vertIdx = 0;
+ triangles[0] = uint3(0+vertIdx, 1+vertIdx, 2+vertIdx);
+
+ float4 org = float4(-0.65, 0.0, 0.0, 0.0) + float4(0.42, 0.0, 0.0, 0.0) * tri;
+ vertices[0+vertIdx].pos = float4(-0.2, -0.2, 0.0, 1.0) + org;
+ vertices[0+vertIdx].col = float4(0.0, 1.0, 0.0, 1.0);
+ vertices[0+vertIdx].uv = float2(0.0, 0.0);
+
+ vertices[1+vertIdx].pos = float4(0.0, 0.2, 0.0, 1.0) + org;
+ vertices[1+vertIdx].col = float4(0.0, 1.0, 0.0, 1.0);
+ vertices[1+vertIdx].uv = float2(0.0, 1.0);
+
+ vertices[2+vertIdx].pos = float4(0.2, -0.2, 0.0, 1.0) + org;
+ vertices[2+vertIdx].col = float4(0.0, 1.0, 0.0, 1.0);
+ vertices[2+vertIdx].uv = float2(1.0, 0.0);
+}
+
+)EOSHADER";
+
+std::string LocalPayload_Shaders = R"EOSHADER(
+
+struct Payload
+{
+ uint tri[4];
+};
+
+[numthreads(1, 1, 1)]
+void as_amplify(uint gtid : SV_GroupThreadID, uint dtid : SV_DispatchThreadID, uint gid : SV_GroupIndex)
+{
+ Payload sPayload;
+ sPayload.tri[0] = 0;
+ sPayload.tri[1] = 1;
+ sPayload.tri[2] = 2;
+ sPayload.tri[3] = 3;
+ DispatchMesh(4, 1, 1, sPayload);
+}
+
+struct m2f
+{
+ float4 pos : SV_POSITION;
+ float4 col : COLOR0;
+ float2 uv : TEXCOORD0;
+};
+
+[outputtopology("triangle")]
+[numthreads(1, 1, 1)]
+void ms_amplify(uint gtid : SV_GroupThreadID, uint dtid : SV_DispatchThreadID, in payload Payload payload, out indices uint3 triangles[128], out vertices m2f vertices[64])
+{
+ SetMeshOutputCounts(3, 1);
+
+ uint tri = payload.tri[dtid];
+ uint vertIdx = 0;
+ triangles[0] = uint3(0+vertIdx, 1+vertIdx, 2+vertIdx);
+
+ float4 org = float4(-0.65, -0.65, 0.0, 0.0) + float4(0.42, 0.0, 0.0, 0.0) * tri;
+ vertices[0+vertIdx].pos = float4(-0.2, -0.2, 0.0, 1.0) + org;
+ vertices[0+vertIdx].col = float4(0.0, 0.0, 1.0, 1.0);
+ vertices[0+vertIdx].uv = float2(0.0, 0.0);
+
+ vertices[1+vertIdx].pos = float4(0.0, 0.2, 0.0, 1.0) + org;
+ vertices[1+vertIdx].col = float4(0.0, 0.0, 1.0, 1.0);
+ vertices[1+vertIdx].uv = float2(0.0, 1.0);
+
+ vertices[2+vertIdx].pos = float4(0.2, -0.2, 0.0, 1.0) + org;
+ vertices[2+vertIdx].col = float4(0.0, 0.0, 1.0, 1.0);
+ vertices[2+vertIdx].uv = float2(1.0, 0.0);
+}
+
+)EOSHADER";
+
+std::string SimpleMeshShader = R"EOSHADER(
+
+struct m2f
+{
+ float4 pos : SV_POSITION;
+ float4 col : COLOR0;
+ float2 uv : TEXCOORD0;
+};
+
+[outputtopology("triangle")]
+[numthreads(1, 1, 1)]
+void ms_simple(in uint gid : SV_GroupID, out indices uint3 triangles[2], out vertices m2f vertices[6])
+{
+ SetMeshOutputCounts(6, 2);
+
+ for (uint i = 0; i < 2; i++)
+ {
+ uint tri = i;
+ uint vertIdx = tri * 3;
+ triangles[tri] = uint3(0+vertIdx, 1+vertIdx, 2+vertIdx);
+ tri += 2 * gid;
+
+ float4 org = float4(-0.65, +0.65, 0.0, 0.0) + float4(0.42, 0.0, 0.0, 0.0) * tri;
+ vertices[0+vertIdx].pos = float4(-0.2, -0.2, 0.0, 1.0) + org;
+ vertices[0+vertIdx].col = float4(1.0, 0.0, 0.0, 1.0);
+ vertices[0+vertIdx].uv = float2(0.0, 0.0);
+
+ vertices[1+vertIdx].pos = float4(0.0, 0.2, 0.0, 1.0) + org;
+ vertices[1+vertIdx].col = float4(1.0, 0.0, 0.0, 1.0);
+ vertices[1+vertIdx].uv = float2(0.0, 1.0);
+
+ vertices[2+vertIdx].pos = float4(0.2, -0.2, 0.0, 1.0) + org;
+ vertices[2+vertIdx].col = float4(1.0, 0.0, 0.0, 1.0);
+ vertices[2+vertIdx].uv = float2(1.0, 0.0);
+ }
+}
+
+)EOSHADER";
+
+RD_TEST(D3D12_Mesh_Shader, D3D12GraphicsTest)
+{
+ static constexpr const char *Description = "Draws geometry using mesh shader pipeline.";
+
+ void Prepare(int argc, char **argv)
+ {
+ D3D12GraphicsTest::Prepare(argc, argv);
+
+ if(!Avail.empty())
+ return;
+
+ if(opts7.MeshShaderTier == D3D12_MESH_SHADER_TIER_NOT_SUPPORTED)
+ Avail = "Mesh Shaders are not supported";
+ }
+
+ ID3D12PipelineStatePtr CreatePipeline(const D3D12PSOCreator &psoData) const
+ {
+ GraphicsStreamData graphicsStreamData;
+ const D3D12_GRAPHICS_PIPELINE_STATE_DESC &GraphicsDesc = psoData.GraphicsDesc;
+
+ graphicsStreamData.pRootSignature = GraphicsDesc.pRootSignature;
+ graphicsStreamData.VS = GraphicsDesc.VS;
+ graphicsStreamData.AS = psoData.GetAS();
+ graphicsStreamData.MS = psoData.GetMS();
+ graphicsStreamData.PS = GraphicsDesc.PS;
+ graphicsStreamData.DS = GraphicsDesc.DS;
+ graphicsStreamData.HS = GraphicsDesc.HS;
+ graphicsStreamData.GS = GraphicsDesc.GS;
+ graphicsStreamData.BlendState = GraphicsDesc.BlendState;
+ graphicsStreamData.SampleMask = GraphicsDesc.SampleMask;
+ graphicsStreamData.IBStripCutValue = GraphicsDesc.IBStripCutValue;
+ graphicsStreamData.PrimitiveTopologyType = GraphicsDesc.PrimitiveTopologyType;
+ for(uint32_t i = 0; i < 8; ++i)
+ graphicsStreamData.RTVFormats.RTFormats[i] = GraphicsDesc.RTVFormats[i];
+ graphicsStreamData.RTVFormats.NumRenderTargets = GraphicsDesc.NumRenderTargets;
+
+ graphicsStreamData.DSVFormat = GraphicsDesc.DSVFormat;
+ graphicsStreamData.SampleDesc = GraphicsDesc.SampleDesc;
+ graphicsStreamData.NodeMask = GraphicsDesc.NodeMask;
+ graphicsStreamData.Flags = GraphicsDesc.Flags;
+
+ graphicsStreamData.RasterizerState = GraphicsDesc.RasterizerState;
+
+ D3D12_PIPELINE_STATE_STREAM_DESC streamDesc;
+ streamDesc.pPipelineStateSubobjectStream = &graphicsStreamData;
+ streamDesc.SizeInBytes = sizeof(GraphicsStreamData);
+
+ ID3D12PipelineStatePtr pso;
+ dev2->CreatePipelineState(&streamDesc, __uuidof(ID3D12PipelineState), (void **)&pso);
+
+ return pso;
+ }
+
+ int main()
+ {
+ // initialise, create window, create device, etc
+ if(!Init())
+ return 3;
+
+ ID3DBlobPtr as_globalpayload_blob = Compile(GlobalPayload_Shaders, "as_amplify", "as_6_5");
+ ID3DBlobPtr ms_globalpayload_blob = Compile(GlobalPayload_Shaders, "ms_amplify", "ms_6_5");
+ ID3DBlobPtr as_localpayload_blob = Compile(LocalPayload_Shaders, "as_amplify", "as_6_5");
+ ID3DBlobPtr ms_localpayload_blob = Compile(LocalPayload_Shaders, "ms_amplify", "ms_6_5");
+ ID3DBlobPtr msblob = Compile(SimpleMeshShader, "ms_simple", "ms_6_5");
+ ID3DBlobPtr psblob = Compile(D3DDefaultPixel, "main", "ps_6_5");
+
+ ID3D12RootSignaturePtr sig = MakeSig({});
+
+ ID3D12PipelineStatePtr psos[] = {
+ CreatePipeline(MakePSO().RootSig(sig).InputLayout().MS(msblob).PS(psblob)),
+ CreatePipeline(MakePSO()
+ .RootSig(sig)
+ .InputLayout()
+ .AS(as_globalpayload_blob)
+ .MS(ms_globalpayload_blob)
+ .PS(psblob)),
+ CreatePipeline(
+ MakePSO().RootSig(sig).InputLayout().AS(as_localpayload_blob).MS(ms_localpayload_blob).PS(psblob)),
+ };
+
+ while(Running())
+ {
+ ID3D12GraphicsCommandList6Ptr cmd = GetCommandBuffer();
+
+ Reset(cmd);
+
+ ID3D12ResourcePtr bb = StartUsingBackbuffer(cmd, D3D12_RESOURCE_STATE_RENDER_TARGET);
+
+ D3D12_CPU_DESCRIPTOR_HANDLE rtv =
+ MakeRTV(bb).Format(DXGI_FORMAT_R8G8B8A8_UNORM_SRGB).CreateCPU(0);
+
+ ClearRenderTargetView(cmd, rtv, {0.2f, 0.2f, 0.2f, 1.0f});
+
+ cmd->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
+
+ setMarker(cmd, "Mesh Shaders");
+ for(size_t i = 0; i < ARRAY_COUNT(psos); i++)
+ {
+ cmd->SetPipelineState(psos[i]);
+ cmd->SetGraphicsRootSignature(sig);
+
+ RSSetViewport(cmd, {0.0f, 0.0f, (float)screenWidth, (float)screenHeight, 0.0f, 1.0f});
+ RSSetScissorRect(cmd, {0, 0, screenWidth, screenHeight});
+
+ OMSetRenderTargets(cmd, {rtv}, {});
+ if(i < 2)
+ cmd->DispatchMesh(2, 1, 1);
+ else
+ cmd->DispatchMesh(1, 1, 1);
+ }
+
+ FinishUsingBackbuffer(cmd, D3D12_RESOURCE_STATE_RENDER_TARGET);
+
+ cmd->Close();
+
+ Submit({cmd});
+
+ Present();
+ }
+
+ return 0;
+ }
+};
+
+REGISTER_TEST();
diff --git a/util/test/demos/demos.vcxproj b/util/test/demos/demos.vcxproj
index 300371308..33bcdcb47 100644
--- a/util/test/demos/demos.vcxproj
+++ b/util/test/demos/demos.vcxproj
@@ -194,6 +194,7 @@
+
diff --git a/util/test/demos/demos.vcxproj.filters b/util/test/demos/demos.vcxproj.filters
index 552682bbc..7b50263e4 100644
--- a/util/test/demos/demos.vcxproj.filters
+++ b/util/test/demos/demos.vcxproj.filters
@@ -688,6 +688,9 @@
Vulkan\demos
+
+ D3D12\demos
+
diff --git a/util/test/rdtest/analyse.py b/util/test/rdtest/analyse.py
index 9eef165df..74bafed53 100644
--- a/util/test/rdtest/analyse.py
+++ b/util/test/rdtest/analyse.py
@@ -149,6 +149,10 @@ def get_postvs_attrs(controller: rd.ReplayController, mesh: rd.MeshFormat, data_
if data_stage == rd.MeshDataStage.VSOut:
shader = pipe.GetShaderReflection(rd.ShaderStage.Vertex)
+ elif data_stage == rd.MeshDataStage.TaskOut:
+ raise RuntimeError("Use get_postts_attrs to get TaskOut attributes!")
+ elif data_stage == rd.MeshDataStage.MeshOut:
+ shader = pipe.GetShaderReflection(rd.ShaderStage.Mesh)
else:
shader = pipe.GetShaderReflection(rd.ShaderStage.Geometry)
if shader is None:
diff --git a/util/test/rdtest/testcase.py b/util/test/rdtest/testcase.py
index 707fbc77a..3ba1d1171 100644
--- a/util/test/rdtest/testcase.py
+++ b/util/test/rdtest/testcase.py
@@ -378,6 +378,23 @@ class TestCase:
return analyse.decode_mesh_data(self.controller, indices, in_indices, attrs, 0, mesh.baseVertex)
+ def check_task_data(self, task_ref, task_data):
+ for idx in task_ref:
+ ref = task_ref[idx]
+ if idx >= len(task_data):
+ raise TestFailureException('Task data doesn\'t have expected element {}'.format(idx))
+
+ data = task_data[idx]
+
+ for key in ref:
+ if key not in data:
+ raise TestFailureException('Task data[{}] doesn\'t contain data {} as expected. Data is: {}'.format(idx, key, list(data.keys())))
+
+ if not util.value_compare(ref[key], data[key]):
+ raise TestFailureException('Task data[{}] \'{}\': {} is not as expected: {}'.format(idx, key, data[key], ref[key]))
+
+ log.success("Task data is identical to reference")
+
def check_mesh_data(self, mesh_ref, mesh_data):
for idx in mesh_ref:
ref = mesh_ref[idx]
diff --git a/util/test/tests/D3D12/D3D12_Mesh_Shader.py b/util/test/tests/D3D12/D3D12_Mesh_Shader.py
new file mode 100644
index 000000000..8382ca853
--- /dev/null
+++ b/util/test/tests/D3D12/D3D12_Mesh_Shader.py
@@ -0,0 +1,146 @@
+import renderdoc as rd
+import rdtest
+from rdtest import analyse
+
+class D3D12_Mesh_Shader(rdtest.TestCase):
+ demos_test_name = 'D3D12_Mesh_Shader'
+ demos_frame_cap = 5
+
+ def decode_task_data(self, controller: rd.ReplayController, mesh: rd.MeshFormat, payload: rd.ConstantBlock, task: int = 0):
+
+ begin = mesh.vertexByteOffset + mesh.vertexByteStride * task
+ end = min(begin + mesh.vertexByteSize, 0xffffffffffffffff)
+ buffer_data = controller.GetBufferData(mesh.vertexResourceId, begin, end -begin)
+
+ ret = []
+ offset = 0
+ for var in payload.variables:
+ var_data = {}
+ var_data[var.name] = []
+ # This is not complete to decode all possible payload layouts
+ for i in range(var.type.elements):
+ format = rd.ResourceFormat()
+ format.compByteWidth = rd.VarTypeByteSize(var.type.baseType)
+ format.compCount = var.type.columns
+ format.compType = rd.VarTypeCompType(var.type.baseType)
+ format.type = rd.ResourceFormatType.Regular
+
+ data = analyse.unpack_data(format, buffer_data, offset)
+ var_data[var.name] += data
+ offset += format.compByteWidth * format.compCount
+ ret.append(var_data)
+
+ return ret
+
+ def get_task_data(self, action: rd.ActionDescription):
+ mesh: rd.MeshFormat = self.controller.GetPostVSData(0, 0, rd.MeshDataStage.TaskOut)
+ if mesh.numIndices == 0:
+ raise self.TestFailureException("Task data is empty")
+
+ if len(mesh.taskSizes) == 0:
+ raise self.TestFailureException("Task data is empty")
+
+ pipe: rd.PipeState = self.controller.GetPipelineState()
+ shader = pipe.GetShaderReflection(rd.ShaderStage.Task)
+ taskIdx = 0
+ task = action.dispatchDimension
+ data = []
+ for x in range(task[0]):
+ for y in range(task[1]):
+ for z in range(task[2]):
+ data += self.decode_task_data(self.controller, mesh, shader.taskPayload, taskIdx)
+ taskIdx += 1
+ return data
+
+ def build_global_taskout_reference(self):
+ reference = {}
+ for i in range(2):
+ reference[i] = {
+ 'tri': (i*2,i*2+1),
+ }
+ return reference
+
+ def build_local_taskout_reference(self):
+ reference = {}
+ reference[0] = { 'tri': [0, 1, 2, 3] }
+ return reference
+
+ def build_meshout_reference(self, orgY, color):
+ countTris = 4
+ triSize = 0.2
+ deltX = 0.42
+ orgX = -0.65
+ i = 0
+ reference = {}
+ for tri in range(countTris):
+ for vert in range(3):
+ posX = orgX + tri * deltX
+ posY = orgY
+
+ if vert == 0:
+ posX += -0.2
+ posY += -0.2
+ uv = [0.0, 0.0]
+ elif vert == 1:
+ posX += 0.0
+ posY += 0.2
+ uv = [0.0, 1.0]
+ elif vert == 2:
+ posX += 0.2
+ posY += -0.2
+ uv = [1.0, 0.0]
+
+ reference[i] = {
+ 'vtx': i,
+ 'idx': i,
+ 'SV_Position': [posX, posY, 0.0, 1.0],
+ 'COLOR': color,
+ 'TEXCOORD': uv
+ }
+ i += 1
+ return reference
+
+ def check_capture(self):
+ last_action: rd.ActionDescription = self.get_last_action()
+
+ self.controller.SetFrameEvent(last_action.eventId, True)
+
+ action = self.find_action("Mesh Shaders")
+
+ action = action.next
+ self.controller.SetFrameEvent(action.eventId, False)
+ rdtest.log.print(f"Pure Mesh Shader Test EID:{action.eventId}")
+
+ orgY = 0.65
+ color = [1.0, 0.0, 0.0, 1.0]
+ postms_ref = self.build_meshout_reference(orgY, color)
+ postms_data = self.get_postvs(action, rd.MeshDataStage.MeshOut, 0, action.numIndices)
+ self.check_mesh_data(postms_ref, postms_data)
+
+ action = action.next
+ self.controller.SetFrameEvent(action.eventId, False)
+ rdtest.log.print(f"Amplification Shader with Global Payload EID:{action.eventId}")
+
+ postts_ref = self.build_global_taskout_reference()
+ postts_data = self.get_task_data(action)
+ self.check_task_data(postts_ref, postts_data)
+
+ orgY = 0.0
+ color = [0.0, 1.0, 0.0, 1.0]
+ postms_ref = self.build_meshout_reference(orgY, color)
+ postms_data = self.get_postvs(action, rd.MeshDataStage.MeshOut, 0, action.numIndices)
+ self.check_mesh_data(postms_ref, postms_data)
+
+ action = action.next
+ self.controller.SetFrameEvent(action.eventId, False)
+ rdtest.log.print(f"Amplification Shader with Local Payload EID:{action.eventId}")
+
+ postts_ref = self.build_local_taskout_reference()
+ postts_data = self.get_task_data(action)
+ self.check_task_data(postts_ref, postts_data)
+
+ orgY = -0.65
+ color = [0.0, 0.0, 1.0, 1.0]
+ postms_ref = self.build_meshout_reference(orgY, color)
+ postms_data = self.get_postvs(action, rd.MeshDataStage.MeshOut, 0, action.numIndices)
+ self.check_mesh_data(postms_ref, postms_data)