1
0
Fork 0
mirror of https://git.suyu.dev/suyu/sirit.git synced 2024-12-22 20:22:02 +00:00

Use some C++17 features

This commit is contained in:
ReinUsesLisp 2018-10-03 00:32:45 -03:00
parent 45555c0e57
commit 0485e1877c
20 changed files with 170 additions and 245 deletions

6
.clang-format Normal file
View file

@ -0,0 +1,6 @@
BasedOnStyle: LLVM
IndentWidth: 4
Language: Cpp
DerivePointerAlignment: false
PointerAlignment: Left

View file

@ -7,16 +7,15 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <set>
#include <vector>
#include <memory> #include <memory>
#include <optional>
#include <set>
#include <spirv/unified1/spirv.hpp11> #include <spirv/unified1/spirv.hpp11>
#include <vector>
namespace Sirit { namespace Sirit {
static const std::uint32_t GeneratorMagicNumber = 0; constexpr std::uint32_t GENERATOR_MAGIC_NUMBER = 0;
static const std::uint32_t Undefined = UINT32_MAX;
class Op; class Op;
class Operand; class Operand;
@ -24,13 +23,14 @@ class Operand;
typedef const Op* Ref; typedef const Op* Ref;
class Module { class Module {
public: public:
explicit Module(); explicit Module();
~Module(); ~Module();
/** /**
* Assembles current module into a SPIR-V stream. * Assembles current module into a SPIR-V stream.
* It can be called multiple times but it's recommended to copy code externally. * It can be called multiple times but it's recommended to copy code
* externally.
* @return A stream of bytes representing a SPIR-V module. * @return A stream of bytes representing a SPIR-V module.
*/ */
std::vector<std::uint8_t> Assemble() const; std::vector<std::uint8_t> Assemble() const;
@ -46,15 +46,18 @@ public:
void AddCapability(spv::Capability capability); void AddCapability(spv::Capability capability);
/// Sets module memory model. /// Sets module memory model.
void SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryModel memory_model); void SetMemoryModel(spv::AddressingModel addressing_model,
spv::MemoryModel memory_model);
/// Adds an entry point. /// Adds an entry point.
void AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point, void AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point,
const std::string& name, const std::vector<Ref>& interfaces = {}); const std::string& name,
const std::vector<Ref>& interfaces = {});
/** /**
* Adds an instruction to module's code * Adds an instruction to module's code
* @param op Instruction to insert into code. Types and constants must not be emitted. * @param op Instruction to insert into code. Types and constants must not
* be emitted.
* @return Returns op. * @return Returns op.
*/ */
Ref Emit(Ref op); Ref Emit(Ref op);
@ -80,9 +83,9 @@ public:
Ref TypeMatrix(Ref column_type, int column_count); Ref TypeMatrix(Ref column_type, int column_count);
/// Returns type image. /// Returns type image.
Ref TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, Ref TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed,
int sampled, spv::ImageFormat image_format, bool ms, int sampled, spv::ImageFormat image_format,
spv::AccessQualifier access_qualifier = static_cast<spv::AccessQualifier>(Undefined)); std::optional<spv::AccessQualifier> access_qualifier = {});
/// Returns type sampler. /// Returns type sampler.
Ref TypeSampler(); Ref TypeSampler();
@ -135,10 +138,12 @@ public:
Ref Constant(Ref result_type, Operand* literal); Ref Constant(Ref result_type, Operand* literal);
/// Returns a numeric scalar constant. /// Returns a numeric scalar constant.
Ref ConstantComposite(Ref result_type, const std::vector<Ref>& constituents); Ref ConstantComposite(Ref result_type,
const std::vector<Ref>& constituents);
/// Returns a sampler constant. /// Returns a sampler constant.
Ref ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressing_mode, Ref ConstantSampler(Ref result_type,
spv::SamplerAddressingMode addressing_mode,
bool normalized, spv::SamplerFilterMode filter_mode); bool normalized, spv::SamplerFilterMode filter_mode);
/// Returns a null constant value. /// Returns a null constant value.
@ -147,7 +152,8 @@ public:
// Function // Function
/// Declares a function. /// Declares a function.
Ref Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type); Ref Function(Ref result_type, spv::FunctionControlMask function_control,
Ref function_type);
/// Ends a function. /// Ends a function.
Ref FunctionEnd(); Ref FunctionEnd();
@ -155,21 +161,26 @@ public:
// Flow // Flow
/// Declare a structured loop. /// Declare a structured loop.
Ref LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control, Ref LoopMerge(Ref merge_block, Ref continue_target,
spv::LoopControlMask loop_control,
const std::vector<Ref>& literals = {}); const std::vector<Ref>& literals = {});
/// Declare a structured selection. /// Declare a structured selection.
Ref SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control); Ref SelectionMerge(Ref merge_block,
spv::SelectionControlMask selection_control);
/// The block label instruction: Any reference to a block is through this ref. /// The block label instruction: Any reference to a block is through this
/// ref.
Ref Label(); Ref Label();
/// Unconditional jump to label. /// Unconditional jump to label.
Ref Branch(Ref target_label); Ref Branch(Ref target_label);
/// If condition is true branch to true_label, otherwise branch to false_label. /// If condition is true branch to true_label, otherwise branch to
/// false_label.
Ref BranchConditional(Ref condition, Ref true_label, Ref false_label, Ref BranchConditional(Ref condition, Ref true_label, Ref false_label,
std::uint32_t true_weight = 0, std::uint32_t false_weight = 0); std::uint32_t true_weight = 0,
std::uint32_t false_weight = 0);
/// Returns with no value from a function with void return type. /// Returns with no value from a function with void return type.
Ref Return(); Ref Return();
@ -188,10 +199,10 @@ public:
static Operand* Literal(float value); static Operand* Literal(float value);
static Operand* Literal(double value); static Operand* Literal(double value);
private: private:
Ref AddCode(Op* op); Ref AddCode(Op* op);
Ref AddCode(spv::Op opcode, std::uint32_t id = UINT32_MAX); Ref AddCode(spv::Op opcode, std::optional<std::uint32_t> id = {});
Ref AddDeclaration(Op* op); Ref AddDeclaration(Op* op);

View file

@ -6,13 +6,12 @@
#pragma once #pragma once
#include "stream.h"
#include "op.h" #include "op.h"
#include "stream.h"
namespace Sirit { namespace Sirit {
template<typename T> template <typename T> inline void AddEnum(Op* op, T value) {
inline void AddEnum(Op* op, T value) {
op->Add(static_cast<u32>(value)); op->Add(static_cast<u32>(value));
} }

View file

@ -19,13 +19,13 @@ Ref Module::ConstantFalse(Ref result_type) {
} }
Ref Module::Constant(Ref result_type, Operand* literal) { Ref Module::Constant(Ref result_type, Operand* literal) {
Op* op{new Op(spv::Op::OpConstant, bound, result_type)}; auto const op{new Op(spv::Op::OpConstant, bound, result_type)};
op->Add(literal); op->Add(literal);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::ConstantComposite(Ref result_type, const std::vector<Ref>& constituents) { Ref Module::ConstantComposite(Ref result_type, const std::vector<Ref>& constituents) {
Op* op{new Op(spv::Op::OpConstantComposite, bound, result_type)}; auto const op{new Op(spv::Op::OpConstantComposite, bound, result_type)};
op->Add(constituents); op->Add(constituents);
return AddDeclaration(op); return AddDeclaration(op);
} }
@ -34,7 +34,7 @@ Ref Module::ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressi
bool normalized, spv::SamplerFilterMode filter_mode) { bool normalized, spv::SamplerFilterMode filter_mode) {
AddCapability(spv::Capability::LiteralSampler); AddCapability(spv::Capability::LiteralSampler);
AddCapability(spv::Capability::Kernel); AddCapability(spv::Capability::Kernel);
Op* op{new Op(spv::Op::OpConstantSampler, bound, result_type)}; auto const op{new Op(spv::Op::OpConstantSampler, bound, result_type)};
AddEnum(op, addressing_mode); AddEnum(op, addressing_mode);
op->Add(normalized ? 1 : 0); op->Add(normalized ? 1 : 0);
AddEnum(op, filter_mode); AddEnum(op, filter_mode);

View file

@ -10,7 +10,7 @@
namespace Sirit { namespace Sirit {
Ref Module::Name(Ref target, const std::string& name) { Ref Module::Name(Ref target, const std::string& name) {
Op* op{new Op(spv::Op::OpName)}; auto const op{new Op(spv::Op::OpName)};
op->Add(target); op->Add(target);
op->Add(name); op->Add(name);
debug.push_back(std::unique_ptr<Op>(op)); debug.push_back(std::unique_ptr<Op>(op));

View file

@ -11,7 +11,7 @@ namespace Sirit {
Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control, Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control,
const std::vector<Ref>& literals) { const std::vector<Ref>& literals) {
Op* op{new Op(spv::Op::OpLoopMerge)}; auto const op{new Op(spv::Op::OpLoopMerge)};
op->Add(merge_block); op->Add(merge_block);
op->Add(continue_target); op->Add(continue_target);
AddEnum(op, loop_control); AddEnum(op, loop_control);
@ -20,7 +20,7 @@ Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask
} }
Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) { Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) {
Op* op{new Op(spv::Op::OpSelectionMerge)}; auto const op{new Op(spv::Op::OpSelectionMerge)};
op->Add(merge_block); op->Add(merge_block);
AddEnum(op, selection_control); AddEnum(op, selection_control);
return AddCode(op); return AddCode(op);
@ -31,14 +31,14 @@ Ref Module::Label() {
} }
Ref Module::Branch(Ref target_label) { Ref Module::Branch(Ref target_label) {
Op* op{new Op(spv::Op::OpBranch)}; auto const op{new Op(spv::Op::OpBranch)};
op->Add(target_label); op->Add(target_label);
return AddCode(op); return AddCode(op);
} }
Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label, Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label,
std::uint32_t true_weight, std::uint32_t false_weight) { std::uint32_t true_weight, std::uint32_t false_weight) {
Op* op{new Op(spv::Op::OpBranchConditional)}; auto const op{new Op(spv::Op::OpBranchConditional)};
op->Add(condition); op->Add(condition);
op->Add(true_label); op->Add(true_label);
op->Add(false_label); op->Add(false_label);

View file

@ -10,7 +10,7 @@
namespace Sirit { namespace Sirit {
Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) { Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) {
Op* op{new Op{spv::Op::OpFunction, bound++, result_type}}; auto const op{new Op{spv::Op::OpFunction, bound++, result_type}};
op->Add(static_cast<u32>(function_control)); op->Add(static_cast<u32>(function_control));
op->Add(function_type); op->Add(function_type);
return AddCode(op); return AddCode(op);

View file

@ -5,6 +5,7 @@
*/ */
#include <cassert> #include <cassert>
#include <optional>
#include "sirit/sirit.h" #include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
@ -26,7 +27,7 @@ Ref Module::TypeInt(int width, bool is_signed) {
} else if (width == 64) { } else if (width == 64) {
AddCapability(spv::Capability::Int64); AddCapability(spv::Capability::Int64);
} }
Op* op{new Op(spv::Op::OpTypeInt, bound)}; auto const op{new Op(spv::Op::OpTypeInt, bound)};
op->Add(width); op->Add(width);
op->Add(is_signed ? 1 : 0); op->Add(is_signed ? 1 : 0);
return AddDeclaration(op); return AddDeclaration(op);
@ -38,14 +39,14 @@ Ref Module::TypeFloat(int width) {
} else if (width == 64) { } else if (width == 64) {
AddCapability(spv::Capability::Float64); AddCapability(spv::Capability::Float64);
} }
Op* op{new Op(spv::Op::OpTypeFloat, bound)}; auto const op{new Op(spv::Op::OpTypeFloat, bound)};
op->Add(width); op->Add(width);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeVector(Ref component_type, int component_count) { Ref Module::TypeVector(Ref component_type, int component_count) {
assert(component_count >= 2); assert(component_count >= 2);
Op* op{new Op(spv::Op::OpTypeVector, bound)}; auto const op{new Op(spv::Op::OpTypeVector, bound)};
op->Add(component_type); op->Add(component_type);
op->Add(component_count); op->Add(component_count);
return AddDeclaration(op); return AddDeclaration(op);
@ -61,8 +62,8 @@ Ref Module::TypeMatrix(Ref column_type, int column_count) {
} }
Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms,
int sampled, spv::ImageFormat image_format, int sampled, spv::ImageFormat image_format,
spv::AccessQualifier access_qualifier) { std::optional<spv::AccessQualifier> access_qualifier) {
switch (dim) { switch (dim) {
case spv::Dim::Dim1D: case spv::Dim::Dim1D:
AddCapability(spv::Capability::Sampled1D); AddCapability(spv::Capability::Sampled1D);
@ -123,7 +124,7 @@ Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, b
AddCapability(spv::Capability::StorageImageExtendedFormats); AddCapability(spv::Capability::StorageImageExtendedFormats);
break; break;
} }
Op* op{new Op(spv::Op::OpTypeImage, bound)}; auto const op{new Op(spv::Op::OpTypeImage, bound)};
op->Add(sampled_type); op->Add(sampled_type);
op->Add(static_cast<u32>(dim)); op->Add(static_cast<u32>(dim));
op->Add(depth); op->Add(depth);
@ -131,9 +132,9 @@ Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, b
op->Add(ms ? 1 : 0); op->Add(ms ? 1 : 0);
op->Add(sampled); op->Add(sampled);
op->Add(static_cast<u32>(image_format)); op->Add(static_cast<u32>(image_format));
if (static_cast<u32>(access_qualifier) != Undefined) { if (access_qualifier.has_value()) {
AddCapability(spv::Capability::Kernel); AddCapability(spv::Capability::Kernel);
op->Add(static_cast<u32>(access_qualifier)); op->Add(static_cast<u32>(access_qualifier.value()));
} }
return AddDeclaration(op); return AddDeclaration(op);
} }
@ -143,13 +144,13 @@ Ref Module::TypeSampler() {
} }
Ref Module::TypeSampledImage(Ref image_type) { Ref Module::TypeSampledImage(Ref image_type) {
Op* op{new Op(spv::Op::OpTypeSampledImage, bound)}; auto const op{new Op(spv::Op::OpTypeSampledImage, bound)};
op->Add(image_type); op->Add(image_type);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeArray(Ref element_type, Ref length) { Ref Module::TypeArray(Ref element_type, Ref length) {
Op* op{new Op(spv::Op::OpTypeArray, bound)}; auto const op{new Op(spv::Op::OpTypeArray, bound)};
op->Add(element_type); op->Add(element_type);
op->Add(length); op->Add(length);
return AddDeclaration(op); return AddDeclaration(op);
@ -157,20 +158,20 @@ Ref Module::TypeArray(Ref element_type, Ref length) {
Ref Module::TypeRuntimeArray(Ref element_type) { Ref Module::TypeRuntimeArray(Ref element_type) {
AddCapability(spv::Capability::Shader); AddCapability(spv::Capability::Shader);
Op* op{new Op(spv::Op::OpTypeRuntimeArray, bound)}; auto const op{new Op(spv::Op::OpTypeRuntimeArray, bound)};
op->Add(element_type); op->Add(element_type);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeStruct(const std::vector<Ref>& members) { Ref Module::TypeStruct(const std::vector<Ref>& members) {
Op* op{new Op(spv::Op::OpTypeStruct, bound)}; auto const op{new Op(spv::Op::OpTypeStruct, bound)};
op->Add(members); op->Add(members);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeOpaque(const std::string& name) { Ref Module::TypeOpaque(const std::string& name) {
AddCapability(spv::Capability::Kernel); AddCapability(spv::Capability::Kernel);
Op* op{new Op(spv::Op::OpTypeOpaque, bound)}; auto const op{new Op(spv::Op::OpTypeOpaque, bound)};
op->Add(name); op->Add(name);
return AddDeclaration(op); return AddDeclaration(op);
} }
@ -191,14 +192,14 @@ Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) {
AddCapability(spv::Capability::AtomicStorage); AddCapability(spv::Capability::AtomicStorage);
break; break;
} }
Op* op{new Op(spv::Op::OpTypePointer, bound)}; auto const op{new Op(spv::Op::OpTypePointer, bound)};
op->Add(static_cast<u32>(storage_class)); op->Add(static_cast<u32>(storage_class));
op->Add(type); op->Add(type);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeFunction(Ref return_type, const std::vector<Ref>& arguments) { Ref Module::TypeFunction(Ref return_type, const std::vector<Ref>& arguments) {
Op* op{new Op(spv::Op::OpTypeFunction, bound)}; auto const op{new Op(spv::Op::OpTypeFunction, bound)};
op->Add(return_type); op->Add(return_type);
op->Add(arguments); op->Add(arguments);
return AddDeclaration(op); return AddDeclaration(op);
@ -226,7 +227,7 @@ Ref Module::TypeQueue() {
Ref Module::TypePipe(spv::AccessQualifier access_qualifier) { Ref Module::TypePipe(spv::AccessQualifier access_qualifier) {
AddCapability(spv::Capability::Pipes); AddCapability(spv::Capability::Pipes);
Op* op{new Op(spv::Op::OpTypePipe, bound)}; auto const op{new Op(spv::Op::OpTypePipe, bound)};
op->Add(static_cast<u32>(access_qualifier)); op->Add(static_cast<u32>(access_qualifier));
return AddDeclaration(op); return AddDeclaration(op);
} }

View file

@ -4,35 +4,23 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include "sirit/sirit.h"
#include "common_types.h" #include "common_types.h"
#include "operand.h"
#include "lnumber.h" #include "lnumber.h"
#include "operand.h"
#include "sirit/sirit.h"
namespace Sirit { namespace Sirit {
Operand* Module::Literal(u32 value) { #define DEFINE_LITERAL(type) \
return new LiteralNumber(value); Operand* Module::Literal(type value) { \
} return LiteralNumber::Create<type>(value); \
}
Operand* Module::Literal(u64 value) { DEFINE_LITERAL(u32)
return new LiteralNumber(value); DEFINE_LITERAL(u64)
} DEFINE_LITERAL(s32)
DEFINE_LITERAL(s64)
Operand* Module::Literal(s32 value) { DEFINE_LITERAL(f32)
return new LiteralNumber(value); DEFINE_LITERAL(f64)
}
Operand* Module::Literal(s64 value) {
return new LiteralNumber(value);
}
Operand* Module::Literal(f32 value) {
return new LiteralNumber(value);
}
Operand* Module::Literal(f64 value) {
return new LiteralNumber(value);
}
} // namespace Sirit } // namespace Sirit

View file

@ -4,79 +4,26 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include <cassert>
#include "lnumber.h" #include "lnumber.h"
#include <cassert>
namespace Sirit { namespace Sirit {
LiteralNumber::LiteralNumber() { LiteralNumber::LiteralNumber(std::type_index type) : type(type) {
operand_type = OperandType::Number; operand_type = OperandType::Number;
} }
LiteralNumber::LiteralNumber(u32 number)
: uint32(number), type(NumberType::U32) {
LiteralNumber();
}
LiteralNumber::LiteralNumber(s32 number)
: int32(number), type(NumberType::S32) {
LiteralNumber();
}
LiteralNumber::LiteralNumber(f32 number)
: float32(number), type(NumberType::F32) {
LiteralNumber();
}
LiteralNumber::LiteralNumber(u64 number)
: uint64(number), type(NumberType::U64) {
LiteralNumber();
}
LiteralNumber::LiteralNumber(s64 number)
: int64(number), type(NumberType::S64) {
LiteralNumber();
}
LiteralNumber::LiteralNumber(f64 number)
: float64(number), type(NumberType::F64) {
LiteralNumber();
}
LiteralNumber::~LiteralNumber() = default; LiteralNumber::~LiteralNumber() = default;
void LiteralNumber::Fetch(Stream& stream) const { void LiteralNumber::Fetch(Stream& stream) const {
switch (type) { if (is_32) {
case NumberType::S32: stream.Write(static_cast<u32>(raw));
case NumberType::U32: } else {
case NumberType::F32: stream.Write(raw);
stream.Write(uint32);
break;
case NumberType::S64:
case NumberType::U64:
case NumberType::F64:
stream.Write(uint64);
break;
default:
assert(0);
} }
} }
u16 LiteralNumber::GetWordCount() const { u16 LiteralNumber::GetWordCount() const { return is_32 ? 1 : 2; }
switch (type) {
case NumberType::S32:
case NumberType::U32:
case NumberType::F32:
return 1;
case NumberType::S64:
case NumberType::U64:
case NumberType::F64:
return 2;
default:
assert(0);
return 0;
}
}
bool LiteralNumber::operator==(const Operand& other) const { bool LiteralNumber::operator==(const Operand& other) const {
if (operand_type == other.GetType()) { if (operand_type == other.GetType()) {

View file

@ -6,19 +6,15 @@
#pragma once #pragma once
#include "stream.h"
#include "operand.h" #include "operand.h"
#include "stream.h"
#include <typeindex>
namespace Sirit { namespace Sirit {
class LiteralNumber : public Operand { class LiteralNumber : public Operand {
public: public:
LiteralNumber(u32 number); LiteralNumber(std::type_index type);
LiteralNumber(s32 number);
LiteralNumber(f32 number);
LiteralNumber(u64 number);
LiteralNumber(s64 number);
LiteralNumber(f64 number);
~LiteralNumber(); ~LiteralNumber();
virtual void Fetch(Stream& stream) const; virtual void Fetch(Stream& stream) const;
@ -26,27 +22,21 @@ public:
virtual bool operator==(const Operand& other) const; virtual bool operator==(const Operand& other) const;
private: template <typename T> static LiteralNumber* Create(T value) {
LiteralNumber(); static_assert(sizeof(T) == 4 || sizeof(T) == 8);
LiteralNumber* number = new LiteralNumber(std::type_index(typeid(T)));
if (number->is_32 = sizeof(T) == 4; number->is_32) {
number->raw = *reinterpret_cast<u32*>(&value);
} else {
number->raw = *reinterpret_cast<u64*>(&value);
}
return number;
}
enum class NumberType { private:
U32, std::type_index type;
S32, bool is_32;
F32, u64 raw;
U64,
S64,
F64
} type;
union {
u64 raw{};
u32 uint32;
s32 int32;
u64 uint64;
s64 int64;
f32 float32;
f64 float64;
};
}; };
} // namespace Sirit } // namespace Sirit

View file

@ -8,8 +8,7 @@
namespace Sirit { namespace Sirit {
LiteralString::LiteralString(const std::string& string_) LiteralString::LiteralString(const std::string& string_) : string(string_) {
: string(string_) {
operand_type = OperandType::String; operand_type = OperandType::String;
} }

View file

@ -6,14 +6,14 @@
#pragma once #pragma once
#include <string>
#include "stream.h"
#include "operand.h" #include "operand.h"
#include "stream.h"
#include <string>
namespace Sirit { namespace Sirit {
class LiteralString : public Operand { class LiteralString : public Operand {
public: public:
LiteralString(const std::string& string); LiteralString(const std::string& string);
~LiteralString(); ~LiteralString();
@ -22,7 +22,7 @@ public:
virtual bool operator==(const Operand& other) const; virtual bool operator==(const Operand& other) const;
private: private:
std::string string; std::string string;
}; };

View file

@ -5,29 +5,28 @@
*/ */
#include <cassert> #include <cassert>
#include "common_types.h" #include "common_types.h"
#include "operand.h"
#include "op.h"
#include "lnumber.h" #include "lnumber.h"
#include "lstring.h" #include "lstring.h"
#include "op.h"
#include "operand.h"
namespace Sirit { namespace Sirit {
Op::Op(spv::Op opcode_, u32 id_, Ref result_type_) Op::Op(spv::Op opcode, std::optional<u32> id, Ref result_type)
: opcode(opcode_), id(id_), result_type(result_type_) { : opcode(opcode), id(id), result_type(result_type) {
operand_type = OperandType::Op; operand_type = OperandType::Op;
} }
Op::~Op() = default; Op::~Op() = default;
void Op::Fetch(Stream& stream) const { void Op::Fetch(Stream& stream) const {
assert(id != UINT32_MAX); assert(id.has_value());
stream.Write(id); stream.Write(id.value());
} }
u16 Op::GetWordCount() const { u16 Op::GetWordCount() const { return 1; }
return 1;
}
bool Op::operator==(const Operand& other) const { bool Op::operator==(const Operand& other) const {
if (operand_type != other.GetType()) { if (operand_type != other.GetType()) {
@ -53,8 +52,8 @@ void Op::Write(Stream& stream) const {
if (result_type) { if (result_type) {
result_type->Fetch(stream); result_type->Fetch(stream);
} }
if (id != UINT32_MAX) { if (id.has_value()) {
stream.Write(id); stream.Write(id.value());
} }
for (const Operand* operand : operands) { for (const Operand* operand : operands) {
operand->Fetch(stream); operand->Fetch(stream);
@ -66,17 +65,11 @@ void Op::Add(Operand* operand) {
operand_store.push_back(std::unique_ptr<Operand>(operand)); operand_store.push_back(std::unique_ptr<Operand>(operand));
} }
void Op::Add(const Operand* operand) { void Op::Add(const Operand* operand) { operands.push_back(operand); }
operands.push_back(operand);
}
void Op::Add(u32 integer) { void Op::Add(u32 integer) { Add(LiteralNumber::Create<u32>(integer)); }
Add(new LiteralNumber(integer));
}
void Op::Add(const std::string& string) { void Op::Add(const std::string& string) { Add(new LiteralString(string)); }
Add(new LiteralString(string));
}
void Op::Add(const std::vector<Ref>& ids) { void Op::Add(const std::vector<Ref>& ids) {
for (Ref op : ids) { for (Ref op : ids) {
@ -89,7 +82,7 @@ u16 Op::WordCount() const {
if (result_type) { if (result_type) {
count++; count++;
} }
if (id != UINT32_MAX) { if (id.has_value()) {
count++; count++;
} }
for (const Operand* operand : operands) { for (const Operand* operand : operands) {

View file

@ -6,16 +6,18 @@
#pragma once #pragma once
#include "sirit/sirit.h"
#include "common_types.h" #include "common_types.h"
#include "operand.h" #include "operand.h"
#include "sirit/sirit.h"
#include "stream.h" #include "stream.h"
#include <optional>
namespace Sirit { namespace Sirit {
class Op : public Operand { class Op : public Operand {
public: public:
explicit Op(spv::Op opcode, u32 id = UINT32_MAX, Ref result_type = nullptr); explicit Op(spv::Op opcode, std::optional<u32> id = {},
Ref result_type = nullptr);
~Op(); ~Op();
virtual void Fetch(Stream& stream) const; virtual void Fetch(Stream& stream) const;
@ -35,14 +37,14 @@ public:
void Add(const std::vector<Ref>& ids); void Add(const std::vector<Ref>& ids);
private: private:
u16 WordCount() const; u16 WordCount() const;
spv::Op opcode; spv::Op opcode;
Ref result_type; Ref result_type;
u32 id; std::optional<u32> id;
std::vector<const Operand*> operands; std::vector<const Operand*> operands;

View file

@ -4,8 +4,8 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include <cassert>
#include "operand.h" #include "operand.h"
#include <cassert>
namespace Sirit { namespace Sirit {
@ -22,16 +22,12 @@ u16 Operand::GetWordCount() const {
return 0; return 0;
} }
bool Operand::operator==(const Operand& other) const { bool Operand::operator==(const Operand& other) const { return false; }
return false;
}
bool Operand::operator!=(const Operand& other) const { bool Operand::operator!=(const Operand& other) const {
return !(*this == other); return !(*this == other);
} }
OperandType Operand::GetType() const { OperandType Operand::GetType() const { return operand_type; }
return operand_type;
}
} // namespace Sirit } // namespace Sirit

View file

@ -10,15 +10,10 @@
namespace Sirit { namespace Sirit {
enum class OperandType { enum class OperandType { Invalid, Op, Number, String };
Invalid,
Op,
Number,
String
};
class Operand { class Operand {
public: public:
Operand(); Operand();
virtual ~Operand(); virtual ~Operand();
@ -30,7 +25,7 @@ public:
OperandType GetType() const; OperandType GetType() const;
protected: protected:
OperandType operand_type{}; OperandType operand_type{};
}; };

View file

@ -4,17 +4,17 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include <algorithm>
#include <cassert>
#include "sirit/sirit.h" #include "sirit/sirit.h"
#include "common_types.h" #include "common_types.h"
#include "op.h" #include "op.h"
#include "stream.h" #include "stream.h"
#include <algorithm>
#include <cassert>
namespace Sirit { namespace Sirit {
template<typename T> template <typename T>
inline void WriteEnum(Stream& stream, spv::Op opcode, T value) { static void WriteEnum(Stream& stream, spv::Op opcode, T value) {
Op op{opcode}; Op op{opcode};
op.Add(static_cast<u32>(value)); op.Add(static_cast<u32>(value));
op.Write(stream); op.Write(stream);
@ -30,7 +30,7 @@ std::vector<u8> Module::Assemble() const {
stream.Write(spv::MagicNumber); stream.Write(spv::MagicNumber);
stream.Write(spv::Version); stream.Write(spv::Version);
stream.Write(GeneratorMagicNumber); stream.Write(GENERATOR_MAGIC_NUMBER);
stream.Write(bound); stream.Write(bound);
stream.Write(static_cast<u32>(0)); stream.Write(static_cast<u32>(0));
@ -69,21 +69,22 @@ std::vector<u8> Module::Assemble() const {
return bytes; return bytes;
} }
void Module::Optimize(int level) { void Module::Optimize(int level) {}
}
void Module::AddCapability(spv::Capability capability) { void Module::AddCapability(spv::Capability capability) {
capabilities.insert(capability); capabilities.insert(capability);
} }
void Module::SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryModel memory_model) { void Module::SetMemoryModel(spv::AddressingModel addressing_model,
spv::MemoryModel memory_model) {
this->addressing_model = addressing_model; this->addressing_model = addressing_model;
this->memory_model = memory_model; this->memory_model = memory_model;
} }
void Module::AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point, void Module::AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point,
const std::string& name, const std::vector<Ref>& interfaces) { const std::string& name,
Op* op{new Op(spv::Op::OpEntryPoint)}; const std::vector<Ref>& interfaces) {
auto const op{new Op(spv::Op::OpEntryPoint)};
op->Add(static_cast<u32>(execution_model)); op->Add(static_cast<u32>(execution_model));
op->Add(entry_point); op->Add(entry_point);
op->Add(name); op->Add(name);
@ -102,14 +103,14 @@ Ref Module::AddCode(Op* op) {
return op; return op;
} }
Ref Module::AddCode(spv::Op opcode, u32 id) { Ref Module::AddCode(spv::Op opcode, std::optional<u32> id) {
return AddCode(new Op{opcode, id}); return AddCode(new Op(opcode, id));
} }
Ref Module::AddDeclaration(Op* op) { Ref Module::AddDeclaration(Op* op) {
const auto& found{std::find_if(declarations.begin(), declarations.end(), [=](const auto& other) { const auto& found{
return *other == *op; std::find_if(declarations.begin(), declarations.end(),
})}; [&op](const auto& other) { return *other == *op; })};
if (found != declarations.end()) { if (found != declarations.end()) {
delete op; delete op;
return found->get(); return found->get();

View file

@ -8,42 +8,39 @@
namespace Sirit { namespace Sirit {
Stream::Stream(std::vector<u8>& bytes_) Stream::Stream(std::vector<u8>& bytes_) : bytes(bytes_) {}
: bytes(bytes_) {}
Stream::~Stream() = default; Stream::~Stream() = default;
void Stream::Write(std::string string) { void Stream::Write(std::string string) {
std::size_t size{string.size()}; const auto size{string.size()};
u8* data{reinterpret_cast<u8*>(string.data())}; const auto data{reinterpret_cast<u8*>(string.data())};
for (std::size_t i{}; i < size; i++) { for (std::size_t i = 0; i < size; i++) {
Write(data[i]); Write(data[i]);
} }
for (std::size_t i{}; i < 4 - size % 4; i++) { for (std::size_t i = 0; i < 4 - size % 4; i++) {
Write(static_cast<u8>(0)); Write(static_cast<u8>(0));
} }
} }
void Stream::Write(u64 value) { void Stream::Write(u64 value) {
u32* mem{reinterpret_cast<u32*>(&value)}; const auto mem{reinterpret_cast<u32*>(&value)};
Write(mem[0]); Write(mem[0]);
Write(mem[1]); Write(mem[1]);
} }
void Stream::Write(u32 value) { void Stream::Write(u32 value) {
u16* mem{reinterpret_cast<u16*>(&value)}; const auto mem{reinterpret_cast<u16*>(&value)};
Write(mem[0]); Write(mem[0]);
Write(mem[1]); Write(mem[1]);
} }
void Stream::Write(u16 value) { void Stream::Write(u16 value) {
u8* mem{reinterpret_cast<u8*>(&value)}; const auto mem{reinterpret_cast<u8*>(&value)};
Write(mem[0]); Write(mem[0]);
Write(mem[1]); Write(mem[1]);
} }
void Stream::Write(u8 value) { void Stream::Write(u8 value) { bytes.push_back(value); }
bytes.push_back(value);
}
} // namespace Sirit } // namespace Sirit

View file

@ -6,14 +6,14 @@
#pragma once #pragma once
#include "common_types.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "common_types.h"
namespace Sirit { namespace Sirit {
class Stream { class Stream {
public: public:
explicit Stream(std::vector<u8>& bytes); explicit Stream(std::vector<u8>& bytes);
~Stream(); ~Stream();
@ -27,7 +27,7 @@ public:
void Write(u8 value); void Write(u8 value);
private: private:
std::vector<u8>& bytes; std::vector<u8>& bytes;
}; };