From 1b28ecd63eb49917e3711eb7e06739ebe87e8f41 Mon Sep 17 00:00:00 2001
From: riperiperi <rhy3756547@hotmail.com>
Date: Mon, 8 May 2023 11:45:12 +0100
Subject: [PATCH] Vulkan: Simplify MultiFenceHolder and managing them (#4845)

* Vulkan: Simplify waitable add/remove

Removal of unnecessary hashset and dictionary

* Thread safety for GetBufferData in PersistentFlushBuffer

* Fix WaitForFencesImpl thread safety

* Proper methods for risky reference increments

* Wrong type of CB.

* Address feedback
---
 src/Ryujinx.Graphics.Vulkan/Auto.cs           |  17 +++
 src/Ryujinx.Graphics.Vulkan/BufferHolder.cs   |   5 +-
 .../CommandBufferPool.cs                      |  14 ++-
 src/Ryujinx.Graphics.Vulkan/FenceHolder.cs    |  19 +++
 .../MultiFenceHolder.cs                       | 117 ++++++++++++------
 .../PersistentFlushBuffer.cs                  |  14 ++-
 6 files changed, 138 insertions(+), 48 deletions(-)

diff --git a/src/Ryujinx.Graphics.Vulkan/Auto.cs b/src/Ryujinx.Graphics.Vulkan/Auto.cs
index 77261de98..fdce7232c 100644
--- a/src/Ryujinx.Graphics.Vulkan/Auto.cs
+++ b/src/Ryujinx.Graphics.Vulkan/Auto.cs
@@ -105,6 +105,23 @@ namespace Ryujinx.Graphics.Vulkan
             }
         }
 
+        public bool TryIncrementReferenceCount()
+        {
+            int lastValue;
+            do
+            {
+                lastValue = _referenceCount;
+
+                if (lastValue == 0)
+                {
+                    return false;
+                }
+            }
+            while (Interlocked.CompareExchange(ref _referenceCount, lastValue + 1, lastValue) != lastValue);
+
+            return true;
+        }
+
         public void IncrementReferenceCount()
         {
             if (Interlocked.Increment(ref _referenceCount) == 1)
diff --git a/src/Ryujinx.Graphics.Vulkan/BufferHolder.cs b/src/Ryujinx.Graphics.Vulkan/BufferHolder.cs
index a1ea6836f..9a23280d0 100644
--- a/src/Ryujinx.Graphics.Vulkan/BufferHolder.cs
+++ b/src/Ryujinx.Graphics.Vulkan/BufferHolder.cs
@@ -599,9 +599,10 @@ namespace Ryujinx.Graphics.Vulkan
             Auto<DisposableBuffer> dst,
             int srcOffset,
             int dstOffset,
-            int size)
+            int size,
+            bool registerSrcUsage = true)
         {
-            var srcBuffer = src.Get(cbs, srcOffset, size).Value;
+            var srcBuffer = registerSrcUsage ? src.Get(cbs, srcOffset, size).Value : src.GetUnsafe().Value;
             var dstBuffer = dst.Get(cbs, dstOffset, size).Value;
 
             InsertBufferBarrier(
diff --git a/src/Ryujinx.Graphics.Vulkan/CommandBufferPool.cs b/src/Ryujinx.Graphics.Vulkan/CommandBufferPool.cs
index 4cbb24ef7..42b46eaec 100644
--- a/src/Ryujinx.Graphics.Vulkan/CommandBufferPool.cs
+++ b/src/Ryujinx.Graphics.Vulkan/CommandBufferPool.cs
@@ -31,7 +31,7 @@ namespace Ryujinx.Graphics.Vulkan
             public SemaphoreHolder Semaphore;
 
             public List<IAuto> Dependants;
-            public HashSet<MultiFenceHolder> Waitables;
+            public List<MultiFenceHolder> Waitables;
             public HashSet<SemaphoreHolder> Dependencies;
 
             public void Initialize(Vk api, Device device, CommandPool pool)
@@ -47,7 +47,7 @@ namespace Ryujinx.Graphics.Vulkan
                 api.AllocateCommandBuffers(device, allocateInfo, out CommandBuffer);
 
                 Dependants = new List<IAuto>();
-                Waitables = new HashSet<MultiFenceHolder>();
+                Waitables = new List<MultiFenceHolder>();
                 Dependencies = new HashSet<SemaphoreHolder>();
             }
         }
@@ -143,8 +143,10 @@ namespace Ryujinx.Graphics.Vulkan
         public void AddWaitable(int cbIndex, MultiFenceHolder waitable)
         {
             ref var entry = ref _commandBuffers[cbIndex];
-            waitable.AddFence(cbIndex, entry.Fence);
-            entry.Waitables.Add(waitable);
+            if (waitable.AddFence(cbIndex, entry.Fence))
+            {
+                entry.Waitables.Add(waitable);
+            }
         }
 
         public bool HasWaitableOnRentedCommandBuffer(MultiFenceHolder waitable, int offset, int size)
@@ -156,7 +158,7 @@ namespace Ryujinx.Graphics.Vulkan
                     ref var entry = ref _commandBuffers[i];
 
                     if (entry.InUse &&
-                        entry.Waitables.Contains(waitable) &&
+                        waitable.HasFence(i) &&
                         waitable.IsBufferRangeInUse(i, offset, size))
                     {
                         return true;
@@ -331,7 +333,7 @@ namespace Ryujinx.Graphics.Vulkan
 
             foreach (var waitable in entry.Waitables)
             {
-                waitable.RemoveFence(cbIndex, entry.Fence);
+                waitable.RemoveFence(cbIndex);
                 waitable.RemoveBufferUses(cbIndex);
             }
 
diff --git a/src/Ryujinx.Graphics.Vulkan/FenceHolder.cs b/src/Ryujinx.Graphics.Vulkan/FenceHolder.cs
index 1c1e62407..39d226983 100644
--- a/src/Ryujinx.Graphics.Vulkan/FenceHolder.cs
+++ b/src/Ryujinx.Graphics.Vulkan/FenceHolder.cs
@@ -32,6 +32,25 @@ namespace Ryujinx.Graphics.Vulkan
             return _fence;
         }
 
+        public bool TryGet(out Fence fence)
+        {
+            int lastValue;
+            do
+            {
+                lastValue = _referenceCount;
+
+                if (lastValue == 0)
+                {
+                    fence = default;
+                    return false;
+                }
+            }
+            while (Interlocked.CompareExchange(ref _referenceCount, lastValue + 1, lastValue) != lastValue);
+
+            fence = _fence;
+            return true;
+        }
+
         public Fence Get()
         {
             Interlocked.Increment(ref _referenceCount);
diff --git a/src/Ryujinx.Graphics.Vulkan/MultiFenceHolder.cs b/src/Ryujinx.Graphics.Vulkan/MultiFenceHolder.cs
index 9a9a3626c..13a4f4c14 100644
--- a/src/Ryujinx.Graphics.Vulkan/MultiFenceHolder.cs
+++ b/src/Ryujinx.Graphics.Vulkan/MultiFenceHolder.cs
@@ -1,6 +1,5 @@
 using Silk.NET.Vulkan;
-using System.Collections.Generic;
-using System.Linq;
+using System;
 
 namespace Ryujinx.Graphics.Vulkan
 {
@@ -11,7 +10,7 @@ namespace Ryujinx.Graphics.Vulkan
     {
         private static int BufferUsageTrackingGranularity = 4096;
 
-        private readonly Dictionary<FenceHolder, int> _fences;
+        private readonly FenceHolder[] _fences;
         private BufferUsageBitmap _bufferUsageBitmap;
 
         /// <summary>
@@ -19,7 +18,7 @@ namespace Ryujinx.Graphics.Vulkan
         /// </summary>
         public MultiFenceHolder()
         {
-            _fences = new Dictionary<FenceHolder, int>();
+            _fences = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
         }
 
         /// <summary>
@@ -28,7 +27,7 @@ namespace Ryujinx.Graphics.Vulkan
         /// <param name="size">Size of the buffer</param>
         public MultiFenceHolder(int size)
         {
-            _fences = new Dictionary<FenceHolder, int>();
+            _fences = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
             _bufferUsageBitmap = new BufferUsageBitmap(size, BufferUsageTrackingGranularity);
         }
 
@@ -80,25 +79,37 @@ namespace Ryujinx.Graphics.Vulkan
         /// </summary>
         /// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param>
         /// <param name="fence">Fence to be added</param>
-        public void AddFence(int cbIndex, FenceHolder fence)
+        /// <returns>True if the fence was stored because the command buffer had no fence registered, false otherwise</returns>
+        public bool AddFence(int cbIndex, FenceHolder fence)
         {
-            lock (_fences)
+            ref FenceHolder fenceRef = ref _fences[cbIndex];
+
+            if (fenceRef == null)
             {
-                _fences.TryAdd(fence, cbIndex);
+                fenceRef = fence;
+                return true;
             }
+
+            return false;
         }
 
         /// <summary>
         /// Removes a fence from the holder.
         /// </summary>
         /// <param name="cbIndex">Command buffer index of the command buffer that owns the fence</param>
-        /// <param name="fence">Fence to be removed</param>
-        public void RemoveFence(int cbIndex, FenceHolder fence)
+        public void RemoveFence(int cbIndex)
         {
-            lock (_fences)
-            {
-                _fences.Remove(fence);
-            }
+            _fences[cbIndex] = null;
+        }
+
+        /// <summary>
+        /// Determines if a fence is referenced on the given command buffer.
+        /// </summary>
+        /// <param name="cbIndex">Index of the command buffer to check if it's used</param>
+        /// <returns>True if referenced, false otherwise</returns>
+        public bool HasFence(int cbIndex)
+        {
+            return _fences[cbIndex] != null;
         }
 
         /// <summary>
@@ -147,21 +158,29 @@ namespace Ryujinx.Graphics.Vulkan
         /// <returns>True if all fences were signaled before the timeout expired, false otherwise</returns>
         private bool WaitForFencesImpl(Vk api, Device device, int offset, int size, bool hasTimeout, ulong timeout)
         {
-            FenceHolder[] fenceHolders;
-            Fence[] fences;
+            Span<FenceHolder> fenceHolders = new FenceHolder[CommandBufferPool.MaxCommandBuffers];
 
-            lock (_fences)
+            int count = size != 0 ? GetOverlappingFences(fenceHolders, offset, size) : GetFences(fenceHolders);
+            Span<Fence> fences = stackalloc Fence[count];
+
+            int fenceCount = 0;
+
+            for (int i = 0; i < count; i++)
             {
-                fenceHolders = size != 0 ? GetOverlappingFences(offset, size) : _fences.Keys.ToArray();
-                fences = new Fence[fenceHolders.Length];
-
-                for (int i = 0; i < fenceHolders.Length; i++)
+                if (fenceHolders[i].TryGet(out Fence fence))
                 {
-                    fences[i] = fenceHolders[i].Get();
+                    fences[fenceCount] = fence;
+
+                    if (fenceCount < i)
+                    {
+                        fenceHolders[fenceCount] = fenceHolders[i];
+                    }
+
+                    fenceCount++;
                 }
             }
 
-            if (fences.Length == 0)
+            if (fenceCount == 0)
             {
                 return true;
             }
@@ -170,14 +189,14 @@ namespace Ryujinx.Graphics.Vulkan
 
             if (hasTimeout)
             {
-                signaled = FenceHelper.AllSignaled(api, device, fences, timeout);
+                signaled = FenceHelper.AllSignaled(api, device, fences.Slice(0, fenceCount), timeout);
             }
             else
             {
-                FenceHelper.WaitAllIndefinitely(api, device, fences);
+                FenceHelper.WaitAllIndefinitely(api, device, fences.Slice(0, fenceCount));
             }
 
-            for (int i = 0; i < fenceHolders.Length; i++)
+            for (int i = 0; i < fenceCount; i++)
             {
                 fenceHolders[i].Put();
             }
@@ -186,27 +205,49 @@ namespace Ryujinx.Graphics.Vulkan
         }
 
         /// <summary>
-        /// Gets fences to wait for use of a given buffer region.
+        /// Gets fences to wait for.
         /// </summary>
-        /// <param name="offset">Offset of the range</param>
-        /// <param name="size">Size of the range in bytes</param>
-        /// <returns>Fences for the specified region</returns>
-        private FenceHolder[] GetOverlappingFences(int offset, int size)
+        /// <param name="storage">Span to store fences in</param>
+        /// <returns>Number of fences placed in storage</returns>
+        private int GetFences(Span<FenceHolder> storage)
         {
-            List<FenceHolder> overlapping = new List<FenceHolder>();
+            int count = 0;
 
-            foreach (var kv in _fences)
+            for (int i = 0; i < _fences.Length; i++)
             {
-                var fence = kv.Key;
-                var ownerCbIndex = kv.Value;
+                var fence = _fences[i];
 
-                if (_bufferUsageBitmap.OverlapsWith(ownerCbIndex, offset, size))
+                if (fence != null)
                 {
-                    overlapping.Add(fence);
+                    storage[count++] = fence;
                 }
             }
 
-            return overlapping.ToArray();
+            return count;
+        }
+
+        /// <summary>
+        /// Gets fences to wait for use of a given buffer region.
+        /// </summary>
+        /// <param name="storage">Span to store overlapping fences in</param>
+        /// <param name="offset">Offset of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        /// <returns>Number of fences for the specified region placed in storage</returns>
+        private int GetOverlappingFences(Span<FenceHolder> storage, int offset, int size)
+        {
+            int count = 0;
+
+            for (int i = 0; i < _fences.Length; i++)
+            {
+                var fence = _fences[i];
+
+                if (fence != null && _bufferUsageBitmap.OverlapsWith(i, offset, size))
+                {
+                    storage[count++] = fence;
+                }
+            }
+
+            return count;
         }
     }
 }
diff --git a/src/Ryujinx.Graphics.Vulkan/PersistentFlushBuffer.cs b/src/Ryujinx.Graphics.Vulkan/PersistentFlushBuffer.cs
index fca13c314..fc98b68f7 100644
--- a/src/Ryujinx.Graphics.Vulkan/PersistentFlushBuffer.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PersistentFlushBuffer.cs
@@ -34,16 +34,26 @@ namespace Ryujinx.Graphics.Vulkan
         public Span<byte> GetBufferData(CommandBufferPool cbp, BufferHolder buffer, int offset, int size)
         {
             var flushStorage = ResizeIfNeeded(size);
+            Auto<DisposableBuffer> srcBuffer;
 
             using (var cbs = cbp.Rent())
             {
-                var srcBuffer = buffer.GetBuffer(cbs.CommandBuffer);
+                srcBuffer = buffer.GetBuffer(cbs.CommandBuffer);
                 var dstBuffer = flushStorage.GetBuffer(cbs.CommandBuffer);
 
-                BufferHolder.Copy(_gd, cbs, srcBuffer, dstBuffer, offset, 0, size);
+                if (srcBuffer.TryIncrementReferenceCount())
+                {
+                    BufferHolder.Copy(_gd, cbs, srcBuffer, dstBuffer, offset, 0, size, registerSrcUsage: false);
+                }
+                else
+                {
+                    // Source buffer is no longer alive, don't copy anything to flush storage.
+                    srcBuffer = null;
+                }
             }
 
             flushStorage.WaitForFences();
+            srcBuffer?.DecrementReferenceCount();
             return flushStorage.GetDataStorage(0, size);
         }