|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import sys |
| 5 | + |
| 6 | +from pymongo import AsyncMongoClient |
| 7 | +from pymongo.asynchronous.collection import AsyncCollection |
| 8 | +from pymongo.asynchronous.command_cursor import AsyncCommandCursor |
| 9 | +from pymongo.asynchronous.cursor import AsyncCursor |
| 10 | +from pymongo.asynchronous.database import AsyncDatabase |
| 11 | + |
| 12 | +replacements = { |
| 13 | + "Collection": "AsyncCollection", |
| 14 | + "Database": "AsyncDatabase", |
| 15 | + "Cursor": "AsyncCursor", |
| 16 | + "MongoClient": "AsyncMongoClient", |
| 17 | + "CommandCursor": "AsyncCommandCursor", |
| 18 | + "RawBatchCursor": "AsyncRawBatchCursor", |
| 19 | + "RawBatchCommandCursor": "AsyncRawBatchCommandCursor", |
| 20 | + "ClientSession": "AsyncClientSession", |
| 21 | + "ChangeStream": "AsyncChangeStream", |
| 22 | + "CollectionChangeStream": "AsyncCollectionChangeStream", |
| 23 | + "DatabaseChangeStream": "AsyncDatabaseChangeStream", |
| 24 | + "ClusterChangeStream": "AsyncClusterChangeStream", |
| 25 | + "_Bulk": "_AsyncBulk", |
| 26 | + "_ClientBulk": "_AsyncClientBulk", |
| 27 | + "Connection": "AsyncConnection", |
| 28 | + "synchronous": "asynchronous", |
| 29 | + "Synchronous": "Asynchronous", |
| 30 | + "next": "await anext", |
| 31 | + "_Lock": "_ALock", |
| 32 | + "_Condition": "_ACondition", |
| 33 | + "GridFS": "AsyncGridFS", |
| 34 | + "GridFSBucket": "AsyncGridFSBucket", |
| 35 | + "GridIn": "AsyncGridIn", |
| 36 | + "GridOut": "AsyncGridOut", |
| 37 | + "GridOutCursor": "AsyncGridOutCursor", |
| 38 | + "GridOutIterator": "AsyncGridOutIterator", |
| 39 | + "GridOutChunkIterator": "_AsyncGridOutChunkIterator", |
| 40 | + "_grid_in_property": "_a_grid_in_property", |
| 41 | + "_grid_out_property": "_a_grid_out_property", |
| 42 | + "ClientEncryption": "AsyncClientEncryption", |
| 43 | + "MongoCryptCallback": "AsyncMongoCryptCallback", |
| 44 | + "ExplicitEncrypter": "AsyncExplicitEncrypter", |
| 45 | + "AutoEncrypter": "AsyncAutoEncrypter", |
| 46 | + "ContextManager": "AsyncContextManager", |
| 47 | + "ClientContext": "AsyncClientContext", |
| 48 | + "TestCollection": "AsyncTestCollection", |
| 49 | + "IntegrationTest": "AsyncIntegrationTest", |
| 50 | + "PyMongoTestCase": "AsyncPyMongoTestCase", |
| 51 | + "MockClientTest": "AsyncMockClientTest", |
| 52 | + "client_context": "async_client_context", |
| 53 | + "setUp": "asyncSetUp", |
| 54 | + "tearDown": "asyncTearDown", |
| 55 | + "wait_until": "await async_wait_until", |
| 56 | + "addCleanup": "addAsyncCleanup", |
| 57 | + "TestCase": "IsolatedAsyncioTestCase", |
| 58 | + "UnitTest": "AsyncUnitTest", |
| 59 | + "MockClient": "AsyncMockClient", |
| 60 | + "SpecRunner": "AsyncSpecRunner", |
| 61 | + "TransactionsBase": "AsyncTransactionsBase", |
| 62 | + "get_pool": "await async_get_pool", |
| 63 | + "is_mongos": "await async_is_mongos", |
| 64 | + "rs_or_single_client": "await async_rs_or_single_client", |
| 65 | + "rs_or_single_client_noauth": "await async_rs_or_single_client_noauth", |
| 66 | + "rs_client": "await async_rs_client", |
| 67 | + "single_client": "await async_single_client", |
| 68 | + "from_client": "await async_from_client", |
| 69 | + "closing": "aclosing", |
| 70 | + "assertRaisesExactly": "asyncAssertRaisesExactly", |
| 71 | + "get_mock_client": "await get_async_mock_client", |
| 72 | + "close": "await aclose", |
| 73 | +} |
| 74 | + |
| 75 | +async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor] |
| 76 | + |
| 77 | + |
| 78 | +def get_async_methods() -> set[str]: |
| 79 | + result: set[str] = set() |
| 80 | + for x in async_classes: |
| 81 | + methods = { |
| 82 | + k |
| 83 | + for k, v in vars(x).items() |
| 84 | + if callable(v) |
| 85 | + and not isinstance(v, classmethod) |
| 86 | + and asyncio.iscoroutinefunction(v) |
| 87 | + and v.__name__[0] != "_" |
| 88 | + } |
| 89 | + result = result | methods |
| 90 | + return result |
| 91 | + |
| 92 | + |
| 93 | +async_methods = get_async_methods() |
| 94 | + |
| 95 | + |
| 96 | +def apply_replacements(lines: list[str]) -> list[str]: |
| 97 | + for i in range(len(lines)): |
| 98 | + if "_IS_SYNC = True" in lines[i]: |
| 99 | + lines[i] = "_IS_SYNC = False" |
| 100 | + if "def test" in lines[i]: |
| 101 | + lines[i] = lines[i].replace("def test", "async def test") |
| 102 | + for k in replacements: |
| 103 | + if k in lines[i]: |
| 104 | + lines[i] = lines[i].replace(k, replacements[k]) |
| 105 | + for k in async_methods: |
| 106 | + if k + "(" in lines[i]: |
| 107 | + tokens = lines[i].split(" ") |
| 108 | + for j in range(len(tokens)): |
| 109 | + if k + "(" in tokens[j]: |
| 110 | + if j < 2: |
| 111 | + tokens.insert(0, "await") |
| 112 | + else: |
| 113 | + tokens.insert(j, "await") |
| 114 | + break |
| 115 | + new_line = " ".join(tokens) |
| 116 | + |
| 117 | + lines[i] = new_line |
| 118 | + |
| 119 | + return lines |
| 120 | + |
| 121 | + |
| 122 | +def process_file(input_file: str, output_file: str) -> None: |
| 123 | + with open(input_file, "r+") as f: |
| 124 | + lines = f.readlines() |
| 125 | + lines = apply_replacements(lines) |
| 126 | + |
| 127 | + with open(output_file, "w+") as f2: |
| 128 | + f2.seek(0) |
| 129 | + f2.writelines(lines) |
| 130 | + f2.truncate() |
| 131 | + |
| 132 | + |
| 133 | +def main() -> None: |
| 134 | + args = sys.argv[1:] |
| 135 | + sync_file = "./test/" + args[0] |
| 136 | + async_file = "./" + args[0] |
| 137 | + |
| 138 | + process_file(sync_file, async_file) |
| 139 | + |
| 140 | + |
| 141 | +main() |
0 commit comments