From 80340c98d3456d92e2a8741134b96b3800c57692 Mon Sep 17 00:00:00 2001
From: Isaac Marovitz <isaacryu@icloud.com>
Date: Tue, 19 Mar 2024 14:05:09 -0400
Subject: [PATCH] Revise ISampler

---
 src/Ryujinx.Graphics.Metal/CounterEvent.cs  |  1 -
 src/Ryujinx.Graphics.Metal/HelperShaders.cs |  4 +--
 src/Ryujinx.Graphics.Metal/MetalRenderer.cs | 20 +----------
 src/Ryujinx.Graphics.Metal/Pipeline.cs      | 37 +++++++++++++++------
 src/Ryujinx.Graphics.Metal/Sampler.cs       | 31 +++++++++++++++--
 src/Ryujinx.Graphics.Metal/Texture.cs       |  8 ++++-
 6 files changed, 65 insertions(+), 36 deletions(-)

diff --git a/src/Ryujinx.Graphics.Metal/CounterEvent.cs b/src/Ryujinx.Graphics.Metal/CounterEvent.cs
index 1773b9b63..46b04997e 100644
--- a/src/Ryujinx.Graphics.Metal/CounterEvent.cs
+++ b/src/Ryujinx.Graphics.Metal/CounterEvent.cs
@@ -4,7 +4,6 @@ namespace Ryujinx.Graphics.Metal
 {
     class CounterEvent : ICounterEvent
     {
-
         public CounterEvent()
         {
             Invalid = false;
diff --git a/src/Ryujinx.Graphics.Metal/HelperShaders.cs b/src/Ryujinx.Graphics.Metal/HelperShaders.cs
index 7c1eada7b..8ca7adb6f 100644
--- a/src/Ryujinx.Graphics.Metal/HelperShaders.cs
+++ b/src/Ryujinx.Graphics.Metal/HelperShaders.cs
@@ -25,7 +25,7 @@ namespace Ryujinx.Graphics.Metal
                 Logger.Error?.PrintMsg(LogClass.Gpu, $"Failed to create Library: {StringHelper.String(error.LocalizedDescription)}");
             }
 
-            BlitShader = new HelperShader(device, library, "vertexBlit", "fragmentBlit");
+            BlitShader = new HelperShader(library, "vertexBlit", "fragmentBlit");
         }
     }
 
@@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Metal
         public readonly MTLFunction VertexFunction;
         public readonly MTLFunction FragmentFunction;
 
-        public HelperShader(MTLDevice device, MTLLibrary library, string vertex, string fragment)
+        public HelperShader(MTLLibrary library, string vertex, string fragment)
         {
             VertexFunction = library.NewFunction(StringHelper.NSString(vertex));
             FragmentFunction = library.NewFunction(StringHelper.NSString(fragment));
diff --git a/src/Ryujinx.Graphics.Metal/MetalRenderer.cs b/src/Ryujinx.Graphics.Metal/MetalRenderer.cs
index a58b7cb60..0930be9ef 100644
--- a/src/Ryujinx.Graphics.Metal/MetalRenderer.cs
+++ b/src/Ryujinx.Graphics.Metal/MetalRenderer.cs
@@ -84,25 +84,7 @@ namespace Ryujinx.Graphics.Metal
 
         public ISampler CreateSampler(SamplerCreateInfo info)
         {
-            (MTLSamplerMinMagFilter minFilter, MTLSamplerMipFilter mipFilter) = info.MinFilter.Convert();
-
-            var sampler = _device.NewSamplerState(new MTLSamplerDescriptor
-            {
-                BorderColor = MTLSamplerBorderColor.TransparentBlack,
-                MinFilter = minFilter,
-                MagFilter = info.MagFilter.Convert(),
-                MipFilter = mipFilter,
-                CompareFunction = info.CompareOp.Convert(),
-                LodMinClamp = info.MinLod,
-                LodMaxClamp = info.MaxLod,
-                LodAverage = false,
-                MaxAnisotropy = (uint)info.MaxAnisotropy,
-                SAddressMode = info.AddressU.Convert(),
-                TAddressMode = info.AddressV.Convert(),
-                RAddressMode = info.AddressP.Convert()
-            });
-
-            return new Sampler(sampler);
+            return new Sampler(_device, info);
         }
 
         public ITexture CreateTexture(TextureCreateInfo info)
diff --git a/src/Ryujinx.Graphics.Metal/Pipeline.cs b/src/Ryujinx.Graphics.Metal/Pipeline.cs
index b1c3c03de..4d182b7ef 100644
--- a/src/Ryujinx.Graphics.Metal/Pipeline.cs
+++ b/src/Ryujinx.Graphics.Metal/Pipeline.cs
@@ -323,7 +323,6 @@ namespace Ryujinx.Graphics.Metal
         {
             var renderCommandEncoder = GetOrCreateRenderEncoder();
 
-
             // TODO: Support topology re-indexing to provide support for TriangleFans
             var primitiveType = _renderEncoderState.Topology.Convert();
 
@@ -332,26 +331,36 @@ namespace Ryujinx.Graphics.Metal
 
         public void DrawIndexedIndirect(BufferRange indirectBuffer)
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
         public void DrawIndexedIndirectCount(BufferRange indirectBuffer, BufferRange parameterBuffer, int maxDrawCount, int stride)
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
         public void DrawIndirect(BufferRange indirectBuffer)
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
         public void DrawIndirectCount(BufferRange indirectBuffer, BufferRange parameterBuffer, int maxDrawCount, int stride)
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
         public void DrawTexture(ITexture texture, ISampler sampler, Extents2DF srcRegion, Extents2DF dstRegion)
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
@@ -437,15 +446,10 @@ namespace Ryujinx.Graphics.Metal
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
-        public void SetImage(int binding, ITexture texture, Format imageFormat)
-        {
-            Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
-        }
-
         public void SetLineParameters(float width, bool smooth)
         {
             // Not supported in Metal
-            Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
+            Logger.Warning?.Print(LogClass.Gpu, "Wide-line is not supported without private Metal API");
         }
 
         public void SetLogicOpState(bool enable, LogicalOp op)
@@ -493,7 +497,7 @@ namespace Ryujinx.Graphics.Metal
         {
             Program prg = (Program)program;
 
-            if (prg.VertexFunction == null)
+            if (prg.VertexFunction == IntPtr.Zero)
             {
                 Logger.Error?.PrintMsg(LogClass.Gpu, "Invalid Vertex Function!");
                 return;
@@ -556,7 +560,10 @@ namespace Ryujinx.Graphics.Metal
             fixed (MTLScissorRect* pMtlScissorRects = mtlScissorRects)
             {
                 // TODO: Fix this function which currently wont accept pointer as intended
-                // _renderCommandEncoder.SetScissorRects(pMtlScissorRects, regions.Length);
+                if (_currentEncoderType == EncoderType.Render)
+                {
+                    // new MTLRenderCommandEncoder(_currentEncoder.Value).SetScissorRects(pMtlScissorRects, (ulong)regions.Length);
+                }
             }
         }
 
@@ -621,6 +628,7 @@ namespace Ryujinx.Graphics.Metal
                     attrib.Format = MTLVertexFormat.Float4;
                     attrib.BufferIndex = (ulong)vertexAttribs[i].BufferIndex;
                     attrib.Offset = (ulong)vertexAttribs[i].Offset;
+                    _vertexDescriptor.Attributes.SetObject(attrib, (ulong)i);
                 }
             }
         }
@@ -668,17 +676,26 @@ namespace Ryujinx.Graphics.Metal
             fixed (MTLViewport* pMtlViewports = mtlViewports)
             {
                 // TODO: Fix this function which currently wont accept pointer as intended
-                // _renderCommandEncoder.SetViewports(pMtlViewports, viewports.Length);
+                if (_currentEncoderType == EncoderType.Render)
+                {
+                    // new MTLRenderCommandEncoder(_currentEncoder.Value).SetViewports(pMtlViewports, (ulong)regions.Length);
+                }
             }
         }
 
         public void TextureBarrier()
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
+            // renderCommandEncoder.MemoryBarrier(MTLBarrierScope.Textures, );
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
         public void TextureBarrierTiled()
         {
+            var renderCommandEncoder = GetOrCreateRenderEncoder();
+
+            // renderCommandEncoder.MemoryBarrier(MTLBarrierScope.Textures, );
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
diff --git a/src/Ryujinx.Graphics.Metal/Sampler.cs b/src/Ryujinx.Graphics.Metal/Sampler.cs
index cc1923cc3..f4ffecc02 100644
--- a/src/Ryujinx.Graphics.Metal/Sampler.cs
+++ b/src/Ryujinx.Graphics.Metal/Sampler.cs
@@ -1,15 +1,40 @@
 using Ryujinx.Graphics.GAL;
 using SharpMetal.Metal;
+using System.Runtime.Versioning;
 
 namespace Ryujinx.Graphics.Metal
 {
+    [SupportedOSPlatform("macos")]
     class Sampler : ISampler
     {
-        // private readonly MTLSamplerState _mtlSamplerState;
+        private readonly MTLSamplerState _mtlSamplerState;
 
-        public Sampler(MTLSamplerState mtlSamplerState)
+        public Sampler(MTLDevice device, SamplerCreateInfo info)
         {
-            // _mtlSamplerState = mtlSamplerState;
+            (MTLSamplerMinMagFilter minFilter, MTLSamplerMipFilter mipFilter) = info.MinFilter.Convert();
+
+            var samplerState = device.NewSamplerState(new MTLSamplerDescriptor
+            {
+                BorderColor = MTLSamplerBorderColor.TransparentBlack,
+                MinFilter = minFilter,
+                MagFilter = info.MagFilter.Convert(),
+                MipFilter = mipFilter,
+                CompareFunction = info.CompareOp.Convert(),
+                LodMinClamp = info.MinLod,
+                LodMaxClamp = info.MaxLod,
+                LodAverage = false,
+                MaxAnisotropy = (uint)info.MaxAnisotropy,
+                SAddressMode = info.AddressU.Convert(),
+                TAddressMode = info.AddressV.Convert(),
+                RAddressMode = info.AddressP.Convert()
+            });
+
+            _mtlSamplerState = samplerState;
+        }
+
+        public MTLSamplerState GetSampler()
+        {
+            return _mtlSamplerState;
         }
 
         public void Dispose()
diff --git a/src/Ryujinx.Graphics.Metal/Texture.cs b/src/Ryujinx.Graphics.Metal/Texture.cs
index 7ba9647ac..43ead5bcf 100644
--- a/src/Ryujinx.Graphics.Metal/Texture.cs
+++ b/src/Ryujinx.Graphics.Metal/Texture.cs
@@ -88,6 +88,12 @@ namespace Ryujinx.Graphics.Metal
 
         public void CopyTo(ITexture destination, Extents2D srcRegion, Extents2D dstRegion, bool linearFilter)
         {
+            var blitCommandEncoder = _pipeline.GetOrCreateBlitEncoder();
+
+            if (destination is Texture destinationTexture)
+            {
+
+            }
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
         }
 
@@ -120,7 +126,7 @@ namespace Ryujinx.Graphics.Metal
         public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel)
         {
             Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
-            return this;
+            throw new NotImplementedException();
         }
 
         public PinnedSpan<byte> GetData()