From fd3e0d120f572b474e8b8576fbcde7797501111a Mon Sep 17 00:00:00 2001 From: Victorien <65306057+Viicos@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:31:10 +0200 Subject: [PATCH 1/4] Fix `style` argument of `BaseDatabaseOperations.sql_flush` --- django-stubs/db/backends/base/operations.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index ac3a1981b..0d26c6bb1 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -66,7 +66,7 @@ class BaseDatabaseOperations: def savepoint_rollback_sql(self, sid: str) -> str: ... def set_time_zone_sql(self) -> str: ... def sql_flush( - self, style: Any, tables: Sequence[str], *, reset_sequences: bool = ..., allow_cascade: bool = ... + self, style: Style, tables: Sequence[str], *, reset_sequences: bool = ..., allow_cascade: bool = ... ) -> list[str]: ... def execute_sql_flush(self, sql_list: Iterable[str]) -> None: ... def sequence_reset_by_name_sql(self, style: Style | None, sequences: list[Any]) -> list[Any]: ... From 5474e6239dc8e3e6abef7c6be62c99a08a068049 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:01:23 +0200 Subject: [PATCH 2/4] Improve types of `BaseDatabaseOperations` --- .../gis/db/backends/base/operations.pyi | 6 +++-- django-stubs/db/backends/base/operations.pyi | 27 ++++++++++--------- django-stubs/db/backends/mysql/operations.pyi | 8 ------ .../db/backends/oracle/operations.pyi | 10 ------- django-stubs/db/models/expressions.pyi | 7 ++--- 5 files changed, 23 insertions(+), 35 deletions(-) diff --git a/django-stubs/contrib/gis/db/backends/base/operations.pyi b/django-stubs/contrib/gis/db/backends/base/operations.pyi index bf6aa10d3..91cfc0387 100644 --- a/django-stubs/contrib/gis/db/backends/base/operations.pyi +++ b/django-stubs/contrib/gis/db/backends/base/operations.pyi @@ -1,5 +1,7 @@ from typing import Any +from django.db.backends.base.operations import _Converter +from django.db.models.expressions import Expression from django.utils.functional import cached_property class BaseSpatialOperations: @@ -24,13 +26,13 @@ class BaseSpatialOperations: def geo_db_type(self, f: Any) -> Any: ... def get_distance(self, f: Any, value: Any, lookup_type: Any) -> Any: ... def get_geom_placeholder(self, f: Any, value: Any, compiler: Any) -> Any: ... - def check_expression_support(self, expression: Any) -> None: ... + def check_expression_support(self, expression: Expression) -> None: ... def spatial_aggregate_name(self, agg_name: Any) -> Any: ... def spatial_function_name(self, func_name: Any) -> Any: ... def geometry_columns(self) -> Any: ... def spatial_ref_sys(self) -> Any: ... distance_expr_for_lookup: Any - def get_db_converters(self, expression: Any) -> Any: ... + def get_db_converters(self, expression: Expression) -> list[_Converter]: ... def get_geometry_converter(self, expression: Any) -> Any: ... def get_area_att_for_field(self, field: Any) -> Any: ... def get_distance_att_for_field(self, field: Any) -> Any: ... diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index 0d26c6bb1..6ad3f8d93 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -1,5 +1,5 @@ import json -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Sequence from datetime import date, time, timedelta from datetime import datetime as real_datetime from decimal import Decimal @@ -13,6 +13,9 @@ from django.db.models.constants import OnConflict from django.db.models.expressions import Case, Col, Expression from django.db.models.fields import Field from django.db.models.sql.compiler import SQLCompiler +from typing_extensions import TypeAlias + +_Converter: TypeAlias = Callable[[Any, Expression, BaseDatabaseWrapper], Any] class BaseDatabaseOperations: compiler_module: str @@ -57,7 +60,7 @@ class BaseDatabaseOperations: def pk_default_value(self) -> str: ... def prepare_sql_script(self, sql: Any) -> list[str]: ... def process_clob(self, value: str) -> str: ... - def return_insert_columns(self, fields: Any) -> Any: ... + def return_insert_columns(self, fields: list[Field[Any, Any]]) -> tuple[str, list[Any]]: ... def compiler(self, compiler_name: str) -> type[SQLCompiler]: ... def quote_name(self, name: str) -> str: ... def regex_lookup(self, lookup_type: str) -> str: ... @@ -68,14 +71,14 @@ class BaseDatabaseOperations: def sql_flush( self, style: Style, tables: Sequence[str], *, reset_sequences: bool = ..., allow_cascade: bool = ... ) -> list[str]: ... - def execute_sql_flush(self, sql_list: Iterable[str]) -> None: ... - def sequence_reset_by_name_sql(self, style: Style | None, sequences: list[Any]) -> list[Any]: ... - def sequence_reset_sql(self, style: Style, model_list: Sequence[type[Model]]) -> list[Any]: ... + def execute_sql_flush(self, sql_list: list[str]) -> None: ... + def sequence_reset_by_name_sql(self, style: Style, sequences: list[dict[str, str | None]]) -> list[str]: ... + def sequence_reset_sql(self, style: Style, model_list: list[type[Model]]) -> list[str]: ... def start_transaction_sql(self) -> str: ... def end_transaction_sql(self, success: bool = ...) -> str: ... - def tablespace_sql(self, tablespace: str | None, inline: bool = ...) -> str: ... - def prep_for_like_query(self, x: str) -> str: ... - prep_for_iexact_query: Any + def tablespace_sql(self, tablespace: str, inline: bool = ...) -> str: ... + def prep_for_like_query(self, x: object) -> str: ... + def prep_for_iexact_query(self, x: object) -> str: ... def validate_autopk_value(self, value: int) -> int: ... def adapt_unknown_value(self, value: Any) -> Any: ... def adapt_datefield_value(self, value: date | None) -> str | None: ... @@ -89,14 +92,14 @@ class BaseDatabaseOperations: def adapt_integerfield_value(self, value: Any, internal_type: Any) -> Any: ... def year_lookup_bounds_for_date_field(self, value: int, iso_year: bool = ...) -> list[str]: ... def year_lookup_bounds_for_datetime_field(self, value: int, iso_year: bool = ...) -> list[str]: ... - def get_db_converters(self, expression: Expression) -> list[Any]: ... + def get_db_converters(self, expression: Expression) -> list[_Converter]: ... def convert_durationfield_value( self, value: float | None, expression: Expression, connection: BaseDatabaseWrapper ) -> timedelta | None: ... - def check_expression_support(self, expression: Any) -> None: ... - def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool: ... + def check_expression_support(self, expression: Expression) -> None: ... + def conditional_expression_supported_in_where_clause(self, expression: Expression) -> bool: ... def combine_expression(self, connector: str, sub_expressions: list[str]) -> str: ... - def combine_duration_expression(self, connector: Any, sub_expressions: Any) -> str: ... + def combine_duration_expression(self, connector: str, sub_expressions: list[str]) -> str: ... def binary_placeholder_sql(self, value: Case | None) -> str: ... def modify_insert_params(self, placeholder: str, params: Any) -> Any: ... def integer_field_range(self, internal_type: Any) -> tuple[int, int]: ... diff --git a/django-stubs/db/backends/mysql/operations.pyi b/django-stubs/db/backends/mysql/operations.pyi index 7c330c9ad..28a709383 100644 --- a/django-stubs/db/backends/mysql/operations.pyi +++ b/django-stubs/db/backends/mysql/operations.pyi @@ -23,22 +23,14 @@ class DatabaseOperations(BaseDatabaseOperations): def force_no_ordering(self) -> Any: ... def last_executed_query(self, cursor: Any, sql: Any, params: Any) -> Any: ... def no_limit_value(self) -> Any: ... - def quote_name(self, name: str) -> Any: ... - def return_insert_columns(self, fields: Any) -> Any: ... - def sequence_reset_by_name_sql(self, style: Any, sequences: Any) -> Any: ... - def validate_autopk_value(self, value: Any) -> Any: ... def adapt_datetimefield_value(self, value: Any) -> Any: ... def adapt_timefield_value(self, value: Any) -> Any: ... def max_name_length(self) -> Any: ... def bulk_insert_sql(self, fields: Any, placeholder_rows: Any) -> Any: ... - def combine_expression(self, connector: Any, sub_expressions: Any) -> Any: ... - def get_db_converters(self, expression: Any) -> Any: ... def convert_booleanfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... def convert_datetimefield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... def convert_uuidfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... - def binary_placeholder_sql(self, value: Any) -> Any: ... def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any) -> Any: ... def explain_query_prefix(self, format: Any | None = ..., **options: Any) -> Any: ... - def regex_lookup(self, lookup_type: str) -> Any: ... def insert_statement(self, on_conflict: OnConflict | None = ...) -> str: ... def lookup_cast(self, lookup_type: str, internal_type: Any | None = ...) -> Any: ... diff --git a/django-stubs/db/backends/oracle/operations.pyi b/django-stubs/db/backends/oracle/operations.pyi index 8047fecd5..9d104286f 100644 --- a/django-stubs/db/backends/oracle/operations.pyi +++ b/django-stubs/db/backends/oracle/operations.pyi @@ -17,7 +17,6 @@ class DatabaseOperations(BaseDatabaseOperations): def datetime_extract_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None) -> tuple[str, Any]: ... def datetime_trunc_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None) -> str: ... def time_trunc_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None = ...) -> str: ... - def get_db_converters(self, expression: Any) -> list[Any]: ... def convert_textfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... def convert_binaryfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... def convert_booleanfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ... @@ -40,15 +39,7 @@ class DatabaseOperations(BaseDatabaseOperations): def max_in_list_size(self) -> int: ... def max_name_length(self) -> int: ... def pk_default_value(self) -> str: ... - def prep_for_iexact_query(self, x: Any) -> str: ... def process_clob(self, value: Any) -> Any: ... - def quote_name(self, name: str) -> str: ... - def regex_lookup(self, lookup_type: str) -> str: ... - def return_insert_columns(self, fields: Any) -> Any: ... - def sequence_reset_by_name_sql(self, style: Any, sequences: Any) -> list[str]: ... - def sequence_reset_sql(self, style: Any, model_list: Any) -> list[str]: ... - def start_transaction_sql(self) -> str: ... - def tablespace_sql(self, tablespace: Any, inline: bool = ...) -> str: ... def adapt_datefield_value(self, value: Any) -> Any: ... def adapt_datetimefield_value(self, value: Any) -> Any: ... def adapt_timefield_value(self, value: Any) -> Any: ... @@ -56,4 +47,3 @@ class DatabaseOperations(BaseDatabaseOperations): def bulk_insert_sql(self, fields: Any, placeholder_rows: Any) -> str: ... def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any) -> Any: ... def bulk_batch_size(self, fields: Any, objs: Any) -> int: ... - def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool: ... diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index 304050c70..be95aaca2 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -1,9 +1,10 @@ import datetime -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from decimal import Decimal from typing import Any, ClassVar, Generic, Literal, TypeVar from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.base.operations import _Converter from django.db.models import Q, fields from django.db.models.fields import Field from django.db.models.lookups import Lookup, Transform @@ -63,7 +64,7 @@ class BaseExpression: window_compatible: bool allowed_default: bool def __init__(self, output_field: Field | None = ...) -> None: ... - def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Callable]: ... + def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[_Converter]: ... def get_source_expressions(self) -> list[Any]: ... def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: ... @cached_property @@ -89,7 +90,7 @@ class BaseExpression: @cached_property def output_field(self) -> Field: ... @cached_property - def convert_value(self) -> Callable: ... + def convert_value(self) -> _Converter: ... def get_lookup(self, lookup: str) -> type[Lookup] | None: ... def get_transform(self, name: str) -> type[Transform] | None: ... def relabeled_clone(self, change_map: Mapping[str, str]) -> Self: ... From 51bad6566b04a7977727f6c5bbc842fef5e482fb Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:04:16 +0200 Subject: [PATCH 3/4] Add type for `cast_char_field_without_max_length` --- django-stubs/db/backends/base/operations.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index 6ad3f8d93..50bd08d89 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -22,7 +22,7 @@ class BaseDatabaseOperations: integer_field_ranges: dict[str, tuple[int, int]] set_operators: dict[str, str] cast_data_types: dict[Any, Any] - cast_char_field_without_max_length: Any + cast_char_field_without_max_length: str | None PRECEDING: str FOLLOWING: str UNBOUNDED_PRECEDING: str From 73353a9ce709e2835b3eb34f0d0e3ca1072e8028 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 30 Jun 2024 12:05:25 +0200 Subject: [PATCH 4/4] Fix overrides --- django-stubs/contrib/gis/db/models/aggregates.pyi | 4 +++- django-stubs/db/models/functions/datetime.pyi | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/django-stubs/contrib/gis/db/models/aggregates.pyi b/django-stubs/contrib/gis/db/models/aggregates.pyi index 88497cfc6..7b24feda6 100644 --- a/django-stubs/contrib/gis/db/models/aggregates.pyi +++ b/django-stubs/contrib/gis/db/models/aggregates.pyi @@ -1,7 +1,7 @@ from typing import Any from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.models import Aggregate +from django.db.models import Aggregate, Expression from django.db.models.sql.compiler import SQLCompiler, _AsSqlType class GeoAggregate(Aggregate): @@ -15,10 +15,12 @@ class Collect(GeoAggregate): class Extent(GeoAggregate): name: str def __init__(self, expression: Any, **extra: Any) -> None: ... + def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ... class Extent3D(GeoAggregate): name: str def __init__(self, expression: Any, **extra: Any) -> None: ... + def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ... class MakeLine(GeoAggregate): name: str diff --git a/django-stubs/db/models/functions/datetime.pyi b/django-stubs/db/models/functions/datetime.pyi index fde368f62..21e9c9ff3 100644 --- a/django-stubs/db/models/functions/datetime.pyi +++ b/django-stubs/db/models/functions/datetime.pyi @@ -4,7 +4,7 @@ from typing import Any, ClassVar from django.db import models from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Func, Transform -from django.db.models.expressions import Combinable +from django.db.models.expressions import Combinable, Expression from django.db.models.fields import Field from django.db.models.sql.compiler import SQLCompiler, _AsSqlType @@ -44,6 +44,7 @@ class TruncBase(TimezoneMixin, Transform): self, expression: Combinable | str, output_field: Field | None = ..., tzinfo: tzinfo | None = ..., **extra: Any ) -> None: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... # type: ignore[override] + def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ... class Trunc(TruncBase): def __init__(