diff --git a/alloydb/notebooks/batch_embeddings_update.ipynb b/alloydb/notebooks/batch_embeddings_update.ipynb
index 4c836eb88b17..5a1ac3e7d0b0 100644
--- a/alloydb/notebooks/batch_embeddings_update.ipynb
+++ b/alloydb/notebooks/batch_embeddings_update.ipynb
@@ -55,22 +55,6 @@
     "* A Google Cloud Account and Google Cloud Project"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "vHdR4fF3vLWA"
-   },
-   "source": [
-    "## Objectives\n",
-    "\n",
-    "In the following instructions you will learn to:\n",
-    "\n",
-    "1. Install required dependencies for our application\n",
-    "2. Set up authentication for our project\n",
-    "3. Set up a AlloyDB for PostgreSQL Instance\n",
-    "4. Import the data used by our application"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -382,6 +366,17 @@
     "    return pool"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from google.cloud.alloydb.connector import AsyncConnector\n",
+    "\n",
+    "connector = AsyncConnector()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -405,15 +400,9 @@
    },
    "outputs": [],
    "source": [
-    "from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine\n",
     "from sqlalchemy import text, exc\n",
     "\n",
-    "from google.cloud.alloydb.connector import AsyncConnector, IPTypes\n",
-    "\n",
-    "async def create_db(database_name):\n",
-    "    # Get a raw connection directly from the connector\n",
-    "    connector = AsyncConnector()\n",
-    "    connection_string = f\"projects/{project_id}/locations/{region}/clusters/{cluster_name}/instances/{instance_name}\"\n",
+    "async def create_db(database_name, connector):\n",
     "    pool = await init_connection_pool(connector, \"postgres\")\n",
     "    async with pool.connect() as conn:\n",
     "        try:\n",
@@ -423,7 +412,7 @@
     "        except exc.ProgrammingError:\n",
     "            print(f\"Database '{database_name}' already exists\")\n",
     "\n",
-    "await create_db(database_name=database_name)"
+    "await create_db(database_name=database_name, connector=connector)"
    ]
   },
   {
@@ -600,7 +589,7 @@     "        \"overview\": row[\"overview\"],\n",
     "        \"analysis\": row[\"analysis\"],\n",
     "    }\n",
-    "    for index, row in df.iterrows()\n",
+    "    for _, row in df.iterrows()\n",
     "]"
    ]
   },
   {
@@ -614,8 +603,6 @@
    "source": [
     "from google.cloud.alloydb.connector import AsyncConnector\n",
     "\n",
-    "connector = AsyncConnector()\n",
-    "\n",
     "# Create table and insert data\n",
     "async def insert_data(pool):\n",
     "    async with pool.connect() as db_conn:\n",
@@ -1042,7 +1029,6 @@
    "source": [
     "import vertexai\n",
     "import time\n",
-    "import asyncio\n",
     "from vertexai.language_models import TextEmbeddingModel\n",
     "\n",
     "pool_size = 10\n",
@@ -1051,7 +1037,6 @@
     "total_char_count = 0\n",
     "\n",
     "# Set up connections to the database\n",
-    "connector = AsyncConnector()\n",
     "pool = await init_connection_pool(connector, database_name, pool_size=pool_size)\n",
     "\n",
     "# Initialise VertexAI and the model to be used to generate embeddings\n",
@@ -1067,10 +1052,14 @@
     "batch_data = batch_source_data(source_data, cols_to_embed)\n",
     "\n",
     "# Generate embeddings for the batched data concurrently\n",
-    "embeddings_data = embed_objects_concurrently(cols_to_embed, batch_data, model, task, max_concurrency=embed_data_concurrency)\n",
+    "embeddings_data = embed_objects_concurrently(\n",
+    "    cols_to_embed, batch_data, model, task, max_concurrency=embed_data_concurrency\n",
+    ")\n",
     "\n",
     "# Update the database with the generated embeddings concurrently\n",
-    "await batch_update_rows_concurrently(pool, embeddings_data, cols_to_embed, max_concurrency=batch_update_concurrency)\n",
+    "await batch_update_rows_concurrently(\n",
+    "    pool, embeddings_data, cols_to_embed, max_concurrency=batch_update_concurrency\n",
+    ")\n",
     "\n",
     "end_time = time.monotonic()\n",
     "elapsed_time = end_time - start_time\n",
@@ -1084,15 +1073,6 @@
     "print(f\"Total run time: {elapsed_time:.2f} seconds\")\n",
     "print(f\"Total characters embedded: {total_char_count}\")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 41,
-   "metadata": {
-    "id": "fzZJsWRZAMxs"
-   },
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
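
Reviewer note: the net effect of the connector changes is that a single `AsyncConnector` is now created once in its own cell, passed explicitly into `create_db` and `init_connection_pool`, and reused by the later insert and embedding cells instead of being re-instantiated in each one. Below is a minimal sketch of the resulting lifecycle, assuming the notebook's existing `init_connection_pool`, `create_db`, and `database_name` definitions; the `try`/`finally` cleanup with `connector.close()` is a suggested follow-up, not something this diff adds.

```python
# Sketch only: the shared-connector lifecycle this diff moves the notebook to.
# init_connection_pool(), create_db(), and database_name are defined elsewhere
# in the notebook; the explicit close() is an assumption, not part of the diff.
from google.cloud.alloydb.connector import AsyncConnector

async def run_notebook_flow() -> None:
    connector = AsyncConnector()  # created once, near the top of the notebook
    try:
        # Every helper receives the same connector instead of building its own.
        await create_db(database_name=database_name, connector=connector)
        pool = await init_connection_pool(connector, database_name, pool_size=10)
        # ... insert data, generate embeddings, batch-update rows via `pool` ...
    finally:
        await connector.close()  # release the connector's background resources
```

Sharing one connector avoids repeating its setup in every cell and makes the dependency explicit in each function signature, which is also why the now-unused `AsyncConnector, IPTypes` imports and the dead `connection_string` assignment could be dropped from `create_db`.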
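The awaited `batch_update_rows_concurrently(...)` call bounds its fan-out with `max_concurrency`. For readers reviewing the diff without the notebook open, the sketch below shows the generic semaphore-plus-gather pattern an async helper like this typically uses; it illustrates the technique only, not the notebook's actual implementation, and `process_batch` is a hypothetical stand-in for the per-batch update.

```python
import asyncio
from typing import Awaitable, Callable, Iterable, List

async def run_bounded(
    items: Iterable,
    process_batch: Callable[..., Awaitable],  # hypothetical per-batch coroutine
    max_concurrency: int,
) -> List:
    # A semaphore caps how many batches are in flight at once, while
    # asyncio.gather() preserves the order of the results.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(item):
        async with semaphore:
            return await process_batch(item)

    return await asyncio.gather(*(guarded(item) for item in items))
```

Keeping `max_concurrency` at or below `pool_size` (10 here) means tasks wait on the semaphore rather than queueing on the connection pool itself.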