diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py
index d20c755ee3..e8ff930ad8 100644
--- a/dlt/common/schema/schema.py
+++ b/dlt/common/schema/schema.py
@@ -1128,3 +1128,14 @@ def __getstate__(self) -> Any:
         del state["naming"]
         del state["data_item_normalizer"]
         return state
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Schema):
+            raise NotImplementedError(
+                f"Equality between a `dlt.Schema` object and {type(other).__name__} is not supported."
+            )
+
+        return self.version_hash == other.version_hash
+
+    def __hash__(self) -> int:
+        return hash(self.version_hash)
diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py
index 9711725836..ca3117e9c1 100644
--- a/dlt/dataset/dataset.py
+++ b/dlt/dataset/dataset.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
+from contextlib import contextmanager
+import tempfile
 from types import TracebackType
-from typing import Any, Optional, Type, Union, TYPE_CHECKING, Literal, overload
+from typing import Any, Generator, Optional, Type, Union, TYPE_CHECKING, Literal, overload
 
 from sqlglot.schema import Schema as SQLGlotSchema
 import sqlglot.expressions as sge
@@ -12,21 +14,27 @@
 from dlt.common.json import json
 from dlt.common.destination.reference import AnyDestination, TDestinationReferenceArg, Destination
 from dlt.common.destination.client import JobClientBase, SupportsOpenTables, WithStateSync
-from dlt.common.schema import Schema
-from dlt.common.typing import Self
-from dlt.common.schema.typing import C_DLT_LOAD_ID
+from dlt.common.typing import Self, TDataItems
+from dlt.common.schema.typing import C_DLT_LOAD_ID, TWriteDisposition
+from dlt.common.pipeline import LoadInfo
 from dlt.common.utils import simple_repr, without_none
 from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
 from dlt.dataset import lineage
 from dlt.dataset.utils import get_destination_clients
 from dlt.destinations.queries import build_row_counts_expr
 from dlt.common.destination.exceptions import SqlClientNotAvailable
+from dlt.common.schema.exceptions import (
+    TableNotFound,
+)
 
 if TYPE_CHECKING:
     from ibis import ir
     from ibis import BaseBackend as IbisBackend
 
+_INTERNAL_DATASET_PIPELINE_NAME_TEMPLATE = "_dlt_dataset_{dataset_name}"
+
+
 class Dataset:
     """Access to dataframes and arrow tables in the destination dataset via dbapi"""
@@ -170,6 +178,62 @@ def is_same_physical_destination(self, other: dlt.Dataset) -> bool:
         """
         return is_same_physical_destination(self, other)
 
+    # TODO explain that users can inspect the `_dlt_loads` table to differentiate data
+    # originating from `pipeline.run()` vs. `dataset.write()`
+    @contextmanager
+    def write_pipeline(self) -> Generator[dlt.Pipeline, None, None]:
+        """Get the internal pipeline used by `Dataset.write()`.
+        It uses "_dlt_dataset_{dataset_name}" as the pipeline name. Its working directory is
+        kept so that load packages can be inspected after a run, but it is cleared before
+        each write.
+
+        """
+        pipeline = _get_internal_pipeline(
+            dataset_name=self.dataset_name, destination=self._destination
+        )
+        yield pipeline
+
+    def write(
+        self,
+        data: TDataItems,
+        *,
+        table_name: str,
+    ) -> LoadInfo:
+        """Write `data` to the specified table.
+
+        This method uses a full `dlt.Pipeline` internally. You can access this pipeline
+        through the `Dataset.write_pipeline()` context manager for complete flexibility.
+        The resulting load packages can be inspected in the pipeline's working directory,
+        which is named "_dlt_dataset_{dataset_name}".
+        This directory will be wiped before each `write()` call.
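+
+        Example (illustrative sketch; the duckdb destination, dataset and table names below
+        are placeholders, not prescribed by this method):
+
+            # any configured destination works; duckdb is used here only as an example
+            dataset = dlt.dataset(dlt.destinations.duckdb("example.duckdb"), "my_dataset")
+            info = dataset.write([{"id": 1, "value": "a"}], table_name="items")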
+ """ + with self.write_pipeline() as internal_pipeline: + # drop all load packages from previous writes + # internal_pipeline._wipe_working_folder() + internal_pipeline.drop() + + # get write dispostion for existing table from schema (or "append" if table is new) + try: + write_disposition = self.schema.get_table(table_name)["write_disposition"] + except TableNotFound: + write_disposition = "append" + # TODO should we try/except this run to gracefully handle failed writes? + info = internal_pipeline.run( + data, + dataset_name=self.dataset_name, + table_name=table_name, + schema=self.schema, + write_disposition=write_disposition, + ) + + # maybe update the dataset schema + self._update_schema(internal_pipeline.default_schema) + return info + + def _update_schema(self, new_schema: dlt.Schema) -> None: + """Update the dataset schema""" + # todo: verify if we need to purge any cached objects (eg. sql_client) + self._schema = new_schema + def query( self, query: Union[str, sge.Select, ir.Expr], @@ -387,7 +451,7 @@ def __str__(self) -> str: def dataset( destination: TDestinationReferenceArg, dataset_name: str, - schema: Union[Schema, str, None] = None, + schema: Union[dlt.Schema, str, None] = None, ) -> Dataset: return Dataset(destination, dataset_name, schema) @@ -451,3 +515,22 @@ def _get_dataset_schema_from_destination_using_dataset_name( schema = dlt.Schema.from_stored_schema(json.loads(stored_schema.schema)) return schema + + +def _get_internal_pipeline( + dataset_name: str, + destination: TDestinationReferenceArg, + pipelines_dir: str = None, +) -> dlt.Pipeline: + """Setup the internal pipeline used by `Dataset.write()`""" + pipeline = dlt.pipeline( + pipeline_name=_INTERNAL_DATASET_PIPELINE_NAME_TEMPLATE.format(dataset_name=dataset_name), + dataset_name=dataset_name, + destination=destination, + pipelines_dir=pipelines_dir, + ) + # the internal write pipeline should be stateless; it is limited to the data passed + # it shouldn't persist state (e.g., incremntal cursor) and interfere with other `pipeline.run()` + pipeline.config.restore_from_destination = False + + return pipeline diff --git a/tests/dataset/__init__.py b/tests/dataset/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dataset/test_dataset_write.py b/tests/dataset/test_dataset_write.py new file mode 100644 index 0000000000..b0f674b6e4 --- /dev/null +++ b/tests/dataset/test_dataset_write.py @@ -0,0 +1,276 @@ +import pathlib +from typing import Any, Dict, List, Sequence, Tuple + +import duckdb +import pytest + +import dlt +from dlt.common.pipeline import LoadInfo +from dlt.dataset.dataset import ( + _INTERNAL_DATASET_PIPELINE_NAME_TEMPLATE, + is_same_physical_destination, + _get_internal_pipeline, +) +from dlt.destinations.exceptions import DatabaseUndefinedRelation + +from tests.utils import preserve_environ, patch_home_dir, autouse_test_storage, TEST_STORAGE_ROOT +from tests.pipeline.utils import assert_load_info, assert_records_as_set, assert_table_counts + +from dlt.common.destination import Destination, TDestinationReferenceArg + + +@pytest.fixture() +def pipeline_and_foo_dataset() -> Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]: + dataset_name = "foo" + destination = dlt.destinations.duckdb(duckdb.connect()) + pipeline = dlt.pipeline(destination=destination, dataset_name=dataset_name) + dataset = dlt.dataset(destination, dataset_name) + return pipeline, dataset, dataset_name, destination + + +def test_get_internal_pipeline( + pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, 
+):
+    _, dataset, dataset_name, destination = pipeline_and_foo_dataset
+
+    expected_pipeline_name = _INTERNAL_DATASET_PIPELINE_NAME_TEMPLATE.format(
+        dataset_name=dataset_name
+    )
+
+    internal_pipeline = _get_internal_pipeline(dataset_name=dataset_name, destination=destination)
+    internal_dataset = internal_pipeline.dataset()
+
+    assert isinstance(internal_pipeline, dlt.Pipeline)
+    assert internal_pipeline.pipeline_name == expected_pipeline_name
+    assert internal_pipeline.dataset_name == dataset_name
+    assert internal_pipeline.destination == destination
+    assert is_same_physical_destination(dataset, internal_dataset)
+    assert dataset.schema == internal_dataset.schema
+
+
+def test_dataset_get_write_pipeline(
+    pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]
+):
+    _, dataset, dataset_name, destination = pipeline_and_foo_dataset
+    expected_pipeline_name = _INTERNAL_DATASET_PIPELINE_NAME_TEMPLATE.format(
+        dataset_name=dataset_name
+    )
+
+    with dataset.write_pipeline() as write_pipeline:
+        write_dataset = write_pipeline.dataset()
+
+        assert isinstance(write_pipeline, dlt.Pipeline)
+        assert write_pipeline.pipeline_name == expected_pipeline_name
+        assert write_pipeline.dataset_name == dataset_name
+        assert write_pipeline.destination == destination
+
+        assert is_same_physical_destination(dataset, write_dataset)
+        assert dataset.schema == write_dataset.schema
+
+
+def test_dataset_write(
+    pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]
+):
+    _, dataset, _, _ = pipeline_and_foo_dataset
+    table_name = "bar"
+    items = [{"id": 0, "value": "bingo"}, {"id": 1, "value": "bongo"}]
+
+    # TODO this is currently odd because the tables exist on the `Schema` used by the
+    # `Dataset` but they don't exist on the destination yet
+    assert dataset.tables == ["_dlt_version", "_dlt_loads"]
+    with pytest.raises(DatabaseUndefinedRelation):
+        dataset.table("_dlt_version").fetchall()
+    with pytest.raises(DatabaseUndefinedRelation):
+        dataset.table("_dlt_loads").fetchall()
+
+    load_info = dataset.write(items, table_name=table_name)
+
+    assert isinstance(load_info, LoadInfo)
+    assert dataset.tables == ["bar", "_dlt_version", "_dlt_loads"]
+    assert dataset.table("bar").select("id", "value").fetchall() == [
+        tuple(i.values()) for i in items
+    ]
+
+
+def test_dataset_write_to_existing_table(
+    pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]
+):
+    pipeline, dataset, _, _ = pipeline_and_foo_dataset
+
+    # create existing table
+    data = [{"id": 0, "value": 1}, {"id": 1, "value": 2}]
+    pipeline.run(data, table_name="numbers")
+    assert_table_counts(pipeline, {"numbers": 2})
+
+    schema_before = dataset.schema.clone()
+
+    # execute
+    load_info = dataset.write([{"id": 2, "value": 3}], table_name="numbers")
+
+    # verify data got written
+    assert_load_info(load_info, expected_load_packages=1)
+    assert_table_counts(pipeline, {"numbers": 3})
+
+    # verify data is readable from the dataset
+    expected_rows = data + [{"id": 2, "value": 3}]
+    assert dataset.table("numbers").select("id", "value").fetchall() == [
+        tuple(i.values()) for i in expected_rows
+    ]
+
+    # schema didn't change
+    assert schema_before == dataset.schema
+
+
+def test_dataset_write_respects_write_disposition_of_existing_tables(
+    pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]
+):
+    pipeline, _, dataset_name, destination = pipeline_and_foo_dataset
+
+    # create existing tables with different write dispositions
+    data = [{"id": 0, "value": 1}, {"id": 1, "value": 2}]
"value": 1}, {"id": 1, "value": 2}] + pipeline.run(data, table_name="merge_table", write_disposition="merge", primary_key="id") + pipeline.run(data, table_name="replace_table", write_disposition="replace") + pipeline.run(data, table_name="append_table", write_disposition="append") + assert_table_counts(pipeline, {"merge_table": 2, "replace_table": 2, "append_table": 2}) + + dataset = dlt.dataset(destination, dataset_name) + assert dataset.schema.get_table("merge_table")["write_disposition"] == "merge" + assert dataset.schema.get_table("replace_table")["write_disposition"] == "replace" + assert dataset.schema.get_table("append_table")["write_disposition"] == "append" + + schema_before = dataset.schema.clone() + + # execute + new_data = [{"id": 0, "value": 3}] + dataset.write(new_data, table_name="merge_table") + dataset.write(new_data, table_name="replace_table") + dataset.write(new_data, table_name="append_table") + + assert_table_counts(pipeline, {"merge_table": 2, "replace_table": 1, "append_table": 3}) + + # verify data that is returned from the dataset for id 0 + assert dataset.table("merge_table").where("id = 0").select("value").fetchall() == [(3,)] + assert dataset.table("replace_table").where("id = 0").select("value").fetchall() == [(3,)] + assert dataset.table("append_table").where("id = 0").select("value").fetchall() == [(1,), (3,)] + + # schema didn't change + assert schema_before == dataset.schema + + +def test_dataset_writes_new_table_to_existing_schema( + pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination] +): + pipeline, dataset, _, _ = pipeline_and_foo_dataset + + # create existing table in the destination + data = [{"id": 0, "value": 1}, {"id": 1, "value": 2}] + pipeline.run(data, table_name="numbers") + assert_table_counts(pipeline, {"numbers": 2}) + + schema_before = dataset.schema.clone() + + # execute + new_table_name = "letters" + load_info = dataset.write( + [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}], table_name=new_table_name + ) + + assert_load_info(load_info, expected_load_packages=1) + + # assert schema has changed + assert schema_before != dataset.schema + + # new table should show up in the Dataset schema + assert "letters" in dataset.schema.data_table_names() + + # data is queryable from the Dataset + assert dataset.table("letters").select("id", "value").fetchall() == [ + tuple(i.values()) for i in [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}] + ] + + # expect write_disposition of new table to be "append" + assert dataset.schema.get_table("letters")["write_disposition"] == "append" + + +@pytest.mark.xfail( + reason=( + "schema syncing is using version hash and doesnt check if there is a newer schema in the" + " _dlt_versions table" + ) +) # noqa: E501 +def test_pipeline_dataset_updates_after_dataset_write(): + # future issue: + # pipeline dataset should also see the new table after doing something + + # maybe after syncing the pipeline + # pipeline.sync_destination() + # assert "letters" not in pipeline_dataset.schema.data_table_names() + + # after syncing the schema + # pipeline.sync_schema(schema_name=dataset.schema.name) + + # after dropping the pipeline + + # using sql client on pipeline dataset will work + # pipeline_dataset = pipeline.dataset() + # assert pipeline_dataset.query("SELECT id, value FROM letters").fetchall() == [ + # tuple(i.values()) for i in [{"id": 1, "value": 'a'}, {"id": 2, "value": 'b'}] + # ] + pass + + +def test_data_write_wipes_working_directory( + pipeline_and_foo_dataset: Tuple[dlt.Pipeline, 
+):
+    _, dataset, _, _ = pipeline_and_foo_dataset
+
+    table_name = "bar"
+    storage_dir = pathlib.Path(TEST_STORAGE_ROOT)
+
+    # create a load package with faulty data
+    with dataset.write_pipeline() as write_pipeline:
+        write_pipeline.extract([{"id": 0, "value": "faulty"}], table_name=table_name)
+
+    items = [{"id": 0, "value": "correct"}, {"id": 1, "value": "also correct"}]
+    load_info = dataset.write(items, table_name=table_name)
+    assert_load_info(load_info, expected_load_packages=1)
+
+    assert dataset.table("bar").select("id", "value").fetchall() == [
+        (0, "correct"),
+        (1, "also correct"),
+    ]
+
+
+def test_internal_pipeline_can_write_to_scratchpad_dataset(
+    pipeline_and_foo_dataset: Tuple[dlt.Pipeline, dlt.Dataset, str, Destination]
+):
+    _, dataset, dataset_name, destination = pipeline_and_foo_dataset
+    items = [{"id": 0, "value": "something"}]
+
+    # write to the main dataset
+    dataset.write(items, table_name="bar")
+
+    # write to a scratchpad dataset
+    # todo: maybe this could be exposed as dataset.write(data, dataset_name="scratchpad_1")
+    with dataset.write_pipeline() as write_pipeline:
+        new_items = [{"id": 1, "value": "something else"}]
+        load_info = write_pipeline.run(new_items, table_name="bar", dataset_name="scratchpad_1")
+        assert_load_info(load_info, expected_load_packages=1)
+
+    # check that the main dataset is still the same
+    assert dataset.row_counts().fetchall() == [("bar", 1)]
+    assert dataset.table("bar").select("id", "value").fetchall() == [
+        (0, "something"),
+    ]
+
+    # but the new data exists in the scratchpad dataset
+    scratchpad_dataset = dlt.dataset(destination, "scratchpad_1")
+    assert scratchpad_dataset.table("bar").select("id", "value").fetchall() == [
+        (1, "something else"),
+    ]
+
+
+# Helpers
+def _rows_to_dicts(rows: List[Tuple[Any, ...]], columns: Sequence[str]) -> List[Dict[str, Any]]:
+    """Convert SQL result tuples into dictionaries keyed by the supplied column names."""
+    return [dict(zip(columns, row)) for row in rows]
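
Usage sketch (illustrative, not part of the patch; the duckdb destination, dataset and table
names are placeholders). It shows how the `Dataset.write()` and `Dataset.write_pipeline()` APIs
introduced above are intended to be used, mirroring the tests:

    import dlt
    import duckdb

    destination = dlt.destinations.duckdb(duckdb.connect())
    dataset = dlt.dataset(destination, "my_dataset")

    # write rows; the write disposition is read from the existing table schema,
    # or defaults to "append" for a new table
    info = dataset.write([{"id": 1, "value": "a"}], table_name="items")

    # for full control (e.g. writing to a scratchpad dataset), use the internal pipeline
    with dataset.write_pipeline() as pipeline:
        pipeline.run([{"id": 2, "value": "b"}], table_name="items", dataset_name="scratchpad_1")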