From a71b5f1a3ad7eba7fd4abfc5296bb435d7489946 Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Sat, 22 Jun 2024 14:38:09 +0100 Subject: [PATCH] VoteAllEqual, FindLSB/MSB --- .../CodeGen/Msl/Declarations.cs | 26 +++++++++++++++++++ .../CodeGen/Msl/HelperFunctions/FindLSB.metal | 5 ++++ .../Msl/HelperFunctions/FindMSBS32.metal | 5 ++++ .../Msl/HelperFunctions/FindMSBU32.metal | 6 +++++ .../HelperFunctions/HelperFunctionNames.cs | 4 ++- .../Msl/HelperFunctions/VoteAllEqual.metal | 4 --- .../CodeGen/Msl/Instructions/InstGen.cs | 13 ++++------ .../CodeGen/Msl/Instructions/InstGenBallot.cs | 9 +++++++ .../Msl/Instructions/InstGenBarrier.cs | 16 ++++++++++++ .../CodeGen/Msl/Instructions/InstGenHelper.cs | 9 +++---- .../Ryujinx.Graphics.Shader.csproj | 8 +++--- .../StructuredIr/HelperFunctionsMask.cs | 5 ++++ .../StructuredIr/StructuredProgram.cs | 14 ++++++++-- 13 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal delete mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/VoteAllEqual.metal create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBarrier.cs diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 0b6aadd03..fc199da2c 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -1,3 +1,4 @@ +using Ryujinx.Common; using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; @@ -57,6 +58,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl context.AppendLine(); DeclareBufferStructures(context, 
context.Properties.ConstantBuffers.Values); DeclareBufferStructures(context, context.Properties.StorageBuffers.Values); + + if ((info.HelperFunctionsMask & HelperFunctionsMask.FindLSB) != 0) + { + AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal"); + } + + if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBS32) != 0) + { + AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal"); + } + + if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBU32) != 0) + { + AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal"); + } } static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind) @@ -310,5 +326,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } } } + + private static void AppendHelperFunction(CodeGenContext context, string filename) + { + string code = EmbeddedResources.ReadAllText(filename); + + code = code.Replace("\t", CodeGenContext.Tab); + + context.AppendLine(code); + context.AppendLine(); + } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal new file mode 100644 index 000000000..ad786adb3 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal @@ -0,0 +1,5 @@ +template<typename T> +inline T findLSB(T x) +{ + return select(ctz(x), T(-1), x == T(0)); +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal new file mode 100644 index 000000000..af4eb6cbd --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal @@ -0,0 +1,8 @@ +// Signed variant: negative inputs are complemented first (T(-1) - x), +// so both 0 and -1 return -1. +template<typename T> +inline T findMSBS32(T x) +{ + T v = select(x, T(-1) - x, x < T(0)); + return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0)); +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal
b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal new file mode 100644 index 000000000..6d97c41a9 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal @@ -0,0 +1,6 @@ +// Unsigned variant: index of the highest set bit, or T(-1) when x is 0. +template<typename T> +inline T findMSBU32(T x) +{ + return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0)); +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/HelperFunctionNames.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/HelperFunctionNames.cs index 1e10f0721..a48da4990 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/HelperFunctionNames.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/HelperFunctionNames.cs @@ -2,6 +2,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { static class HelperFunctionNames { - public static string SwizzleAdd = "helperSwizzleAdd"; + public const string FindLSB = "findLSB"; + public const string FindMSBS32 = "findMSBS32"; + public const string FindMSBU32 = "findMSBU32"; } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/VoteAllEqual.metal b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/VoteAllEqual.metal deleted file mode 100644 index efbcee24d..000000000 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/VoteAllEqual.metal +++ /dev/null @@ -1,4 +0,0 @@ -inline bool voteAllEqual(bool value) -{ - return simd_all(value) || !simd_any(value); -} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 0bea4d1aa..6c983445b 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation; using System; using System.Text; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot; +using static
Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBarrier; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; @@ -123,19 +124,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions case Instruction.Ballot: return Ballot(context, operation); case Instruction.Barrier: - return "threadgroup_barrier(mem_flags::mem_threadgroup)"; + return Barrier(context, operation); case Instruction.Call: return Call(context, operation); case Instruction.FSIBegin: return "|| FSI BEGIN ||"; case Instruction.FSIEnd: return "|| FSI END ||"; - case Instruction.FindLSB: - return "|| FIND LSB ||"; - case Instruction.FindMSBS32: - return "|| FIND MSB S32 ||"; - case Instruction.FindMSBU32: - return "|| FIND MSB U32 ||"; case Instruction.GroupMemoryBarrier: return "|| FIND GROUP MEMORY BARRIER ||"; case Instruction.ImageLoad: @@ -152,6 +147,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions return "|| MEMORY BARRIER ||"; case Instruction.Store: return Store(context, operation); + case Instruction.SwizzleAdd: + return "|| SWIZZLE ADD ||"; case Instruction.TextureSample: return TextureSample(context, operation); case Instruction.TextureQuerySamples: @@ -165,7 +162,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions case Instruction.VectorExtract: return VectorExtract(context, operation); case Instruction.VoteAllEqual: - return "|| VOTE ALL EQUAL ||"; + return VoteAllEqual(context, operation); } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs index 1f53c74ed..19a065d77 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBallot.cs @@ -17,5 +17,14 @@ namespace 
Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}"; } + + public static string VoteAllEqual(CodeGenContext context, AstOperation operation) + { + AggregateType srcType = GetSrcVarType(operation.Inst, 0); + + string arg = GetSourceExpr(context, operation.GetSource(0), srcType); + + return $"(simd_all({arg}) || !simd_any({arg}))"; // Parenthesized so '||' keeps its grouping when embedded in a larger expression. + } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBarrier.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBarrier.cs new file mode 100644 index 000000000..7d681de26 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenBarrier.cs @@ -0,0 +1,16 @@ +using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; + +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; + +namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions +{ + static class InstGenBarrier + { + public static string Barrier(CodeGenContext context, AstOperation operation) + { + return "threadgroup_barrier(mem_flags::mem_threadgroup)"; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index d230e2ed4..68ec872af 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -71,10 +71,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions Add(Instruction.ExponentB2, InstType.CallUnary, "exp2"); Add(Instruction.FSIBegin, InstType.Special); Add(Instruction.FSIEnd, InstType.Special); - // TODO: LSB and MSB Implementations https://github.com/KhronosGroup/SPIRV-Cross/blob/bccaa94db814af33d8ef05c153e7c34d8bd4d685/reference/shaders-msl-no-opt/asm/comp/bitscan.asm.comp#L8 - Add(Instruction.FindLSB,
InstType.Special); - Add(Instruction.FindMSBS32, InstType.Special); - Add(Instruction.FindMSBU32, InstType.Special); + Add(Instruction.FindLSB, InstType.CallUnary, HelperFunctionNames.FindLSB); + Add(Instruction.FindMSBS32, InstType.CallUnary, HelperFunctionNames.FindMSBS32); + Add(Instruction.FindMSBU32, InstType.CallUnary, HelperFunctionNames.FindMSBU32); Add(Instruction.Floor, InstType.CallUnary, "floor"); Add(Instruction.FusedMultiplyAdd, InstType.CallTernary, "fma"); Add(Instruction.GroupMemoryBarrier, InstType.Special); @@ -117,7 +116,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.Store, InstType.Special); Add(Instruction.Subtract, InstType.OpBinary, "-", 2); - Add(Instruction.SwizzleAdd, InstType.CallTernary, HelperFunctionNames.SwizzleAdd); + Add(Instruction.SwizzleAdd, InstType.Special); Add(Instruction.TextureSample, InstType.Special); Add(Instruction.TextureQuerySamples, InstType.Special); Add(Instruction.TextureQuerySize, InstType.Special); diff --git a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj index e0a92da5a..7803d9aa5 100644 --- a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj +++ b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj @@ -14,8 +14,10 @@ - + - - + + + + diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs index 2a3d65e75..8e7bbd6f1 100644 --- a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs +++ b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs @@ -7,6 +7,11 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { MultiplyHighS32 = 1 << 2, MultiplyHighU32 = 1 << 3, + + FindLSB = 1 << 5, + FindMSBS32 = 1 << 6, + FindMSBU32 = 1 << 7, + SwizzleAdd = 1 << 10, FSI = 1 << 11, } diff --git 
a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs index 88053658d..394099902 100644 --- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs +++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs @@ -321,8 +321,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } // Those instructions needs to be emulated by using helper functions, - // because they are NVIDIA specific. Those flags helps the backend to - // decide which helper functions are needed on the final generated code. + // because they are NVIDIA specific or because the target language has + // no direct equivalent. These flags help the backend decide which + // helper functions are needed in the final generated code. switch (operation.Inst) { case Instruction.MultiplyHighS32: @@ -331,6 +332,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr case Instruction.MultiplyHighU32: context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32; break; + case Instruction.FindLSB: + context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindLSB; + break; + case Instruction.FindMSBS32: + context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBS32; + break; + case Instruction.FindMSBU32: + context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBU32; + break; case Instruction.SwizzleAdd: context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd; break;