Skip to content

Commit 0ec1f8b

Browse files
committed
working demo
1 parent da9df57 commit 0ec1f8b

File tree

2 files changed

+133
-55
lines changed

2 files changed

+133
-55
lines changed

src/include/wvlet_extension.hpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,31 @@
11
#pragma once
22

33
#include "duckdb.hpp"
4+
#include "duckdb/common/types/data_chunk.hpp"
45
#include "duckdb/function/scalar_function.hpp"
6+
#include "duckdb/function/table_function.hpp"
7+
#include "duckdb/main/client_context.hpp"
8+
9+
// Declare the external wvlet_compile_query function
10+
extern "C" {
11+
int wvlet_compile_main(const char*);
12+
const char* wvlet_compile_query(const char* json_query);
13+
}
514

615
namespace duckdb {
716

8-
class WvletExtension : public Extension {
9-
public:
10-
void Load(DuckDB &db) override;
11-
std::string Name() override;
12-
std::string Version() const override;
17+
struct WvletQueryResult {
18+
unique_ptr<QueryResult> result;
19+
bool initialized;
20+
21+
WvletQueryResult() : initialized(false) {}
22+
};
23+
24+
struct WvletBindData : public TableFunctionData {
25+
string query;
26+
unique_ptr<WvletQueryResult> query_result;
27+
28+
WvletBindData() : query_result(make_uniq<WvletQueryResult>()) {}
1329
};
1430

1531
struct WvletScriptFunction {
@@ -18,4 +34,11 @@ struct WvletScriptFunction {
1834
vector<unique_ptr<Expression>> &arguments);
1935
};
2036

37+
class WvletExtension : public Extension {
38+
public:
39+
void Load(DuckDB &db) override;
40+
std::string Name() override;
41+
std::string Version() const override;
42+
};
43+
2144
} // namespace duckdb

src/wvlet_extension.cpp

Lines changed: 105 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,78 +8,137 @@
88
#include <duckdb/parser/parsed_data/create_table_function_info.hpp>
99
#include <fstream>
1010
#include <sstream>
11-
12-
#include <iostream>
13-
#include <cstdio>
1411
#include <stdexcept>
1512

16-
// OpenSSL linked through vcpkg
17-
#include <openssl/opensslv.h>
13+
extern "C" {
14+
int wvlet_compile_main(const char*);
15+
const char* wvlet_compile_compile(const char*);
16+
}
1817

1918
namespace duckdb {
2019

21-
extern "C" {
22-
int wvlet_compile_main(const char*);
20+
void WvletScriptFunction::ParseWvletScript(DataChunk &args, ExpressionState &state, Vector &result) {
21+
auto &input_vector = args.data[0];
22+
auto input = FlatVector::GetData<string_t>(input_vector);
23+
24+
for (idx_t i = 0; i < args.size(); i++) {
25+
string query = input[i].GetString();
26+
std::string json = "[\"-q\", \"" + query + "\"]";
27+
28+
// std::cout << "Input wvlet: " << query << std::endl;
29+
30+
// Initialize wvlet compiler
31+
// wvlet_compile_main(json.c_str());
32+
33+
// Get compiled SQL
34+
const char* sql_result = wvlet_compile_query(json.c_str());
35+
// std::cout << "Compiled SQL: " << sql_result << std::endl;
36+
37+
if (!sql_result || strlen(sql_result) == 0) {
38+
throw std::runtime_error("Failed to compile wvlet script");
39+
}
40+
41+
FlatVector::GetData<string_t>(result)[i] = StringVector::AddString(result, sql_result);
42+
}
43+
44+
result.Verify(args.size());
2345
}
2446

25-
struct WvletBindData : public TableFunctionData {
26-
string query;
27-
bool has_returned = false;
28-
};
47+
unique_ptr<FunctionData> WvletScriptFunction::Bind(ClientContext &context, ScalarFunction &bound_function,
48+
vector<unique_ptr<Expression>> &arguments) {
49+
return nullptr;
50+
}
51+
52+
static std::string CleanSQL(const std::string& sql) {
53+
// Find first occurrence of "select" (case insensitive)
54+
std::string lower_sql = sql;
55+
std::transform(lower_sql.begin(), lower_sql.end(), lower_sql.begin(), ::tolower);
56+
auto pos = lower_sql.find("select");
57+
if (pos == std::string::npos) {
58+
throw std::runtime_error("No SELECT statement found in compiled SQL");
59+
}
60+
return sql.substr(pos);
61+
}
2962

3063
static unique_ptr<FunctionData> WvletBind(ClientContext &context, TableFunctionBindInput &input,
3164
vector<LogicalType> &return_types, vector<string> &names) {
32-
// Get all the lineitem columns here
3365
auto result = make_uniq<WvletBindData>();
3466
result->query = input.inputs[0].GetValue<string>();
3567

36-
// TODO: We should probably get these from the schema of the target table
37-
return_types = {LogicalType::INTEGER, LogicalType::VARCHAR}; // Example columns
38-
names = {"id", "name"}; // Example column names
68+
std::string json = "[\"-q\", \"" + result->query + "\"]";
69+
// std::cout << "Input wvlet: " << result->query << std::endl;
70+
71+
// Initialize wvlet compiler
72+
// std::cout << "Calling wvlet_compile_main..." << std::endl;
73+
wvlet_compile_main(json.c_str());
74+
75+
// Get compiled SQL
76+
// std::cout << "Calling wvlet_compile_query..." << std::endl;
77+
const char* sql_result = wvlet_compile_query(json.c_str());
78+
// std::cout << "Compiled SQL: " << sql_result << std::endl;
79+
80+
if (!sql_result || strlen(sql_result) == 0) {
81+
throw std::runtime_error("Failed to compile wvlet script");
82+
}
3983

84+
// Store the compiled SQL query
85+
result->query = std::string(sql_result);
86+
87+
// For t1, we know it has two INTEGER columns
88+
return_types = {LogicalType::INTEGER, LogicalType::INTEGER};
89+
names = {"i", "j"};
90+
91+
// std::cout << "Bind complete with query: " << result->query << std::endl;
4092
return std::move(result);
4193
}
4294

4395
static void WvletFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
44-
auto &bind_data = (WvletBindData &)*data_p.bind_data;
96+
auto &bind_data = data_p.bind_data->Cast<WvletBindData>();
97+
98+
if (!bind_data.query_result->initialized) {
99+
// std::cout << "Starting query execution..." << std::endl;
100+
101+
// Use a new connection with the existing database instance
102+
Connection conn(*context.db);
103+
104+
// std::cout << "Executing query: " << bind_data.query << std::endl;
105+
auto result = conn.Query(bind_data.query);
106+
107+
if (result->HasError()) {
108+
throw std::runtime_error(result->GetError());
109+
}
110+
111+
bind_data.query_result->result = std::move(result);
112+
bind_data.query_result->initialized = true;
113+
114+
// Initialize output with INTEGER types to match t1
115+
output.Initialize(Allocator::DefaultAllocator(), {LogicalType::INTEGER, LogicalType::INTEGER});
116+
// std::cout << "Query initialized successfully" << std::endl;
117+
}
118+
119+
// Fetch next chunk
120+
// std::cout << "Fetching chunk..." << std::endl;
121+
auto chunk = bind_data.query_result->result->Fetch();
122+
// std::cout << "Chunk fetched" << std::endl;
45123

46-
if (bind_data.has_returned) {
124+
if (!chunk || chunk->size() == 0) {
125+
// std::cout << "No more data" << std::endl;
47126
output.SetCardinality(0);
48127
return;
49128
}
50-
51-
std::ostringstream captured_output;
52-
FILE* original_stdout = stdout;
53-
stdout = fdopen(fileno(stdout), "w");
54-
std::ostringstream captured_error;
55-
FILE* original_stderr = stderr;
56-
stderr = fdopen(fileno(stderr), "w");
57-
58-
// Convert script to JSON array format as expected by wvlet_compile_main
59-
std::string json = "[\"-x\", \"-q\", \"" + bind_data.query + "\"]";
60-
61-
// Call wvlet compiler - it will print the SQL
62-
int compile_result = wvlet_compile_main(json.c_str());
63-
64-
if (compile_result != 0) {
65-
throw std::runtime_error("Failed to compile wvlet script");
66-
}
67-
68-
std::string query = captured_output.str();
69-
std::cout << "Captured Output: " << query << std::endl;
70-
71-
stdout = original_stdout;
72-
stderr = original_stderr;
73-
74-
// The SQL has been printed, now we can execute it
75-
// TODO: Execute the printed SQL and fill the output chunk with results
76129

77-
bind_data.has_returned = true;
130+
// std::cout << "Got chunk with " << chunk->size() << " rows" << std::endl;
131+
output.Reference(*chunk);
132+
output.SetCardinality(chunk->size());
78133
}
79134

80135
static void LoadInternal(DatabaseInstance &instance) {
81-
TableFunction wvlet_func("wvlet", {LogicalType::VARCHAR},
82-
WvletFunction, WvletBind);
136+
auto wvlet_fun = ScalarFunction("wvlet", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
137+
WvletScriptFunction::ParseWvletScript,
138+
WvletScriptFunction::Bind);
139+
ExtensionUtil::RegisterFunction(instance, wvlet_fun);
140+
141+
TableFunction wvlet_func("wvlet", {LogicalType::VARCHAR}, WvletFunction, WvletBind);
83142
ExtensionUtil::RegisterFunction(instance, wvlet_func);
84143
}
85144

@@ -111,7 +170,3 @@ DUCKDB_EXTENSION_API const char *wvlet_version() {
111170
return duckdb::DuckDB::LibraryVersion();
112171
}
113172
}
114-
115-
#ifndef DUCKDB_EXTENSION_MAIN
116-
#error DUCKDB_EXTENSION_MAIN not defined
117-
#endif

0 commit comments

Comments
 (0)