Skip to content

Commit 2ebb619

Browse files
committed
Rust: Reimplement type inference for impl Traits and await expressions
1 parent ad4aae2 commit 2ebb619

File tree

9 files changed

+327
-162
lines changed

9 files changed

+327
-162
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
private import rust
66
private import codeql.rust.elements.internal.generated.ParentChild
77
private import codeql.rust.internal.CachedStages
8+
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
89

910
private newtype TNamespace =
1011
TTypeNamespace() or
@@ -165,6 +166,8 @@ abstract class ItemNode extends Locatable {
165166
or
166167
// type parameters have access to the associated items of its bounds
167168
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
169+
or
170+
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
168171
}
169172

170173
/**
@@ -618,6 +621,28 @@ class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
618621
}
619622
}
620623

624+
private class ImplTraitTypeReprItemNode extends ItemNode instanceof ImplTraitTypeRepr {
625+
pragma[nomagic]
626+
Path getABoundPath() {
627+
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
628+
}
629+
630+
pragma[nomagic]
631+
ItemNode resolveABound() { result = resolvePathFull(this.getABoundPath()) }
632+
633+
override string getName() { result = "(impl trait)" }
634+
635+
override Namespace getNamespace() { result.isType() }
636+
637+
override Visibility getVisibility() { none() }
638+
639+
override TypeParam getTypeParam(int i) { none() }
640+
641+
override predicate hasCanonicalPath(Crate c) { none() }
642+
643+
override string getCanonicalPath(Crate c) { none() }
644+
}
645+
621646
private class MacroCallItemNode extends AssocItemNode instanceof MacroCall {
622647
override string getName() { result = "(macro call)" }
623648

@@ -1062,14 +1087,21 @@ private predicate crateDefEdge(CrateItemNode c, string name, ItemNode i) {
10621087
not i instanceof Crate
10631088
}
10641089

1090+
private class BuiltinSourceFile extends SourceFileItemNode {
1091+
BuiltinSourceFile() { this.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder }
1092+
}
1093+
10651094
/**
10661095
* Holds if `m` depends on crate `dep` named `name`.
10671096
*/
1097+
pragma[nomagic]
10681098
private predicate crateDependencyEdge(ModuleLikeNode m, string name, CrateItemNode dep) {
1069-
exists(CrateItemNode c |
1070-
dep = c.(Crate).getDependency(name) and
1071-
m = c.getASourceFile()
1072-
)
1099+
exists(CrateItemNode c | dep = c.(Crate).getDependency(name) | m = c.getASourceFile())
1100+
or
1101+
// Give builtin files, such as `await.rs`, access to `std`
1102+
m instanceof BuiltinSourceFile and
1103+
dep.getName() = name and
1104+
name = "std"
10731105
}
10741106

10751107
private predicate useTreeDeclares(UseTree tree, string name) {
@@ -1414,9 +1446,14 @@ private predicate useImportEdge(Use use, string name, ItemNode item) {
14141446
* [1]: https://doc.rust-lang.org/core/prelude/index.html
14151447
* [2]: https://doc.rust-lang.org/std/prelude/index.html
14161448
*/
1449+
pragma[nomagic]
14171450
private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14181451
exists(Crate stdOrCore, ModuleLikeNode mod, ModuleItemNode prelude, ModuleItemNode rust |
1419-
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile() and
1452+
f = any(Crate c0 | stdOrCore = c0.getDependency(_) or stdOrCore = c0).getASourceFile()
1453+
or
1454+
// Give builtin files, such as `await.rs`, access to the prelude
1455+
f instanceof BuiltinSourceFile
1456+
|
14201457
stdOrCore.getName() = ["std", "core"] and
14211458
mod = stdOrCore.getSourceFile() and
14221459
prelude = mod.getASuccessorRec("prelude") and
@@ -1426,12 +1463,10 @@ private predicate preludeEdge(SourceFile f, string name, ItemNode i) {
14261463
)
14271464
}
14281465

1429-
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
1430-
14311466
pragma[nomagic]
14321467
private predicate builtin(string name, ItemNode i) {
1433-
exists(SourceFileItemNode builtins |
1434-
builtins.getFile().getParentContainer() instanceof Builtins::BuiltinsFolder and
1468+
exists(BuiltinSourceFile builtins |
1469+
builtins.getFile().getBaseName() = "types.rs" and
14351470
i = builtins.getASuccessorRec(name)
14361471
)
14371472
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@ newtype TType =
1515
TTrait(Trait t) or
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
18-
TImplTraitType(int bounds) {
19-
bounds = any(ImplTraitTypeRepr impl).getTypeBoundList().getNumberOfBounds()
20-
} or
18+
TImplTraitType(ImplTraitTypeRepr impl) or
2119
TTypeParamTypeParameter(TypeParam t) or
2220
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2321
TRefTypeParameter() or
24-
TSelfTypeParameter(Trait t) or
25-
TImplTraitTypeParameter(ImplTraitType t, int i) { i in [0 .. t.getNumberOfBounds() - 1] }
22+
TSelfTypeParameter(Trait t)
2623

2724
/**
2825
* A type without type arguments.
@@ -184,30 +181,50 @@ class RefType extends Type, TRefType {
184181
}
185182

186183
/**
187-
* An [`impl Trait`][1] type.
184+
* An [impl Trait][1] type.
188185
*
189-
* We represent `impl Trait` types as generic types with as many type parameters
190-
* as there are bounds.
186+
* Each syntactic `impl Trait` type gives rise to its own type, even if
187+
* two `impl Trait` types have the same bounds.
191188
*
192-
* [1] https://doc.rust-lang.org/book/ch10-02-traits.html#traits-as-parameters
189+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html
193190
*/
194191
class ImplTraitType extends Type, TImplTraitType {
195-
private int bounds;
192+
ImplTraitTypeRepr impl;
196193

197-
ImplTraitType() { this = TImplTraitType(bounds) }
194+
ImplTraitType() { this = TImplTraitType(impl) }
198195

199-
/** Gets the number of bounds of this `impl Trait` type. */
200-
int getNumberOfBounds() { result = bounds }
196+
/** Gets the underlying AST node. */
197+
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }
198+
199+
/** Gets the function that this `impl Trait` belongs to. */
200+
abstract Function getFunction();
201201

202202
override StructField getStructField(string name) { none() }
203203

204204
override TupleField getTupleField(int i) { none() }
205205

206-
override TypeParameter getTypeParameter(int i) { result = TImplTraitTypeParameter(this, i) }
206+
override TypeParameter getTypeParameter(int i) { none() }
207207

208-
override string toString() { result = "impl Trait ..." }
208+
override string toString() { result = impl.toString() }
209209

210-
override Location getLocation() { result instanceof EmptyLocation }
210+
override Location getLocation() { result = impl.getLocation() }
211+
}
212+
213+
/**
214+
* An [impl Trait in return position][1] type, for example:
215+
*
216+
* ```rust
217+
* fn foo() -> impl Trait
218+
* ```
219+
*
220+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.return
221+
*/
222+
class ImplTraitReturnType extends ImplTraitType {
223+
private Function function;
224+
225+
ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
226+
227+
override Function getFunction() { result = function }
211228
}
212229

213230
/** A type parameter. */
@@ -219,7 +236,7 @@ abstract class TypeParameter extends Type {
219236
override TypeParameter getTypeParameter(int i) { none() }
220237
}
221238

222-
private class RawTypeParameter = @type_param or @trait or @type_alias;
239+
private class RawTypeParameter = @type_param or @trait or @type_alias or @impl_trait_type_repr;
223240

224241
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
225242

@@ -316,23 +333,34 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
316333
}
317334

318335
/**
319-
* An `impl Trait` type parameter.
336+
* An [impl Trait in argument position][1] type, for example:
337+
*
338+
* ```rust
339+
* fn foo(arg: impl Trait)
340+
* ```
341+
*
342+
* Such types are syntactic sugar for type parameters, that is
343+
*
344+
* ```rust
345+
* fn foo<T: Trait>(arg: T)
346+
* ```
347+
*
348+
* so we model them as type parameters.
349+
*
350+
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.param
320351
*/
321-
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
322-
private ImplTraitType implTraitType;
323-
private int i;
352+
class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {
353+
private Function function;
324354

325-
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTraitType, i) }
355+
ImplTraitTypeTypeParameter() { impl = function.getParamList().getAParam().getTypeRepr() }
326356

327-
/** Gets the `impl Trait` type that this parameter belongs to. */
328-
ImplTraitType getImplTraitType() { result = implTraitType }
357+
override Function getFunction() { result = function }
329358

330-
/** Gets the index of this type parameter. */
331-
int getIndex() { result = i }
359+
override StructField getStructField(string name) { none() }
332360

333-
override string toString() { result = "impl Trait<" + i.toString() + ">" }
361+
override TupleField getTupleField(int i) { none() }
334362

335-
override Location getLocation() { result instanceof EmptyLocation }
363+
override TypeParameter getTypeParameter(int i) { none() }
336364
}
337365

338366
/**
@@ -370,3 +398,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
370398

371399
override TypeParamTypeParameter getATypeParameter() { none() }
372400
}
401+
402+
final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
403+
override TypeParamTypeParameter getATypeParameter() { none() }
404+
}

0 commit comments

Comments
 (0)