The in operator in ResultSetCollection.append causes problems with numpy arrays #1049

Open
wants to merge 7 commits into
base: master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
 ## 0.11.0dev

 * [API Change] Disabled `%sql` and `%%sql` on Databricks ([#1047](https://github.com/ploomber/jupysql/issues/1047))
+* [Fix] Add support for numpy arrays in result sets

 ## 0.10.17 (2025-01-08)

28 changes: 26 additions & 2 deletions src/sql/connection/connection.py
@@ -5,6 +5,7 @@
 from difflib import get_close_matches
 import atexit
 from functools import partial
+from itertools import starmap

 import sqlalchemy
 from sqlalchemy.engine import Engine
@@ -36,6 +37,7 @@
 from sql.warnings import JupySQLQuotedNamedParametersWarning, JupySQLRollbackPerformed
 from sql import _current
 from sql.connection import error_handling
+from sql.run.resultset import ResultSet

 BASE_DOC_URL = "https://jupysql.ploomber.io/en/latest"

@@ -107,13 +109,35 @@ def extract_module_name_from_NoSuchModuleError(e):
     return str(e).split(":")[-1].split(".")[-1]


+def _bool(x):
+    if x is True or x is False:
+        return x
+    return all(x)
+
+
+def _eq(a, b):
+    return a is b or _bool(a == b)
+
+
+def _results(x):
+    if isinstance(x, ResultSet):
+        return x._results
+    return x
+
+
 class ResultSetCollection:
     def __init__(self) -> None:
         self._result_sets = []

     def append(self, result):
-        if result in self._result_sets:
-            self._result_sets.remove(result)
+        for idx in reversed(
+            [
+                i
+                for i, item in enumerate(self._result_sets)
+                if all(starmap(_eq, zip(_results(result), _results(item))))
+            ]
+        ):
+            self._result_sets.pop(idx)
Comment on lines +133 to +140

Can you add some comments to explain what this is doing?

Author

I'll explain here first so you can comment on it. This replaces if result in self._result_sets: self._result_sets.remove(result), which removes the existing entry equal to result from the result set list. I'm not actually clear on why this is done, but the in operator needs the == comparison to produce something usable as a single bool. That doesn't happen with np.array, pd.Series, etc., whose == is element-wise.

The new code iterates over result and item in parallel and compares each pair of elements with the _eq function, which returns a bool for both atomic values and array comparisons. If all of the _eq calls return true, the index of that item is added to a list. That list of indexes is then reversed so the items can be popped from self._result_sets without shifting the indexes of the ones still to be removed.


         self._result_sets.append(result)

33 changes: 26 additions & 7 deletions src/tests/test_connection.py
@@ -1067,23 +1067,42 @@ def test_transpile_query_doesnt_transpile_if_it_doesnt_need_to(monkeypatch):

 def test_result_set_collection_append():
     collection = ResultSetCollection()
-    collection.append(1)
-    collection.append(2)
+    collection.append((1,))
+    collection.append((2,))

Comment on lines 1068 to 1072

let's keep the original test and add a new one

Author

I changed this because the code in append now needs the values to be iterable objects. That is also more consistent with the result objects the connection actually appends. I'm not sure there is any real-world case where the value is atomic, as in the original test.
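
As a rough illustration of the point above, the sketch below isolates the comparison used in the patched append and shows what happens with tuples versus atomic values. matches is a hypothetical wrapper introduced only for this example; _bool and _eq are copied from the diff.

```python
from itertools import starmap


def _bool(x):
    if x is True or x is False:
        return x
    return all(x)


def _eq(a, b):
    return a is b or _bool(a == b)


def matches(result, item):
    # Hypothetical helper: the same expression as in the patch's comprehension.
    return all(starmap(_eq, zip(result, item)))


# Row-like values (tuples) are compared element by element.
print(matches((1,), (1,)))
print(matches((1,), (2,)))

# Atomic values are not iterable, so zip() raises TypeError.
try:
    matches(1, 2)
except TypeError as exc:
    print("atomic values fail:", exc)
```

In the patched append this comparison only runs once the collection already holds an item, so the first append of an atomic value would succeed and the second would raise. One more thing worth keeping in mind with this approach: zip stops at the shorter input, so two values of different lengths are compared on their common prefix only.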

-    assert collection._result_sets == [1, 2]
+    assert collection._result_sets == [(1,), (2,)]


+def test_result_set_collection_append_numpy():
+    try:
+        import numpy as np
+
+        a1 = (np.array([1, 2]),)
+        a2 = (np.array([3, 4]),)
+
+        collection = ResultSetCollection()
+        collection.append(a1)
+        collection.append(a2)
+
+        assert len(collection._result_sets) == 2
+        assert collection._result_sets[0] is a1
+        assert collection._result_sets[1] is a2
+
+    except ImportError:
+        pass
Comment on lines +1076 to +1092

let's add numpy as a dev dependency and get rid of the ImportError:

DEV = [



 def test_result_set_collection_iterate():
     collection = ResultSetCollection()
-    collection.append(1)
-    collection.append(2)
+    collection.append((1,))
+    collection.append((2,))

-    assert list(collection) == [1, 2]
+    assert list(collection) == [(1,), (2,)]


 def test_result_set_collection_is_last():
     collection = ResultSetCollection()
-    first, second = object(), object()
+    first, second = (object(),), (object(),)
     collection.append(first)

     assert len(collection) == 1
Comment on lines 1095 to 1108

same, keep old tests, add new ones
