Skip to content

Commit

Permalink
feat(alloydb): Added generate batch embeddings sample (#12721)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Jennifer Davis <[email protected]>
Co-authored-by: Katie McLaughlin <[email protected]>
Co-authored-by: Katie McLaughlin <[email protected]>
  • Loading branch information
4 people authored Nov 20, 2024
1 parent 0fdcba8 commit 08e0146
Show file tree
Hide file tree
Showing 7 changed files with 1,592 additions and 0 deletions.
225 changes: 225 additions & 0 deletions alloydb/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import re
import subprocess
import sys
import textwrap
import uuid
from collections.abc import Callable, Iterable
from datetime import datetime
from typing import AsyncIterator

import pytest
import pytest_asyncio


def get_env_var(key: str) -> str:
v = os.environ.get(key)
if v is None:
raise ValueError(f"Must set env var {key}")
return v


@pytest.fixture(scope="session")
def table_name() -> str:
return "investments"


@pytest.fixture(scope="session")
def cluster_name() -> str:
return get_env_var("ALLOYDB_CLUSTER")


@pytest.fixture(scope="session")
def instance_name() -> str:
return get_env_var("ALLOYDB_INSTANCE")


@pytest.fixture(scope="session")
def region() -> str:
return get_env_var("ALLOYDB_REGION")


@pytest.fixture(scope="session")
def database_name() -> str:
return get_env_var("ALLOYDB_DATABASE_NAME")


@pytest.fixture(scope="session")
def password() -> str:
return get_env_var("ALLOYDB_PASSWORD")


@pytest_asyncio.fixture(scope="session")
def project_id() -> str:
gcp_project = get_env_var("GOOGLE_CLOUD_PROJECT")
run_cmd("gcloud", "config", "set", "project", gcp_project)
# Since everything requires the project, let's confiugre and show some
# debugging information here.
run_cmd("gcloud", "version")
run_cmd("gcloud", "config", "list")
return gcp_project


def run_cmd(*cmd: str) -> subprocess.CompletedProcess:
try:
print(f">> {cmd}")
start = datetime.now()
p = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
print(p.stderr.decode("utf-8"))
print(p.stdout.decode("utf-8"))
elapsed = (datetime.now() - start).seconds
minutes = int(elapsed / 60)
seconds = elapsed - minutes * 60
print(f"Command `{cmd[0]}` finished in {minutes}m {seconds}s")
return p
except subprocess.CalledProcessError as e:
# Include the error message from the failed command.
print(e.stderr.decode("utf-8"))
print(e.stdout.decode("utf-8"))
raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e


def run_notebook(
ipynb_file: str,
prelude: str = "",
section: str = "",
variables: dict = {},
replace: dict[str, str] = {},
preprocess: Callable[[str], str] = lambda source: source,
skip_shell_commands: bool = False,
until_end: bool = False,
) -> None:
import nbformat
from nbclient.client import NotebookClient
from nbclient.exceptions import CellExecutionError

def notebook_filter_section(
start: str,
end: str,
cells: list[nbformat.NotebookNode],
until_end: bool = False,
) -> Iterable[nbformat.NotebookNode]:
in_section = False
for cell in cells:
if cell["cell_type"] == "markdown":
if not in_section and cell["source"].startswith(start):
in_section = True
elif in_section and not until_end and cell["source"].startswith(end):
return
if in_section:
yield cell

# Regular expression to match and remove shell commands from the notebook.
# https://regex101.com/r/EHWBpT/1
shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE)
# Compile regular expressions for variable substitutions.
# https://regex101.com/r/e32vfW/1
compiled_substitutions = [
(
re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""),
f"{name} = {repr(value)}",
)
for name, value in variables.items()
]
# Filter the section if any, otherwise use the entire notebook.
nb = nbformat.read(ipynb_file, as_version=4)
if section:
start = section
end = section.split(" ", 1)[0] + " "
nb.cells = list(notebook_filter_section(start, end, nb.cells, until_end))
if len(nb.cells) == 0:
raise ValueError(
f"Section {repr(section)} not found in notebook {repr(ipynb_file)}"
)
# Preprocess the cells.
for cell in nb.cells:
# Only preprocess code cells.
if cell["cell_type"] != "code":
continue
# Run any custom preprocessing functions before.
cell["source"] = preprocess(cell["source"])
# Preprocess shell commands.
if skip_shell_commands:
cmd = "pass"
cell["source"] = shell_command_re.sub(cmd, cell["source"])
else:
cell["source"] = shell_command_re.sub(r"_run(f'''\1''')", cell["source"])
# Apply variable substitutions.
for regex, new_value in compiled_substitutions:
cell["source"] = regex.sub(new_value, cell["source"])
# Apply replacements.
for old, new in replace.items():
cell["source"] = cell["source"].replace(old, new)
# Clear outputs.
cell["outputs"] = []
# Prepend the prelude cell.
prelude_src = textwrap.dedent(
"""\
def _run(cmd):
import subprocess as _sp
import sys as _sys
_p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE)
_stdout = _p.stdout.decode('utf-8').strip()
_stderr = _p.stderr.decode('utf-8').strip()
if _stdout:
print(f'➜ !{cmd}')
print(_stdout)
if _stderr:
print(f'➜ !{cmd}', file=_sys.stderr)
print(_stderr, file=_sys.stderr)
if _p.returncode:
raise RuntimeError('\\n'.join([
f"Command returned non-zero exit status {_p.returncode}.",
f"-------- command --------",
f"{cmd}",
f"-------- stderr --------",
f"{_stderr}",
f"-------- stdout --------",
f"{_stdout}",
]))
"""
+ prelude
)
nb.cells = [nbformat.v4.new_code_cell(prelude_src)] + nb.cells
# Run the notebook.
error = ""
client = NotebookClient(nb)
try:
client.execute()
except CellExecutionError as e:
# Remove colors and other escape characters to make it easier to read in the logs.
# https://stackoverflow.com/a/33925425
color_chars = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
error = color_chars.sub("", str(e))
for cell in nb.cells:
if cell["cell_type"] != "code":
continue
for output in cell["outputs"]:
if output.get("name") == "stdout":
print(color_chars.sub("", output["text"]))
elif output.get("name") == "stderr":
print(color_chars.sub("", output["text"]), file=sys.stderr)
if error:
raise RuntimeError(
f"Error on {repr(ipynb_file)}, section {repr(section)}: {error}"
)
155 changes: 155 additions & 0 deletions alloydb/notebooks/e2e_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Maintainer Note: this sample presumes data exists in
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database)

import asyncpg # type: ignore
import conftest as conftest # python-docs-samples/alloydb/conftest.py
from google.cloud.alloydb.connector import AsyncConnector, IPTypes
import pytest
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


def preprocess(source: str) -> str:
# Skip the cells which add data to table
if "df" in source:
return ""
# Skip the colab auth cell
if "colab" in source:
return ""
return source


async def _init_connection_pool(
connector: AsyncConnector,
db_name: str,
project_id: str,
cluster_name: str,
instance_name: str,
region: str,
password: str,
) -> AsyncEngine:
connection_string = (
f"projects/{project_id}/locations/"
f"{region}/clusters/{cluster_name}/"
f"instances/{instance_name}"
)

async def getconn() -> asyncpg.Connection:
conn: asyncpg.Connection = await connector.connect(
connection_string,
"asyncpg",
user="postgres",
password=password,
db=db_name,
ip_type=IPTypes.PUBLIC,
)
return conn

pool = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
max_overflow=0,
)
return pool


@pytest.mark.asyncio
async def test_embeddings_batch_processing(
project_id: str,
cluster_name: str,
instance_name: str,
region: str,
database_name: str,
password: str,
table_name: str,
) -> None:
# TODO: Create new table
# Populate the table with embeddings by running the notebook
conftest.run_notebook(
"embeddings_batch_processing.ipynb",
variables={
"project_id": project_id,
"cluster_name": cluster_name,
"database_name": database_name,
"region": region,
"instance_name": instance_name,
"table_name": table_name,
},
preprocess=preprocess,
skip_shell_commands=True,
replace={
(
"password = input(\"Please provide "
"a password to be used for 'postgres' "
"database user: \")"
): f"password = '{password}'",
(
"await create_db("
"database_name=database_name, "
"connector=connector)"
): "",
},
until_end=True,
)

# Connect to the populated table for validation and clean up
async with AsyncConnector() as connector:
pool = await _init_connection_pool(
connector,
database_name,
project_id,
cluster_name,
instance_name,
region,
password,
)
async with pool.connect() as conn:
# Validate that embeddings are non-empty for all rows
result = await conn.execute(
sqlalchemy.text(
f"SELECT COUNT(*) FROM "
f"{table_name} WHERE "
f"analysis_embedding IS NULL"
)
)
row = result.fetchone()
assert row[0] == 0
result = await conn.execute(
sqlalchemy.text(
f"SELECT COUNT(*) FROM "
f"{table_name} WHERE "
f"overview_embedding IS NULL"
)
)
row = result.fetchone()
assert row[0] == 0

# Get the table back to the original state
await conn.execute(
sqlalchemy.text(
f"UPDATE {table_name} set "
f"analysis_embedding = NULL"
)
)
await conn.execute(
sqlalchemy.text(
f"UPDATE {table_name} set "
f"overview_embedding = NULL"
)
)
await conn.commit()
await pool.dispose()
Loading

0 comments on commit 08e0146

Please sign in to comment.