Skip to content

Commit 0e24009

Browse files
authored
Refactoring proof trace code (#876)
As discussed previously, this PR is an initial cleanup and refactoring of the existing proof-hint generation code for rewrite, function and hook events. The changes in this PR don't add any new features to the traces; they just reorganise the existing code. Most of the changes add documentation or identify duplicated code across different call sites that can be merged together. I plan to port the reorganisation in #862 over to the infrastructure in this PR once it is merged. The trace format is not currently tested in the backend (this is future work), but I have verified that a proof trace (one that uses all event types) generated using this branch is byte-for-byte identical to one generated from the current master branch.
1 parent cb8fe60 commit 0e24009

File tree

4 files changed

+344
-321
lines changed

4 files changed

+344
-321
lines changed

include/kllvm/codegen/ProofEvent.h

+95-13
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,114 @@
77

88
#include "llvm/IR/Instructions.h"
99

10-
namespace kllvm {
10+
#include <map>
11+
#include <tuple>
1112

12-
void writeUInt64(
13-
llvm::Value *outputFile, llvm::Module *Module, uint64_t value,
14-
llvm::BasicBlock *Block);
13+
namespace kllvm {
1514

1615
class ProofEvent {
1716
private:
1817
KOREDefinition *Definition;
19-
llvm::BasicBlock *CurrentBlock;
2018
llvm::Module *Module;
2119
llvm::LLVMContext &Ctx;
20+
21+
/*
22+
* Load the boolean flag that controls whether proof hint output is enabled or
23+
* not, then create a branch at the end of this basic block depending on the
24+
* result.
25+
*
26+
* Returns a pair of blocks [proof enabled, merge]; the first of these is
27+
* intended for self-contained behaviour only relevant in proof output mode,
28+
* while the second is for the continuation of the interpreter's previous
29+
* behaviour.
30+
*/
2231
std::pair<llvm::BasicBlock *, llvm::BasicBlock *>
23-
proofBranch(std::string label);
32+
proofBranch(std::string const &label, llvm::BasicBlock *insertAtEnd);
33+
34+
/*
35+
* Set up a standard event prelude by creating a pair of basic blocks for the
36+
* proof output and continuation, then loading the output filename from its
37+
* global.
38+
*
39+
* Returns a triple [proof enabled, merge, output_file]; see `proofBranch` and
40+
* `emitGetOutputFileName`.
41+
*/
42+
std::tuple<llvm::BasicBlock *, llvm::BasicBlock *, llvm::Value *>
43+
eventPrelude(std::string const &label, llvm::BasicBlock *insertAtEnd);
44+
45+
/*
46+
* Emit a call that will serialize `term` to the specified `outputFile` as
47+
* binary KORE. This function can be called on any term, but the sort of that
48+
* term must be known.
49+
*/
50+
llvm::CallInst *emitSerializeTerm(
51+
KORECompositeSort &sort, llvm::Value *outputFile, llvm::Value *term,
52+
llvm::BasicBlock *insertAtEnd);
53+
54+
/*
55+
* Emit a call that will serialize `config` to the specified `outputFile` as
56+
* binary KORE. This function does not require a sort, but the configuration
57+
* passed must be a top-level configuration.
58+
*/
59+
llvm::CallInst *emitSerializeConfiguration(
60+
llvm::Value *outputFile, llvm::Value *config,
61+
llvm::BasicBlock *insertAtEnd);
62+
63+
/*
64+
* Emit a call that will serialize `value` to the specified `outputFile`.
65+
*/
66+
llvm::CallInst *emitWriteUInt64(
67+
llvm::Value *outputFile, uint64_t value, llvm::BasicBlock *insertAtEnd);
68+
69+
/*
70+
* Emit a call that will serialize `str` to the specified `outputFile`.
71+
*/
72+
llvm::CallInst *emitWriteString(
73+
llvm::Value *outputFile, std::string const &str,
74+
llvm::BasicBlock *insertAtEnd);
75+
76+
/*
77+
* Emit an instruction that has no effect and will be removed by optimization
78+
* passes.
79+
*
80+
* We need this workaround because some callsites will try to use
81+
* llvm::Instruction::insertAfter on the back of the MergeBlock after a proof
82+
* branch is created. If the MergeBlock has no instructions, this has resulted
83+
* in a segfault when printing the IR. Adding an effective no-op prevents this.
84+
*/
85+
llvm::BinaryOperator *emitNoOp(llvm::BasicBlock *insertAtEnd);
86+
87+
/*
88+
* Emit instructions to load the path of the interpreter's current output
89+
* file; used here for binary proof trace data.
90+
*/
91+
llvm::LoadInst *emitGetOutputFileName(llvm::BasicBlock *insertAtEnd);
2492

2593
public:
26-
llvm::BasicBlock *hookEvent_pre(std::string name);
27-
llvm::BasicBlock *hookEvent_post(llvm::Value *val, KORECompositeSort *sort);
28-
llvm::BasicBlock *hookArg(llvm::Value *val, KORECompositeSort *sort);
94+
[[nodiscard]] llvm::BasicBlock *
95+
hookEvent_pre(std::string name, llvm::BasicBlock *current_block);
96+
97+
[[nodiscard]] llvm::BasicBlock *hookEvent_post(
98+
llvm::Value *val, KORECompositeSort *sort,
99+
llvm::BasicBlock *current_block);
100+
101+
[[nodiscard]] llvm::BasicBlock *hookArg(
102+
llvm::Value *val, KORECompositeSort *sort,
103+
llvm::BasicBlock *current_block);
104+
105+
[[nodiscard]] llvm::BasicBlock *rewriteEvent(
106+
KOREAxiomDeclaration *axiom, llvm::Value *return_value, uint64_t arity,
107+
std::map<std::string, KOREVariablePattern *> vars,
108+
llvm::StringMap<llvm::Value *> const &subst,
109+
llvm::BasicBlock *current_block);
110+
111+
[[nodiscard]] llvm::BasicBlock *functionEvent(
112+
llvm::BasicBlock *current_block, KORECompositePattern *pattern,
113+
std::string const &locationStack);
29114

30115
public:
31-
ProofEvent(
32-
KOREDefinition *Definition, llvm::BasicBlock *EntryBlock,
33-
llvm::Module *Module)
116+
ProofEvent(KOREDefinition *Definition, llvm::Module *Module)
34117
: Definition(Definition)
35-
, CurrentBlock(EntryBlock)
36118
, Module(Module)
37119
, Ctx(Module->getContext()) { }
38120
};

lib/codegen/CreateTerm.cpp

+12-138
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <iostream>
1111

1212
#include "runtime/header.h" //for macros
13+
1314
#include "llvm/IR/BasicBlock.h"
1415
#include "llvm/IR/Constants.h"
1516
#include "llvm/IR/DerivedTypes.h"
@@ -117,10 +118,6 @@ declare void @printConfiguration(i8 *, %block *)
117118
}
118119
} // namespace
119120

120-
void writeUInt64(
121-
llvm::Value *outputFile, llvm::Module *Module, uint64_t value,
122-
llvm::BasicBlock *Block);
123-
124121
std::unique_ptr<llvm::Module>
125122
newModule(std::string name, llvm::LLVMContext &Context) {
126123
llvm::SMDiagnostic Err;
@@ -318,8 +315,8 @@ llvm::Value *CreateTerm::alloc_arg(
318315
llvm::Value *ret
319316
= createAllocation(p, fmt::format("{}:{}", locationStack, idx)).first;
320317
auto sort = dynamic_cast<KORECompositeSort *>(p->getSort().get());
321-
ProofEvent e(Definition, CurrentBlock, Module);
322-
CurrentBlock = e.hookArg(ret, sort);
318+
ProofEvent e(Definition, Module);
319+
CurrentBlock = e.hookArg(ret, sort, CurrentBlock);
323320
return ret;
324321
}
325322

@@ -700,49 +697,8 @@ llvm::Value *CreateTerm::createFunctionCall(
700697
}
701698
}
702699

703-
llvm::Function *func = CurrentBlock->getParent();
704-
705-
auto ProofOutputFlag = Module->getOrInsertGlobal(
706-
"proof_output", llvm::Type::getInt1Ty(Module->getContext()));
707-
auto OutputFileName = Module->getOrInsertGlobal(
708-
"output_file", llvm::Type::getInt8PtrTy(Module->getContext()));
709-
auto proofOutput = new llvm::LoadInst(
710-
llvm::Type::getInt1Ty(Module->getContext()), ProofOutputFlag,
711-
"proof_output", CurrentBlock);
712-
llvm::BasicBlock *TrueBlock
713-
= llvm::BasicBlock::Create(Module->getContext(), "if", func);
714-
auto outputFile = new llvm::LoadInst(
715-
llvm::Type::getInt8PtrTy(Module->getContext()), OutputFileName, "output",
716-
TrueBlock);
717-
auto ir = new llvm::IRBuilder(TrueBlock);
718-
llvm::BasicBlock *MergeBlock
719-
= llvm::BasicBlock::Create(Module->getContext(), "tail", func);
720-
llvm::BranchInst::Create(TrueBlock, MergeBlock, proofOutput, CurrentBlock);
721-
722-
std::ostringstream symbolName;
723-
pattern->getConstructor()->print(symbolName);
724-
725-
auto symbolString
726-
= ir->CreateGlobalStringPtr(symbolName.str(), "", 0, Module);
727-
auto positionString = ir->CreateGlobalStringPtr(locationStack, "", 0, Module);
728-
writeUInt64(outputFile, Module, 0xdddddddddddddddd, TrueBlock);
729-
ir->CreateCall(
730-
getOrInsertFunction(
731-
Module, "printVariableToFile",
732-
llvm::Type::getVoidTy(Module->getContext()),
733-
llvm::Type::getInt8PtrTy(Module->getContext()),
734-
llvm::Type::getInt8PtrTy(Module->getContext())),
735-
{outputFile, symbolString});
736-
ir->CreateCall(
737-
getOrInsertFunction(
738-
Module, "printVariableToFile",
739-
llvm::Type::getVoidTy(Module->getContext()),
740-
llvm::Type::getInt8PtrTy(Module->getContext()),
741-
llvm::Type::getInt8PtrTy(Module->getContext())),
742-
{outputFile, positionString});
743-
744-
llvm::BranchInst::Create(MergeBlock, TrueBlock);
745-
CurrentBlock = MergeBlock;
700+
auto event = ProofEvent(Definition, Module);
701+
CurrentBlock = event.functionEvent(CurrentBlock, pattern, locationStack);
746702

747703
return createFunctionCall(name, returnCat, args, sret, tailcc, locationStack);
748704
}
@@ -932,13 +888,12 @@ CreateTerm::createAllocation(KOREPattern *pattern, std::string locationStack) {
932888
.get());
933889
std::string name = strPattern->getContents();
934890

935-
ProofEvent p1(Definition, CurrentBlock, Module);
936-
CurrentBlock = p1.hookEvent_pre(name);
891+
ProofEvent p(Definition, Module);
892+
CurrentBlock = p.hookEvent_pre(name, CurrentBlock);
937893
llvm::Value *val = createHook(
938894
symbolDecl->getAttributes().at("hook").get(), constructor,
939895
locationStack);
940-
ProofEvent p2(Definition, CurrentBlock, Module);
941-
CurrentBlock = p2.hookEvent_post(val, sort);
896+
CurrentBlock = p.hookEvent_post(val, sort, CurrentBlock);
942897

943898
return std::make_pair(val, true);
944899
} else {
@@ -1114,91 +1069,10 @@ bool makeFunction(
11141069

11151070
auto CurrentBlock = creator.getCurrentBlock();
11161071
if (apply && bigStep) {
1117-
auto ProofOutputFlag = Module->getOrInsertGlobal(
1118-
"proof_output", llvm::Type::getInt1Ty(Module->getContext()));
1119-
auto OutputFileName = Module->getOrInsertGlobal(
1120-
"output_file", llvm::Type::getInt8PtrTy(Module->getContext()));
1121-
auto proofOutput = new llvm::LoadInst(
1122-
llvm::Type::getInt1Ty(Module->getContext()), ProofOutputFlag,
1123-
"proof_output", CurrentBlock);
1124-
llvm::BasicBlock *TrueBlock
1125-
= llvm::BasicBlock::Create(Module->getContext(), "if", applyRule);
1126-
auto ir = new llvm::IRBuilder(TrueBlock);
1127-
llvm::BasicBlock *MergeBlock
1128-
= llvm::BasicBlock::Create(Module->getContext(), "tail", applyRule);
1129-
llvm::BranchInst::Create(TrueBlock, MergeBlock, proofOutput, CurrentBlock);
1130-
auto outputFile = new llvm::LoadInst(
1131-
llvm::Type::getInt8PtrTy(Module->getContext()), OutputFileName,
1132-
"output", TrueBlock);
1133-
writeUInt64(outputFile, Module, axiom->getOrdinal(), TrueBlock);
1134-
writeUInt64(
1135-
outputFile, Module, applyRule->arg_end() - applyRule->arg_begin(),
1136-
TrueBlock);
1137-
for (auto entry = subst.begin(); entry != subst.end(); ++entry) {
1138-
auto key = entry->getKey();
1139-
auto val = entry->getValue();
1140-
auto var = vars[key.str()];
1141-
auto sort = dynamic_cast<KORECompositeSort *>(var->getSort().get());
1142-
auto cat = sort->getCategory(definition);
1143-
std::ostringstream Out;
1144-
sort->print(Out);
1145-
auto sortptr = ir->CreateGlobalStringPtr(Out.str(), "", 0, Module);
1146-
auto varname = ir->CreateGlobalStringPtr(key, "", 0, Module);
1147-
ir->CreateCall(
1148-
getOrInsertFunction(
1149-
Module, "printVariableToFile",
1150-
llvm::Type::getVoidTy(Module->getContext()),
1151-
llvm::Type::getInt8PtrTy(Module->getContext()),
1152-
llvm::Type::getInt8PtrTy(Module->getContext())),
1153-
{outputFile, varname});
1154-
if (cat.cat == SortCategory::Symbol
1155-
|| cat.cat == SortCategory::Variable) {
1156-
ir->CreateCall(
1157-
getOrInsertFunction(
1158-
Module, "serializeTermToFile",
1159-
llvm::Type::getVoidTy(Module->getContext()),
1160-
llvm::Type::getInt8PtrTy(Module->getContext()),
1161-
getValueType({SortCategory::Symbol, 0}, Module),
1162-
llvm::Type::getInt8PtrTy(Module->getContext())),
1163-
{outputFile, val, sortptr});
1164-
} else if (val->getType()->isIntegerTy()) {
1165-
val = ir->CreateIntToPtr(
1166-
val, llvm::Type::getInt8PtrTy(Module->getContext()));
1167-
ir->CreateCall(
1168-
getOrInsertFunction(
1169-
Module, "serializeRawTermToFile",
1170-
llvm::Type::getVoidTy(Module->getContext()),
1171-
llvm::Type::getInt8PtrTy(Module->getContext()),
1172-
llvm::Type::getInt8PtrTy(Module->getContext()),
1173-
llvm::Type::getInt8PtrTy(Module->getContext())),
1174-
{outputFile, val, sortptr});
1175-
} else {
1176-
val = ir->CreatePointerCast(
1177-
val, llvm::Type::getInt8PtrTy(Module->getContext()));
1178-
ir->CreateCall(
1179-
getOrInsertFunction(
1180-
Module, "serializeRawTermToFile",
1181-
llvm::Type::getVoidTy(Module->getContext()),
1182-
llvm::Type::getInt8PtrTy(Module->getContext()),
1183-
llvm::Type::getInt8PtrTy(Module->getContext()),
1184-
llvm::Type::getInt8PtrTy(Module->getContext())),
1185-
{outputFile, val, sortptr});
1186-
}
1187-
writeUInt64(outputFile, Module, 0xcccccccccccccccc, TrueBlock);
1188-
}
1189-
1190-
writeUInt64(outputFile, Module, 0xffffffffffffffff, TrueBlock);
1191-
ir->CreateCall(
1192-
getOrInsertFunction(
1193-
Module, "serializeConfigurationToFile",
1194-
llvm::Type::getVoidTy(Module->getContext()),
1195-
llvm::Type::getInt8PtrTy(Module->getContext()),
1196-
getValueType({SortCategory::Symbol, 0}, Module)),
1197-
{outputFile, retval});
1198-
writeUInt64(outputFile, Module, 0xcccccccccccccccc, TrueBlock);
1199-
1200-
llvm::BranchInst::Create(MergeBlock, TrueBlock);
1201-
CurrentBlock = MergeBlock;
1072+
auto event = ProofEvent(definition, Module);
1073+
CurrentBlock = event.rewriteEvent(
1074+
axiom, retval, applyRule->arg_end() - applyRule->arg_begin(), vars,
1075+
subst, CurrentBlock);
12021076
}
12031077

12041078
if (bigStep) {

lib/codegen/EmitConfigParser.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@ static llvm::Constant *getOffsetOfMember(
11191119
auto offset
11201120
= llvm::DataLayout(mod).getStructLayout(struct_ty)->getElementOffset(
11211121
nth_member);
1122-
auto offset_ty = llvm::Type::getInt32Ty(mod->getContext());
1122+
auto offset_ty = llvm::Type::getInt64Ty(mod->getContext());
11231123
return llvm::ConstantInt::get(offset_ty, offset);
11241124
#else
11251125
return llvm::ConstantExpr::getOffsetOf(struct_ty, nth_member);

0 commit comments

Comments
 (0)