Skip to content

Emit table of return sorts #911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions bindings/core/src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *);

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label) {
auto tag = getTagForSymbolName(label.c_str());
return getReturnSortForTag(tag);
}

std::shared_ptr<KOREPattern> make_injection(
std::shared_ptr<KOREPattern> term, std::shared_ptr<KORESort> from,
std::shared_ptr<KORESort> to) {
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ void bind_runtime(py::module_ &m) {
m.def("simplify_pattern", bindings::simplify);
m.def("simplify_bool_pattern", bindings::simplify_to_bool);

m.def("return_sort_for_label", bindings::return_sort_for_label);

// This class can't be used directly from Python; the mutability semantics
// that we get from the Pybind wrappers make it really easy to break things.
// We therefore have to wrap it up in some external Python code; see
Expand Down
12 changes: 12 additions & 0 deletions include/kllvm/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ using sptr = std::shared_ptr<T>;

std::string decodeKore(std::string);

/*
* Helper function to avoid repeated call-site uses of ostringstream when we
* just want the string representation of a node, rather than to print it to a
* stream.
*/
template <typename T>
std::string ast_to_string(T &&node) {
auto os = std::ostringstream{};
std::forward<T>(node).print(os);
return os.str();
}

// KORESort
class KORESort : public std::enable_shared_from_this<KORESort> {
public:
Expand Down
2 changes: 2 additions & 0 deletions include/kllvm/bindings/core/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace kllvm::bindings {

std::string return_sort_for_label(std::string const &label);

std::shared_ptr<kllvm::KOREPattern> make_injection(
std::shared_ptr<kllvm::KOREPattern> term,
std::shared_ptr<kllvm::KORESort> from, std::shared_ptr<kllvm::KORESort> to);
Expand Down
1 change: 1 addition & 0 deletions include/runtime/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag);
bool hook_STRING_eq(SortString, SortString);

const char *getSymbolNameForTag(uint32_t tag);
const char *getReturnSortForTag(uint32_t tag);
const char *topSort(void);

typedef struct {
Expand Down
50 changes: 50 additions & 0 deletions lib/codegen/EmitConfigParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,55 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
}
}

/*
* Emit a table mapping symbol tags to the declared return sort for that symbol.
* For example:
*
* tag_of(initGeneratedTopCell) |-> sort_name_SortGeneratedTopCell{}
*
* Each value in the table is a pointer to a global variable containing the
* relevant sort name as a null-terminated string.
*
* The function `getReturnSortForTag` abstracts accesses to the data in this
* table.
*/
static void
emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) {
auto &ctx = module->getContext();

auto const &syms = definition->getSymbols();

auto element_type = llvm::Type::getInt8PtrTy(ctx);
auto table_type = llvm::ArrayType::get(element_type, syms.size());

auto table = module->getOrInsertGlobal("return_sort_table", table_type);
auto values = std::vector<llvm::Constant *>{};

for (auto [tag, symbol] : syms) {
auto sort = symbol->getSort();
auto sort_str = ast_to_string(*sort);

auto char_type = llvm::Type::getInt8Ty(ctx);
auto str_type = llvm::ArrayType::get(char_type, sort_str.size() + 1);

auto sort_name
= module->getOrInsertGlobal("sort_name_" + sort_str, str_type);

auto i64_type = llvm::Type::getInt64Ty(ctx);
auto zero = llvm::ConstantInt::get(i64_type, 0);

auto pointer = llvm::ConstantExpr::getInBoundsGetElementPtr(
str_type, sort_name, std::vector<llvm::Constant *>{zero});

values.push_back(pointer);
}

auto global = llvm::dyn_cast<llvm::GlobalVariable>(table);
if (!global->hasInitializer()) {
global->setInitializer(llvm::ConstantArray::get(table_type, values));
}
}

void emitConfigParserFunctions(
KOREDefinition *definition, llvm::Module *module) {
emitGetTagForSymbolName(definition, module);
Expand All @@ -1329,6 +1378,7 @@ void emitConfigParserFunctions(
emitInjTags(definition, module);

emitSortTable(definition, module);
emitReturnSortTable(definition, module);
}

} // namespace kllvm
6 changes: 6 additions & 0 deletions runtime/util/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

extern "C" {

extern char *return_sort_table;

const char *getReturnSortForTag(uint32_t tag) {
return (&return_sort_table)[tag];
}

block *dot_k() {
return leaf_block(getTagForSymbolName("dotk{}"));
}
Expand Down
6,145 changes: 6,145 additions & 0 deletions test/python/Inputs/sorts.kore

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions test/python/k-files/sorts.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module SORTS
imports DOMAINS

syntax Int ::= func() [function, label(func), symbol]

syntax Foo ::= foo() [label(foo), symbol]
syntax Bar ::= Foo
| bar() [label(bar), symbol]

rule func() => 0
rule foo() => bar()
endmodule
26 changes: 26 additions & 0 deletions test/python/test_return_sorts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# RUN: mkdir -p %t
# RUN: export IN=$(realpath Inputs/sorts.kore)
# RUN: cd %t && %kompile "$IN" python --python %py-interpreter --python-output-dir .
# RUN: KLLVM_DEFINITION=%t %python -u %s

from test_bindings import kllvm

import unittest

class TestReturnSorts(unittest.TestCase):

def _check_sort(self, label, sort):
self.assertEqual(kllvm.runtime.return_sort_for_label(label), sort)

def test_function(self):
self._check_sort('Lblfunc{}', 'SortInt{}')

def test_constructor(self):
self._check_sort('Lblfoo{}', 'SortFoo{}')

def test_subsort(self):
self._check_sort('Lblbar{}', 'SortBar{}')


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions unittests/runtime-ffi/ffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#define KCHAR char
#define TYPETAG(type) "Lbl'Hash'ffi'Unds'" #type "{}"

char *return_sort_table = nullptr;

void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
return nullptr;
}
Expand Down
2 changes: 2 additions & 0 deletions unittests/runtime-io/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#define KCHAR char

char *return_sort_table = nullptr;

void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
return nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions unittests/runtime-strings/bytestest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#define KCHAR char
extern "C" {

char *return_sort_table = nullptr;

uint32_t getTagForSymbolName(const char *s) {
return 0;
}
Expand Down