diff --git a/chb/app/CHVersion.py b/chb/app/CHVersion.py index 6f4ed84a..254c8c22 100644 --- a/chb/app/CHVersion.py +++ b/chb/app/CHVersion.py @@ -1 +1 @@ -chbversion: str = "0.3.0-20250722" +chbversion: str = "0.3.0-20250723" diff --git a/chb/arm/ARMInstruction.py b/chb/arm/ARMInstruction.py index 3230106c..87cedbd1 100644 --- a/chb/arm/ARMInstruction.py +++ b/chb/arm/ARMInstruction.py @@ -441,8 +441,9 @@ def to_string( opcodewidth: int = 40, typingrules: bool = False, sp: bool = False) -> str: + + lines: List[str] = [] if typingrules: - lines: List[str] = [] rulesapplied = self.app.type_constraints.rules_applied_to_instruction( self.armfunction.faddr, self.iaddr) for r in sorted(str(r) for r in rulesapplied): diff --git a/chb/ast/ASTApplicationInterface.py b/chb/ast/ASTApplicationInterface.py index 7980c4e5..5027dd49 100644 --- a/chb/ast/ASTApplicationInterface.py +++ b/chb/ast/ASTApplicationInterface.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------ # The MIT License (MIT) # -# Copyright (c) 2022-2023 Aarno Labs, LLC +# Copyright (c) 2022-2025 Aarno Labs, LLC # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,7 +42,7 @@ from chb.ast.CustomASTSupport import CustomASTSupport -pirversion: str = "0.1.0-20240703" +pirversion: str = "0.1.0-202507024" class ASTApplicationInterface: diff --git a/chb/ast/ASTDeserializer.py b/chb/ast/ASTDeserializer.py index 0b16d4b4..61e32cbf 100644 --- a/chb/ast/ASTDeserializer.py +++ b/chb/ast/ASTDeserializer.py @@ -71,6 +71,15 @@ def low_level_ast(self) -> AST.ASTStmt: def high_unreduced_ast(self) -> AST.ASTStmt: return self.asts[1] + def get_stmt(self, stmtid: int) -> AST.ASTStmt: + for node in self.nodes.values(): + if node.is_ast_stmt: + node = cast(AST.ASTStmt, node) + if node.stmtid == stmtid: + return node + else: + raise Exception("No stmt found with id: " + str(stmtid)) + def get_instruction(self, instrid: int) -> AST.ASTInstruction: for node in self.nodes.values(): if node.is_ast_instruction: @@ -742,13 +751,15 @@ def arg(ix: int) -> Dict[str, Any]: thenbranch = cast(AST.ASTStmt, mk_node(arg(1))) elsebranch = cast(AST.ASTStmt, mk_node(arg(2))) targetaddr = r["destination-addr"] + predicated = r.get("predicated", None) nodes[id] = astree.mk_branch( condition, thenbranch, elsebranch, targetaddr, optstmtid=stmtid, - optlocationid=locationid) + optlocationid=locationid, + predicated=predicated) elif tag == "block": stmtid = r["stmtid"] diff --git a/chb/ast/ASTNOPVisitor.py b/chb/ast/ASTNOPVisitor.py index 0e1f1a70..7f58c2e5 100644 --- a/chb/ast/ASTNOPVisitor.py +++ b/chb/ast/ASTNOPVisitor.py @@ -148,6 +148,9 @@ def visit_question_expression(self, expr: AST.ASTQuestion) -> None: def visit_address_of_expression(self, expr: AST.ASTAddressOf) -> None: pass + def visit_start_of_expression(self, expr: AST.ASTStartOf) -> None: + pass + def visit_void_typ(self, typ: AST.ASTTypVoid) -> None: pass diff --git a/chb/ast/ASTNode.py b/chb/ast/ASTNode.py index 2ce74022..aed63174 100644 --- a/chb/ast/ASTNode.py +++ b/chb/ast/ASTNode.py @@ -595,13 +595,15 @@ def __init__( elsestmt: "ASTStmt", tgtaddress: str, mergeaddress: Optional[str], - labels: List["ASTStmtLabel"] = []) -> None: + labels: List["ASTStmtLabel"] = [], + predicated: Optional[int] = None) -> None: ASTStmt.__init__(self, stmtid, locationid, labels, "if") self._cond = cond self._ifstmt = ifstmt self._elsestmt = elsestmt self._tgtaddress = tgtaddress self._mergeaddress = mergeaddress + self._predicated = predicated @property def is_ast_branch(self) -> bool: @@ -627,6 +629,10 @@ def target_address(self) -> str: def merge_address(self) -> Optional[str]: return self._mergeaddress + @property + def predicated(self) -> Optional[int]: + return self._predicated + def accept(self, visitor: "ASTVisitor") -> None: visitor.visit_branch_stmt(self) diff --git a/chb/ast/ASTProvenanceCollector.py b/chb/ast/ASTProvenanceCollector.py new file mode 100644 index 00000000..9e552d1d --- /dev/null +++ b/chb/ast/ASTProvenanceCollector.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------------ +# CodeHawk Binary Analyzer +# Author: Henny Sipma +# ------------------------------------------------------------------------------ +# The MIT License (MIT) +# +# Copyright (c) 2021-2025 Aarno Labs LLC +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ------------------------------------------------------------------------------ +"""Collects stmt and instruction provenance for pir inspector.""" + +from typing import Dict, List, TYPE_CHECKING + +import chb.ast.ASTNode as AST +from chb.ast.ASTNOPVisitor import ASTNOPVisitor + +if TYPE_CHECKING: + from chb.ast.ASTDeserializer import ASTFunctionDeserialization + + +class ASTProvenanceCollector(ASTNOPVisitor): + + def __init__(self, dfn: "ASTFunctionDeserialization") -> None: + self._dfn = dfn + self._hl_ll_instr_mapping: Dict[int, List[int]] = {} + + @property + def dfn(self) -> "ASTFunctionDeserialization": + return self._dfn + + @property + def provenance(self) -> Dict[int, List[AST.ASTInstruction]]: + result: Dict[int, List[AST.ASTInstruction]] = {} + for hlinstrid in self._hl_ll_instr_mapping: + llinstrids = self._hl_ll_instr_mapping[hlinstrid] + result[hlinstrid] = [self.dfn.get_instruction(i) for i in llinstrids] + return result + + def instruction_provenance( + self, stmt: AST.ASTStmt) -> Dict[int, List[AST.ASTInstruction]]: + stmt.accept(self) + return self.provenance + + def visit_loop_stmt(self, stmt: AST.ASTLoop) -> None: + stmt.body.accept(self) + + def visit_block_stmt(self, stmt: AST.ASTBlock) -> None: + for s in stmt.stmts: + s.accept(self) + + def visit_instruction_sequence_stmt(self, stmt: AST.ASTInstrSequence) -> None: + for i in stmt.instructions: + if i.instrid in self.dfn.astree.provenance.instruction_mapping: + self._hl_ll_instr_mapping[i.instrid] = ( + self.dfn.astree.provenance.instruction_mapping[i.instrid]) + + def visit_branch_stmt(self, stmt: AST.ASTBranch) -> None: + stmt.ifstmt.accept(self) + stmt.elsestmt.accept(self) diff --git a/chb/ast/ASTSerializer.py b/chb/ast/ASTSerializer.py index 5bc99c78..7624caa5 100644 --- a/chb/ast/ASTSerializer.py +++ b/chb/ast/ASTSerializer.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------ # The MIT License (MIT) # -# Copyright (c) 2022-2024 Aarno Labs LLC +# Copyright (c) 2022-2025 Aarno Labs LLC # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -373,6 +373,8 @@ def index_branch_stmt(self, stmt: AST.ASTBranch) -> int: stmt.elsestmt.index(self)]) node["destination-addr"] = stmt.target_address node["merge-addr"] = stmt.merge_address + if stmt.predicated is not None: + node["predicated"] = stmt.predicated return self.add(tags, args, node) def index_instruction_sequence_stmt(self, stmt: AST.ASTInstrSequence) -> int: diff --git a/chb/ast/ASTViewer.py b/chb/ast/ASTViewer.py index 70f12d38..2f681a21 100644 --- a/chb/ast/ASTViewer.py +++ b/chb/ast/ASTViewer.py @@ -95,6 +95,16 @@ def to_graph(self, stmt: AST.ASTStmt) -> DU.ASTDotGraph: stmt.accept(self) return self.dotgraph + def stmt_to_graph( + self, + stmt: AST.ASTStmt, + provinstrs: Dict[int, List[AST.ASTInstruction]] = {}) -> DU.ASTDotGraph: + stmt.accept(self) + for pl in provinstrs.values(): + for p in pl: + p.accept(self) + return self.dotgraph + def instr_to_graph( self, instr: AST.ASTInstruction, @@ -124,6 +134,9 @@ def get_expr_connections(self, expr: AST.ASTExpr) -> str: if self.astree.provenance.has_reaching_definitions(expr.exprid): ids = self.astree.provenance.reaching_definitions[expr.exprid] result += "\\nrdefs:[" + ",".join(str(id) for id in ids) + "]" + if expr.exprid in self.astree.expr_spanmap(): + span = self.astree.expr_spanmap()[expr.exprid] + result += "\\ncc:" + span return result def stmt_name(self, stmt: AST.ASTStmt) -> str: @@ -157,8 +170,10 @@ def visit_loop_stmt(self, stmt: AST.ASTLoop) -> None: def visit_branch_stmt(self, stmt: AST.ASTBranch) -> None: name = self.stmt_name(stmt) - self.add_node( - name, labeltxt="if:" + str(stmt.stmtid), color=nodecolors["stmt"]) + labeltxt = "if:" + str(stmt.stmtid) + if stmt.predicated is not None: + labeltxt += "\\npredicated:" + str(stmt.predicated) + self.add_node(name, labeltxt=labeltxt, color=nodecolors["stmt"]) self.add_edge(name, self.stmt_name(stmt.ifstmt), labeltxt="then") self.add_edge(name, self.stmt_name(stmt.elsestmt), labeltxt="else") self.add_edge(name, self.expr_name(stmt.condition), labeltxt="condition") diff --git a/chb/ast/AbstractSyntaxTree.py b/chb/ast/AbstractSyntaxTree.py index 3c78a95e..8564e3b9 100644 --- a/chb/ast/AbstractSyntaxTree.py +++ b/chb/ast/AbstractSyntaxTree.py @@ -287,6 +287,25 @@ def get_exprid(self, exprid: Optional[int]) -> int: def add_span(self, span: ASTSpanRecord) -> None: self._spans.append(span) + def add_stmt_span( + self, locationid: int, spans: List[Tuple[str, str]]) -> None: + """Add a span for the ast instructions contained in a stmt. + + Note: this is currently done only for if statements originating from + predicated instructions. + """ + spaninstances: List[Dict[str, Union[str, int]]] = [] + for (iaddr, bytestring) in spans: + span: Dict[str, Union[str, int]] = {} + span["base_va"] = iaddr + span["size"] = len(bytestring) // 2 + spaninstances.append(span) + spanrec: Dict[str, Any] = {} + spanrec["locationid"] = locationid + spanrec["spans"] = spaninstances + self.add_span(cast(ASTSpanRecord, spanrec)) + + def add_instruction_span( self, locationid: int, base: str, bytestring: str) -> None: """Add a span for an ast instruction.""" @@ -402,7 +421,8 @@ def mk_branch( mergeaddr: Optional[str] = None, optstmtid: Optional[int] = None, optlocationid: Optional[int] = None, - labels: List[AST.ASTStmtLabel] = []) -> AST.ASTBranch: + labels: List[AST.ASTStmtLabel] = [], + predicated: Optional[int] = None) -> AST.ASTBranch: stmtid = self.get_stmtid(optstmtid) locationid = self.get_locationid(optlocationid) if condition is None: @@ -410,7 +430,7 @@ def mk_branch( condition = self.mk_tmp_lval_expression() return AST.ASTBranch( stmtid, locationid, condition, ifbranch, elsebranch, - targetaddr, mergeaddr) + targetaddr, mergeaddr, predicated=predicated) def mk_goto_stmt( self, @@ -1224,7 +1244,7 @@ def _cast_if_needed(self, e: AST.ASTExpr, opunsigned: bool) -> AST.ASTExpr: # CIL encodes signedness in the shift operator but C infers # the operator flavor from the type of the left operand, so # we may need to insert a cast to ensure that serialization - # through C code preserves the AST. + # through C code preserves the AST. mb_t = e.ctype(ASTBasicCTyper(self.globalsymboltable)) force_cast = mb_t is None diff --git a/chb/ast/astutil.py b/chb/ast/astutil.py index eeda4420..fbde34f7 100644 --- a/chb/ast/astutil.py +++ b/chb/ast/astutil.py @@ -39,6 +39,7 @@ from chb.ast.ASTCPrettyPrinter import ASTCPrettyPrinter from chb.ast.ASTDeserializer import ASTDeserializer import chb.ast.ASTNode as AST +from chb.ast.ASTProvenanceCollector import ASTProvenanceCollector from chb.ast.ASTViewer import ASTViewer import chb.ast.astdotutil as DU @@ -191,6 +192,36 @@ def viewastcmd(args: argparse.Namespace) -> NoReturn: exit(0) +def viewstmtcmd(args: argparse.Namespace) -> NoReturn: + + # arguments + pirfile: str = args.pirfile + function: Optional[str] = args.function + stmtid: int = args.stmtid + provenance: bool = args.provenance + outputfilename: str = args.output + + with open(pirfile, "r") as fp: + pirjson = json.load(fp) + + faddr = get_function_addr(pirjson, function) + deserializer = ASTDeserializer(pirjson) + (globaltable, dfns) = deserializer.deserialize() + for dfn in dfns: + if dfn.astree.faddr == faddr: + stmt = dfn.get_stmt(stmtid) + viewer = ASTViewer(faddr, dfn.astree) + if provenance: + provcollector = ASTProvenanceCollector(dfn) + provinstrs = provcollector.instruction_provenance(stmt) + g = viewer.stmt_to_graph(stmt, provinstrs) + else: + g = viewer.to_graph(stmt) + + DU.print_dot(outputfilename, g) + exit(0) + + def viewinstrcmd(args: argparse.Namespace) -> NoReturn: # arguments diff --git a/chb/ast/doc/CHANGELOG.md b/chb/ast/doc/CHANGELOG.md index c2ec7d6c..da047438 100644 --- a/chb/ast/doc/CHANGELOG.md +++ b/chb/ast/doc/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +### version 0.1.0-2025-07-24 + +- Add property predicated to branch stmt to indicate its origin from a predicated instruction +- Add option to pir inspector to view a single statement + + ### version 0.1.0-2024-07-03 - Add printing of structs to ASTCPrettyPrinter diff --git a/chb/ast/pirinspector b/chb/ast/pirinspector index 45d71f46..370ffd1d 100755 --- a/chb/ast/pirinspector +++ b/chb/ast/pirinspector @@ -93,6 +93,24 @@ def parse() -> argparse.Namespace: ) viewastcmd.set_defaults(func=AU.viewastcmd) + # --- view selected stmt ast + viewstmtcmd = viewparsers.add_parser("stmt") + viewstmtcmd.add_argument("pirfile", help="name of json file with ast information") + viewstmtcmd.add_argument("--function", help="name or address of function to view") + viewstmtcmd.add_argument( + "--stmtid", help="stmt id of statement", type=int) + viewstmtcmd.add_argument( + "--provenance", + help="show associated low-level instructions", + action="store_true", + ) + viewstmtcmd.add_argument( + "-o", "--output", + help="name of output graph file (without extension)", + required=True, + ) + viewstmtcmd.set_defaults(func=AU.viewstmtcmd) + # --- view single instruction ast viewinstrcmd = viewparsers.add_parser("instruction") viewinstrcmd.add_argument("pirfile", help="name of json file with ast information") diff --git a/chb/astinterface/ASTICodeTransformer.py b/chb/astinterface/ASTICodeTransformer.py index 661a2c7c..ac0f5270 100644 --- a/chb/astinterface/ASTICodeTransformer.py +++ b/chb/astinterface/ASTICodeTransformer.py @@ -123,7 +123,8 @@ def transform_branch_stmt(self, stmt: AST.ASTBranch) -> AST.ASTStmt: newelse, stmt.target_address, mergeaddr=stmt.merge_address, - optlocationid=stmt.locationid) + optlocationid=stmt.locationid, + predicated=stmt.predicated) def transform_switch_stmt(self, stmt: AST.ASTSwitchStmt) -> AST.ASTStmt: newcases = stmt.cases.transform(self) diff --git a/chb/astinterface/ASTInterface.py b/chb/astinterface/ASTInterface.py index 9a7a286e..5720ca55 100644 --- a/chb/astinterface/ASTInterface.py +++ b/chb/astinterface/ASTInterface.py @@ -585,7 +585,10 @@ def add_formal( return nextindex def add_span(self, span: ASTSpanRecord) -> None: - self.astree.add_span + self.astree.add_span(span) + + def add_stmt_span(self, id: int, spans: List[Tuple[str, str]]) -> None: + self.astree.add_stmt_span(id, spans) def add_instruction_span(self, id: int, base: str, bytestring: str) -> None: self.astree.add_instruction_span(id, base, bytestring) @@ -696,14 +699,16 @@ def mk_branch( elsebranch: AST.ASTStmt, targetaddr: str, mergeaddr: Optional[str] = None, - optlocationid: Optional[int] = None) -> AST.ASTStmt: + optlocationid: Optional[int] = None, + predicated: Optional[int] = None) -> AST.ASTStmt: return self.astree.mk_branch( condition, ifbranch, elsebranch, targetaddr, mergeaddr=mergeaddr, - optlocationid=optlocationid) + optlocationid=optlocationid, + predicated=predicated) def mk_instr_sequence( self, diff --git a/chb/astinterface/ASTInterfaceBasicBlock.py b/chb/astinterface/ASTInterfaceBasicBlock.py index eace028a..8151e94f 100644 --- a/chb/astinterface/ASTInterfaceBasicBlock.py +++ b/chb/astinterface/ASTInterfaceBasicBlock.py @@ -204,6 +204,7 @@ def ast_fragment( elseinstrs = [self.get_instruction(i.iaddr) for i in frag.elsebranch] thenstmt = self.linear_block_ast(astree, theninstrs) elsestmt = self.linear_block_ast(astree, elseinstrs) + spans = [(i.iaddr, i.bytestring) for i in theninstrs + elseinstrs] cinstr = theninstrs[0] brcond = cinstr.ast_cc_condition(astree) if brcond is None: @@ -214,7 +215,11 @@ def ast_fragment( else: astree.astree.add_expr_span( brcond.exprid, cinstr.iaddr, cinstr.bytestring) - return astree.mk_branch(brcond, thenstmt, elsestmt, "0x0") + instrcount = len(theninstrs) + len(elseinstrs) + ifstmt = astree.mk_branch( + brcond, thenstmt, elsestmt, cinstr.iaddr, predicated=instrcount) + astree.add_stmt_span(ifstmt.locationid, spans) + return ifstmt else: instrs = [self.get_instruction(i.iaddr) for i in frag.linear] return self.linear_ast(astree, instrs) diff --git a/chb/astinterface/ASTInterfaceFunction.py b/chb/astinterface/ASTInterfaceFunction.py index e3d646e5..44fc14ea 100644 --- a/chb/astinterface/ASTInterfaceFunction.py +++ b/chb/astinterface/ASTInterfaceFunction.py @@ -274,7 +274,7 @@ def set_invariants(self) -> None: self.astinterface, anonymous=True) - if str(var).startswith("astmem_tmp"): + if "astmem_tmp" in str(var): chklogger.logger.info( "Skipping invariant %s at %s", str(fact), str(loc)) @@ -294,6 +294,11 @@ def set_invariants(self) -> None: instr.iaddr, self.astinterface, anonymous=True) + if "astmem_tmp" in str(aexpr): + chklogger.logger.info( + "Skipping invariant %s at %s", + str(aexpr), str(loc)) + continue aexprindex = aexpr.index(self.astinterface.serializer) else: continue @@ -328,7 +333,8 @@ def set_invariants(self) -> None: self.astinterface, anonymous=True) - if str(var).startswith("astmem_tmp"): + if "astmem_tmp" in str(aexpr): + # if str(var).startswith("astmem_tmp"): chklogger.logger.info( "Skipping invariant %s at %s", str(fact), str(loc))