1
0
Fork 0
mirror of https://git.suyu.dev/suyu/sirit.git synced 2025-01-05 02:06:13 +00:00

Remove forward references and add phi node patching

The previous API for forward declarations broke when more than one
definition was done. Forward references on instructions that are not
labels were only needed for phi nodes, so it has been replaced with a
deferred phi node instruction and a method to patch these after
everything has been defined.
This commit is contained in:
ReinUsesLisp 2021-04-11 02:02:40 -03:00
parent f1cccfd0f3
commit 51d541f1a1
4 changed files with 48 additions and 23 deletions

View file

@ -12,6 +12,7 @@
#include <optional> #include <optional>
#include <span> #include <span>
#include <string> #include <string>
#include <functional>
#include <string_view> #include <string_view>
#include <type_traits> #include <type_traits>
#include <unordered_set> #include <unordered_set>
@ -52,6 +53,9 @@ public:
*/ */
std::vector<std::uint32_t> Assemble() const; std::vector<std::uint32_t> Assemble() const;
/// Patches deferred phi nodes calling the passed function on each phi argument
void PatchDeferredPhi(const std::function<Id(std::size_t index)>& func);
/// Adds a SPIR-V extension. /// Adds a SPIR-V extension.
void AddExtension(std::string extension_name); void AddExtension(std::string extension_name);
@ -87,15 +91,6 @@ public:
AddExecutionMode(entry_point, mode, std::span<const Literal>({literals...})); AddExecutionMode(entry_point, mode, std::span<const Literal>({literals...}));
} }
/// Generate a new id for forward declarations
[[nodiscard]] Id ForwardDeclarationId();
/// Returns the current generator id, useful for self-referencing phi nodes
[[nodiscard]] Id CurrentId() const noexcept;
/// Assign a new id and return the old one, useful for defining forward declarations
Id ExchangeCurrentId(Id new_current_id);
/** /**
* Adds an existing label to the code * Adds an existing label to the code
* @param label Label to insert into code. * @param label Label to insert into code.
@ -253,6 +248,12 @@ public:
*/ */
Id OpPhi(Id result_type, std::span<const Id> operands); Id OpPhi(Id result_type, std::span<const Id> operands);
/**
* The SSA phi function. This instruction will be revisited when patching phi nodes.
* @param operands An immutable span of block pairs
*/
Id DeferredOpPhi(Id result_type, std::span<const Id> blocks);
/// Declare a structured loop. /// Declare a structured loop.
Id OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, Id OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control,
std::span<const Id> literals = {}); std::span<const Id> literals = {});
@ -1236,6 +1237,7 @@ private:
std::unique_ptr<Declarations> declarations; std::unique_ptr<Declarations> declarations;
std::unique_ptr<Stream> global_variables; std::unique_ptr<Stream> global_variables;
std::unique_ptr<Stream> code; std::unique_ptr<Stream> code;
std::vector<std::uint32_t> deferred_phi_nodes;
}; };
} // namespace Sirit } // namespace Sirit

View file

@ -18,6 +18,16 @@ Id Module::OpPhi(Id result_type, std::span<const Id> operands) {
return *code << OpId{spv::Op::OpPhi, result_type} << operands << EndOp{}; return *code << OpId{spv::Op::OpPhi, result_type} << operands << EndOp{};
} }
Id Module::DeferredOpPhi(Id result_type, std::span<const Id> blocks) {
deferred_phi_nodes.push_back(code->LocalAddress());
code->Reserve(3 + blocks.size() * 2);
*code << OpId{spv::Op::OpPhi, result_type};
for (const Id block : blocks) {
*code << u32{0} << block;
}
return *code << EndOp{};
}
Id Module::OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control, Id Module::OpLoopMerge(Id merge_block, Id continue_target, spv::LoopControlMask loop_control,
std::span<const Id> literals) { std::span<const Id> literals) {
code->Reserve(4 + literals.size()); code->Reserve(4 + literals.size());

View file

@ -68,6 +68,20 @@ std::vector<u32> Module::Assemble() const {
return words; return words;
} }
void Module::PatchDeferredPhi(const std::function<Id(std::size_t index)>& func) {
for (const u32 phi_index : deferred_phi_nodes) {
const u32 first_word = code->Value(phi_index);
[[maybe_unused]] const spv::Op op = static_cast<spv::Op>(first_word & 0xffff);
assert(op == spv::Op::OpPhi);
const u32 num_words = first_word >> 16;
const u32 num_args = (num_words - 3) / 2;
u32 cursor = phi_index + 3;
for (u32 arg = 0; arg < num_args; ++arg, cursor += 2) {
code->SetValue(cursor, func(arg).value);
}
}
}
void Module::AddExtension(std::string extension_name) { void Module::AddExtension(std::string extension_name) {
extensions.insert(std::move(extension_name)); extensions.insert(std::move(extension_name));
} }
@ -95,19 +109,6 @@ void Module::AddExecutionMode(Id entry_point, spv::ExecutionMode mode,
*execution_modes << spv::Op::OpExecutionMode << entry_point << mode << literals << EndOp{}; *execution_modes << spv::Op::OpExecutionMode << entry_point << mode << literals << EndOp{};
} }
Id Module::ForwardDeclarationId() {
return Id{++bound};
}
Id Module::CurrentId() const noexcept {
return Id{bound + 1};
}
Id Module::ExchangeCurrentId(Id new_current_id) {
const std::uint32_t old_id = std::exchange(bound, new_current_id.value - 1);
return Id{old_id + 1};
}
Id Module::AddLabel(Id label) { Id Module::AddLabel(Id label) {
assert(label.value != 0); assert(label.value != 0);
code->Reserve(2); code->Reserve(2);

View file

@ -40,7 +40,7 @@ struct OpId {
struct EndOp {}; struct EndOp {};
constexpr size_t WordsInString(std::string_view string) { inline size_t WordsInString(std::string_view string) {
return string.size() / sizeof(u32) + 1; return string.size() / sizeof(u32) + 1;
} }
@ -76,6 +76,18 @@ public:
return std::span(words.data(), insert_index); return std::span(words.data(), insert_index);
} }
u32 LocalAddress() const noexcept {
return static_cast<u32>(words.size());
}
u32 Value(u32 index) const noexcept {
return words[index];
}
void SetValue(u32 index, u32 value) noexcept {
words[index] = value;
}
Stream& operator<<(spv::Op op) { Stream& operator<<(spv::Op op) {
op_index = insert_index; op_index = insert_index;
words[insert_index++] = static_cast<u32>(op); words[insert_index++] = static_cast<u32>(op);