Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pgsqlite/pgsqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def get_table_sql(self, table: ParsedTable) -> SQL:
col_sql_str = col.parsed_column.sql(dialect="postgres")
if "SERIAL" in col_sql_str:
col_sql_str = col_sql_str.replace("INT", "")

if "INT NOT NULL" in col_sql_str and col.source_name in table.src_table.pks:
col_sql_str = col_sql_str.replace("INT", "SERIAL")

if "PRIMARY KEY SERIAL" in col_sql_str:
col_sql_str = col_sql_str.replace("PRIMARY KEY SERIAL", "SERIAL PRIMARY KEY")
cols[col.source_name] = SQL(col_sql_str)
Expand Down Expand Up @@ -393,6 +397,12 @@ async def write_table_data(self, table: ParsedTable) -> None:
sl_conn = sqlite3.connect(self.sqlite_filename)
sl_cur = sl_conn.cursor()
logger.info(f"Loading data into {table}", table=table.transpiled_name)

pk_col_name = next((col.name for col in table.src_table.columns if col.is_pk and col.notnull and col.type == "INTEGER"), None)
last_index = None
if pk_col_name:
last_index, = sl_cur.execute(f'SELECT MAX("{pk_col_name}") FROM "{table.source_name}"').fetchone()

# Given the table name came from the SQLITE database, and we're using it
# to read from the sqlite database, we are okay with the literal substitution here
sl_cur.execute(f'SELECT * FROM "{table.source_name}"')
Expand Down Expand Up @@ -423,6 +433,9 @@ async def write_table_data(self, table: ParsedTable) -> None:

self.summary["tables"]["data"][table.source_name]["status"] = f"LOADED {rows_copied}"
logger.info(f"Finished loading {rows_copied} rows of data into {table.transpiled_name}")
if pk_col_name and last_index:
logger.info(f"Updating sequence data for {table.transpiled_name}.{pk_col_name}")
await pg_cur.execute(f'ALTER SEQUENCE "{table.transpiled_name}_{pk_col_name}_seq" RESTART WITH {last_index + 1}')

sl_conn.close()

Expand Down