@@ -95,7 +95,9 @@ use crate::types::mro::MroErrorKind;
9595use crate::types::newtype::NewType;
9696use crate::types::signatures::{Parameter, Parameters, Signature};
9797use crate::types::subclass_of::SubclassOfInner;
98- use crate::types::tuple::{Tuple, TupleLength, TupleSpec, TupleType};
98+ use crate::types::tuple::{
99+ Tuple, TupleLength, TupleSpec, TupleSpecBuilder, TupleType, VariableLengthTuple,
100+ };
99101use crate::types::typed_dict::{
100102 TypedDictAssignmentKind, validate_typed_dict_constructor, validate_typed_dict_dict_literal,
101103 validate_typed_dict_key_assignment,
@@ -7048,7 +7050,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
70487050 ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx),
70497051 ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
70507052 ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
7051- ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
7053+ ast::Expr::Starred(starred) => self.infer_starred_expression(starred, tcx ),
70527054 ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
70537055 ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
70547056 ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),
@@ -7284,25 +7286,66 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
72847286 )
72857287 });
72867288
7289+ let mut is_homogeneous_tuple_annotation = false;
7290+
72877291 let annotated_tuple = tcx
72887292 .known_specialization(self.db(), KnownClass::Tuple)
72897293 .and_then(|specialization| {
7290- specialization
7294+ let spec = specialization
72917295 .tuple(self.db())
7292- .expect("the specialization of `KnownClass::Tuple` must have a tuple spec")
7293- .resize(self.db(), TupleLength::Fixed(elts.len()))
7294- .ok()
7296+ .expect("the specialization of `KnownClass::Tuple` must have a tuple spec");
7297+
7298+ if matches!(
7299+ spec,
7300+ Tuple::Variable(VariableLengthTuple { prefix, variable: _, suffix})
7301+ if prefix.is_empty() && suffix.is_empty()
7302+ ) {
7303+ is_homogeneous_tuple_annotation = true;
7304+ }
7305+
7306+ spec.resize(self.db(), TupleLength::Fixed(elts.len())).ok()
72957307 });
72967308
72977309 let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements);
72987310
72997311 let db = self.db();
7300- let element_types = elts.iter().map(|element| {
7301- let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
7302- self.infer_expression(element, TypeContext::new(annotated_elt_ty))
7303- });
73047312
7305- Type::heterogeneous_tuple(db, element_types)
7313+ let can_use_type_context =
7314+ is_homogeneous_tuple_annotation || elts.iter().all(|elt| !elt.is_starred_expr());
7315+
7316+ let mut infer_element = |elt: &ast::Expr| {
7317+ if can_use_type_context {
7318+ let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
7319+ let context = if let ast::Expr::Starred(starred) = elt {
7320+ annotated_elt_ty
7321+ .map(|expected_element_type| {
7322+ TypeContext::for_starred_expression(db, expected_element_type, starred)
7323+ })
7324+ .unwrap_or_default()
7325+ } else {
7326+ TypeContext::new(annotated_elt_ty)
7327+ };
7328+ self.infer_expression(elt, context)
7329+ } else {
7330+ self.infer_expression(elt, TypeContext::default())
7331+ }
7332+ };
7333+
7334+ let mut builder = TupleSpecBuilder::with_capacity(elts.len());
7335+
7336+ for element in elts {
7337+ if element.is_starred_expr() {
7338+ let element_type = infer_element(element);
7339+ // Fine to use `iterate` rather than `try_iterate` here:
7340+ // errors from iterating over something not iterable will have been
7341+ // emitted in the `infer_element` call above.
7342+ builder = builder.concat(db, &element_type.iterate(db));
7343+ } else {
7344+ builder.push(infer_element(element));
7345+ }
7346+ }
7347+
7348+ Type::tuple(TupleType::new(db, &builder.build()))
73067349 }
73077350
73087351 fn infer_list_expression(&mut self, list: &ast::ExprList, tcx: TypeContext<'db>) -> Type<'db> {
@@ -7459,7 +7502,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
74597502
74607503 let inferable = generic_context.inferable_typevars(self.db());
74617504
7462- // Remove any union elements of that are unrelated to the collection type.
7505+ // Remove any union elements of the annotation that are unrelated to the collection type.
74637506 //
74647507 // For example, we only want the `list[int]` from `annotation: list[int] | None` if
74657508 // `collection_ty` is `list`.
@@ -7499,8 +7542,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
74997542 }
75007543
75017544 let elt_tcxs = match annotated_elt_tys {
7502- None => Either::Left(iter::repeat(TypeContext::default() )),
7503- Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new( Some(*ty)) )),
7545+ None => Either::Left(iter::repeat(None )),
7546+ Some(tys) => Either::Right(tys.iter().copied(). map(Some)),
75047547 };
75057548
75067549 for elts in elts {
@@ -7529,6 +7572,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
75297572 {
75307573 let Some(elt) = elt else { continue };
75317574
7575+ let elt_tcx = if let ast::Expr::Starred(starred) = elt {
7576+ elt_tcx
7577+ .map(|ty| TypeContext::for_starred_expression(self.db(), ty, starred))
7578+ .unwrap_or_default()
7579+ } else {
7580+ TypeContext::new(elt_tcx)
7581+ };
7582+
75327583 let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx);
75337584
75347585 // Simplify the inference based on the declared type of the element.
@@ -7542,7 +7593,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
75427593 // unions for large nested list literals, which the constraint solver struggles with.
75437594 let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db(), elt_tcx);
75447595
7545- builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?;
7596+ builder
7597+ .infer(
7598+ Type::TypeVar(elt_ty),
7599+ if elt.is_starred_expr() {
7600+ inferred_elt_ty
7601+ .iterate(self.db())
7602+ .homogeneous_element_type(self.db())
7603+ } else {
7604+ inferred_elt_ty
7605+ },
7606+ )
7607+ .ok()?;
75467608 }
75477609 }
75487610
@@ -8359,25 +8421,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
83598421 }
83608422 }
83618423
8362- fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
8424+ fn infer_starred_expression(
8425+ &mut self,
8426+ starred: &ast::ExprStarred,
8427+ tcx: TypeContext<'db>,
8428+ ) -> Type<'db> {
83638429 let ast::ExprStarred {
83648430 range: _,
83658431 node_index: _,
83668432 value,
83678433 ctx: _,
83688434 } = starred;
83698435
8370- let iterable_type = self.infer_expression(value, TypeContext::default());
8436+ let db = self.db();
8437+ let iterable_type = self.infer_expression(value, tcx);
8438+
83718439 iterable_type
8372- .try_iterate(self.db() )
8373- .map(|tuple| tuple.homogeneous_element_type(self.db( )))
8440+ .try_iterate(db )
8441+ .map(|spec| Type:: tuple(TupleType::new(db, &spec )))
83748442 .unwrap_or_else(|err| {
83758443 err.report_diagnostic(&self.context, iterable_type, value.as_ref().into());
8376- err.fallback_element_type(self.db())
8377- });
8378-
8379- // TODO
8380- todo_type!("starred expression")
8444+ Type::homogeneous_tuple(db, err.fallback_element_type(db))
8445+ })
83818446 }
83828447
83838448 fn infer_yield_expression(&mut self, yield_expression: &ast::ExprYield) -> Type<'db> {
0 commit comments