Skip to content

Commit 3840d9d

Browse files
authored
Add script to help convert sync tests to async tests (#1825)
1 parent a4645f0 commit 3840d9d

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

CONTRIBUTING.md

+7
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo.
248248
## Making a Release
249249

250250
Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process).
251+
252+
## Converting a test to async
253+
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
254+
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
255+
Use this generated file as a starting point for the completed conversion.
256+
257+
The script is used like so: `python tools/convert_test_to_async.py [test_file.py]`

tools/convert_test_to_async.py

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import sys
5+
6+
from pymongo import AsyncMongoClient
7+
from pymongo.asynchronous.collection import AsyncCollection
8+
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
9+
from pymongo.asynchronous.cursor import AsyncCursor
10+
from pymongo.asynchronous.database import AsyncDatabase
11+
12+
replacements = {
13+
"Collection": "AsyncCollection",
14+
"Database": "AsyncDatabase",
15+
"Cursor": "AsyncCursor",
16+
"MongoClient": "AsyncMongoClient",
17+
"CommandCursor": "AsyncCommandCursor",
18+
"RawBatchCursor": "AsyncRawBatchCursor",
19+
"RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
20+
"ClientSession": "AsyncClientSession",
21+
"ChangeStream": "AsyncChangeStream",
22+
"CollectionChangeStream": "AsyncCollectionChangeStream",
23+
"DatabaseChangeStream": "AsyncDatabaseChangeStream",
24+
"ClusterChangeStream": "AsyncClusterChangeStream",
25+
"_Bulk": "_AsyncBulk",
26+
"_ClientBulk": "_AsyncClientBulk",
27+
"Connection": "AsyncConnection",
28+
"synchronous": "asynchronous",
29+
"Synchronous": "Asynchronous",
30+
"next": "await anext",
31+
"_Lock": "_ALock",
32+
"_Condition": "_ACondition",
33+
"GridFS": "AsyncGridFS",
34+
"GridFSBucket": "AsyncGridFSBucket",
35+
"GridIn": "AsyncGridIn",
36+
"GridOut": "AsyncGridOut",
37+
"GridOutCursor": "AsyncGridOutCursor",
38+
"GridOutIterator": "AsyncGridOutIterator",
39+
"GridOutChunkIterator": "_AsyncGridOutChunkIterator",
40+
"_grid_in_property": "_a_grid_in_property",
41+
"_grid_out_property": "_a_grid_out_property",
42+
"ClientEncryption": "AsyncClientEncryption",
43+
"MongoCryptCallback": "AsyncMongoCryptCallback",
44+
"ExplicitEncrypter": "AsyncExplicitEncrypter",
45+
"AutoEncrypter": "AsyncAutoEncrypter",
46+
"ContextManager": "AsyncContextManager",
47+
"ClientContext": "AsyncClientContext",
48+
"TestCollection": "AsyncTestCollection",
49+
"IntegrationTest": "AsyncIntegrationTest",
50+
"PyMongoTestCase": "AsyncPyMongoTestCase",
51+
"MockClientTest": "AsyncMockClientTest",
52+
"client_context": "async_client_context",
53+
"setUp": "asyncSetUp",
54+
"tearDown": "asyncTearDown",
55+
"wait_until": "await async_wait_until",
56+
"addCleanup": "addAsyncCleanup",
57+
"TestCase": "IsolatedAsyncioTestCase",
58+
"UnitTest": "AsyncUnitTest",
59+
"MockClient": "AsyncMockClient",
60+
"SpecRunner": "AsyncSpecRunner",
61+
"TransactionsBase": "AsyncTransactionsBase",
62+
"get_pool": "await async_get_pool",
63+
"is_mongos": "await async_is_mongos",
64+
"rs_or_single_client": "await async_rs_or_single_client",
65+
"rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
66+
"rs_client": "await async_rs_client",
67+
"single_client": "await async_single_client",
68+
"from_client": "await async_from_client",
69+
"closing": "aclosing",
70+
"assertRaisesExactly": "asyncAssertRaisesExactly",
71+
"get_mock_client": "await get_async_mock_client",
72+
"close": "await aclose",
73+
}
74+
75+
async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor]
76+
77+
78+
def get_async_methods() -> set[str]:
79+
result: set[str] = set()
80+
for x in async_classes:
81+
methods = {
82+
k
83+
for k, v in vars(x).items()
84+
if callable(v)
85+
and not isinstance(v, classmethod)
86+
and asyncio.iscoroutinefunction(v)
87+
and v.__name__[0] != "_"
88+
}
89+
result = result | methods
90+
return result
91+
92+
93+
async_methods = get_async_methods()
94+
95+
96+
def apply_replacements(lines: list[str]) -> list[str]:
97+
for i in range(len(lines)):
98+
if "_IS_SYNC = True" in lines[i]:
99+
lines[i] = "_IS_SYNC = False"
100+
if "def test" in lines[i]:
101+
lines[i] = lines[i].replace("def test", "async def test")
102+
for k in replacements:
103+
if k in lines[i]:
104+
lines[i] = lines[i].replace(k, replacements[k])
105+
for k in async_methods:
106+
if k + "(" in lines[i]:
107+
tokens = lines[i].split(" ")
108+
for j in range(len(tokens)):
109+
if k + "(" in tokens[j]:
110+
if j < 2:
111+
tokens.insert(0, "await")
112+
else:
113+
tokens.insert(j, "await")
114+
break
115+
new_line = " ".join(tokens)
116+
117+
lines[i] = new_line
118+
119+
return lines
120+
121+
122+
def process_file(input_file: str, output_file: str) -> None:
123+
with open(input_file, "r+") as f:
124+
lines = f.readlines()
125+
lines = apply_replacements(lines)
126+
127+
with open(output_file, "w+") as f2:
128+
f2.seek(0)
129+
f2.writelines(lines)
130+
f2.truncate()
131+
132+
133+
def main() -> None:
134+
args = sys.argv[1:]
135+
sync_file = "./test/" + args[0]
136+
async_file = "./" + args[0]
137+
138+
process_file(sync_file, async_file)
139+
140+
141+
main()

0 commit comments

Comments
 (0)