diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index b430f3026fe5..d8a0b255bb50 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -9,6 +9,8 @@ import org.apache.commons.lang3.StringUtils object AstCreatorHelper { private val TagsToKeepInFullName = List("", "", "", "", "", "") + private val ReturnTypeMatcher = """^\(.*\)(->|:)(.+)$""".r + private val ClosureSignatureMatcher = """^(\(.*\))\s*(.*)\s*->(.+)$""".r /** Removes generic type parameters from qualified names while preserving special tags. * @@ -84,9 +86,15 @@ object AstCreatorHelper { case "Dictionary" => Defines.Dictionary case "Nil" => Defines.Nil // Special patterns with specific handling - case t if t.startsWith("[") && t.endsWith("]") => Defines.Array - case t if t.contains("=>") || t.contains("->") => Defines.Function - case t if t.contains("( ") => t.substring(0, t.indexOf("( ")) + case t if t.startsWith("[") && t.endsWith("]") => Defines.Array + case ClosureSignatureMatcher(params, mods, returnType) => + // "throws" is the only modifier that swiftc keeps + // so we have to restore it here to keep signatures + // consistent between runs with compiler support and without. + val m = if (mods.contains("throws")) { "throws" } + else "" + s"${Defines.Function}<$params$m->$returnType>".replace(" ", "") + case t if t.contains("( ") => t.substring(0, t.indexOf("( ")) // Default case case typeStr => typeStr } @@ -162,7 +170,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As case None if identNode.typeFullName != Defines.Any => identNode.typeFullName case _ => Defines.Any } - val typedIdentNode = identNode.typeFullName(tpe) + identNode.typeFullName = tpe scope.addVariableReference(identifierName, identNode, tpe, EvaluationStrategies.BY_REFERENCE) Ast(identNode) } @@ -191,8 +199,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } - private val ReturnTypeMatcher = """\(.*\)(->|:)(.+)""".r - protected def methodInfoForFunctionDeclLike(node: FunctionDeclLike): MethodInfo = { val name = calcMethodName(node) fullnameProvider.declFullname(node) match { @@ -226,9 +232,16 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val returnType = cleanType(code(s.returnClause.`type`)) (s"${paramSignature(s.parameterClause)}->$returnType", returnType) case c: ClosureExprSyntax => - val returnType = c.signature.flatMap(_.returnClause).fold(Defines.Any)(r => cleanType(code(r.`type`))) - val paramClauseCode = c.signature.flatMap(_.parameterClause).fold("()")(paramSignature) - (s"$paramClauseCode->$returnType", returnType) + fullnameProvider.typeFullnameRaw(node) match { + case Some(tpe) => + val signature = tpe + val returnType = ReturnTypeMatcher.findFirstMatchIn(signature).map(_.group(2)).getOrElse(Defines.Any) + (signature, returnType) + case _ => + val returnType = c.signature.flatMap(_.returnClause).fold(Defines.Any)(r => cleanType(code(r.`type`))) + val paramClauseCode = c.signature.flatMap(_.parameterClause).fold("()")(paramSignature) + (s"$paramClauseCode->$returnType", returnType) + } } registerType(returnType) MethodInfo(methodName, methodFullName, signature, returnType) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForDeclSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForDeclSyntaxCreator.scala index 6597137d11ad..320508c929e8 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForDeclSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForDeclSyntaxCreator.scala @@ -705,7 +705,7 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) { List.empty[Ast] } - val methodReturnNode_ = methodReturnNode(node, returnType, Some(returnType)) + val methodReturnNode_ = methodReturnNode(node, returnType) val blockAst_ = blockAst(block, methodBlockContent ++ bodyStmtAsts) val astForMethod = diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index 65edef691eb9..2c8098050ffd 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -210,7 +210,6 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { baseNode: NewNode, callName: String ): Ast = { - val trailingClosureAsts = callExpr.trailingClosure.toList.map(astForNode) val additionalTrailingClosuresAsts = callExpr.additionalTrailingClosures.children.map(c => astForNode(c.closure)) @@ -289,6 +288,8 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { val thisTmpNode = identifierNode(callee, tmpVarName) (fieldAccessAst, thisTmpNode, memberCode) } + case other if isRefToClosure(node, other) => + return astForClosureCall(node) case _ => val receiverAst = astForNode(callee) val thisNode = identifierNode(callee, "this") @@ -299,6 +300,48 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } } + private def astForClosureCall(expr: FunctionCallExprSyntax): Ast = { + val tpe = fullnameProvider.typeFullname(expr).getOrElse(Defines.Any) + registerType(tpe) + val signature = fullnameProvider.typeFullnameRaw(expr.calledExpression).getOrElse(x2cpg.Defines.UnresolvedSignature) + val callName = Defines.ClosureApplyMethodName + val callMethodFullname = s"${Defines.Function}<$signature>.$callName:$signature" + val baseAst = astForNode(expr.calledExpression) + + val trailingClosureAsts = expr.trailingClosure.toList.map(astForNode) + val additionalTrailingClosuresAsts = expr.additionalTrailingClosures.children.map(c => astForNode(c.closure)) + + val args = expr.arguments.children.map(astForNode) ++ trailingClosureAsts ++ additionalTrailingClosuresAsts + + val callExprCode = code(expr) + val callNode_ = callNode( + expr, + callExprCode, + callName, + callMethodFullname, + DispatchTypes.DYNAMIC_DISPATCH, + Option(signature), + Option(tpe) + ) + callAst(callNode_, args, Option(baseAst)) + } + + private def isRefToClosure(func: FunctionCallExprSyntax, node: ExprSyntax): Boolean = { + if (!config.swiftBuild) { + // Early exit; without types from the compiler we will be unable to identify closure calls anyway. + // This saves us the typeFullname lookup below. + return false + } + node match { + case refExpr: DeclReferenceExprSyntax + if refExpr.baseName.isInstanceOf[identifier] && refExpr.argumentNames.isEmpty && + fullnameProvider.declFullname(func).isEmpty && + fullnameProvider.typeFullname(refExpr).exists(_.startsWith(s"${Defines.Function}<")) => + true + case _ => false + } + } + private def astForGenericSpecializationExprSyntax(node: GenericSpecializationExprSyntax): Ast = { astForNode(node.expression) } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala index 09f39729eabe..66464b9b6a57 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala @@ -207,6 +207,13 @@ trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstC protected def createFunctionTypeAndTypeDecl(node: SwiftNode, methodNode: NewMethod): Unit = { registerType(methodNode.fullName) + + val (inherits, bindingName) = if (node.isInstanceOf[ClosureExprSyntax]) { + val inheritsFunctionFullName = s"${Defines.Function}<${methodNode.signature}>" + registerType(inheritsFunctionFullName) + (Seq(inheritsFunctionFullName), Defines.ClosureApplyMethodName) + } else (Seq.empty, methodNode.name) + val (astParentType, astParentFullName) = astParentInfo() val methodTypeDeclNode = typeDeclNode( node, @@ -214,15 +221,16 @@ trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstC methodNode.fullName, methodNode.filename, methodNode.fullName, - astParentType, - astParentFullName + astParentType = astParentType, + astParentFullName = astParentFullName, + inherits = inherits ) methodNode.astParentFullName = astParentFullName methodNode.astParentType = astParentType val functionBinding = NewBinding() - .name(methodNode.name) + .name(bindingName) .methodFullName(methodNode.fullName) .signature(methodNode.signature) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala index bc9d25f854dc..105bd36ff58f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala @@ -15,7 +15,7 @@ object SwiftTypeNodePass { override def fullToShortName(typeName: String): String = { typeName match { case name if name.endsWith(NamespaceTraversal.globalNamespaceName) => NamespaceTraversal.globalNamespaceName - case _ => typeName.split('.').lastOption.getOrElse(typeName) + case _ => typeName.takeWhile(_ != '<').split('.').lastOption.getOrElse(typeName) } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/FullnameProvider.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/FullnameProvider.scala index 38d7928f8597..e49aea12791a 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/FullnameProvider.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/FullnameProvider.scala @@ -20,7 +20,12 @@ private object FullnameProvider { } // TODO: provide the actual mapping from SwiftNode.toString (nodeKind) to ResolvedTypeInfo.nodeKind - private val NodeKindMapping = Map("DeclReferenceExprSyntax" -> "type_expr") + private val NodeKindMapping = Map( + "DeclReferenceExprSyntax" -> "type_expr", + "VariableDeclSyntax" -> "var_decl", + "PatternBindingSyntax" -> "var_decl", + "IdentifierPatternSyntax" -> "var_decl" + ) } /** Provides functionality to resolve and retrieve fullnames for Swift types and declarations. Uses a type mapping to @@ -92,6 +97,12 @@ class FullnameProvider(typeMap: SwiftFileLocalTypeMapping) { * An optional String containing the type fullname if found */ protected def typeFullname(range: (Int, Int), nodeKind: String): Option[String] = { + fullName(range, FullnameProvider.Kind.Type, nodeKind).map(AstCreatorHelper.cleanType) + } + + /** Same as FullnameProvider.typeFullname but does no type name sanitation. + */ + protected def typeFullnameRaw(range: (Int, Int), nodeKind: String): Option[String] = { fullName(range, FullnameProvider.Kind.Type, nodeKind).map(AstCreatorHelper.cleanName) } @@ -124,6 +135,16 @@ class FullnameProvider(typeMap: SwiftFileLocalTypeMapping) { } } + /** Same as FullnameProvider.typeFullname but does no type name sanitation. + */ + def typeFullnameRaw(node: SwiftNode): Option[String] = { + if (typeMap.isEmpty) return None + (node.startOffset, node.endOffset) match { + case (Some(start), Some(end)) => typeFullnameRaw((start, end), node.toString) + case _ => None + } + } + /** Retrieves the declaration fullname for a given Swift node. Extracts the start and end offsets from the node if * available. Returns None if typeMap is empty. * diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AsyncTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AsyncTests.scala index 090478515ea0..adc124b0cc31 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AsyncTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AsyncTests.scala @@ -29,10 +29,10 @@ class AsyncTests extends AstSwiftSrc2CpgSuite { val cpg = code("func asyncGlobal3(fn: () throws -> Int) async rethrows { }") val List(asyncGlobal3) = cpg.method.internal.nameNot(NamespaceTraversal.globalNamespaceName).l asyncGlobal3.name shouldBe "asyncGlobal3" - asyncGlobal3.fullName shouldBe s"Test0.swift:.asyncGlobal3:(fn:${Defines.Function})->ANY" + asyncGlobal3.fullName shouldBe s"Test0.swift:.asyncGlobal3:(fn:Swift.Function<()throws->Int>)->ANY" val List(fn) = asyncGlobal3.parameter.l fn.name shouldBe "fn" - fn.typeFullName shouldBe Defines.Function + fn.typeFullName shouldBe "Swift.Function<()throws->Int>" } "testAsync4" in { diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ClosureWithCompilerTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ClosureWithCompilerTests.scala new file mode 100644 index 000000000000..41ea8411e9c3 --- /dev/null +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/ClosureWithCompilerTests.scala @@ -0,0 +1,273 @@ +package io.joern.swiftsrc2cpg.passes.ast + +import io.joern.swiftsrc2cpg.testfixtures.SwiftCompilerSrc2CpgSuite +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* + +class ClosureWithCompilerTests extends SwiftCompilerSrc2CpgSuite { + + "ClosureWithCompilerTests" should { + + "create type decls and bindings correctly (local closure declaration)" in { + // We test all closures here in one example because running the Swift compiler + // on Windows machines tend to be quite slow. Splitting them into individual test cases + // just results on awfully long test-code-test-repeat cycles. + val testCode = + """ + |import Foundation + | + |func main() { + | // closure call is allways "function_ref": "single_apply" + | + | let compare = { (s1: String, s2: String) -> Bool in + | return s1 > s2 + | } + | let compareResult = compare("1", "2") + | + | var customersInLine = ["Chris", "Alex", "Ewa", "Barry", "Daniella"] + | // auto-closure example 1: + | let customerProvider = { customersInLine.remove(at: 0) } + | let customerProviderResult = customerProvider() + | + | // auto-closure example 2: + | let greet = { + | print("Hello, World!") + | } + | greet() + | + | let greetUser = { (name: String) in + | print("Hey there, \(name).") + | } + | greetUser("Alex") + | + | let findSquare = { (num: Int) -> (Int) in + | let square = num * num + | return square + | } + | let findSquareResult = findSquare(5) + |}""".stripMargin + + val cpg = codeWithSwiftSetup(testCode) + + // compare ((Swift.String,Swift.String)->Swift.Bool) + val List(compareLocal) = cpg.local.nameExact("compare").l + compareLocal.typeFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + + val List(compareResultLocal) = cpg.local.nameExact("compareResult").l + compareResultLocal.typeFullName shouldBe "Swift.Bool" + + val compareClosureFullName = "Sources/main.swift:.main.0:(Swift.String,Swift.String)->Swift.Bool" + val List(compareClosure) = cpg.method.fullNameExact(compareClosureFullName).l + val List(compareClosureTypeDecl) = cpg.typeDecl.fullNameExact(compareClosureFullName).l + compareClosureTypeDecl.inheritsFromTypeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + val List(compareClosureBinding) = compareClosureTypeDecl.bindsOut.l + compareClosureBinding.name shouldBe "single_apply" + compareClosureBinding.methodFullName shouldBe compareClosureFullName + compareClosureBinding.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + + val List(compareClosureCall) = cpg.call.codeExact("""compare("1", "2")""").l + compareClosureCall.name shouldBe "single_apply" + compareClosure.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.methodFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>.single_apply:(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.receiver.isIdentifier.name.l shouldBe List("compare") + compareClosureCall.receiver.isIdentifier.typeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + compareClosureCall.argument(1).code shouldBe """"1"""" + compareClosureCall.argument(2).code shouldBe """"2"""" + + // customerProvider (no-arg -> Swift.String) + val List(customerProviderLocal) = cpg.local.nameExact("customerProvider").l + customerProviderLocal.typeFullName shouldBe "Swift.Function<()->Swift.String>" + + val List(customerProviderResultLocal) = cpg.local.nameExact("customerProviderResult").l + customerProviderResultLocal.typeFullName shouldBe "Swift.String" + + val customerProviderClosureFullName = "Sources/main.swift:.main.1:()->Swift.String" + val List(customerProviderClosure) = cpg.method.fullNameExact(customerProviderClosureFullName).l + val List(customerProviderTypeDecl) = cpg.typeDecl.fullNameExact(customerProviderClosureFullName).l + customerProviderTypeDecl.inheritsFromTypeFullName.l shouldBe List("Swift.Function<()->Swift.String>") + val List(customerProviderBinding) = customerProviderTypeDecl.bindsOut.l + customerProviderBinding.name shouldBe "single_apply" + customerProviderBinding.methodFullName shouldBe customerProviderClosureFullName + customerProviderBinding.signature shouldBe "()->Swift.String" + + val List(customerProviderCall) = cpg.call.codeExact("""customerProvider()""").l + customerProviderCall.name shouldBe "single_apply" + customerProviderClosure.signature shouldBe "()->Swift.String" + customerProviderCall.methodFullName shouldBe "Swift.Function<()->Swift.String>.single_apply:()->Swift.String" + customerProviderCall.receiver.isIdentifier.name.l shouldBe List("customerProvider") + customerProviderCall.receiver.isIdentifier.typeFullName.l shouldBe List("Swift.Function<()->Swift.String>") + + // greet (no-arg -> ()) + val List(greetLocal) = cpg.local.nameExact("greet").l + greetLocal.typeFullName shouldBe "Swift.Function<()->()>" + + val greetClosureFullName = "Sources/main.swift:.main.2:()->()" + val List(greetClosure) = cpg.method.fullNameExact(greetClosureFullName).l + val List(greetTypeDecl) = cpg.typeDecl.fullNameExact(greetClosureFullName).l + greetTypeDecl.inheritsFromTypeFullName.l shouldBe List("Swift.Function<()->()>") + val List(greetBinding) = greetTypeDecl.bindsOut.l + greetBinding.name shouldBe "single_apply" + greetBinding.methodFullName shouldBe greetClosureFullName + greetBinding.signature shouldBe "()->()" + + val List(greetCall) = cpg.call.codeExact("""greet()""").l + greetCall.name shouldBe "single_apply" + greetClosure.signature shouldBe "()->()" + greetCall.methodFullName shouldBe "Swift.Function<()->()>.single_apply:()->()" + greetCall.receiver.isIdentifier.name.l shouldBe List("greet") + greetCall.receiver.isIdentifier.typeFullName.l shouldBe List("Swift.Function<()->()>") + + // greetUser (Swift.String -> ()) + val List(greetUserLocal) = cpg.local.nameExact("greetUser").l + greetUserLocal.typeFullName shouldBe "Swift.Function<(Swift.String)->()>" + + val greetUserClosureFullName = "Sources/main.swift:.main.3:(Swift.String)->()" + val List(greetUserClosure) = cpg.method.fullNameExact(greetUserClosureFullName).l + val List(greetUserTypeDecl) = cpg.typeDecl.fullNameExact(greetUserClosureFullName).l + greetUserTypeDecl.inheritsFromTypeFullName.l shouldBe List("Swift.Function<(Swift.String)->()>") + val List(greetUserBinding) = greetUserTypeDecl.bindsOut.l + greetUserBinding.name shouldBe "single_apply" + greetUserBinding.methodFullName shouldBe greetUserClosureFullName + greetUserBinding.signature shouldBe "(Swift.String)->()" + + val List(greetUserCall) = cpg.call.codeExact("""greetUser("Alex")""").l + greetUserCall.name shouldBe "single_apply" + greetUserClosure.signature shouldBe "(Swift.String)->()" + greetUserCall.methodFullName shouldBe "Swift.Function<(Swift.String)->()>.single_apply:(Swift.String)->()" + greetUserCall.receiver.isIdentifier.name.l shouldBe List("greetUser") + greetUserCall.receiver.isIdentifier.typeFullName.l shouldBe List("Swift.Function<(Swift.String)->()>") + greetUserCall.argument(1).code shouldBe """"Alex"""" + + // findSquare (Swift.Int -> Swift.Int) + val List(findSquareLocal) = cpg.local.nameExact("findSquare").l + findSquareLocal.typeFullName shouldBe "Swift.Function<(Swift.Int)->Swift.Int>" + + val List(findSquareResultLocal) = cpg.local.nameExact("findSquareResult").l + findSquareResultLocal.typeFullName shouldBe "Swift.Int" + + val findSquareClosureFullName = "Sources/main.swift:.main.4:(Swift.Int)->Swift.Int" + val List(findSquareClosure) = cpg.method.fullNameExact(findSquareClosureFullName).l + val List(findSquareTypeDecl) = cpg.typeDecl.fullNameExact(findSquareClosureFullName).l + findSquareTypeDecl.inheritsFromTypeFullName.l shouldBe List("Swift.Function<(Swift.Int)->Swift.Int>") + val List(findSquareBinding) = findSquareTypeDecl.bindsOut.l + findSquareBinding.name shouldBe "single_apply" + findSquareBinding.methodFullName shouldBe findSquareClosureFullName + findSquareBinding.signature shouldBe "(Swift.Int)->Swift.Int" + + val List(findSquareCall) = cpg.call.codeExact("""findSquare(5)""").l + findSquareCall.name shouldBe "single_apply" + findSquareClosure.signature shouldBe "(Swift.Int)->Swift.Int" + findSquareCall.methodFullName shouldBe "Swift.Function<(Swift.Int)->Swift.Int>.single_apply:(Swift.Int)->Swift.Int" + findSquareCall.receiver.isIdentifier.name.l shouldBe List("findSquare") + findSquareCall.receiver.isIdentifier.typeFullName.l shouldBe List("Swift.Function<(Swift.Int)->Swift.Int>") + findSquareCall.argument(1).code shouldBe "5" + } + + "create type decls and bindings correctly (class variable closure declaration)" in { + val testCode = + """ + |import Foundation + | + |class Foo { + | var compare = { (s1: String, s2: String) -> Bool in + | return s1 > s2 + | } + | + | func main() { + | let compareResult = compare("1", "2") + | } + |}""".stripMargin + + val cpg = codeWithSwiftSetup(testCode) + + val List(compareClassLocal, compareFunctionLocal) = cpg.local.nameExact("compare").l + compareClassLocal.typeFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + compareFunctionLocal.typeFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + + val List(compareResultLocal) = cpg.local.nameExact("compareResult").l + compareResultLocal.typeFullName shouldBe "Swift.Bool" + + val compareClosureFullName = "Sources/main.swift:.Foo.0:(Swift.String,Swift.String)->Swift.Bool" + val List(compareClosure) = cpg.method.fullNameExact(compareClosureFullName).l + val List(compareClosureTypeDecl) = cpg.typeDecl.fullNameExact(compareClosureFullName).l + compareClosureTypeDecl.inheritsFromTypeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + val List(compareClosureBinding) = compareClosureTypeDecl.bindsOut.l + compareClosureBinding.name shouldBe "single_apply" + compareClosureBinding.methodFullName shouldBe compareClosureFullName + compareClosureBinding.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + + val List(compareClosureCall) = cpg.call.codeExact("""compare("1", "2")""").l + compareClosureCall.name shouldBe "single_apply" + compareClosure.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.methodFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>.single_apply:(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.receiver.isIdentifier.name.l shouldBe List("compare") + compareClosureCall.receiver.isIdentifier.typeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + compareClosureCall.argument(1).code shouldBe """"1"""" + compareClosureCall.argument(2).code shouldBe """"2"""" + } + + "create type decls and bindings correctly (closure as function parameter)" in { + val testCode = + """ + |import Foundation + | + |func runCompare(compare:(String, String) -> Bool) { + | let compareResult = compare("1", "2") + |} + | + |func main() { + | let compareFunc = { (s1: String, s2: String) -> Bool in + | return s1 > s2 + | } + | runCompare(compare: compareFunc) + |}""".stripMargin + + val cpg = codeWithSwiftSetup(testCode) + + val List(runCompareCall) = cpg.call.nameExact("runCompare").l + runCompareCall.methodFullName shouldBe "SwiftTest.runCompare:(compare:(Swift.String,Swift.String)->Swift.Bool)->()" + + val List(runCompareMethod) = cpg.method.nameExact("runCompare").l + runCompareMethod.fullName shouldBe "SwiftTest.runCompare:(compare:(Swift.String,Swift.String)->Swift.Bool)->()" + + val List(compareFuncLocal) = cpg.local.nameExact("compareFunc").l + compareFuncLocal.typeFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + + val List(compareResultLocal) = cpg.local.nameExact("compareResult").l + compareResultLocal.typeFullName shouldBe "Swift.Bool" + + val compareClosureFullName = "Sources/main.swift:.main.0:(Swift.String,Swift.String)->Swift.Bool" + val List(compareClosure) = cpg.method.fullNameExact(compareClosureFullName).l + val List(compareClosureTypeDecl) = cpg.typeDecl.fullNameExact(compareClosureFullName).l + compareClosureTypeDecl.inheritsFromTypeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + val List(compareClosureBinding) = compareClosureTypeDecl.bindsOut.l + compareClosureBinding.name shouldBe "single_apply" + compareClosureBinding.methodFullName shouldBe compareClosureFullName + compareClosureBinding.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + + val List(compareClosureCall) = cpg.call.codeExact("""compare("1", "2")""").l + compareClosureCall.name shouldBe "single_apply" + compareClosure.signature shouldBe "(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.methodFullName shouldBe "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>.single_apply:(Swift.String,Swift.String)->Swift.Bool" + compareClosureCall.receiver.isIdentifier.name.l shouldBe List("compare") + compareClosureCall.receiver.isIdentifier.typeFullName.l shouldBe List( + "Swift.Function<(Swift.String,Swift.String)->Swift.Bool>" + ) + compareClosureCall.argument(1).code shouldBe """"1"""" + compareClosureCall.argument(2).code shouldBe """"2"""" + } + + } + +} diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/InitDeinitTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/InitDeinitTests.scala index 4a8a06cc824e..5dc11343f8e3 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/InitDeinitTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/InitDeinitTests.scala @@ -23,13 +23,9 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.FooStructConstructorA" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq( - "Sources/main.swift:.FooStructConstructorA" - ) constructorA.fullName shouldBe "Sources/main.swift:.FooStructConstructorA.init:()->Sources/main.swift:.FooStructConstructorA" constructorB.methodReturn.typeFullName shouldBe "SwiftTest.FooStructConstructorA" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("SwiftTest.FooStructConstructorA") constructorB.fullName shouldBe "SwiftTest.FooStructConstructorA.init:()->SwiftTest.FooStructConstructorA" val List(paramA) = constructorA.parameter.l @@ -58,13 +54,9 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.FooStructConstructorA" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq( - "Sources/main.swift:.FooStructConstructorA" - ) constructorA.fullName shouldBe "Sources/main.swift:.FooStructConstructorA.init:()->Sources/main.swift:.FooStructConstructorA" constructorB.methodReturn.typeFullName shouldBe "Swift.Optional" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("Swift.Optional") constructorB.fullName shouldBe "SwiftTest.FooStructConstructorA.init:()->Swift.Optional" val List(paramA) = constructorA.parameter.l @@ -121,11 +113,9 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.BarUnion" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq("Sources/main.swift:.BarUnion") constructorA.fullName shouldBe "Sources/main.swift:.BarUnion.init:()->Sources/main.swift:.BarUnion" constructorB.methodReturn.typeFullName shouldBe "SwiftTest.BarUnion" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("SwiftTest.BarUnion") constructorB.fullName shouldBe "SwiftTest.BarUnion.init:()->SwiftTest.BarUnion" val List(paramA) = constructorA.parameter.l @@ -155,11 +145,9 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.BarClass" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq("Sources/main.swift:.BarClass") constructorA.fullName shouldBe "Sources/main.swift:.BarClass.init:()->Sources/main.swift:.BarClass" constructorB.methodReturn.typeFullName shouldBe "SwiftTest.BarClass" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("SwiftTest.BarClass") constructorB.fullName shouldBe "SwiftTest.BarClass.init:()->SwiftTest.BarClass" val List(paramA) = constructorA.parameter.l @@ -208,11 +196,9 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.BarClass" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq("Sources/main.swift:.BarClass") constructorA.fullName shouldBe "Sources/main.swift:.BarClass.init:(a:Swift.Int)->Sources/main.swift:.BarClass" constructorB.methodReturn.typeFullName shouldBe "SwiftTest.BarClass" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("SwiftTest.BarClass") constructorB.fullName shouldBe "SwiftTest.BarClass.init:(a:Swift.Int)->SwiftTest.BarClass" val List(paramA, xA) = constructorA.parameter.l @@ -251,7 +237,6 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { val List(constructorB) = compilerCpg.method.isConstructor.l constructorA.methodReturn.typeFullName shouldBe "Sources/main.swift:.BarProtocol" - constructorA.methodReturn.dynamicTypeHintFullName shouldBe Seq("Sources/main.swift:.BarProtocol") constructorA.fullName shouldBe "Sources/main.swift:.BarProtocol.init:()->Sources/main.swift:.BarProtocol" // Swift initializers conceptually return `Self`. For a protocol requirement, `Self` is not a concrete type; @@ -263,7 +248,6 @@ class InitDeinitTests extends SwiftCompilerSrc2CpgSuite { // but a protocol initializer constructs the concrete conforming type, not the existential. // Hence, the signature is `()->A`, matching Swift’s `init` => `Self`. constructorB.methodReturn.typeFullName shouldBe "A" - constructorB.methodReturn.dynamicTypeHintFullName shouldBe Seq("A") constructorB.fullName shouldBe "SwiftTest.BarProtocol.init:()->A" val List(paramA) = constructorA.parameter.l diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/ConstClosurePass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/ConstClosurePass.scala index 64fddba9ef77..5f42d6ecbf32 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/ConstClosurePass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/ConstClosurePass.scala @@ -1,7 +1,7 @@ package io.joern.x2cpg.frontendspecific.swiftsrc2cpg +import io.shiftleft.codepropertygraph.generated.nodes.{Method, MethodRef} import io.shiftleft.codepropertygraph.generated.{Cpg, PropertyNames} -import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Method, MethodRef} import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.language.* @@ -11,7 +11,7 @@ class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg) { // Keeps track of how many times an identifier has been on the LHS of an assignment, by name private lazy val identifiersAssignedCount: Map[String, Int] = - cpg.assignment.target.collectAll[Identifier].name.groupCount + cpg.assignment.target.isIdentifier.name.groupCount override def run(diffGraph: DiffGraphBuilder): Unit = { handleConstClosures(diffGraph) @@ -19,11 +19,17 @@ class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg) { handleClosuresAssignedToMutableVar(diffGraph) } + private def qualifies(methodRef: MethodRef): Boolean = { + // We only want this pass to handle closures where the frontend itself + // was not able to generate the correct calls, typedecls, and bindings itself. + cpg.typeDecl.fullNameExact(methodRef.methodFullName).bindsOut.nameExact(Defines.ClosureApplyMethodName).isEmpty + } + private def handleConstClosures(diffGraph: DiffGraphBuilder): Unit = for { assignment <- cpg.assignment - name <- assignment.filter(_.code.startsWith("let ")).target.isIdentifier.name - methodRef <- assignment.start.source.isMethodRef + name <- assignment.start.code("^let .*").target.isIdentifier.name + methodRef <- assignment.start.source.isMethodRef if qualifies(methodRef) method <- methodRef.referencedMethod enclosingMethod <- assignment.start.method.fullName } { @@ -41,7 +47,7 @@ class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg) { .isFieldIdentifier .canonicalName .l - methodRef <- assignment.start.source.ast.isMethodRef + methodRef <- assignment.start.source.ast.isMethodRef if qualifies(methodRef) method <- methodRef.referencedMethod enclosingMethod <- assignment.start.method.fullName } { @@ -52,8 +58,8 @@ class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg) { // Handle closures assigned to mutable variables for { assignment <- cpg.assignment - name <- assignment.start.code("^(var|let) .*").target.isIdentifier.name - methodRef <- assignment.start.source.ast.isMethodRef + name <- assignment.start.code("^var .*").target.isIdentifier.name + methodRef <- assignment.start.source.isMethodRef if qualifies(methodRef) method <- methodRef.referencedMethod enclosingMethod <- assignment.start.method.fullName } { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/Defines.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/Defines.scala index 4f14e2043c36..d8811591a04f 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/Defines.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/frontendspecific/swiftsrc2cpg/Defines.scala @@ -8,22 +8,23 @@ object Defines { private val logger: Logger = LoggerFactory.getLogger(this.getClass) - val Any: String = "ANY" - val Character: String = "Swift.Character" - val String: String = "Swift.String" - val Int: String = "Swift.Int" - val Float: String = "Swift.Float" - val Double: String = "Swift.Double" - val Bool: String = "Swift.Bool" - val Function: String = "Swift.Function" - val Array: String = "Swift.Array" - val Dictionary: String = "Swift.Dictionary" - val Nil: String = "Swift.Nil" - val Iterator: String = "Swift.Iterator" - val Void: String = "()" - val ConstructorMethodName: String = "init" - val DuplicateSuffix: String = "" - val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName + val Any: String = "ANY" + val Character: String = "Swift.Character" + val String: String = "Swift.String" + val Int: String = "Swift.Int" + val Float: String = "Swift.Float" + val Double: String = "Swift.Double" + val Bool: String = "Swift.Bool" + val Function: String = "Swift.Function" + val Array: String = "Swift.Array" + val Dictionary: String = "Swift.Dictionary" + val Nil: String = "Swift.Nil" + val Iterator: String = "Swift.Iterator" + val Void: String = "()" + val ConstructorMethodName: String = "init" + val ClosureApplyMethodName: String = "single_apply" + val DuplicateSuffix: String = "" + val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName val SwiftTypes: List[String] = List(Any, Nil, Character, String, Int, Float, Double, Bool, Function, Array, Dictionary, Iterator, Void) @@ -41,8 +42,8 @@ object Defines { } val PrefixOperatorMap: Map[String, String] = Map( - "-" -> Operators.preDecrement, - "+" -> Operators.preIncrement, + "-" -> Operators.minus, + "+" -> Operators.plus, "~" -> Operators.not, "!" -> Operators.logicalNot, "..<" -> Operators.lessThan,