Skip to content

Commit 791369d

Browse files
authored
Merge pull request #19584 from hvitved/rust/type-inference-await
Rust: Type inference for `.await` expressions
2 parents 3af10d2 + 3d395dd commit 791369d

File tree

11 files changed

+588
-24
lines changed

11 files changed

+588
-24
lines changed

rust/ql/lib/codeql/rust/elements/internal/GenericArgListImpl.qll

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,15 @@ module Impl {
3737

3838
/** Gets a type argument of this list. */
3939
TypeRepr getATypeArg() { result = this.getTypeArg(_) }
40+
41+
/** Gets the associated type argument with the given `name`, if any. */
42+
pragma[nomagic]
43+
TypeRepr getAssocTypeArg(string name) {
44+
exists(AssocTypeArg arg |
45+
arg = this.getAGenericArg() and
46+
result = arg.getTypeRepr() and
47+
name = arg.getIdentifier().getText()
48+
)
49+
}
4050
}
4151
}

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,19 @@ class ResultEnum extends Enum {
4949
/** Gets the `Err` variant. */
5050
Variant getErr() { result = this.getVariant("Err") }
5151
}
52+
53+
/**
54+
* The [`Future` trait][1].
55+
*
56+
* [1]: https://doc.rust-lang.org/std/future/trait.Future.html
57+
*/
58+
class FutureTrait extends Trait {
59+
FutureTrait() { this.getCanonicalPath() = "core::future::future::Future" }
60+
61+
/** Gets the `Output` associated type. */
62+
pragma[nomagic]
63+
TypeAlias getOutputType() {
64+
result = this.getAssocItemList().getAnAssocItem() and
65+
result.getName().getText() = "Output"
66+
}
67+
}

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
private import rust
66
private import codeql.rust.elements.internal.generated.ParentChild
77
private import codeql.rust.internal.CachedStages
8+
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
89

910
private newtype TNamespace =
1011
TTypeNamespace() or
@@ -178,6 +179,8 @@ abstract class ItemNode extends Locatable {
178179
or
179180
// type parameters have access to the associated items of its bounds
180181
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
182+
or
183+
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
181184
}
182185

183186
/**
@@ -645,6 +648,28 @@ class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
645648
}
646649
}
647650

651+
private class ImplTraitTypeReprItemNode extends ItemNode instanceof ImplTraitTypeRepr {
652+
pragma[nomagic]
653+
Path getABoundPath() {
654+
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
655+
}
656+
657+
pragma[nomagic]
658+
ItemNode resolveABound() { result = resolvePathFull(this.getABoundPath()) }
659+
660+
override string getName() { result = "(impl trait)" }
661+
662+
override Namespace getNamespace() { result.isType() }
663+
664+
override Visibility getVisibility() { none() }
665+
666+
override TypeParam getTypeParam(int i) { none() }
667+
668+
override predicate hasCanonicalPath(Crate c) { none() }
669+
670+
override string getCanonicalPath(Crate c) { none() }
671+
}
672+
648673
private class MacroCallItemNode extends AssocItemNode instanceof MacroCall {
649674
override string getName() { result = "(macro call)" }
650675

@@ -1093,14 +1118,21 @@ private predicate crateDefEdge(CrateItemNode c, string name, ItemNode i) {
10931118
not i instanceof Crate
10941119
}
10951120

1121+
private class BuiltinSourceFile extends SourceFileItemNode {
1122+
BuiltinSourceFile() { this.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder }
1123+
}
1124+
10961125
/**
10971126
* Holds if `file` depends on crate `dep` named `name`.
10981127
*/
1128+
pragma[nomagic]
10991129
private predicate crateDependencyEdge(SourceFileItemNode file, string name, CrateItemNode dep) {
1100-
exists(CrateItemNode c |
1101-
dep = c.(Crate).getDependency(name) and
1102-
file = c.getASourceFile()
1103-
)
1130+
exists(CrateItemNode c | dep = c.(Crate).getDependency(name) | file = c.getASourceFile())
1131+
or
1132+
// Give builtin files, such as `await.rs`, access to `std`
1133+
file instanceof BuiltinSourceFile and
1134+
dep.getName() = name and
1135+
name = "std"
11041136
}
11051137

11061138
private predicate useTreeDeclares(UseTree tree, string name) {
@@ -1461,9 +1493,14 @@ private predicate externCrateEdge(ExternCrateItemNode ec, string name, CrateItem
14611493
* [1]: https://doc.rust-lang.org/core/prelude/index.html
14621494
* [2]: https://doc.rust-lang.org/std/prelude/index.html
14631495
*/
1496+
pragma[nomagic]
14641497
private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14651498
exists(Crate stdOrCore, ModuleLikeNode mod, ModuleItemNode prelude, ModuleItemNode rust |
1466-
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile() and
1499+
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile()
1500+
or
1501+
// Give builtin files, such as `await.rs`, access to the prelude
1502+
f instanceof BuiltinSourceFile
1503+
|
14671504
stdOrCore.getName() = ["std", "core"] and
14681505
mod = stdOrCore.getSourceFile() and
14691506
prelude = mod.getASuccessorRec("prelude") and
@@ -1473,12 +1510,10 @@ private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14731510
)
14741511
}
14751512

1476-
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
1477-
14781513
pragma[nomagic]
14791514
private predicate builtin(string name, ItemNode i) {
1480-
exists(SourceFileItemNode builtins |
1481-
builtins.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder and
1515+
exists(BuiltinSourceFile builtins |
1516+
builtins.getFile().getBaseName() = "types.rs" and
14821517
i = builtins.getASuccessorRec(name)
14831518
)
14841519
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ newtype TType =
1515
TTrait(Trait t) or
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
18+
TImplTraitType(ImplTraitTypeRepr impl) or
1819
TTypeParamTypeParameter(TypeParam t) or
1920
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2021
TRefTypeParameter() or
@@ -115,6 +116,9 @@ class TraitType extends Type, TTrait {
115116

116117
TraitType() { this = TTrait(trait) }
117118

119+
/** Gets the underlying trait. */
120+
Trait getTrait() { result = trait }
121+
118122
override StructField getStructField(string name) { none() }
119123

120124
override TupleField getTupleField(int i) { none() }
@@ -176,6 +180,53 @@ class RefType extends Type, TRefType {
176180
override Location getLocation() { result instanceof EmptyLocation }
177181
}
178182

183+
/**
184+
* An [impl Trait][1] type.
185+
*
186+
* Each syntactic `impl Trait` type gives rise to its own type, even if
187+
* two `impl Trait` types have the same bounds.
188+
*
189+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html
190+
*/
191+
class ImplTraitType extends Type, TImplTraitType {
192+
ImplTraitTypeRepr impl;
193+
194+
ImplTraitType() { this = TImplTraitType(impl) }
195+
196+
/** Gets the underlying AST node. */
197+
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }
198+
199+
/** Gets the function that this `impl Trait` belongs to. */
200+
abstract Function getFunction();
201+
202+
override StructField getStructField(string name) { none() }
203+
204+
override TupleField getTupleField(int i) { none() }
205+
206+
override TypeParameter getTypeParameter(int i) { none() }
207+
208+
override string toString() { result = impl.toString() }
209+
210+
override Location getLocation() { result = impl.getLocation() }
211+
}
212+
213+
/**
214+
* An [impl Trait in return position][1] type, for example:
215+
*
216+
* ```rust
217+
* fn foo() -> impl Trait
218+
* ```
219+
*
220+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.return
221+
*/
222+
class ImplTraitReturnType extends ImplTraitType {
223+
private Function function;
224+
225+
ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
226+
227+
override Function getFunction() { result = function }
228+
}
229+
179230
/** A type parameter. */
180231
abstract class TypeParameter extends Type {
181232
override StructField getStructField(string name) { none() }
@@ -185,7 +236,7 @@ abstract class TypeParameter extends Type {
185236
override TypeParameter getTypeParameter(int i) { none() }
186237
}
187238

188-
private class RawTypeParameter = @type_param or @trait or @type_alias;
239+
private class RawTypeParameter = @type_param or @trait or @type_alias or @impl_trait_type_repr;
189240

190241
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
191242

@@ -281,6 +332,37 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
281332
override Location getLocation() { result = trait.getLocation() }
282333
}
283334

335+
/**
336+
* An [impl Trait in argument position][1] type, for example:
337+
*
338+
* ```rust
339+
* fn foo(arg: impl Trait)
340+
* ```
341+
*
342+
* Such types are syntactic sugar for type parameters, that is
343+
*
344+
* ```rust
345+
* fn foo<T: Trait>(arg: T)
346+
* ```
347+
*
348+
* so we model them as type parameters.
349+
*
350+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.param
351+
*/
352+
class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {
353+
private Function function;
354+
355+
ImplTraitTypeTypeParameter() { impl = function.getParamList().getAParam().getTypeRepr() }
356+
357+
override Function getFunction() { result = function }
358+
359+
override StructField getStructField(string name) { none() }
360+
361+
override TupleField getTupleField(int i) { none() }
362+
363+
override TypeParameter getTypeParameter(int i) { none() }
364+
}
365+
284366
/**
285367
* A type abstraction. I.e., a place in the program where type variables are
286368
* introduced.
@@ -316,3 +398,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
316398

317399
override TypeParamTypeParameter getATypeParameter() { none() }
318400
}
401+
402+
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
403+
override TypeParamTypeParameter getATypeParameter() { none() }
404+
}

0 commit comments

Comments
 (0)