diff --git a/pgsqlite/pgsqlite.py b/pgsqlite/pgsqlite.py index 18b2e1c..d06c514 100644 --- a/pgsqlite/pgsqlite.py +++ b/pgsqlite/pgsqlite.py @@ -213,6 +213,10 @@ def get_table_sql(self, table: ParsedTable) -> SQL: col_sql_str = col.parsed_column.sql(dialect="postgres") if "SERIAL" in col_sql_str: col_sql_str = col_sql_str.replace("INT", "") + + if "INT NOT NULL" in col_sql_str and col.source_name in table.src_table.pks: + col_sql_str = col_sql_str.replace("INT", "SERIAL") + if "PRIMARY KEY SERIAL" in col_sql_str: col_sql_str = col_sql_str.replace("PRIMARY KEY SERIAL", "SERIAL PRIMARY KEY") cols[col.source_name] = SQL(col_sql_str) @@ -393,6 +397,12 @@ async def write_table_data(self, table: ParsedTable) -> None: sl_conn = sqlite3.connect(self.sqlite_filename) sl_cur = sl_conn.cursor() logger.info(f"Loading data into {table}", table=table.transpiled_name) + + pk_col_name = next((col.name for col in table.src_table.columns if col.is_pk and col.notnull and col.type == "INTEGER"), None) + last_index = None + if pk_col_name: + last_index, = sl_cur.execute(f'SELECT MAX("{pk_col_name}") FROM "{table.source_name}"').fetchone() + # Given the table name came from the SQLITE database, and we're using it # to read from the sqlite database, we are okay with the literal substitution here sl_cur.execute(f'SELECT * FROM "{table.source_name}"') @@ -423,6 +433,9 @@ async def write_table_data(self, table: ParsedTable) -> None: self.summary["tables"]["data"][table.source_name]["status"] = f"LOADED {rows_copied}" logger.info(f"Finished loading {rows_copied} rows of data into {table.transpiled_name}") + if pk_col_name and last_index: + logger.info(f"Updating sequence data for {table.transpiled_name}.{pk_col_name}") + await pg_cur.execute(f'ALTER SEQUENCE "{table.transpiled_name}_{pk_col_name}_seq" RESTART WITH {last_index + 1}') sl_conn.close()