From 00fc8daf56d0a070dc75036cb13ffbfc7a6567c6 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 28 Oct 2018 13:44:12 -0300
Subject: [PATCH] Use variant instead of creating an object for literals

---
 include/sirit/sirit.h    |  22 +++---
 src/CMakeLists.txt       |   1 -
 src/insts/annotation.cpp |  13 ++--
 src/insts/constant.cpp   |  15 ++--
 src/insts/debug.cpp      |   2 +-
 src/insts/flow.cpp       |  23 +++----
 src/insts/function.cpp   |   9 ++-
 src/insts/type.cpp       | 144 +++++++++++++++++++--------------------
 src/literal.cpp          |  26 -------
 src/op.cpp               |  28 ++++++++
 src/op.h                 |   4 ++
 src/sirit.cpp            |   3 +-
 12 files changed, 146 insertions(+), 144 deletions(-)
 delete mode 100644 src/literal.cpp

diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h
index 61c21a1..8ddf9b5 100644
--- a/include/sirit/sirit.h
+++ b/include/sirit/sirit.h
@@ -11,6 +11,7 @@
 #include <optional>
 #include <set>
 #include <spirv/unified1/spirv.hpp11>
+#include <variant>
 #include <vector>
 
 namespace Sirit {
@@ -20,7 +21,9 @@ constexpr std::uint32_t GENERATOR_MAGIC_NUMBER = 0;
 class Op;
 class Operand;
 
-typedef const Op* Ref;
+using Literal = std::variant<std::uint32_t, std::uint64_t, std::int32_t,
+                             std::int64_t, float, double>;
+using Ref = const Op*;
 
 class Module {
   public:
@@ -135,7 +138,7 @@ class Module {
     Ref ConstantFalse(Ref result_type);
 
     /// Returns a numeric scalar constant.
-    Ref Constant(Ref result_type, Operand* literal);
+    Ref Constant(Ref result_type, const Literal& literal);
 
     /// Returns a numeric scalar constant.
     Ref ConstantComposite(Ref result_type,
@@ -201,18 +204,11 @@ class Module {
 
     /// Add a decoration to target.
     Ref Decorate(Ref target, spv::Decoration decoration,
-                 const std::vector<Operand*>& literals = {});
+                 const std::vector<Literal>& literals = {});
 
-    Ref MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration,
-            const std::vector<Operand*>& literals = {});
-
-    // Literals
-    static Operand* Literal(std::uint32_t value);
-    static Operand* Literal(std::uint64_t value);
-    static Operand* Literal(std::int32_t value);
-    static Operand* Literal(std::int64_t value);
-    static Operand* Literal(float value);
-    static Operand* Literal(double value);
+    Ref MemberDecorate(Ref structure_type, Literal member,
+                       spv::Decoration decoration,
+                       const std::vector<Literal>& literals = {});
 
   private:
     Ref AddCode(Op* op);
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8602c9f..db94e83 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -7,7 +7,6 @@ add_library(sirit
     stream.h
     operand.cpp
     operand.h
-    literal.cpp
     literal-number.cpp
     literal-number.h
     literal-string.cpp
diff --git a/src/insts/annotation.cpp b/src/insts/annotation.cpp
index c54adfc..485ad6b 100644
--- a/src/insts/annotation.cpp
+++ b/src/insts/annotation.cpp
@@ -10,21 +10,22 @@
 namespace Sirit {
 
 Ref Module::Decorate(Ref target, spv::Decoration decoration,
-                     const std::vector<Operand*>& literals) {
+                     const std::vector<Literal>& literals) {
     auto op{new Op(spv::Op::OpDecorate)};
     op->Add(target);
     AddEnum(op, decoration);
-    op->Sink(literals);
+    op->Add(literals);
     return AddAnnotation(op);
 }
 
-Ref Module::MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration,
-                           const std::vector<Operand*>& literals) {
+Ref Module::MemberDecorate(Ref structure_type, Literal member,
+                           spv::Decoration decoration,
+                           const std::vector<Literal>& literals) {
     auto op{new Op(spv::Op::OpMemberDecorate)};
     op->Add(structure_type);
-    op->Sink(member);
+    op->Add(member);
     AddEnum(op, decoration);
-    op->Sink(literals);
+    op->Add(literals);
     return AddAnnotation(op);
 }
 
diff --git a/src/insts/constant.cpp b/src/insts/constant.cpp
index 3360603..d5e8802 100644
--- a/src/insts/constant.cpp
+++ b/src/insts/constant.cpp
@@ -4,9 +4,9 @@
  * Lesser General Public License version 2.1 or any later version.
  */
 
-#include <cassert>
-#include "sirit/sirit.h"
 #include "insts.h"
+#include "sirit/sirit.h"
+#include <cassert>
 
 namespace Sirit {
 
@@ -18,20 +18,23 @@ Ref Module::ConstantFalse(Ref result_type) {
     return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type));
 }
 
-Ref Module::Constant(Ref result_type, Operand* literal) {
+Ref Module::Constant(Ref result_type, const Literal& literal) {
     auto op{new Op(spv::Op::OpConstant, bound, result_type)};
     op->Add(literal);
     return AddDeclaration(op);
 }
 
-Ref Module::ConstantComposite(Ref result_type, const std::vector<Ref>& constituents) {
+Ref Module::ConstantComposite(Ref result_type,
+                              const std::vector<Ref>& constituents) {
     auto op{new Op(spv::Op::OpConstantComposite, bound, result_type)};
     op->Add(constituents);
     return AddDeclaration(op);
 }
 
-Ref Module::ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressing_mode,
-                            bool normalized, spv::SamplerFilterMode filter_mode) {
+Ref Module::ConstantSampler(Ref result_type,
+                            spv::SamplerAddressingMode addressing_mode,
+                            bool normalized,
+                            spv::SamplerFilterMode filter_mode) {
     AddCapability(spv::Capability::LiteralSampler);
     AddCapability(spv::Capability::Kernel);
     auto op{new Op(spv::Op::OpConstantSampler, bound, result_type)};
diff --git a/src/insts/debug.cpp b/src/insts/debug.cpp
index c557511..3822dcc 100644
--- a/src/insts/debug.cpp
+++ b/src/insts/debug.cpp
@@ -4,8 +4,8 @@
  * Lesser General Public License version 2.1 or any later version.
  */
 
-#include "sirit/sirit.h"
 #include "insts.h"
+#include "sirit/sirit.h"
 
 namespace Sirit {
 
diff --git a/src/insts/flow.cpp b/src/insts/flow.cpp
index 056e82e..c88df40 100644
--- a/src/insts/flow.cpp
+++ b/src/insts/flow.cpp
@@ -4,12 +4,13 @@
  * Lesser General Public License version 2.1 or any later version.
  */
 
-#include "sirit/sirit.h"
 #include "insts.h"
+#include "sirit/sirit.h"
 
 namespace Sirit {
 
-Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control,
+Ref Module::LoopMerge(Ref merge_block, Ref continue_target,
+                      spv::LoopControlMask loop_control,
                       const std::vector<Ref>& literals) {
     auto op{new Op(spv::Op::OpLoopMerge)};
     op->Add(merge_block);
@@ -19,16 +20,15 @@ Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask
     return AddCode(op);
 }
 
-Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) {
+Ref Module::SelectionMerge(Ref merge_block,
+                           spv::SelectionControlMask selection_control) {
     auto op{new Op(spv::Op::OpSelectionMerge)};
     op->Add(merge_block);
     AddEnum(op, selection_control);
     return AddCode(op);
 }
 
-Ref Module::Label() {
-    return AddCode(spv::Op::OpLabel, bound++);
-}
+Ref Module::Label() { return AddCode(spv::Op::OpLabel, bound++); }
 
 Ref Module::Branch(Ref target_label) {
     auto op{new Op(spv::Op::OpBranch)};
@@ -37,20 +37,19 @@ Ref Module::Branch(Ref target_label) {
 }
 
 Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label,
-                              std::uint32_t true_weight, std::uint32_t false_weight) {
+                              std::uint32_t true_weight,
+                              std::uint32_t false_weight) {
     auto op{new Op(spv::Op::OpBranchConditional)};
     op->Add(condition);
     op->Add(true_label);
     op->Add(false_label);
     if (true_weight != 0 || false_weight != 0) {
-        op->Add(Literal(true_weight));
-        op->Add(Literal(false_weight));
+        op->Add(true_weight);
+        op->Add(false_weight);
     }
     return AddCode(op);
 }
 
-Ref Module::Return() {
-    return AddCode(spv::Op::OpReturn);
-}
+Ref Module::Return() { return AddCode(spv::Op::OpReturn); }
 
 } // namespace Sirit
diff --git a/src/insts/function.cpp b/src/insts/function.cpp
index 9b8ee0f..efcc2c6 100644
--- a/src/insts/function.cpp
+++ b/src/insts/function.cpp
@@ -4,20 +4,19 @@
  * Lesser General Public License version 2.1 or any later version.
  */
 
-#include "sirit/sirit.h"
 #include "insts.h"
+#include "sirit/sirit.h"
 
 namespace Sirit {
 
-Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) {
+Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control,
+                     Ref function_type) {
     auto op{new Op{spv::Op::OpFunction, bound++, result_type}};
     op->Add(static_cast<u32>(function_control));
     op->Add(function_type);
     return AddCode(op);
 }
 
-Ref Module::FunctionEnd() {
-    return AddCode(spv::Op::OpFunctionEnd);
-}
+Ref Module::FunctionEnd() { return AddCode(spv::Op::OpFunctionEnd); }
 
 } // namespace Sirit
diff --git a/src/insts/type.cpp b/src/insts/type.cpp
index 4a2e1a5..2587ff4 100644
--- a/src/insts/type.cpp
+++ b/src/insts/type.cpp
@@ -7,8 +7,8 @@
 #include <cassert>
 #include <optional>
 
-#include "sirit/sirit.h"
 #include "insts.h"
+#include "sirit/sirit.h"
 
 namespace Sirit {
 
@@ -62,68 +62,68 @@ Ref Module::TypeMatrix(Ref column_type, int column_count) {
     return AddDeclaration(op);
 }
 
-Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms,
-                      int sampled, spv::ImageFormat image_format,
+Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed,
+                      bool ms, int sampled, spv::ImageFormat image_format,
                       std::optional<spv::AccessQualifier> access_qualifier) {
     switch (dim) {
-        case spv::Dim::Dim1D:
-            AddCapability(spv::Capability::Sampled1D);
-            break;
-        case spv::Dim::Cube:
-            AddCapability(spv::Capability::Shader);
-            break;
-        case spv::Dim::Rect:
-            AddCapability(spv::Capability::SampledRect);
-            break;
-        case spv::Dim::Buffer:
-            AddCapability(spv::Capability::SampledBuffer);
-            break;
-        case spv::Dim::SubpassData:
-            AddCapability(spv::Capability::InputAttachment);
-            break;
+    case spv::Dim::Dim1D:
+        AddCapability(spv::Capability::Sampled1D);
+        break;
+    case spv::Dim::Cube:
+        AddCapability(spv::Capability::Shader);
+        break;
+    case spv::Dim::Rect:
+        AddCapability(spv::Capability::SampledRect);
+        break;
+    case spv::Dim::Buffer:
+        AddCapability(spv::Capability::SampledBuffer);
+        break;
+    case spv::Dim::SubpassData:
+        AddCapability(spv::Capability::InputAttachment);
+        break;
     }
     switch (image_format) {
-        case spv::ImageFormat::Rgba32f:
-        case spv::ImageFormat::Rgba16f:
-        case spv::ImageFormat::R32f:
-        case spv::ImageFormat::Rgba8:
-        case spv::ImageFormat::Rgba8Snorm:
-        case spv::ImageFormat::Rgba32i:
-        case spv::ImageFormat::Rgba16i:
-        case spv::ImageFormat::Rgba8i:
-        case spv::ImageFormat::R32i:
-        case spv::ImageFormat::Rgba32ui:
-        case spv::ImageFormat::Rgba16ui:
-        case spv::ImageFormat::Rgba8ui:
-        case spv::ImageFormat::R32ui:
-            AddCapability(spv::Capability::Shader);
-            break;
-        case spv::ImageFormat::Rg32f:
-        case spv::ImageFormat::Rg16f:
-        case spv::ImageFormat::R11fG11fB10f:
-        case spv::ImageFormat::R16f:
-        case spv::ImageFormat::Rgba16:
-        case spv::ImageFormat::Rgb10A2:
-        case spv::ImageFormat::Rg16:
-        case spv::ImageFormat::Rg8:
-        case spv::ImageFormat::R16:
-        case spv::ImageFormat::R8:
-        case spv::ImageFormat::Rgba16Snorm:
-        case spv::ImageFormat::Rg16Snorm:
-        case spv::ImageFormat::Rg8Snorm:
-        case spv::ImageFormat::Rg32i:
-        case spv::ImageFormat::Rg16i:
-        case spv::ImageFormat::Rg8i:
-        case spv::ImageFormat::R16i:
-        case spv::ImageFormat::R8i:
-        case spv::ImageFormat::Rgb10a2ui:
-        case spv::ImageFormat::Rg32ui:
-        case spv::ImageFormat::Rg16ui:
-        case spv::ImageFormat::Rg8ui:
-        case spv::ImageFormat::R16ui:
-        case spv::ImageFormat::R8ui:
-            AddCapability(spv::Capability::StorageImageExtendedFormats);
-            break;
+    case spv::ImageFormat::Rgba32f:
+    case spv::ImageFormat::Rgba16f:
+    case spv::ImageFormat::R32f:
+    case spv::ImageFormat::Rgba8:
+    case spv::ImageFormat::Rgba8Snorm:
+    case spv::ImageFormat::Rgba32i:
+    case spv::ImageFormat::Rgba16i:
+    case spv::ImageFormat::Rgba8i:
+    case spv::ImageFormat::R32i:
+    case spv::ImageFormat::Rgba32ui:
+    case spv::ImageFormat::Rgba16ui:
+    case spv::ImageFormat::Rgba8ui:
+    case spv::ImageFormat::R32ui:
+        AddCapability(spv::Capability::Shader);
+        break;
+    case spv::ImageFormat::Rg32f:
+    case spv::ImageFormat::Rg16f:
+    case spv::ImageFormat::R11fG11fB10f:
+    case spv::ImageFormat::R16f:
+    case spv::ImageFormat::Rgba16:
+    case spv::ImageFormat::Rgb10A2:
+    case spv::ImageFormat::Rg16:
+    case spv::ImageFormat::Rg8:
+    case spv::ImageFormat::R16:
+    case spv::ImageFormat::R8:
+    case spv::ImageFormat::Rgba16Snorm:
+    case spv::ImageFormat::Rg16Snorm:
+    case spv::ImageFormat::Rg8Snorm:
+    case spv::ImageFormat::Rg32i:
+    case spv::ImageFormat::Rg16i:
+    case spv::ImageFormat::Rg8i:
+    case spv::ImageFormat::R16i:
+    case spv::ImageFormat::R8i:
+    case spv::ImageFormat::Rgb10a2ui:
+    case spv::ImageFormat::Rg32ui:
+    case spv::ImageFormat::Rg16ui:
+    case spv::ImageFormat::Rg8ui:
+    case spv::ImageFormat::R16ui:
+    case spv::ImageFormat::R8ui:
+        AddCapability(spv::Capability::StorageImageExtendedFormats);
+        break;
     }
     auto op{new Op(spv::Op::OpTypeImage, bound)};
     op->Add(sampled_type);
@@ -179,19 +179,19 @@ Ref Module::TypeOpaque(const std::string& name) {
 
 Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) {
     switch (storage_class) {
-        case spv::StorageClass::Uniform:
-        case spv::StorageClass::Output:
-        case spv::StorageClass::Private:
-        case spv::StorageClass::PushConstant:
-        case spv::StorageClass::StorageBuffer:
-            AddCapability(spv::Capability::Shader);
-            break;
-        case spv::StorageClass::Generic:
-            AddCapability(spv::Capability::GenericPointer);
-            break;
-        case spv::StorageClass::AtomicCounter:
-            AddCapability(spv::Capability::AtomicStorage);
-            break;
+    case spv::StorageClass::Uniform:
+    case spv::StorageClass::Output:
+    case spv::StorageClass::Private:
+    case spv::StorageClass::PushConstant:
+    case spv::StorageClass::StorageBuffer:
+        AddCapability(spv::Capability::Shader);
+        break;
+    case spv::StorageClass::Generic:
+        AddCapability(spv::Capability::GenericPointer);
+        break;
+    case spv::StorageClass::AtomicCounter:
+        AddCapability(spv::Capability::AtomicStorage);
+        break;
     }
     auto op{new Op(spv::Op::OpTypePointer, bound)};
     op->Add(static_cast<u32>(storage_class));
diff --git a/src/literal.cpp b/src/literal.cpp
deleted file mode 100644
index ba8738a..0000000
--- a/src/literal.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-/* This file is part of the sirit project.
- * Copyright (c) 2018 ReinUsesLisp
- * This software may be used and distributed according to the terms of the GNU
- * Lesser General Public License version 2.1 or any later version.
- */
-
-#include "common_types.h"
-#include "literal-number.h"
-#include "operand.h"
-#include "sirit/sirit.h"
-
-namespace Sirit {
-
-#define DEFINE_LITERAL(type)                                                   \
-    Operand* Module::Literal(type value) {                                     \
-        return LiteralNumber::Create<type>(value);                             \
-    }
-
-DEFINE_LITERAL(u32)
-DEFINE_LITERAL(u64)
-DEFINE_LITERAL(s32)
-DEFINE_LITERAL(s64)
-DEFINE_LITERAL(f32)
-DEFINE_LITERAL(f64)
-
-} // namespace Sirit
diff --git a/src/op.cpp b/src/op.cpp
index c1d3e19..ea228ee 100644
--- a/src/op.cpp
+++ b/src/op.cpp
@@ -71,6 +71,34 @@ void Op::Sink(const std::vector<Operand*>& operands) {
     }
 }
 
+void Op::Add(const Literal& literal) {
+    Operand* operand = [&]() {
+        switch (literal.index()) {
+        case 0:
+            return LiteralNumber::Create(std::get<0>(literal));
+        case 1:
+            return LiteralNumber::Create(std::get<1>(literal));
+        case 2:
+            return LiteralNumber::Create(std::get<2>(literal));
+        case 3:
+            return LiteralNumber::Create(std::get<3>(literal));
+        case 4:
+            return LiteralNumber::Create(std::get<4>(literal));
+        case 5:
+            return LiteralNumber::Create(std::get<5>(literal));
+        default:
+            assert(!"invalid literal type");
+        }
+    }();
+    Sink(operand);
+}
+
+void Op::Add(const std::vector<Literal>& literals) {
+    for (const auto& literal : literals) {
+        Add(literal);
+    }
+}
+
 void Op::Add(const Operand* operand) { operands.push_back(operand); }
 
 void Op::Add(u32 integer) { Sink(LiteralNumber::Create<u32>(integer)); }
diff --git a/src/op.h b/src/op.h
index 87e51c1..0e13cdd 100644
--- a/src/op.h
+++ b/src/op.h
@@ -31,6 +31,10 @@ class Op : public Operand {
 
     void Sink(const std::vector<Operand*>& operands);
 
+    void Add(const Literal& literal);
+
+    void Add(const std::vector<Literal>& literals);
+
     void Add(const Operand* operand);
 
     void Add(u32 integer);
diff --git a/src/sirit.cpp b/src/sirit.cpp
index fee7325..9132958 100644
--- a/src/sirit.cpp
+++ b/src/sirit.cpp
@@ -20,8 +20,7 @@ static void WriteEnum(Stream& stream, spv::Op opcode, T value) {
     op.Write(stream);
 }
 
-template <typename T>
-static void WriteSet(Stream& stream, const T& set) {
+template <typename T> static void WriteSet(Stream& stream, const T& set) {
     for (const auto& item : set) {
         item->Write(stream);
     }