-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathconvert_test_to_async.py
141 lines (121 loc) · 4.69 KB
/
convert_test_to_async.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import annotations

import asyncio
import inspect
import re
import sys

from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
# Table of sync -> async text substitutions consumed by apply_replacements().
# Keys are matched as plain substrings of each source line (not as whole
# identifiers), and several keys are substrings of other keys (for example
# "Cursor" and "CommandCursor", or "GridFS" and "GridFSBucket").
# Values that begin with "await " also place the await keyword directly in
# front of the renamed identifier.
replacements = {
    "Collection": "AsyncCollection",
    "Database": "AsyncDatabase",
    "Cursor": "AsyncCursor",
    "MongoClient": "AsyncMongoClient",
    "CommandCursor": "AsyncCommandCursor",
    "RawBatchCursor": "AsyncRawBatchCursor",
    "RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
    "ClientSession": "AsyncClientSession",
    "ChangeStream": "AsyncChangeStream",
    "CollectionChangeStream": "AsyncCollectionChangeStream",
    "DatabaseChangeStream": "AsyncDatabaseChangeStream",
    "ClusterChangeStream": "AsyncClusterChangeStream",
    "_Bulk": "_AsyncBulk",
    "_ClientBulk": "_AsyncClientBulk",
    "Connection": "AsyncConnection",
    "synchronous": "asynchronous",
    "Synchronous": "Asynchronous",
    "next": "await anext",
    "_Lock": "_ALock",
    "_Condition": "_ACondition",
    "GridFS": "AsyncGridFS",
    "GridFSBucket": "AsyncGridFSBucket",
    "GridIn": "AsyncGridIn",
    "GridOut": "AsyncGridOut",
    "GridOutCursor": "AsyncGridOutCursor",
    "GridOutIterator": "AsyncGridOutIterator",
    "GridOutChunkIterator": "_AsyncGridOutChunkIterator",
    "_grid_in_property": "_a_grid_in_property",
    "_grid_out_property": "_a_grid_out_property",
    "ClientEncryption": "AsyncClientEncryption",
    "MongoCryptCallback": "AsyncMongoCryptCallback",
    "ExplicitEncrypter": "AsyncExplicitEncrypter",
    "AutoEncrypter": "AsyncAutoEncrypter",
    "ContextManager": "AsyncContextManager",
    "ClientContext": "AsyncClientContext",
    "TestCollection": "AsyncTestCollection",
    "IntegrationTest": "AsyncIntegrationTest",
    "PyMongoTestCase": "AsyncPyMongoTestCase",
    "MockClientTest": "AsyncMockClientTest",
    "client_context": "async_client_context",
    "setUp": "asyncSetUp",
    "tearDown": "asyncTearDown",
    "wait_until": "await async_wait_until",
    "addCleanup": "addAsyncCleanup",
    "TestCase": "IsolatedAsyncioTestCase",
    "UnitTest": "AsyncUnitTest",
    "MockClient": "AsyncMockClient",
    "SpecRunner": "AsyncSpecRunner",
    "TransactionsBase": "AsyncTransactionsBase",
    "get_pool": "await async_get_pool",
    "is_mongos": "await async_is_mongos",
    "rs_or_single_client": "await async_rs_or_single_client",
    "rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
    "rs_client": "await async_rs_client",
    "single_client": "await async_single_client",
    "from_client": "await async_from_client",
    "closing": "aclosing",
    "assertRaisesExactly": "asyncAssertRaisesExactly",
    "get_mock_client": "await get_async_mock_client",
    "close": "await aclose",
}
# Classes whose own namespaces are scanned by get_async_methods() for public
# coroutine methods.
async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor]
def get_async_methods() -> set[str]:
result: set[str] = set()
for x in async_classes:
methods = {
k
for k, v in vars(x).items()
if callable(v)
and not isinstance(v, classmethod)
and asyncio.iscoroutinefunction(v)
and v.__name__[0] != "_"
}
result = result | methods
return result
# Evaluated once at import time: the method names apply_replacements() looks
# for (as "<name>(") when deciding where to insert "await".
async_methods = get_async_methods()
def apply_replacements(lines: list[str]) -> list[str]:
for i in range(len(lines)):
if "_IS_SYNC = True" in lines[i]:
lines[i] = "_IS_SYNC = False"
if "def test" in lines[i]:
lines[i] = lines[i].replace("def test", "async def test")
for k in replacements:
if k in lines[i]:
lines[i] = lines[i].replace(k, replacements[k])
for k in async_methods:
if k + "(" in lines[i]:
tokens = lines[i].split(" ")
for j in range(len(tokens)):
if k + "(" in tokens[j]:
if j < 2:
tokens.insert(0, "await")
else:
tokens.insert(j, "await")
break
new_line = " ".join(tokens)
lines[i] = new_line
return lines
def process_file(input_file: str, output_file: str) -> None:
    """Convert one synchronous test file and write the async result.

    Reads every line of *input_file*, passes them through
    ``apply_replacements`` and writes the returned lines to *output_file*,
    replacing any existing content.

    Args:
        input_file: Path of the synchronous source file to read.
        output_file: Path of the file to create or overwrite.

    Raises:
        OSError: If either file cannot be opened.
    """
    # Read-only access is enough for the input; "r+" would needlessly require
    # write permission on the source file. Python source is UTF-8 by default,
    # so do not depend on the platform's locale encoding.
    with open(input_file, encoding="utf-8") as src:
        lines = src.readlines()

    converted = apply_replacements(lines)

    # "w" already truncates the file and starts at offset 0.
    with open(output_file, "w", encoding="utf-8") as dst:
        dst.writelines(converted)
def main() -> None:
    """Command-line entry point.

    Takes the first command-line argument as a file name, reads
    ``./test/<name>`` and writes the converted file to ``./<name>`` (both
    relative to the current working directory). Exits with a usage message
    when no argument is given.
    """
    args = sys.argv[1:]
    if not args:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <test_file_name>")
    sync_file = "./test/" + args[0]
    async_file = "./" + args[0]
    process_file(sync_file, async_file)


# Only run the conversion when executed as a script, not when imported.
if __name__ == "__main__":
    main()