Skip to content

Commit f15b60a

Browse files
committed
feat: add SQL statement for egin transaction isolation level
Adds an additional option to the `begin [transaction]` SQL statement to specify the isolation level of that transaction. The following format is now supported: ``` {begin | start} [transaction] [isolation level {repeatable read | serializable}] ```
1 parent b3c259d commit f15b60a

File tree

5 files changed

+178
-2
lines changed

5 files changed

+178
-2
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import TYPE_CHECKING
15+
from google.cloud.spanner_v1 import TransactionOptions
1516

1617
if TYPE_CHECKING:
1718
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -58,7 +59,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
5859
connection.commit()
5960
return None
6061
if statement_type == ClientSideStatementType.BEGIN:
61-
connection.begin()
62+
connection.begin(isolation_level=_get_isolation_level(parsed_statement))
6263
return None
6364
if statement_type == ClientSideStatementType.ROLLBACK:
6465
connection.rollback()
@@ -121,3 +122,19 @@ def _get_streamed_result_set(column_name, type_code, column_values):
121122
column_values_pb.append(_make_value_pb(column_value))
122123
result_set.values.extend(column_values_pb)
123124
return StreamedResultSet(iter([result_set]))
125+
126+
127+
def _get_isolation_level(
128+
statement: ParsedStatement,
129+
) -> TransactionOptions.IsolationLevel | None:
130+
if (
131+
statement.client_side_statement_params is None
132+
or len(statement.client_side_statement_params) == 0
133+
):
134+
return None
135+
level = statement.client_side_statement_params[0]
136+
if not isinstance(level, str) or level == "":
137+
return None
138+
# Replace (duplicate) whitespaces in the string with an underscore.
139+
level = "_".join(level.split()).upper()
140+
return TransactionOptions.IsolationLevel[level]

google/cloud/spanner_dbapi/client_side_statement_parser.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
Statement,
2222
)
2323

24-
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
24+
RE_BEGIN = re.compile(
25+
r"^\s*(?:BEGIN|START)(?:\s+TRANSACTION)?(?:\s+ISOLATION\s+LEVEL\s+(REPEATABLE\s+READ|SERIALIZABLE))?\s*$",
26+
re.IGNORECASE,
27+
)
2528
RE_COMMIT = re.compile(r"^\s*(COMMIT)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2629
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(\s+TRANSACTION)?\s*$", re.IGNORECASE)
2730
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
@@ -68,6 +71,10 @@ def parse_stmt(query):
6871
elif RE_START_BATCH_DML.match(query):
6972
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
7073
elif RE_BEGIN.match(query):
74+
match = re.search(RE_BEGIN, query)
75+
isolation_level = match.group(1)
76+
if isolation_level is not None:
77+
client_side_statement_params.append(isolation_level)
7178
client_side_statement_type = ClientSideStatementType.BEGIN
7279
elif RE_RUN_BATCH.match(query):
7380
client_side_statement_type = ClientSideStatementType.RUN_BATCH

tests/mockserver_tests/test_dbapi_isolation_level.py

+24
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,27 @@ def test_transaction_isolation_level(self):
117117
self.assertEqual(1, len(begin_requests))
118118
self.assertEqual(begin_requests[0].options.isolation_level, level)
119119
MockServerTestBase.spanner_service.clear_requests()
120+
121+
def test_begin_isolation_level(self):
122+
connection = Connection(self.instance, self.database)
123+
for level in [
124+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
125+
TransactionOptions.IsolationLevel.SERIALIZABLE,
126+
]:
127+
isolation_level_name = level.name.replace("_", " ")
128+
with connection.cursor() as cursor:
129+
cursor.execute(f"begin isolation level {isolation_level_name}")
130+
cursor.execute(
131+
"insert into singers (id, name) values (1, 'Some Singer')"
132+
)
133+
self.assertEqual(1, cursor.rowcount)
134+
connection.commit()
135+
begin_requests = list(
136+
filter(
137+
lambda msg: isinstance(msg, BeginTransactionRequest),
138+
self.spanner_service.requests,
139+
)
140+
)
141+
self.assertEqual(1, len(begin_requests))
142+
self.assertEqual(begin_requests[0].options.isolation_level, level)
143+
MockServerTestBase.spanner_service.clear_requests()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from google.cloud.spanner_dbapi.client_side_statement_executor import (
18+
_get_isolation_level,
19+
)
20+
from google.cloud.spanner_dbapi.parse_utils import classify_statement
21+
from google.cloud.spanner_v1 import TransactionOptions
22+
23+
24+
class TestParseUtils(unittest.TestCase):
25+
def test_get_isolation_level(self):
26+
self.assertIsNone(_get_isolation_level(classify_statement("begin")))
27+
self.assertEqual(
28+
TransactionOptions.IsolationLevel.SERIALIZABLE,
29+
_get_isolation_level(
30+
classify_statement("begin isolation level serializable")
31+
),
32+
)
33+
self.assertEqual(
34+
TransactionOptions.IsolationLevel.SERIALIZABLE,
35+
_get_isolation_level(
36+
classify_statement(
37+
"begin transaction isolation level serializable "
38+
)
39+
),
40+
)
41+
self.assertEqual(
42+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
43+
_get_isolation_level(
44+
classify_statement("begin isolation level repeatable read")
45+
),
46+
)
47+
self.assertEqual(
48+
TransactionOptions.IsolationLevel.REPEATABLE_READ,
49+
_get_isolation_level(
50+
classify_statement(
51+
"begin transaction isolation level repeatable read "
52+
)
53+
),
54+
)

tests/unit/spanner_dbapi/test_parse_utils.py

+74
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,28 @@ def test_classify_stmt(self):
6363
("commit", StatementType.CLIENT_SIDE),
6464
("begin", StatementType.CLIENT_SIDE),
6565
("start", StatementType.CLIENT_SIDE),
66+
("begin isolation level serializable", StatementType.CLIENT_SIDE),
67+
("start isolation level serializable", StatementType.CLIENT_SIDE),
68+
("begin isolation level repeatable read", StatementType.CLIENT_SIDE),
69+
("start isolation level repeatable read", StatementType.CLIENT_SIDE),
6670
("begin transaction", StatementType.CLIENT_SIDE),
6771
("start transaction", StatementType.CLIENT_SIDE),
72+
(
73+
"begin transaction isolation level serializable",
74+
StatementType.CLIENT_SIDE,
75+
),
76+
(
77+
"start transaction isolation level serializable",
78+
StatementType.CLIENT_SIDE,
79+
),
80+
(
81+
"begin transaction isolation level repeatable read",
82+
StatementType.CLIENT_SIDE,
83+
),
84+
(
85+
"start transaction isolation level repeatable read",
86+
StatementType.CLIENT_SIDE,
87+
),
6888
("rollback", StatementType.CLIENT_SIDE),
6989
(" commit TRANSACTION ", StatementType.CLIENT_SIDE),
7090
(" rollback TRANSACTION ", StatementType.CLIENT_SIDE),
@@ -84,6 +104,16 @@ def test_classify_stmt(self):
84104
("udpate table set col2=1 where col1 = 2", StatementType.UNKNOWN),
85105
("begin foo", StatementType.UNKNOWN),
86106
("begin transaction foo", StatementType.UNKNOWN),
107+
("begin transaction isolation level", StatementType.UNKNOWN),
108+
("begin transaction repeatable read", StatementType.UNKNOWN),
109+
(
110+
"begin transaction isolation level repeatable read foo",
111+
StatementType.UNKNOWN,
112+
),
113+
(
114+
"begin transaction isolation level unspecified",
115+
StatementType.UNKNOWN,
116+
),
87117
("commit foo", StatementType.UNKNOWN),
88118
("commit transaction foo", StatementType.UNKNOWN),
89119
("rollback foo", StatementType.UNKNOWN),
@@ -100,6 +130,50 @@ def test_classify_stmt(self):
100130
classify_statement(query).statement_type, want_class, query
101131
)
102132

133+
def test_begin_isolation_level(self):
134+
parsed_statement = classify_statement("begin")
135+
self.assertEqual(
136+
parsed_statement,
137+
ParsedStatement(
138+
StatementType.CLIENT_SIDE,
139+
Statement("begin"),
140+
ClientSideStatementType.BEGIN,
141+
[],
142+
),
143+
)
144+
parsed_statement = classify_statement("begin isolation level serializable")
145+
self.assertEqual(
146+
parsed_statement,
147+
ParsedStatement(
148+
StatementType.CLIENT_SIDE,
149+
Statement("begin isolation level serializable"),
150+
ClientSideStatementType.BEGIN,
151+
["serializable"],
152+
),
153+
)
154+
parsed_statement = classify_statement("begin isolation level repeatable read")
155+
self.assertEqual(
156+
parsed_statement,
157+
ParsedStatement(
158+
StatementType.CLIENT_SIDE,
159+
Statement("begin isolation level repeatable read"),
160+
ClientSideStatementType.BEGIN,
161+
["repeatable read"],
162+
),
163+
)
164+
parsed_statement = classify_statement(
165+
"begin isolation level repeatable read "
166+
)
167+
self.assertEqual(
168+
parsed_statement,
169+
ParsedStatement(
170+
StatementType.CLIENT_SIDE,
171+
Statement("begin isolation level repeatable read"),
172+
ClientSideStatementType.BEGIN,
173+
["repeatable read"],
174+
),
175+
)
176+
103177
def test_partition_query_classify_stmt(self):
104178
parsed_statement = classify_statement(
105179
" PARTITION SELECT s.SongName FROM Songs AS s "

0 commit comments

Comments
 (0)