Skip to content

Commit 25b928d

Browse files
committed
Rust: Type inference for .await expressions
1 parent 988e9c6 commit 25b928d

File tree

6 files changed

+230
-15
lines changed

6 files changed

+230
-15
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,19 @@ class ResultEnum extends Enum {
4949
/** Gets the `Err` variant. */
5050
Variant getErr() { result = this.getVariant("Err") }
5151
}
52+
53+
/**
54+
* The [`Future` trait][1].
55+
*
56+
* [1]: https://doc.rust-lang.org/std/future/trait.Future.html
57+
*/
58+
class FutureTrait extends Trait {
59+
FutureTrait() { this.getCanonicalPath() = "core::future::future::Future" }
60+
61+
/** Gets the `Output` associated type. */
62+
pragma[nomagic]
63+
TypeAlias getOutputType() {
64+
result = this.getAssocItemList().getAnAssocItem() and
65+
result.getName().getText() = "Output"
66+
}
67+
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@ newtype TType =
1515
TTrait(Trait t) or
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
18+
TImplTraitType(int bounds) {
19+
bounds = any(ImplTraitTypeRepr impl).getTypeBoundList().getNumberOfBounds()
20+
} or
1821
TTypeParamTypeParameter(TypeParam t) or
1922
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2023
TRefTypeParameter() or
21-
TSelfTypeParameter(Trait t)
24+
TSelfTypeParameter(Trait t) or
25+
TImplTraitTypeParameter(ImplTraitType t, int i) { i in [0 .. t.getNumberOfBounds() - 1] }
2226

2327
/**
2428
* A type without type arguments.
@@ -115,6 +119,9 @@ class TraitType extends Type, TTrait {
115119

116120
TraitType() { this = TTrait(trait) }
117121

122+
/** Gets the underlying trait. */
123+
Trait getTrait() { result = trait }
124+
118125
override StructField getStructField(string name) { none() }
119126

120127
override TupleField getTupleField(int i) { none() }
@@ -176,6 +183,33 @@ class RefType extends Type, TRefType {
176183
override Location getLocation() { result instanceof EmptyLocation }
177184
}
178185

186+
/**
187+
* An [`impl Trait`][1] type.
188+
*
189+
* We represent `impl Trait` types as generic types with as many type parameters
190+
* as there are bounds.
191+
*
192+
* [1] https://doc.rust-lang.org/book/ch10-02-traits.html#traits-as-parameters
193+
*/
194+
class ImplTraitType extends Type, TImplTraitType {
195+
private int bounds;
196+
197+
ImplTraitType() { this = TImplTraitType(bounds) }
198+
199+
/** Gets the number of bounds of this `impl Trait` type. */
200+
int getNumberOfBounds() { result = bounds }
201+
202+
override StructField getStructField(string name) { none() }
203+
204+
override TupleField getTupleField(int i) { none() }
205+
206+
override TypeParameter getTypeParameter(int i) { result = TImplTraitTypeParameter(this, i) }
207+
208+
override string toString() { result = "impl Trait ..." }
209+
210+
override Location getLocation() { result instanceof EmptyLocation }
211+
}
212+
179213
/** A type parameter. */
180214
abstract class TypeParameter extends Type {
181215
override StructField getStructField(string name) { none() }
@@ -281,6 +315,26 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
281315
override Location getLocation() { result = trait.getLocation() }
282316
}
283317

318+
/**
319+
* An `impl Trait` type parameter.
320+
*/
321+
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
322+
private ImplTraitType implTraitType;
323+
private int i;
324+
325+
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTraitType, i) }
326+
327+
/** Gets the `impl Trait` type that this parameter belongs to. */
328+
ImplTraitType getImplTraitType() { result = implTraitType }
329+
330+
/** Gets the index of this type parameter. */
331+
int getIndex() { result = i }
332+
333+
override string toString() { result = "impl Trait<" + i.toString() + ">" }
334+
335+
override Location getLocation() { result instanceof EmptyLocation }
336+
}
337+
284338
/**
285339
* A type abstraction. I.e., a place in the program where type variables are
286340
* introduced.

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ private module Input1 implements InputSig1<Location> {
7777
apos.asMethodTypeArgumentPosition() = ppos.asTypeParam().getPosition()
7878
}
7979

80+
private int getImplTraitTypeParameterId(ImplTraitTypeParameter tp) {
81+
tp =
82+
rank[result](ImplTraitTypeParameter tp0, int bounds, int i |
83+
bounds = tp0.getImplTraitType().getNumberOfBounds() and
84+
i = tp0.getIndex()
85+
|
86+
tp0 order by bounds, i
87+
)
88+
}
89+
8090
int getTypeParameterId(TypeParameter tp) {
8191
tp =
8292
rank[result](TypeParameter tp0, int kind, int id |
@@ -90,6 +100,9 @@ private module Input1 implements InputSig1<Location> {
90100
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
91101
node = tp0.(SelfTypeParameter).getTrait()
92102
)
103+
or
104+
kind = 2 and
105+
id = getImplTraitTypeParameterId(tp0)
93106
|
94107
tp0 order by kind, id
95108
)
@@ -232,8 +245,12 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
232245
n1 = n2.(ParenExpr).getExpr() and
233246
path1 = path2
234247
or
235-
n1 = n2.(BlockExpr).getStmtList().getTailExpr() and
236-
path1 = path2
248+
n2 =
249+
any(BlockExpr be |
250+
not be.isAsync() and
251+
n1 = be.getStmtList().getTailExpr() and
252+
path1 = path2
253+
)
237254
or
238255
n1 = n2.(IfExpr).getABranch() and
239256
path1 = path2
@@ -998,6 +1015,29 @@ private StructType inferLiteralType(LiteralExpr le) {
9981015
)
9991016
}
10001017

1018+
pragma[nomagic]
1019+
private AssociatedTypeTypeParameter getFutureOutputTypeParameter() {
1020+
result.getTypeAlias() = any(FutureTrait ft).getOutputType()
1021+
}
1022+
1023+
pragma[nomagic]
1024+
private Type inferAwaitExprType(AwaitExpr ae, TypePath path) {
1025+
exists(TypePath exprPath | result = inferType(ae.getExpr(), exprPath) |
1026+
exprPath
1027+
.isCons(TImplTraitTypeParameter(_, _),
1028+
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), path)))
1029+
or
1030+
path = exprPath and
1031+
not (
1032+
exprPath = TypePath::singleton(TImplTraitTypeParameter(_, _)) and
1033+
result.(TraitType).getTrait() instanceof FutureTrait
1034+
) and
1035+
not exprPath
1036+
.isCons(TImplTraitTypeParameter(_, _),
1037+
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), _)))
1038+
)
1039+
}
1040+
10011041
private module MethodCall {
10021042
/** An expression that calls a method. */
10031043
abstract private class MethodCallImpl extends Expr {
@@ -1087,12 +1127,17 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
10871127
}
10881128

10891129
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1130+
pragma[nomagic]
1131+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1132+
rootType = mc.getTypeAt(TypePath::nil()) and
1133+
name = mc.getMethodName() and
1134+
arity = mc.getArity()
1135+
}
1136+
10901137
pragma[nomagic]
10911138
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
10921139
exists(Type rootType, string name, int arity |
1093-
rootType = mc.getTypeAt(TypePath::nil()) and
1094-
name = mc.getMethodName() and
1095-
arity = mc.getArity() and
1140+
isMethodCall(mc, rootType, name, arity) and
10961141
constraint = impl.(ImplTypeAbstraction).getSelfTy()
10971142
|
10981143
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
@@ -1129,6 +1174,12 @@ private Function getMethodFromImpl(MethodCall mc) {
11291174
)
11301175
}
11311176

1177+
bindingset[trait, name]
1178+
pragma[inline_late]
1179+
private Function getTraitMethod(TraitType trait, string name) {
1180+
result = getMethodSuccessor(trait.getTrait(), name)
1181+
}
1182+
11321183
/**
11331184
* Gets a method that the method call `mc` resolves to based on type inference,
11341185
* if any.
@@ -1140,6 +1191,11 @@ private Function inferMethodCallTarget(MethodCall mc) {
11401191
// The type of the receiver is a type parameter and the method comes from a
11411192
// trait bound on the type parameter.
11421193
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1194+
or
1195+
// The type of the receiver is an `impl Trait` type.
1196+
result =
1197+
getTraitMethod(mc.getTypeAt(TypePath::singleton(TImplTraitTypeParameter(_, _))),
1198+
mc.getMethodName())
11431199
}
11441200

11451201
cached
@@ -1315,6 +1371,8 @@ private module Cached {
13151371
or
13161372
result = inferLiteralType(n) and
13171373
path.isEmpty()
1374+
or
1375+
result = inferAwaitExprType(n, path)
13181376
}
13191377
}
13201378

@@ -1331,7 +1389,7 @@ private module Debug {
13311389
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
13321390
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
13331391
filepath.matches("%/main.rs") and
1334-
startline = 948
1392+
startline = 1334
13351393
)
13361394
}
13371395

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ abstract class TypeMention extends AstNode {
1515

1616
/** Gets the sub mention at `path`. */
1717
pragma[nomagic]
18-
private TypeMention getMentionAt(TypePath path) {
18+
TypeMention getMentionAt(TypePath path) {
1919
path.isEmpty() and
2020
result = this
2121
or
@@ -150,6 +150,54 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
150150
not exists(resolved.(TypeAlias).getTypeRepr()) and
151151
result = super.resolveTypeAt(typePath)
152152
}
153+
154+
pragma[nomagic]
155+
private TypeAlias getResolvedTraitAlias(string name) {
156+
exists(TraitItemNode trait |
157+
trait = resolvePath(path) and
158+
result = trait.getAnAssocItem() and
159+
name = result.getName().getText()
160+
)
161+
}
162+
163+
pragma[nomagic]
164+
private TypeRepr getAssocTypeArg(string name) {
165+
exists(AssocTypeArg arg |
166+
arg = path.getSegment().getGenericArgList().getAGenericArg() and
167+
result = arg.getTypeRepr() and
168+
name = arg.getIdentifier().getText()
169+
)
170+
}
171+
172+
/** Gets the type argument for the associated type `alias`, if any. */
173+
pragma[nomagic]
174+
private TypeRepr getAnAssocTypeArgument(TypeAlias alias) {
175+
exists(string name |
176+
alias = this.getResolvedTraitAlias(name) and
177+
result = this.getAssocTypeArg(name)
178+
)
179+
}
180+
181+
override TypeMention getMentionAt(TypePath tp) {
182+
result = super.getMentionAt(tp)
183+
or
184+
exists(TypeAlias alias, AssociatedTypeTypeParameter attp, TypeMention arg, TypePath suffix |
185+
arg = this.getAnAssocTypeArgument(alias) and
186+
result = arg.getMentionAt(suffix) and
187+
tp = TypePath::cons(attp, suffix) and
188+
attp.getTypeAlias() = alias
189+
)
190+
}
191+
}
192+
193+
class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr {
194+
override TypeMention getTypeArgument(int i) {
195+
result = super.getTypeBoundList().getBound(i).getTypeRepr()
196+
}
197+
198+
override ImplTraitType resolveType() {
199+
result.getNumberOfBounds() = super.getTypeBoundList().getNumberOfBounds()
200+
}
153201
}
154202

155203
private TypeParameter pathGetTypeParameter(TypeAlias alias, int i) {

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,9 +1299,9 @@ mod async_ {
12991299
}
13001300

13011301
pub async fn f() {
1302-
f1().await.f(); // $ MISSING: method=S1f
1303-
f2().await.f(); // $ MISSING: method=S1f
1304-
f3().await.f(); // $ MISSING: method=S1f
1302+
f1().await.f(); // $ method=S1f
1303+
f2().await.f(); // $ method=S1f
1304+
f3().await.f(); // $ method=S1f
13051305
}
13061306
}
13071307

@@ -1331,8 +1331,8 @@ mod impl_trait {
13311331

13321332
pub fn f() {
13331333
let x = f1();
1334-
x.f1(); // $ MISSING: method=Trait1f1
1335-
x.f2(); // $ MISSING: method=Trait2f2
1334+
x.f1(); // $ method=Trait1f1
1335+
x.f2(); // $ method=Trait2f2
13361336
}
13371337
}
13381338

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,8 +1621,12 @@ inferType
16211621
| main.rs:1274:18:1274:21 | SelfParam | | main.rs:1271:5:1271:14 | S1 |
16221622
| main.rs:1277:25:1279:5 | { ... } | | main.rs:1271:5:1271:14 | S1 |
16231623
| main.rs:1278:9:1278:10 | S1 | | main.rs:1271:5:1271:14 | S1 |
1624-
| main.rs:1281:41:1285:5 | { ... } | | main.rs:1271:5:1271:14 | S1 |
1625-
| main.rs:1282:9:1284:9 | { ... } | | main.rs:1271:5:1271:14 | S1 |
1624+
| main.rs:1281:41:1285:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
1625+
| main.rs:1281:41:1285:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1626+
| main.rs:1281:41:1285:5 | { ... } | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
1627+
| main.rs:1282:9:1284:9 | { ... } | | file://:0:0:0:0 | impl Trait ... |
1628+
| main.rs:1282:9:1284:9 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1629+
| main.rs:1282:9:1284:9 | { ... } | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
16261630
| main.rs:1283:13:1283:14 | S1 | | main.rs:1271:5:1271:14 | S1 |
16271631
| main.rs:1292:17:1292:46 | SelfParam | | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/pin.rs:934:1:1104:1 | Pin |
16281632
| main.rs:1292:17:1292:46 | SelfParam | Ptr | file://:0:0:0:0 | & |
@@ -1634,9 +1638,26 @@ inferType
16341638
| main.rs:1293:13:1293:38 | ...::Ready(...) | | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/task/poll.rs:6:1:28:1 | Poll |
16351639
| main.rs:1293:13:1293:38 | ...::Ready(...) | T | main.rs:1271:5:1271:14 | S1 |
16361640
| main.rs:1293:36:1293:37 | S1 | | main.rs:1271:5:1271:14 | S1 |
1641+
| main.rs:1297:41:1299:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
16371642
| main.rs:1297:41:1299:5 | { ... } | | main.rs:1287:5:1287:14 | S2 |
1643+
| main.rs:1297:41:1299:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1644+
| main.rs:1297:41:1299:5 | { ... } | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
1645+
| main.rs:1298:9:1298:10 | S2 | | file://:0:0:0:0 | impl Trait ... |
16381646
| main.rs:1298:9:1298:10 | S2 | | main.rs:1287:5:1287:14 | S2 |
1647+
| main.rs:1298:9:1298:10 | S2 | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1648+
| main.rs:1298:9:1298:10 | S2 | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
16391649
| main.rs:1302:9:1302:12 | f1(...) | | main.rs:1271:5:1271:14 | S1 |
1650+
| main.rs:1302:9:1302:18 | await ... | | main.rs:1271:5:1271:14 | S1 |
1651+
| main.rs:1303:9:1303:12 | f2(...) | | file://:0:0:0:0 | impl Trait ... |
1652+
| main.rs:1303:9:1303:12 | f2(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1653+
| main.rs:1303:9:1303:12 | f2(...) | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
1654+
| main.rs:1303:9:1303:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
1655+
| main.rs:1303:9:1303:18 | await ... | | main.rs:1271:5:1271:14 | S1 |
1656+
| main.rs:1304:9:1304:12 | f3(...) | | file://:0:0:0:0 | impl Trait ... |
1657+
| main.rs:1304:9:1304:12 | f3(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
1658+
| main.rs:1304:9:1304:12 | f3(...) | impl Trait<0>.Output | main.rs:1271:5:1271:14 | S1 |
1659+
| main.rs:1304:9:1304:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
1660+
| main.rs:1304:9:1304:18 | await ... | | main.rs:1271:5:1271:14 | S1 |
16401661
| main.rs:1313:15:1313:19 | SelfParam | | file://:0:0:0:0 | & |
16411662
| main.rs:1313:15:1313:19 | SelfParam | &T | main.rs:1312:5:1314:5 | Self [trait Trait1] |
16421663
| main.rs:1317:15:1317:19 | SelfParam | | file://:0:0:0:0 | & |
@@ -1645,8 +1666,26 @@ inferType
16451666
| main.rs:1321:15:1321:19 | SelfParam | &T | main.rs:1310:5:1310:14 | S1 |
16461667
| main.rs:1325:15:1325:19 | SelfParam | | file://:0:0:0:0 | & |
16471668
| main.rs:1325:15:1325:19 | SelfParam | &T | main.rs:1310:5:1310:14 | S1 |
1669+
| main.rs:1328:37:1330:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
16481670
| main.rs:1328:37:1330:5 | { ... } | | main.rs:1310:5:1310:14 | S1 |
1671+
| main.rs:1328:37:1330:5 | { ... } | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1672+
| main.rs:1328:37:1330:5 | { ... } | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
1673+
| main.rs:1329:9:1329:10 | S1 | | file://:0:0:0:0 | impl Trait ... |
16491674
| main.rs:1329:9:1329:10 | S1 | | main.rs:1310:5:1310:14 | S1 |
1675+
| main.rs:1329:9:1329:10 | S1 | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1676+
| main.rs:1329:9:1329:10 | S1 | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
1677+
| main.rs:1333:13:1333:13 | x | | file://:0:0:0:0 | impl Trait ... |
1678+
| main.rs:1333:13:1333:13 | x | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1679+
| main.rs:1333:13:1333:13 | x | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
1680+
| main.rs:1333:17:1333:20 | f1(...) | | file://:0:0:0:0 | impl Trait ... |
1681+
| main.rs:1333:17:1333:20 | f1(...) | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1682+
| main.rs:1333:17:1333:20 | f1(...) | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
1683+
| main.rs:1334:9:1334:9 | x | | file://:0:0:0:0 | impl Trait ... |
1684+
| main.rs:1334:9:1334:9 | x | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1685+
| main.rs:1334:9:1334:9 | x | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
1686+
| main.rs:1335:9:1335:9 | x | | file://:0:0:0:0 | impl Trait ... |
1687+
| main.rs:1335:9:1335:9 | x | impl Trait<0> | main.rs:1312:5:1314:5 | trait Trait1 |
1688+
| main.rs:1335:9:1335:9 | x | impl Trait<1> | main.rs:1316:5:1318:5 | trait Trait2 |
16501689
| main.rs:1341:5:1341:20 | ...::f(...) | | main.rs:67:5:67:21 | Foo |
16511690
| main.rs:1342:5:1342:60 | ...::g(...) | | main.rs:67:5:67:21 | Foo |
16521691
| main.rs:1342:20:1342:38 | ...::Foo {...} | | main.rs:67:5:67:21 | Foo |

0 commit comments

Comments
 (0)