Skip to content

Commit 44f29b2

Browse files
authored
Emit table of return sorts (#911)
Part of: #905 In #905, we are implementing a Python binding for the backend's function evaluator: given a function label and list of argument `Pattern`s, construct runtime terms for the arguments, evaluate the function with the given label, and return the result as an AST pattern. To safely reify the runtime term produced by the function call to an AST pattern, we need to know its sort (so that the machinery in #907, #908 can be used correctly). In some places in the bindings, we have to require that callers provide a sort when reifying terms back to patterns. However, when calling a function, the label of the function determines precisely the correct sort to use. This PR emits a new table of global data into compiled interpreters that maps tags to declared return sorts, along with a function that abstracts away indexing into this table. This change is similar to (but simpler than) an existing table of _argument sorts_ for each symbol that we already emit. Testing is handled by binding the new function to Python.
1 parent 4cbbb3a commit 44f29b2

File tree

13 files changed

+6268
-0
lines changed

13 files changed

+6268
-0
lines changed

bindings/core/src/core.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ void *constructInitialConfiguration(const KOREPattern *);
1212

1313
namespace kllvm::bindings {
1414

15+
std::string return_sort_for_label(std::string const &label) {
16+
auto tag = getTagForSymbolName(label.c_str());
17+
return getReturnSortForTag(tag);
18+
}
19+
1520
std::shared_ptr<KOREPattern> make_injection(
1621
std::shared_ptr<KOREPattern> term, std::shared_ptr<KORESort> from,
1722
std::shared_ptr<KORESort> to) {

bindings/python/runtime.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ void bind_runtime(py::module_ &m) {
5353
m.def("simplify_pattern", bindings::simplify);
5454
m.def("simplify_bool_pattern", bindings::simplify_to_bool);
5555

56+
m.def("return_sort_for_label", bindings::return_sort_for_label);
57+
5658
// This class can't be used directly from Python; the mutability semantics
5759
// that we get from the Pybind wrappers make it really easy to break things.
5860
// We therefore have to wrap it up in some external Python code; see

include/kllvm/ast/AST.h

+12
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ using sptr = std::shared_ptr<T>;
3636

3737
std::string decodeKore(std::string);
3838

39+
/*
40+
* Helper function to avoid repeated call-site uses of ostringstream when we
41+
* just want the string representation of a node, rather than to print it to a
42+
* stream.
43+
*/
44+
template <typename T>
45+
std::string ast_to_string(T &&node) {
46+
auto os = std::ostringstream{};
47+
std::forward<T>(node).print(os);
48+
return os.str();
49+
}
50+
3951
// KORESort
4052
class KORESort : public std::enable_shared_from_this<KORESort> {
4153
public:

include/kllvm/bindings/core/core.h

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace kllvm::bindings {
1414

15+
std::string return_sort_for_label(std::string const &label);
16+
1517
std::shared_ptr<kllvm::KOREPattern> make_injection(
1618
std::shared_ptr<kllvm::KOREPattern> term,
1719
std::shared_ptr<kllvm::KORESort> from, std::shared_ptr<kllvm::KORESort> to);

include/runtime/header.h

+1
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ uint32_t getInjectionForSortOfTag(uint32_t tag);
350350
bool hook_STRING_eq(SortString, SortString);
351351

352352
const char *getSymbolNameForTag(uint32_t tag);
353+
const char *getReturnSortForTag(uint32_t tag);
353354
const char *topSort(void);
354355

355356
typedef struct {

lib/codegen/EmitConfigParser.cpp

+50
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,55 @@ static void emitSortTable(KOREDefinition *definition, llvm::Module *module) {
13081308
}
13091309
}
13101310

1311+
/*
1312+
* Emit a table mapping symbol tags to the declared return sort for that symbol.
1313+
* For example:
1314+
*
1315+
* tag_of(initGeneratedTopCell) |-> sort_name_SortGeneratedTopCell{}
1316+
*
1317+
* Each value in the table is a pointer to a global variable containing the
1318+
* relevant sort name as a null-terminated string.
1319+
*
1320+
* The function `getReturnSortForTag` abstracts accesses to the data in this
1321+
* table.
1322+
*/
1323+
static void
1324+
emitReturnSortTable(KOREDefinition *definition, llvm::Module *module) {
1325+
auto &ctx = module->getContext();
1326+
1327+
auto const &syms = definition->getSymbols();
1328+
1329+
auto element_type = llvm::Type::getInt8PtrTy(ctx);
1330+
auto table_type = llvm::ArrayType::get(element_type, syms.size());
1331+
1332+
auto table = module->getOrInsertGlobal("return_sort_table", table_type);
1333+
auto values = std::vector<llvm::Constant *>{};
1334+
1335+
for (auto [tag, symbol] : syms) {
1336+
auto sort = symbol->getSort();
1337+
auto sort_str = ast_to_string(*sort);
1338+
1339+
auto char_type = llvm::Type::getInt8Ty(ctx);
1340+
auto str_type = llvm::ArrayType::get(char_type, sort_str.size() + 1);
1341+
1342+
auto sort_name
1343+
= module->getOrInsertGlobal("sort_name_" + sort_str, str_type);
1344+
1345+
auto i64_type = llvm::Type::getInt64Ty(ctx);
1346+
auto zero = llvm::ConstantInt::get(i64_type, 0);
1347+
1348+
auto pointer = llvm::ConstantExpr::getInBoundsGetElementPtr(
1349+
str_type, sort_name, std::vector<llvm::Constant *>{zero});
1350+
1351+
values.push_back(pointer);
1352+
}
1353+
1354+
auto global = llvm::dyn_cast<llvm::GlobalVariable>(table);
1355+
if (!global->hasInitializer()) {
1356+
global->setInitializer(llvm::ConstantArray::get(table_type, values));
1357+
}
1358+
}
1359+
13111360
void emitConfigParserFunctions(
13121361
KOREDefinition *definition, llvm::Module *module) {
13131362
emitGetTagForSymbolName(definition, module);
@@ -1329,6 +1378,7 @@ void emitConfigParserFunctions(
13291378
emitInjTags(definition, module);
13301379

13311380
emitSortTable(definition, module);
1381+
emitReturnSortTable(definition, module);
13321382
}
13331383

13341384
} // namespace kllvm

runtime/util/util.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
extern "C" {
44

5+
extern char *return_sort_table;
6+
7+
const char *getReturnSortForTag(uint32_t tag) {
8+
return (&return_sort_table)[tag];
9+
}
10+
511
block *dot_k() {
612
return leaf_block(getTagForSymbolName("dotk{}"));
713
}

test/python/Inputs/sorts.kore

+6,145
Large diffs are not rendered by default.

test/python/k-files/sorts.k

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module SORTS
2+
imports DOMAINS
3+
4+
syntax Int ::= func() [function, label(func), symbol]
5+
6+
syntax Foo ::= foo() [label(foo), symbol]
7+
syntax Bar ::= Foo
8+
| bar() [label(bar), symbol]
9+
10+
rule func() => 0
11+
rule foo() => bar()
12+
endmodule

test/python/test_return_sorts.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# RUN: mkdir -p %t
2+
# RUN: export IN=$(realpath Inputs/sorts.kore)
3+
# RUN: cd %t && %kompile "$IN" python --python %py-interpreter --python-output-dir .
4+
# RUN: KLLVM_DEFINITION=%t %python -u %s
5+
6+
from test_bindings import kllvm
7+
8+
import unittest
9+
10+
class TestReturnSorts(unittest.TestCase):
11+
12+
def _check_sort(self, label, sort):
13+
self.assertEqual(kllvm.runtime.return_sort_for_label(label), sort)
14+
15+
def test_function(self):
16+
self._check_sort('Lblfunc{}', 'SortInt{}')
17+
18+
def test_constructor(self):
19+
self._check_sort('Lblfoo{}', 'SortFoo{}')
20+
21+
def test_subsort(self):
22+
self._check_sort('Lblbar{}', 'SortBar{}')
23+
24+
25+
if __name__ == "__main__":
26+
unittest.main()

unittests/runtime-ffi/ffi.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define KCHAR char
1515
#define TYPETAG(type) "Lbl'Hash'ffi'Unds'" #type "{}"
1616

17+
char *return_sort_table = nullptr;
18+
1719
void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
1820
return nullptr;
1921
}

unittests/runtime-io/io.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define KCHAR char
1616

17+
char *return_sort_table = nullptr;
18+
1719
void *constructCompositePattern(uint32_t tag, std::vector<void *> &arguments) {
1820
return nullptr;
1921
}

unittests/runtime-strings/bytestest.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
#define KCHAR char
1010
extern "C" {
11+
12+
char *return_sort_table = nullptr;
13+
1114
uint32_t getTagForSymbolName(const char *s) {
1215
return 0;
1316
}

0 commit comments

Comments
 (0)