Commit c330249

Query on multiple Pandas DataFrames (#89)
* Add dataframe.query() to run SQL over multiple DataFrames
* Add a DataFrame join test
* Support passing a DataFrame as a table
1 parent e503e5b commit c330249
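The new module-level `chdb.dataframe.query()` takes the SQL plus one keyword argument per table; each `__name__` placeholder in the SQL is resolved against the keyword argument of the same name, and the arguments may be `Table` objects or plain pandas DataFrames. A minimal usage sketch (the `orders`/`users` frames and columns are illustrative, not from this commit):

```python
import pandas as pd
import chdb.dataframe as cdf

# Hypothetical input frames; any keyword name can be referenced as __name__ in the SQL.
orders = pd.DataFrame({"id": [1, 2], "amount": [10, 20]})
users = pd.DataFrame({"uid": [1, 2], "name": ["ann", "bob"]})

ret = cdf.query(
    sql="select name, amount from __orders__ o join __users__ u on o.id = u.uid",
    orders=orders,
    users=users,
)
print(ret)  # a chdb Table wrapping the joined result
```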

File tree

7 files changed: +222 −58 lines changed

README-zh.md

Lines changed: 6 additions & 2 deletions
````diff
@@ -77,9 +77,13 @@ chdb.query('select * from file("data.parquet", Parquet)', 'Dataframe')
 ```python
 import chdb.dataframe as cdf
 import pandas as pd
-tbl = cdf.Table(dataframe=pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']}))
-ret_tbl = tbl.query('select * from __table__')
+# Join 2 DataFrames
+df1 = pd.DataFrame({'a': [1, 2, 3], 'b': ["one", "two", "three"]})
+df2 = pd.DataFrame({'c': [1, 2, 3], 'd': ["①", "②", "③"]})
+ret_tbl = cdf.query(sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
+                    tbl1=df1, tbl2=df2)
 print(ret_tbl)
+# Query on the DataFrame Table
 print(ret_tbl.query('select b, sum(a) from __table__ group by b'))
 ```
 </details>
````

README.md

Lines changed: 6 additions & 2 deletions
````diff
@@ -82,9 +82,13 @@ chdb.query('select * from file("data.parquet", Parquet)', 'Dataframe')
 ```python
 import chdb.dataframe as cdf
 import pandas as pd
-tbl = cdf.Table(dataframe=pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']}))
-ret_tbl = tbl.query('select * from __table__')
+# Join 2 DataFrames
+df1 = pd.DataFrame({'a': [1, 2, 3], 'b': ["one", "two", "three"]})
+df2 = pd.DataFrame({'c': [1, 2, 3], 'd': ["①", "②", "③"]})
+ret_tbl = cdf.query(sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
+                    tbl1=df1, tbl2=df2)
 print(ret_tbl)
+# Query on the DataFrame Table
 print(ret_tbl.query('select b, sum(a) from __table__ group by b'))
 ```
 </details>
````
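Note, grounded in the new tests/test_joindf.py below: string columns round-trip through chDB as Python `bytes`, so the joined frame prints `b'one'` rather than `'one'`. A quick check (a sketch, assuming this commit is installed):

```python
import pandas as pd
import chdb.dataframe as cdf

df = pd.DataFrame({"a": [1], "b": ["one"]})
print(cdf.query(sql="select * from __df__", df=df))
# the b column comes back as b'one' (bytes), matching the test assertions below
```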

chdb/dataframe/__init__.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -11,4 +11,8 @@
 if pd.__version__[0] < '2':
     print('Please upgrade pandas to version 2.0.0 or higher to have better performance')
 
-from .query import *
+from .query import Table, pandas_read_parquet  # noqa: C0413
+
+query = Table.queryStatic
+
+__all__ = ['Table', 'query', 'pandas_read_parquet']
```
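With this re-export, the module-level `query` is simply an alias of `Table.queryStatic`, so the two calls below should be interchangeable (a minimal sketch; the frame and SQL are illustrative):

```python
import pandas as pd
import chdb.dataframe as cdf

df = pd.DataFrame({"a": [1, 2, 3]})

r1 = cdf.query(sql="select sum(a) from __df__", df=df)              # new module-level alias
r2 = cdf.Table.queryStatic(sql="select sum(a) from __df__", df=df)  # same method, called directly
```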

chdb/dataframe/query.py

Lines changed: 155 additions & 52 deletions
```diff
@@ -1,26 +1,29 @@
 import os
 import tempfile
+from io import BytesIO
+import re
 import pandas as pd
 import pyarrow as pa
-from io import BytesIO
 from chdb import query as chdb_query
 
 
-class Table(object):
+class Table:
     """
     Table is a wrapper of multiple formats of data buffer, including parquet file path,
     parquet bytes, and pandas dataframe.
     if use_memfd is True, will try using memfd_create to create a temp file in memory, which is
     only available on Linux. If failed, will fallback to use tempfile.mkstemp to create a temp file
     """
 
-    def __init__(self,
-                 parquet_path: str = None,
-                 temp_parquet_path: str = None,
-                 parquet_memoryview: memoryview = None,
-                 dataframe: pd.DataFrame = None,
-                 arrow_table: pa.Table = None,
-                 use_memfd: bool = False):
+    def __init__(
+        self,
+        parquet_path: str = None,
+        temp_parquet_path: str = None,
+        parquet_memoryview: memoryview = None,
+        dataframe: pd.DataFrame = None,
+        arrow_table: pa.Table = None,
+        use_memfd: bool = False,
+    ):
         """
         Initialize a Table object with one of parquet file path, parquet bytes, pandas dataframe or
         parquet table.
```
```diff
@@ -33,11 +36,11 @@ def __init__(self,
         self.use_memfd = use_memfd
 
     def __del__(self):
-        try:
-            if self._temp_parquet_path is not None:
+        if self._temp_parquet_path is not None:
+            try:
                 os.remove(self._temp_parquet_path)
-        except:
-            pass
+            except OSError:
+                pass
 
     def to_pandas(self) -> pd.DataFrame:
         if self._dataframe is None:
```
```diff
@@ -63,55 +66,53 @@ def flush_to_disk(self):
             return
 
         if self._dataframe is not None:
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                self._dataframe.to_parquet(tmp)
-                self._temp_parquet_path = tmp.name
-            del self._dataframe
-            self._dataframe = None
+            self._df_to_disk(self._dataframe)
+            self._dataframe = None
         elif self._arrow_table is not None:
-            import pyarrow.parquet as pq
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                pq.write_table(self._arrow_table, tmp.name)
-                self._temp_parquet_path = tmp.name
-            del self._arrow_table
-            self._arrow_table = None
+            self._arrow_table_to_disk(self._arrow_table)
+            self._arrow_table = None
         elif self._parquet_memoryview is not None:
-            # copy memoryview to temp file
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                tmp.write(self._parquet_memoryview.tobytes())
-                self._temp_parquet_path = tmp.name
-            self._parquet_memoryview.release()
-            del self._parquet_memoryview
-            self._parquet_memoryview = None
+            self._memoryview_to_disk(self._parquet_memoryview)
+            self._parquet_memoryview = None
         else:
             raise ValueError("No data in Table object")
 
+    def _df_to_disk(self, df):
+        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+            df.to_parquet(tmp)
+            self._temp_parquet_path = tmp.name
+
+    def _arrow_table_to_disk(self, arrow_table):
+        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+            pa.parquet.write_table(arrow_table, tmp.name)
+            self._temp_parquet_path = tmp.name
+
+    def _memoryview_to_disk(self, memoryview):
+        # copy memoryview to temp file
+        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+            tmp.write(memoryview.tobytes())
+            self._temp_parquet_path = tmp.name
+
     def __repr__(self):
         return repr(self.to_pandas())
 
     def __str__(self):
         return str(self.to_pandas())
 
-    def query(self, sql, **kwargs) -> "Table":
+    def query(self, sql: str, **kwargs) -> "Table":
         """
         Query on current Table object, return a new Table object.
         The `FROM` table name in SQL should always be `__table__`. eg:
         `SELECT * FROM __table__ WHERE ...`
         """
-        # check if "__table__" is in sql
-        if "__table__" not in sql:
-            raise ValueError("SQL should always contain `FROM __table__`")
+        self._validate_sql(sql)
 
-        if self._parquet_path is not None:  # if we have parquet file path, run chdb query on it directly is faster
-            # replace "__table__" with file("self._parquet_path", Parquet)
-            new_sql = sql.replace("__table__", f"file(\"{self._parquet_path}\", Parquet)")
-            res = chdb_query(new_sql, "Parquet", **kwargs)
-            return Table(parquet_memoryview=res.get_memview())
+        if (
+            self._parquet_path is not None
+        ):  # if we have parquet file path, run chdb query on it directly is faster
+            return self._query_on_path(self._parquet_path, sql, **kwargs)
         elif self._temp_parquet_path is not None:
-            # replace "__table__" with file("self._temp_parquet_path", Parquet)
-            new_sql = sql.replace("__table__", f"file(\"{self._temp_parquet_path}\", Parquet)")
-            res = chdb_query(new_sql, "Parquet", **kwargs)
-            return Table(parquet_memoryview=res.get_memview())
+            return self._query_on_path(self._temp_parquet_path, sql, **kwargs)
         elif self._parquet_memoryview is not None:
             return self.queryParquetBuffer(sql, **kwargs)
         elif self._dataframe is not None:
```
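The refactored `flush_to_disk` is what lets the rest of the class treat every `Table` uniformly: once spilled, a `Table` is just a parquet path, and `query()` takes the fast file-path branch. A minimal sketch of that round trip, using only names visible in this diff:

```python
import pandas as pd
from chdb.dataframe import Table

t = Table(dataframe=pd.DataFrame({"a": [1, 2, 3]}))
t.flush_to_disk()             # dataframe is written to a temp .parquet and dropped
print(t._temp_parquet_path)   # later queries now hit the _query_on_path branch
print(t.query("select count(*) from __table__"))
```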
```diff
@@ -121,6 +122,15 @@ def query(self, sql, **kwargs) -> "Table":
         else:
             raise ValueError("Table object is not initialized correctly")
 
+    def _query_on_path(self, path, sql, **kwargs):
+        new_sql = sql.replace("__table__", f'file("{path}", Parquet)')
+        res = chdb_query(new_sql, "Parquet", **kwargs)
+        return Table(parquet_memoryview=res.get_memview())
+
+    def _validate_sql(self, sql):
+        if "__table__" not in sql:
+            raise ValueError("SQL should always contain `FROM __table__`")
+
     def queryParquetBuffer(self, sql: str, **kwargs) -> "Table":
         if "__table__" not in sql:
             raise ValueError("SQL should always contain `FROM __table__`")
```
```diff
@@ -139,6 +149,8 @@ def queryParquetBuffer(self, sql: str, **kwargs) -> "Table":
         ffd.flush()
         ret = self._run_on_temp(parquet_fd, temp_path, sql=sql, fmt="Parquet", **kwargs)
         ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
         return ret
 
     def queryArrowTable(self, sql: str, **kwargs) -> "Table":
@@ -159,6 +171,8 @@ def queryArrowTable(self, sql: str, **kwargs) -> "Table":
         ffd.flush()
         ret = self._run_on_temp(arrow_fd, temp_path, sql=sql, fmt="Arrow", **kwargs)
         ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
         return ret
 
     def queryDF(self, sql: str, **kwargs) -> "Table":
```
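The two-line cleanup added to each `query*` helper closes out the shared temp-file pattern: write parquet bytes to a file descriptor (a memfd on Linux when `use_memfd` is set, otherwise a `mkstemp` file), rewind, and point chDB at `/dev/fd/<fd>`. A standalone sketch of that pattern, here taking the `mkstemp` fallback (names are illustrative):

```python
import os
import tempfile

import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})
fd, temp_path = tempfile.mkstemp(suffix=".parquet")
ffd = os.fdopen(fd, "wb")
ffd.write(df.to_parquet(engine="pyarrow", compression=None))  # to_parquet with no path returns bytes
ffd.flush()
os.lseek(fd, 0, os.SEEK_SET)             # rewind so chDB reads from the start
print(f'file("/dev/fd/{fd}", Parquet)')  # the clause __table__ is replaced with
ffd.close()
os.remove(temp_path)                     # the cleanup these hunks add
```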
```diff
@@ -174,19 +188,104 @@ def queryDF(self, sql: str, **kwargs) -> "Table":
         if parquet_fd == -1:
             parquet_fd, temp_path = tempfile.mkstemp()
         ffd = os.fdopen(parquet_fd, "wb")
-        self._dataframe.to_parquet(ffd, engine='pyarrow', compression=None)
+        self._dataframe.to_parquet(ffd, engine="pyarrow", compression=None)
         ffd.flush()
         ret = self._run_on_temp(parquet_fd, temp_path, sql=sql, fmt="Parquet", **kwargs)
         ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
         return ret
 
-    def _run_on_temp(self, fd: int, temp_path: str = None, sql: str = None, fmt: str = "Parquet", **kwargs) -> "Table":
+    @staticmethod
+    def queryStatic(sql: str, **kwargs) -> "Table":
+        """
+        Query on multiple Tables, use Table variables as the table name in SQL
+        eg.
+            table1 = Table(...)
+            table2 = Table(...)
+            query("SELECT * FROM __table1__ JOIN __table2__ ON ...", table1=table1, table2=table2)
+        """
+        ansiTablePattern = re.compile(r"__([a-zA-Z][a-zA-Z0-9_]*)__")
+        temp_paths = []
+        ffds = []
+
+        def replace_table_name(match):
+            tableName = match.group(1)
+            if tableName not in kwargs:
+                raise ValueError(f"Table {tableName} should be passed as a parameter")
+
+            tbl = kwargs[tableName]
+            # if tbl is a DataFrame, convert it to a Table
+            if isinstance(tbl, pd.DataFrame):
+                tbl = Table(dataframe=tbl)
+            elif not isinstance(tbl, Table):
+                raise ValueError(f"Table {tableName} should be an instance of Table or DataFrame")
+
+            if tbl._parquet_path is not None:
+                return f'file("{tbl._parquet_path}", Parquet)'
+
+            if tbl._temp_parquet_path is not None:
+                return f'file("{tbl._temp_parquet_path}", Parquet)'
+
+            temp_path = None
+            data_fd = -1
+
+            if tbl.use_memfd:
+                data_fd = memfd_create()
+
+            if data_fd == -1:
+                data_fd, temp_path = tempfile.mkstemp()
+                temp_paths.append(temp_path)
+
+            ffd = os.fdopen(data_fd, "wb")
+            ffds.append(ffd)
+
+            if tbl._parquet_memoryview is not None:
+                ffd.write(tbl._parquet_memoryview.tobytes())
+                ffd.flush()
+                os.lseek(data_fd, 0, os.SEEK_SET)
+                return f'file("/dev/fd/{data_fd}", Parquet)'
+
+            if tbl._dataframe is not None:
+                ffd.write(tbl._dataframe.to_parquet(engine="pyarrow", compression=None))
+                ffd.flush()
+                os.lseek(data_fd, 0, os.SEEK_SET)
+                return f'file("/dev/fd/{data_fd}", Parquet)'
+
+            if tbl._arrow_table is not None:
+                with pa.RecordBatchFileWriter(ffd, tbl._arrow_table.schema) as writer:
+                    writer.write_table(tbl._arrow_table)
+                ffd.flush()
+                os.lseek(data_fd, 0, os.SEEK_SET)
+                return f'file("/dev/fd/{data_fd}", Arrow)'
+
+            raise ValueError(f"Table {tableName} is not initialized correctly")
+
+        sql = ansiTablePattern.sub(replace_table_name, sql)
+        res = chdb_query(sql, "Parquet")
+
+        for fd in ffds:
+            fd.close()
+
+        for tmp_path in temp_paths:
+            os.remove(tmp_path)
+
+        return Table(parquet_memoryview=res.get_memview())
+
+    def _run_on_temp(
+        self,
+        fd: int,
+        temp_path: str = None,
+        sql: str = None,
+        fmt: str = "Parquet",
+        **kwargs,
+    ) -> "Table":
         # replace "__table__" with file("temp_path", Parquet) or file("/dev/fd/{parquet_fd}", Parquet)
         if temp_path is not None:
-            new_sql = sql.replace("__table__", f"file(\"{temp_path}\", {fmt})")
+            new_sql = sql.replace("__table__", f'file("{temp_path}", {fmt})')
         else:
             os.lseek(fd, 0, os.SEEK_SET)
-            new_sql = sql.replace("__table__", f"file(\"/dev/fd/{fd}\", {fmt})")
+            new_sql = sql.replace("__table__", f'file("/dev/fd/{fd}", {fmt})')
         res = chdb_query(new_sql, "Parquet", **kwargs)
         return Table(parquet_memoryview=res.get_memview())
```
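`queryStatic` is the heart of the commit: it scans the SQL for `__name__` placeholders, materializes each keyword argument as something chDB can read (an existing parquet path, or a freshly written fd), and substitutes the matching `file(...)` clause in place. The rewrite itself is a plain `re.sub`; a standalone sketch of just that step, assuming both tables already live on disk (paths are hypothetical):

```python
import re

pattern = re.compile(r"__([a-zA-Z][a-zA-Z0-9_]*)__")            # same pattern as queryStatic
paths = {"tbl1": "/tmp/t1.parquet", "tbl2": "/tmp/t2.parquet"}  # hypothetical parquet files

sql = "select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c"
print(pattern.sub(lambda m: f'file("{paths[m.group(1)]}", Parquet)', sql))
# select * from file("/tmp/t1.parquet", Parquet) t1 join file("/tmp/t2.parquet", Parquet) t2 on t1.a = t2.c
```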

```diff
@@ -212,10 +311,14 @@ def memfd_create(name: str = None) -> int:
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description='Run SQL on parquet file')
-    parser.add_argument('parquet_path', type=str, help='path to parquet file')
-    parser.add_argument('sql', type=str, help='SQL to run')
-    parser.add_argument('--use-memfd', action='store_true', help='use memfd_create to create file descriptor')
+    parser = argparse.ArgumentParser(description="Run SQL on parquet file")
+    parser.add_argument("parquet_path", type=str, help="path to parquet file")
+    parser.add_argument("sql", type=str, help="SQL to run")
+    parser.add_argument(
+        "--use-memfd",
+        action="store_true",
+        help="use memfd_create to create file descriptor",
+    )
     args = parser.parse_args()
 
     table = Table(parquet_path=args.parquet_path, use_memfd=args.use_memfd)
```

tests/test_gc.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,6 +36,6 @@ def test_gc(self):
         gc.collect()
         self.assertEqual(mv3.tobytes(), b'123,"adbcdefg"\n')
         self.assertEqual(len(mv3), 15)
-
+
 if __name__ == '__main__':
     unittest.main()
```

(Whitespace-only change on the blank line 39.)

tests/test_joindf.py

Lines changed: 47 additions & 0 deletions
```diff
@@ -0,0 +1,47 @@
+#!python3
+
+import unittest
+import pandas as pd
+from chdb import dataframe as cdf
+
+
+class TestJoinDf(unittest.TestCase):
+    def test_1df(self):
+        df1 = pd.DataFrame({"a": [1, 2, 3], "b": [b"one", b"two", b"three"]})
+        cdf1 = cdf.Table(dataframe=df1)
+        ret1 = cdf.query(sql="select * from __tbl1__", tbl1=cdf1)
+        self.assertEqual(str(ret1), str(df1))
+
+    def test_2df(self):
+        df1 = pd.DataFrame({"a": [1, 2, 3], "b": ["one", "two", "three"]})
+        df2 = pd.DataFrame({"c": [1, 2, 3], "d": ["①", "②", "③"]})
+        ret_tbl = cdf.query(
+            sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
+            tbl1=df1,
+            tbl2=df2,
+        )
+        self.assertEqual(
+            str(ret_tbl),
+            str(
+                pd.DataFrame(
+                    {
+                        "a": [1, 2, 3],
+                        "b": [b"one", b"two", b"three"],
+                        "c": [1, 2, 3],
+                        "d": [b"\xe2\x91\xa0", b"\xe2\x91\xa1", b"\xe2\x91\xa2"],
+                    }
+                )
+            ),
+        )
+
+        ret_tbl2 = ret_tbl.query(
+            "select b, a+c s from __table__ order by s"
+        )
+        self.assertEqual(
+            str(ret_tbl2),
+            str(pd.DataFrame({"b": [b"one", b"two", b"three"], "s": [2, 4, 6]})),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
```
