Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Constraint set satisfaction

```toml
[environment]
python-version = "3.12"
```

Constraint sets exist to help us check assignability and subtyping of types in the presence of
typevars. We construct a constraint set describing the conditions under which assignability holds
between the two types. Then we check whether that constraint set is satisfied for the valid
specializations of the relevant typevars. This file tests that final step.

## Inferable vs non-inferable typevars

Typevars can appear in _inferable_ or _non-inferable_ positions.

When a typevar is in an inferable position, the constraint set only needs to be satisfied for _some_
valid specialization. The most common inferable position occurs when invoking a generic function:
all of the function's typevars are inferable, because we want to use the argument types to infer
which specialization is being invoked.
Comment on lines +17 to +20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. Thanks for adding it


When a typevar is in a non-inferable position, the constraint set must be satisfied for _every_
valid specialization. The most common non-inferable position occurs in the body of a generic
function or class: here we don't know in advance what type the typevar will be specialized to, and
so we have to ensure that the body is valid for all possible specializations.

```py
def f[T](t: T) -> T:
# In the function body, T is non-inferable. All assignability checks involving T must be
# satisfied for _all_ valid specializations of T.
return t

# When invoking the function, T is inferable — we attempt to infer a specialization that is valid
# for the particular arguments that are passed to the function. Assignability checks (in particular,
# that the argument type is assignable to the parameter type) only need to succeed for _at least
# one_ specialization.
f(1)
```

In all of the examples below, for ease of reproducibility, we explicitly list the typevars that are
inferable in each `satisfied_by_all_typevars` call; any typevar not listed is assumed to be
non-inferable.

## Unbounded typevar

If a typevar has no bound or constraints, then it can specialize to any type. In an inferable
position, that means we just need a single type (any type at all!) that satisfies the constraint
set. In a non-inferable position, that means the constraint set must be satisfied for every possible
type.

```py
from typing import final, Never
from ty_extensions import ConstraintSet, static_assert

class Super: ...
class Base(Super): ...
class Sub(Base): ...

@final
class Unrelated: ...

def unbounded[T]():
static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(ConstraintSet.always().satisfied_by_all_typevars())

static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(not ConstraintSet.never().satisfied_by_all_typevars())

# (T = Never) is a valid specialization, which satisfies (T ≤ Unrelated).
static_assert(ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated).
static_assert(not ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Super).
static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Super).
static_assert(not ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Base).
static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Base).
static_assert(not ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars())

# (T = Sub) is a valid specialization, which satisfies (T ≤ Sub).
static_assert(ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Sub).
static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars())
```

## Typevar with an upper bound

If a typevar has an upper bound, then it must specialize to a type that is a subtype of that bound.
For an inferable typevar, that means we need a single type that satisfies both the constraint set
and the upper bound. For a non-inferable typevar, that means the constraint set must be satisfied
for every type that satisfies the upper bound.

```py
from typing import final, Never
from ty_extensions import ConstraintSet, static_assert

class Super: ...
class Base(Super): ...
class Sub(Base): ...

@final
class Unrelated: ...

def bounded[T: Base]():
static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(ConstraintSet.always().satisfied_by_all_typevars())

static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(not ConstraintSet.never().satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Super).
static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T]))
# Every valid specialization satisfies (T ≤ Base). Since (Base ≤ Super), every valid
# specialization also satisfies (T ≤ Super).
static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Base).
static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T]))
# Every valid specialization satisfies (T ≤ Base).
static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars())

# (T = Sub) is a valid specialization, which satisfies (T ≤ Sub).
static_assert(ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T ≤ Sub).
static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars())

# (T = Never) is a valid specialization, which satisfies (T ≤ Unrelated).
constraints = ConstraintSet.range(Never, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated).
static_assert(not constraints.satisfied_by_all_typevars())

# Never is the only type that satisfies both (T ≤ Base) and (T ≤ Unrelated). So there is no
# valid specialization that satisfies (T ≤ Unrelated ∧ T ≠ Never).
constraints = constraints & ~ConstraintSet.range(Never, T, Never)
static_assert(not constraints.satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(not constraints.satisfied_by_all_typevars())
```

## Constrained typevar

If a typevar has constraints, then it must specialize to one of those specific types. (Not to a
subtype of one of those types!) For an inferable typevar, that means we need the constraint set to
be satisfied by any one of the constraints. For a non-inferable typevar, that means we need the
constraint set to be satisfied by all of those constraints.

```py
from typing import final, Never
from ty_extensions import ConstraintSet, static_assert

class Super: ...
class Base(Super): ...
class Sub(Base): ...

@final
class Unrelated: ...

def constrained[T: (Base, Unrelated)]():
static_assert(ConstraintSet.always().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(ConstraintSet.always().satisfied_by_all_typevars())

static_assert(not ConstraintSet.never().satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(not ConstraintSet.never().satisfied_by_all_typevars())

# (T = Unrelated) is a valid specialization, which satisfies (T ≤ Unrelated).
static_assert(ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T ≤ Unrelated).
static_assert(not ConstraintSet.range(Never, T, Unrelated).satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Super).
static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still somewhat confused by the DSL here but I (maybe?) finally figured out how to read this?

Is my understanding correct that this results in:

Never <= any(Base, Unrelated) <= Super

which is true because Base satisfies this constraint.

And the next example is:

Never < forall(Base, Unrelated) < Super

which is false, because Unrelated doesn't satisfy this constraint.

The part I'm struggling with right now is what a real-world example of static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T])) would look like?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still somewhat confused by the DSL here but I (maybe?) finally figured out how to read this?

I would formalize it differently, but I think your notation leads to the right understanding:

Is my understanding correct that this results in:

Never <= any(Base, Unrelated) <= Super

which is true because Base satisfies this constraint.

Yes

And the next example is:

Never < forall(Base, Unrelated) < Super

which is false, because Unrelated doesn't satisfy this constraint.

This should use <= instead of < like the first example, but otherwise yes.

The part I'm struggling with right now is what a real-world example of static_assert(ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars(inferable=tuple[T])) would look like?

I'm actually not sure what Python I could write that would result in this check either! We need something that wants to check that an instance of T is assignable to Super. The non-inferable case is easy:

def f[T: (Base, Unrelated)](t: T):
    x: Super = t

But in the T inferable case, we would be invoking this function. And I'm not sure what Python code would lead to a T ≤ Super check.

The problem is that I can't limit myself to implementing these algorithms for constraint set checks that have obvious Python analogues. 😅

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But in the T inferable case, we would be invoking this function. And I'm not sure what Python code would lead to a T ≤ Super check.

Ah maybe via the return type:

def f[T: (Base, Unrelated)]() -> T:
    raise NotImplementedError

x: Super = f()

(Disregard the fact that there is no real function body that we could write that would satisfy that signature!)

# (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Super).
static_assert(not ConstraintSet.range(Never, T, Super).satisfied_by_all_typevars())

# (T = Base) is a valid specialization, which satisfies (T ≤ Base).
static_assert(ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Unrelated) is a valid specialization, which does not satisfy (T ≤ Base).
static_assert(not ConstraintSet.range(Never, T, Base).satisfied_by_all_typevars())

# Neither (T = Base) nor (T = Unrelated) satisfy (T ≤ Sub).
static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(not ConstraintSet.range(Never, T, Sub).satisfied_by_all_typevars())

# (T = Base) and (T = Unrelated) both satisfy (T ≤ Super ∨ T ≤ Unrelated).
constraints = ConstraintSet.range(Never, T, Super) | ConstraintSet.range(Never, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(constraints.satisfied_by_all_typevars())

# (T = Base) and (T = Unrelated) both satisfy (T ≤ Base ∨ T ≤ Unrelated).
constraints = ConstraintSet.range(Never, T, Base) | ConstraintSet.range(Never, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(constraints.satisfied_by_all_typevars())

# (T = Unrelated) is a valid specialization, which satisfies (T ≤ Sub ∨ T ≤ Unrelated).
constraints = ConstraintSet.range(Never, T, Sub) | ConstraintSet.range(Never, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T ≤ Sub ∨ T ≤ Unrelated).
static_assert(not constraints.satisfied_by_all_typevars())

# (T = Unrelated) is a valid specialization, which satisfies (T = Super ∨ T = Unrelated).
constraints = ConstraintSet.range(Super, T, Super) | ConstraintSet.range(Unrelated, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T = Super ∨ T = Unrelated).
static_assert(not constraints.satisfied_by_all_typevars())

# (T = Base) and (T = Unrelated) both satisfy (T = Base ∨ T = Unrelated).
constraints = ConstraintSet.range(Base, T, Base) | ConstraintSet.range(Unrelated, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
static_assert(constraints.satisfied_by_all_typevars())

# (T = Unrelated) is a valid specialization, which satisfies (T = Sub ∨ T = Unrelated).
constraints = ConstraintSet.range(Sub, T, Sub) | ConstraintSet.range(Unrelated, T, Unrelated)
static_assert(constraints.satisfied_by_all_typevars(inferable=tuple[T]))
# (T = Base) is a valid specialization, which does not satisfy (T = Sub ∨ T = Unrelated).
static_assert(not constraints.satisfied_by_all_typevars())
```
57 changes: 49 additions & 8 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4159,6 +4159,14 @@ impl<'db> Type<'db> {
))
.into()
}
Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked))
if name == "satisfied_by_all_typevars" =>
{
Place::bound(Type::KnownBoundMethod(
KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(tracked),
))
.into()
}

Type::ClassLiteral(class)
if name == "__get__" && class.is_known(db, KnownClass::FunctionType) =>
Expand Down Expand Up @@ -6921,6 +6929,7 @@ impl<'db> Type<'db> {
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
Expand Down Expand Up @@ -7072,7 +7081,8 @@ impl<'db> Type<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
Expand Down Expand Up @@ -10337,6 +10347,7 @@ pub enum KnownBoundMethodType<'db> {
ConstraintSetAlways,
ConstraintSetNever,
ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>),
ConstraintSetSatisfiedByAllTypeVars(TrackedConstraintSet<'db>),
}

pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
Expand Down Expand Up @@ -10364,7 +10375,8 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => {}
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => {}
}
}

Expand Down Expand Up @@ -10432,6 +10444,10 @@ impl<'db> KnownBoundMethodType<'db> {
| (
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
)
| (
KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
) => ConstraintSet::from(true),

(
Expand All @@ -10444,7 +10460,8 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
KnownBoundMethodType::FunctionTypeDunderGet(_)
| KnownBoundMethodType::FunctionTypeDunderCall(_)
| KnownBoundMethodType::PropertyDunderGet(_)
Expand All @@ -10454,7 +10471,8 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
) => ConstraintSet::from(false),
}
}
Expand Down Expand Up @@ -10507,6 +10525,10 @@ impl<'db> KnownBoundMethodType<'db> {
(
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints),
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints),
)
| (
KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(left_constraints),
KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(right_constraints),
) => left_constraints
.constraints(db)
.iff(db, right_constraints.constraints(db)),
Expand All @@ -10521,7 +10543,8 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
KnownBoundMethodType::FunctionTypeDunderGet(_)
| KnownBoundMethodType::FunctionTypeDunderCall(_)
| KnownBoundMethodType::PropertyDunderGet(_)
Expand All @@ -10531,7 +10554,8 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_),
) => ConstraintSet::from(false),
}
}
Expand All @@ -10555,7 +10579,8 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => self,
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => self,
}
}

Expand All @@ -10571,7 +10596,10 @@ impl<'db> KnownBoundMethodType<'db> {
KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => KnownClass::ConstraintSet,
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => {
KnownClass::ConstraintSet
}
}
}

Expand Down Expand Up @@ -10710,6 +10738,19 @@ impl<'db> KnownBoundMethodType<'db> {
Some(KnownClass::ConstraintSet.to_instance(db)),
)))
}

KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_) => {
Either::Right(std::iter::once(Signature::new(
Parameters::new([Parameter::keyword_only(Name::new_static("inferable"))
.type_form()
.with_annotated_type(UnionType::from_elements(
db,
[Type::homogeneous_tuple(db, Type::any()), Type::none(db)],
))
.with_default_type(Type::none(db))]),
Some(KnownClass::Bool.to_instance(db)),
)))
}
}
}
}
Expand Down
Loading
Loading