diff --git a/ext/sqlite3/database.c b/ext/sqlite3/database.c index 36b75787..2d12693a 100644 --- a/ext/sqlite3/database.c +++ b/ext/sqlite3/database.c @@ -17,6 +17,7 @@ database_mark(void *ctx) { sqlite3RubyPtr c = (sqlite3RubyPtr)ctx; rb_gc_mark(c->busy_handler); + rb_gc_mark(c->progress_handler); } static void @@ -51,6 +52,7 @@ static VALUE allocate(VALUE klass) { sqlite3RubyPtr ctx; + return TypedData_Make_Struct(klass, sqlite3Ruby, &database_type, ctx); } @@ -259,6 +261,57 @@ busy_handler(int argc, VALUE *argv, VALUE self) return self; } +#ifdef HAVE_SQLITE3_PROGRESS_HANDLER +static int +rb_sqlite3_progress_handler(void *context) +{ + sqlite3RubyPtr ctx = (sqlite3RubyPtr)context; + + VALUE handle = ctx->progress_handler; + VALUE result = rb_funcall(handle, rb_intern("call"), 0); + + if (Qfalse == result) { return 1; } + + return 0; +} + +/* call-seq: + * progress_handler([n]) { ... } + * progress_handler([n,] Class.new { def call; end }.new) + * + * Register a progress handler with this database instance. + * This handler will be invoked periodically during a long-running query or operation. + * If the handler returns +false+, the operation will be interrupted; otherwise, it continues. + * The parameter 'n' specifies the number of SQLite virtual machine instructions between invocations. + * If 'n' is not provided, the default value is 1. + */ +static VALUE +progress_handler(int argc, VALUE *argv, VALUE self) +{ + sqlite3RubyPtr ctx; + VALUE block, n_value; + + TypedData_Get_Struct(self, sqlite3Ruby, &database_type, ctx); + REQUIRE_OPEN_DB(ctx); + + rb_scan_args(argc, argv, "02", &n_value, &block); + + int n = NIL_P(n_value) ? 1000 : NUM2INT(n_value); + if (NIL_P(block) && rb_block_given_p()) { block = rb_block_proc(); } + ctx->progress_handler = block; + + sqlite3_progress_handler( + ctx->db, + n, + NIL_P(block) ? NULL : rb_sqlite3_progress_handler, + (void *)ctx + ); + + return self; +} +#endif + + /* call-seq: last_insert_row_id * * Obtains the unique row ID of the last row to be inserted by this Database @@ -888,6 +941,10 @@ init_sqlite3_database(void) rb_define_method(cSqlite3Database, "enable_load_extension", enable_load_extension, 1); #endif +#ifdef HAVE_SQLITE3_PROGRESS_HANDLER + rb_define_method(cSqlite3Database, "progress_handler", progress_handler, -1); +#endif + rb_sqlite3_aggregator_init(); } diff --git a/ext/sqlite3/database.h b/ext/sqlite3/database.h index 56833020..e72a85b0 100644 --- a/ext/sqlite3/database.h +++ b/ext/sqlite3/database.h @@ -6,6 +6,7 @@ struct _sqlite3Ruby { sqlite3 *db; VALUE busy_handler; + VALUE progress_handler; }; typedef struct _sqlite3Ruby sqlite3Ruby; diff --git a/ext/sqlite3/extconf.rb b/ext/sqlite3/extconf.rb index 733e22d5..1f251ee4 100644 --- a/ext/sqlite3/extconf.rb +++ b/ext/sqlite3/extconf.rb @@ -121,6 +121,7 @@ def configure_extension have_func("sqlite3_column_database_name") have_func("sqlite3_enable_load_extension") have_func("sqlite3_load_extension") + have_func("sqlite3_progress_handler") unless have_func("sqlite3_open_v2") # https://www.sqlite.org/releaselog/3_5_0.html abort("\nPlease use a version of SQLite3 >= 3.5.0\n\n") diff --git a/lib/sqlite3/database.rb b/lib/sqlite3/database.rb index 4718e9c9..b3f7ae40 100644 --- a/lib/sqlite3/database.rb +++ b/lib/sqlite3/database.rb @@ -118,6 +118,7 @@ def initialize file, options = {}, zvfs = nil @tracefunc = nil @authorizer = nil @busy_handler = nil + @progress_handler = nil @collations = {} @functions = {} @results_as_hash = options[:results_as_hash] diff --git a/test/test_integration.rb b/test/test_integration.rb index f0c005ab..609ecf48 100644 --- a/test/test_integration.rb +++ b/test/test_integration.rb @@ -505,4 +505,66 @@ def test_bind_array_parameter [1, "foo"]) assert_equal "foo", result end + + ### + # The `progress_handler` method may not exist depending on how sqlite3 was compiled + def test_progress_handler_used + skip("progress_handler method not defined") unless @db.respond_to?(:progress_handler) + + progress_calls = [] + @db.progress_handler(10) do + progress_calls << nil + true + end + @db.execute "create table test1(a, b)" + + assert_operator 1, :<, progress_calls.size + end + + def test_progress_handler_opcode_arg + skip("progress_handler method not defined") unless @db.respond_to?(:progress_handler) + + progress_calls = [] + handler = proc do + progress_calls << nil + true + end + @db.progress_handler(1, handler) + @db.execute "create table test1(a, b)" + first_count = progress_calls.size + + progress_calls = [] + @db.progress_handler(10, handler) + @db.execute "create table test2(a, b)" + second_count = progress_calls.size + + assert_operator first_count, :>=, second_count + end + + def test_progress_handler_interrupts_operation + skip("progress_handler method not defined") unless @db.respond_to?(:progress_handler) + + @db.progress_handler(10) do + false + end + + assert_raises(SQLite3::InterruptException) do + @db.execute "create table test1(a, b)" + end + end + + def test_clear_handler + skip("progress_handler method not defined") unless @db.respond_to?(:progress_handler) + + progress_calls = [] + @db.progress_handler do + progress_calls << nil + true + end + @db.progress_handler(nil) + + @db.execute "create table test1(a, b)" + + assert_equal 0, progress_calls.size + end end diff --git a/test/test_integration_pending.rb b/test/test_integration_pending.rb index 1a25910b..706f2adb 100644 --- a/test/test_integration_pending.rb +++ b/test/test_integration_pending.rb @@ -110,4 +110,31 @@ def test_busy_handler_timeout_releases_gvl assert_operator work.size - work.find_index("|"), :>, 3 end + + def test_progress_handler_releasing_gvl + work = [] + + Thread.new do + loop do + sleep 0.1 + work << "." + end + end + + @db.progress_handler { Thread.pass } + + work << ">" + @db.execute <<~SQL + WITH RECURSIVE r(i) AS ( + VALUES(0) + UNION ALL + SELECT i FROM r + LIMIT 10000000 + ) + SELECT i FROM r WHERE i = 1; + SQL + work << "<" + + assert_operator work.find_index("<") - work.find_index(">"), :>, 2 + end end