Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.parsers.base import BaseParser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.time import STRICT_TIME_FORMATS, TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
from sqlglot.typing import EXPRESSION_METADATA
Expand Down Expand Up @@ -134,6 +134,16 @@ class NormalizationStrategy(str, AutoName):
"""Always case-insensitive (uppercase), regardless of quotes."""


def _with_strict_time_fallback(inverse_mapping: dict[str, str]) -> dict[str, str]:
# Dialects that define a "strict" format (e.g. Spark) keep their own mapping;
# everyone else degrades it to the lax counterpart's mapping, so the internal
# token never leaks into generated SQL.
for strict_format, lax_format in STRICT_TIME_FORMATS.items():
inverse_mapping.setdefault(strict_format, inverse_mapping.get(lax_format, lax_format))

return inverse_mapping


class _Dialect(type):
_classes: dict[str, Type[Dialect]] = {}

Expand Down Expand Up @@ -232,16 +242,21 @@ def __new__(cls, clsname, bases, attrs):
cls._classes[enum.value if enum is not None else clsname.lower()] = klass

klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
klass.STRICT_TIME_TRIE = new_trie(klass.STRICT_TIME_MAPPING)
klass.LENIENT_INVERSE_TIME_TRIE = new_trie(klass.LENIENT_INVERSE_TIME_MAPPING)
klass.FORMAT_TRIE = (
new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
)
# Merge class-defined INVERSE_TIME_MAPPING with auto-generated mappings
# This allows dialects to define custom inverse mappings for roundtrip correctness
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} | (
klass.__dict__.get("INVERSE_TIME_MAPPING") or {}
klass.INVERSE_TIME_MAPPING = _with_strict_time_fallback(
{v: k for k, v in klass.TIME_MAPPING.items()}
| (klass.__dict__.get("INVERSE_TIME_MAPPING") or {})
)
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
klass.INVERSE_FORMAT_MAPPING = _with_strict_time_fallback(
{v: k for k, v in klass.FORMAT_MAPPING.items()}
)
klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

klass.INVERSE_CREATABLE_KIND_MAPPING = {
Expand Down Expand Up @@ -412,6 +427,20 @@ class Dialect(metaclass=_Dialect):
TIME_MAPPING: dict[str, str] = {}
"""Associates this dialect's time formats with their equivalent Python `strftime` formats."""

STRICT_TIME_MAPPING: dict[str, str] = {}
"""
Variant of `TIME_MAPPING` used when *parsing* a string with a format (e.g. `StrToTime`).
Lets dialects with strict parsing (e.g. Spark 3+'s zero-padded `MM`/`dd`) map those to a
distinct canonical format, preserving the roundtrip. Empty means `TIME_MAPPING` is used.
"""

LENIENT_INVERSE_TIME_MAPPING: dict[str, str] = {}
"""
Inverse mapping used when *generating* a parse format (e.g. `StrToTime`) for dialects that
parse leniently (e.g. Spark). Maps the canonical specifiers to their lenient single-letter
forms, and the strict tokens back to the padded forms.
"""

# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Exprs-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
FORMAT_MAPPING: dict[str, str] = {}
Expand Down Expand Up @@ -770,10 +799,12 @@ class Dialect(metaclass=_Dialect):

# A trie of the time_mapping keys
TIME_TRIE: dict = {}
STRICT_TIME_TRIE: dict = {}
FORMAT_TRIE: dict = {}

INVERSE_TIME_MAPPING: dict[str, str] = {}
INVERSE_TIME_TRIE: dict = {}
LENIENT_INVERSE_TIME_TRIE: dict = {}
INVERSE_FORMAT_MAPPING: dict[str, str] = {}
INVERSE_FORMAT_TRIE: dict = {}

Expand Down Expand Up @@ -966,16 +997,23 @@ def get_or_raise(cls, dialect: DialectType) -> Dialect:
raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

@classmethod
def format_time(cls, expression: str | exp.Expr | None) -> exp.Expr | None:
def format_time(
cls, expression: str | exp.Expr | None, strict: bool = False
) -> exp.Expr | None:
"""Converts a time format in this dialect to its equivalent Python `strftime` format."""
if strict and cls.STRICT_TIME_MAPPING:
mapping, trie = cls.STRICT_TIME_MAPPING, cls.STRICT_TIME_TRIE
else:
mapping, trie = cls.TIME_MAPPING, cls.TIME_TRIE

if isinstance(expression, str):
return exp.Literal.string(
# the time formats are quoted
format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
format_time(expression[1:-1], mapping, trie)
)

if expression and expression.is_string:
return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
return exp.Literal.string(format_time(expression.this, mapping, trie))

return expression

Expand Down Expand Up @@ -1544,6 +1582,13 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:
return self.sql(result)


# Expressions that parse a string with a format (vs. formatting one, like TimeToStr).
# Dialects with strict parsing semantics (STRICT_TIME_MAPPING) use it for these on the
# parser side, and the corresponding generator (e.g. SparkGenerator.format_time) reuses
# this same set to emit the lenient inverse, which is what preserves the roundtrip.
STRICT_PARSE_TIME_EXPRESSIONS = (exp.StrToTime, exp.StrToDate, exp.TsOrDsToDate)


def build_formatted_time(
exp_class: Type[E], dialect_override: str | None = None, default: bool | str | None = None
) -> t.Callable[[BuilderArgs, Dialect], E]:
Expand All @@ -1569,7 +1614,10 @@ def _builder(args: BuilderArgs, dialect: Dialect) -> E:
if not fmt:
fmt = target_dialect.TIME_FORMAT if default is True else default or None

return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt))
strict = exp_class in STRICT_PARSE_TIME_EXPRESSIONS
return exp_class(
this=seq_get(args, 0), format=target_dialect.format_time(fmt, strict=strict)
)

return _builder

Expand Down
17 changes: 17 additions & 0 deletions sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ class Spark(Spark2):
ARRAY_FUNCS_PROPAGATES_NULLS = True
EXPRESSION_METADATA = EXPRESSION_METADATA.copy()

# Spark 3+ parses MM/dd strictly (single-digit months/days don't parse), unlike the
# lax %m/%d other dialects produce. When *parsing* (StrToTime/StrToDate/...), MM/dd
# map to a distinct canonical token so the strict roundtrip is preserved; formatting
# keeps the regular padded %m/%d -> MM/dd (TIME_MAPPING is unchanged).
STRICT_TIME_MAPPING = {
**Spark2.TIME_MAPPING,
"MM": "%mstrict",
"dd": "%dstrict",
}
# Generating a parse format is lenient: %m/%d -> M/d (matching strptime), while the
# strict tokens map back to MM/dd.
LENIENT_INVERSE_TIME_MAPPING = {
**{v: k for k, v in STRICT_TIME_MAPPING.items()},
"%m": "M",
"%d": "d",
}

class Tokenizer(Spark2.Tokenizer):
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False

Expand Down
12 changes: 10 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sqlglot.expressions import apply_index_offset
from sqlglot.helper import csv, name_sequence, seq_get
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.time import STRICT_TIME_FORMATS, STRICT_TIME_TRIE, format_time
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -4052,7 +4052,15 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: str | None = None) -> str:

# Base implementation that excludes safe, zone, and target_type metadata args
def strtotime_sql(self, expression: exp.StrToTime) -> str:
return self.func("STR_TO_TIME", expression.this, expression.args.get("format"))
# STR_TO_TIME is sqlglot's canonical form, so the format must stay canonical
# strftime - we only strip the internal "strict" tokens (e.g. Spark's %mstrict)
# rather than routing through self.format_time(), which would also rewrite every
# other specifier into the dialect's INVERSE_TIME_MAPPING.
return self.func(
"STR_TO_TIME",
expression.this,
self.format_time(expression, STRICT_TIME_FORMATS, STRICT_TIME_TRIE),
)

def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
Expand Down
17 changes: 16 additions & 1 deletion sqlglot/generators/spark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations


from sqlglot import exp
from sqlglot import generator
from sqlglot.dialects.dialect import (
STRICT_PARSE_TIME_EXPRESSIONS,
array_append_sql,
rename_func,
unit_to_var,
Expand Down Expand Up @@ -89,6 +89,21 @@ class SparkGenerator(Spark2Generator):
exp.DType.SMALLMONEY: ((6, 4), ()),
}

def format_time(
self,
expression: exp.Expr,
inverse_time_mapping: dict[str, str] | None = None,
inverse_time_trie: dict | None = None,
) -> str | None:
# Spark 3+ parses these leniently, so emit M/d (not the padded MM/dd used for
# formatting) for the canonical %m/%d. The expression set is shared with the parser
# (STRICT_PARSE_TIME_EXPRESSIONS), which is what guarantees the strict roundtrip.
if isinstance(expression, STRICT_PARSE_TIME_EXPRESSIONS):
inverse_time_mapping = inverse_time_mapping or self.dialect.LENIENT_INVERSE_TIME_MAPPING
inverse_time_trie = inverse_time_trie or self.dialect.LENIENT_INVERSE_TIME_TRIE

return super().format_time(expression, inverse_time_mapping, inverse_time_trie)

TRANSFORMS = {
k: v
for k, v in {
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import TrieResult, in_trie, new_trie

# "Strict" canonical time formats round-trip in dialects that define them (e.g.
# Spark 3+'s zero-padded MM/dd, which don't parse single-digit values) and degrade
# to their lax counterpart elsewhere. These are sqlglot-internal tokens, not valid
# strftime directives, so they must be normalized away when emitting generic SQL.
STRICT_TIME_FORMATS = {"%mstrict": "%m", "%dstrict": "%d"}
STRICT_TIME_TRIE = new_trie(STRICT_TIME_FORMATS)


def format_time(
string: str, mapping: dict[str, str], trie: dict[t.Any, t.Any] | None = None
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def test_time(self):
"presto": "DATE_PARSE(x, '%Y-%m-%dT%T')",
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH24:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-dTHH:mm:ss')",
},
)
self.validate_all(
Expand All @@ -776,7 +776,7 @@ def test_time(self):
"postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def test_time(self):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
"spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
"spark": "TO_DATE(x, 'yyyy-M-dTHH:mm:ss')",
"doris": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
},
)
Expand All @@ -1231,7 +1231,7 @@ def test_time(self):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
"spark": "TO_DATE(x, 'yyyy-M-d')",
"doris": "STR_TO_DATE(x, '%Y-%m-%d')",
},
)
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,9 @@ def test_datetime_functions(self):
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"spark": "TO_DATE(x)",
"spark": "TO_DATE(x, 'yyyy-M-d')",
"snowflake": "TO_DATE(x, 'yyyy-mm-DD')",
"databricks": "TO_DATE(x)",
"databricks": "TO_DATE(x, 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_time(self):
"duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_PARSE(x, '%Y-%m-%d %T')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss')",
},
)
self.validate_all(
Expand All @@ -315,7 +315,7 @@ def test_time(self):
"duckdb": "STRPTIME(x, '%Y-%m-%d')",
"presto": "DATE_PARSE(x, '%Y-%m-%d')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d')",
},
)
self.validate_all(
Expand All @@ -330,7 +330,7 @@ def test_time(self):
"duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
"presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')",
},
)
self.validate_all(
Expand All @@ -339,7 +339,7 @@ def test_time(self):
"duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
"presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,7 @@ def test_snowflake(self):
"bigquery": "SELECT PARSE_TIMESTAMP('%d-%m-%Y %I:%M:%S', col) FROM t",
"duckdb": "SELECT STRPTIME(col, '%d-%m-%Y %I:%M:%S') FROM t",
"snowflake": "SELECT TO_TIMESTAMP(col, 'DD-mm-yyyy hh12:mi:ss') FROM t",
"spark": "SELECT TO_TIMESTAMP(col, 'dd-MM-yyyy hh:mm:ss') FROM t",
"spark": "SELECT TO_TIMESTAMP(col, 'd-M-yyyy hh:mm:ss') FROM t",
},
)
self.validate_all(
Expand Down Expand Up @@ -1904,7 +1904,7 @@ def test_snowflake(self):
write={
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %T', '04/05/2013 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'M/d/yyyy HH:mm:ss')",
},
)
self.validate_all(
Expand Down
29 changes: 27 additions & 2 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,14 +659,39 @@ def test_spark(self):
},
)
self.validate_all(
"SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
"SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')",
read={
"duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')",
"duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%m-%d')",
},
write={
"": "SELECT STR_TO_TIME('2016-1-1', '%Y-%-m-%-d')",
"duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%-m-%-d')",
"spark": "SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')",
},
)
# Spark 3+ parses MM/dd strictly, so the strict parse format roundtrips, but
# widens to the lax %m/%d for dialects that parse leniently (e.g. duckdb).
self.validate_all(
"SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
write={
"": "SELECT STR_TO_TIME('2016-12-31', '%Y-%m-%d')",
"duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')",
"spark": "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
"databricks": "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
},
)
# Formatting keeps zero-padded MM/dd, unlike the lenient parsing above.
self.validate_identity("SELECT DATE_FORMAT(x, 'yyyy-MM-dd')")
# The strict canonical token must degrade in BigQuery's FORMAT clause too,
# not just INVERSE_TIME_MAPPING (it previously leaked as 'MMstrict/DDstrict').
self.validate_all(
"SELECT TO_DATE(x, 'MM/dd/yyyy')",
write={
"": "SELECT CAST(STR_TO_TIME(x, '%m/%d/%Y') AS DATE)",
"duckdb": "SELECT CAST(CAST(TRY_STRPTIME(x, '%m/%d/%Y') AS TIMESTAMP) AS DATE)",
"bigquery": "SELECT CAST(SAFE_CAST(x AS TIMESTAMP FORMAT 'MM/DD/YYYY') AS DATE)",
"spark": "SELECT TO_DATE(x, 'MM/dd/yyyy')",
"databricks": "SELECT TO_DATE(x, 'MM/dd/yyyy')",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def test_cast(self):
write={
"teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')",
"bigquery": "PARSE_DATE('%Y-%d', '1992-01')",
"databricks": "TO_DATE('1992-01', 'yyyy-dd')",
"databricks": "TO_DATE('1992-01', 'yyyy-d')",
"mysql": "STR_TO_DATE('1992-01', '%Y-%d')",
"spark": "TO_DATE('1992-01', 'yyyy-dd')",
"spark": "TO_DATE('1992-01', 'yyyy-d')",
"": "STR_TO_DATE('1992-01', '%Y-%d')",
},
)
Expand Down
Loading
Loading