-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(alloydb): Added generate batch embeddings sample (#12721)
--------- Co-authored-by: Jennifer Davis <[email protected]> Co-authored-by: Katie McLaughlin <[email protected]> Co-authored-by: Katie McLaughlin <[email protected]>
- Loading branch information
1 parent
0fdcba8
commit 08e0146
Showing
7 changed files
with
1,592 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import os | ||
import re | ||
import subprocess | ||
import sys | ||
import textwrap | ||
import uuid | ||
from collections.abc import Callable, Iterable | ||
from datetime import datetime | ||
from typing import AsyncIterator | ||
|
||
import pytest | ||
import pytest_asyncio | ||
|
||
|
||
def get_env_var(key: str) -> str: | ||
v = os.environ.get(key) | ||
if v is None: | ||
raise ValueError(f"Must set env var {key}") | ||
return v | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def table_name() -> str: | ||
return "investments" | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def cluster_name() -> str: | ||
return get_env_var("ALLOYDB_CLUSTER") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def instance_name() -> str: | ||
return get_env_var("ALLOYDB_INSTANCE") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def region() -> str: | ||
return get_env_var("ALLOYDB_REGION") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def database_name() -> str: | ||
return get_env_var("ALLOYDB_DATABASE_NAME") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def password() -> str: | ||
return get_env_var("ALLOYDB_PASSWORD") | ||
|
||
|
||
@pytest_asyncio.fixture(scope="session") | ||
def project_id() -> str: | ||
gcp_project = get_env_var("GOOGLE_CLOUD_PROJECT") | ||
run_cmd("gcloud", "config", "set", "project", gcp_project) | ||
# Since everything requires the project, let's confiugre and show some | ||
# debugging information here. | ||
run_cmd("gcloud", "version") | ||
run_cmd("gcloud", "config", "list") | ||
return gcp_project | ||
|
||
|
||
def run_cmd(*cmd: str) -> subprocess.CompletedProcess: | ||
try: | ||
print(f">> {cmd}") | ||
start = datetime.now() | ||
p = subprocess.run( | ||
cmd, | ||
check=True, | ||
stdout=subprocess.PIPE, | ||
stderr=subprocess.PIPE, | ||
) | ||
print(p.stderr.decode("utf-8")) | ||
print(p.stdout.decode("utf-8")) | ||
elapsed = (datetime.now() - start).seconds | ||
minutes = int(elapsed / 60) | ||
seconds = elapsed - minutes * 60 | ||
print(f"Command `{cmd[0]}` finished in {minutes}m {seconds}s") | ||
return p | ||
except subprocess.CalledProcessError as e: | ||
# Include the error message from the failed command. | ||
print(e.stderr.decode("utf-8")) | ||
print(e.stdout.decode("utf-8")) | ||
raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e | ||
|
||
|
||
def run_notebook( | ||
ipynb_file: str, | ||
prelude: str = "", | ||
section: str = "", | ||
variables: dict = {}, | ||
replace: dict[str, str] = {}, | ||
preprocess: Callable[[str], str] = lambda source: source, | ||
skip_shell_commands: bool = False, | ||
until_end: bool = False, | ||
) -> None: | ||
import nbformat | ||
from nbclient.client import NotebookClient | ||
from nbclient.exceptions import CellExecutionError | ||
|
||
def notebook_filter_section( | ||
start: str, | ||
end: str, | ||
cells: list[nbformat.NotebookNode], | ||
until_end: bool = False, | ||
) -> Iterable[nbformat.NotebookNode]: | ||
in_section = False | ||
for cell in cells: | ||
if cell["cell_type"] == "markdown": | ||
if not in_section and cell["source"].startswith(start): | ||
in_section = True | ||
elif in_section and not until_end and cell["source"].startswith(end): | ||
return | ||
if in_section: | ||
yield cell | ||
|
||
# Regular expression to match and remove shell commands from the notebook. | ||
# https://regex101.com/r/EHWBpT/1 | ||
shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE) | ||
# Compile regular expressions for variable substitutions. | ||
# https://regex101.com/r/e32vfW/1 | ||
compiled_substitutions = [ | ||
( | ||
re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""), | ||
f"{name} = {repr(value)}", | ||
) | ||
for name, value in variables.items() | ||
] | ||
# Filter the section if any, otherwise use the entire notebook. | ||
nb = nbformat.read(ipynb_file, as_version=4) | ||
if section: | ||
start = section | ||
end = section.split(" ", 1)[0] + " " | ||
nb.cells = list(notebook_filter_section(start, end, nb.cells, until_end)) | ||
if len(nb.cells) == 0: | ||
raise ValueError( | ||
f"Section {repr(section)} not found in notebook {repr(ipynb_file)}" | ||
) | ||
# Preprocess the cells. | ||
for cell in nb.cells: | ||
# Only preprocess code cells. | ||
if cell["cell_type"] != "code": | ||
continue | ||
# Run any custom preprocessing functions before. | ||
cell["source"] = preprocess(cell["source"]) | ||
# Preprocess shell commands. | ||
if skip_shell_commands: | ||
cmd = "pass" | ||
cell["source"] = shell_command_re.sub(cmd, cell["source"]) | ||
else: | ||
cell["source"] = shell_command_re.sub(r"_run(f'''\1''')", cell["source"]) | ||
# Apply variable substitutions. | ||
for regex, new_value in compiled_substitutions: | ||
cell["source"] = regex.sub(new_value, cell["source"]) | ||
# Apply replacements. | ||
for old, new in replace.items(): | ||
cell["source"] = cell["source"].replace(old, new) | ||
# Clear outputs. | ||
cell["outputs"] = [] | ||
# Prepend the prelude cell. | ||
prelude_src = textwrap.dedent( | ||
"""\ | ||
def _run(cmd): | ||
import subprocess as _sp | ||
import sys as _sys | ||
_p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE) | ||
_stdout = _p.stdout.decode('utf-8').strip() | ||
_stderr = _p.stderr.decode('utf-8').strip() | ||
if _stdout: | ||
print(f'➜ !{cmd}') | ||
print(_stdout) | ||
if _stderr: | ||
print(f'➜ !{cmd}', file=_sys.stderr) | ||
print(_stderr, file=_sys.stderr) | ||
if _p.returncode: | ||
raise RuntimeError('\\n'.join([ | ||
f"Command returned non-zero exit status {_p.returncode}.", | ||
f"-------- command --------", | ||
f"{cmd}", | ||
f"-------- stderr --------", | ||
f"{_stderr}", | ||
f"-------- stdout --------", | ||
f"{_stdout}", | ||
])) | ||
""" | ||
+ prelude | ||
) | ||
nb.cells = [nbformat.v4.new_code_cell(prelude_src)] + nb.cells | ||
# Run the notebook. | ||
error = "" | ||
client = NotebookClient(nb) | ||
try: | ||
client.execute() | ||
except CellExecutionError as e: | ||
# Remove colors and other escape characters to make it easier to read in the logs. | ||
# https://stackoverflow.com/a/33925425 | ||
color_chars = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]") | ||
error = color_chars.sub("", str(e)) | ||
for cell in nb.cells: | ||
if cell["cell_type"] != "code": | ||
continue | ||
for output in cell["outputs"]: | ||
if output.get("name") == "stdout": | ||
print(color_chars.sub("", output["text"])) | ||
elif output.get("name") == "stderr": | ||
print(color_chars.sub("", output["text"]), file=sys.stderr) | ||
if error: | ||
raise RuntimeError( | ||
f"Error on {repr(ipynb_file)}, section {repr(section)}: {error}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2022 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Maintainer Note: this sample presumes data exists in | ||
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database) | ||
|
||
import asyncpg # type: ignore | ||
import conftest as conftest # python-docs-samples/alloydb/conftest.py | ||
from google.cloud.alloydb.connector import AsyncConnector, IPTypes | ||
import pytest | ||
import sqlalchemy | ||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine | ||
|
||
|
||
def preprocess(source: str) -> str: | ||
# Skip the cells which add data to table | ||
if "df" in source: | ||
return "" | ||
# Skip the colab auth cell | ||
if "colab" in source: | ||
return "" | ||
return source | ||
|
||
|
||
async def _init_connection_pool( | ||
connector: AsyncConnector, | ||
db_name: str, | ||
project_id: str, | ||
cluster_name: str, | ||
instance_name: str, | ||
region: str, | ||
password: str, | ||
) -> AsyncEngine: | ||
connection_string = ( | ||
f"projects/{project_id}/locations/" | ||
f"{region}/clusters/{cluster_name}/" | ||
f"instances/{instance_name}" | ||
) | ||
|
||
async def getconn() -> asyncpg.Connection: | ||
conn: asyncpg.Connection = await connector.connect( | ||
connection_string, | ||
"asyncpg", | ||
user="postgres", | ||
password=password, | ||
db=db_name, | ||
ip_type=IPTypes.PUBLIC, | ||
) | ||
return conn | ||
|
||
pool = create_async_engine( | ||
"postgresql+asyncpg://", | ||
async_creator=getconn, | ||
max_overflow=0, | ||
) | ||
return pool | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_embeddings_batch_processing( | ||
project_id: str, | ||
cluster_name: str, | ||
instance_name: str, | ||
region: str, | ||
database_name: str, | ||
password: str, | ||
table_name: str, | ||
) -> None: | ||
# TODO: Create new table | ||
# Populate the table with embeddings by running the notebook | ||
conftest.run_notebook( | ||
"embeddings_batch_processing.ipynb", | ||
variables={ | ||
"project_id": project_id, | ||
"cluster_name": cluster_name, | ||
"database_name": database_name, | ||
"region": region, | ||
"instance_name": instance_name, | ||
"table_name": table_name, | ||
}, | ||
preprocess=preprocess, | ||
skip_shell_commands=True, | ||
replace={ | ||
( | ||
"password = input(\"Please provide " | ||
"a password to be used for 'postgres' " | ||
"database user: \")" | ||
): f"password = '{password}'", | ||
( | ||
"await create_db(" | ||
"database_name=database_name, " | ||
"connector=connector)" | ||
): "", | ||
}, | ||
until_end=True, | ||
) | ||
|
||
# Connect to the populated table for validation and clean up | ||
async with AsyncConnector() as connector: | ||
pool = await _init_connection_pool( | ||
connector, | ||
database_name, | ||
project_id, | ||
cluster_name, | ||
instance_name, | ||
region, | ||
password, | ||
) | ||
async with pool.connect() as conn: | ||
# Validate that embeddings are non-empty for all rows | ||
result = await conn.execute( | ||
sqlalchemy.text( | ||
f"SELECT COUNT(*) FROM " | ||
f"{table_name} WHERE " | ||
f"analysis_embedding IS NULL" | ||
) | ||
) | ||
row = result.fetchone() | ||
assert row[0] == 0 | ||
result = await conn.execute( | ||
sqlalchemy.text( | ||
f"SELECT COUNT(*) FROM " | ||
f"{table_name} WHERE " | ||
f"overview_embedding IS NULL" | ||
) | ||
) | ||
row = result.fetchone() | ||
assert row[0] == 0 | ||
|
||
# Get the table back to the original state | ||
await conn.execute( | ||
sqlalchemy.text( | ||
f"UPDATE {table_name} set " | ||
f"analysis_embedding = NULL" | ||
) | ||
) | ||
await conn.execute( | ||
sqlalchemy.text( | ||
f"UPDATE {table_name} set " | ||
f"overview_embedding = NULL" | ||
) | ||
) | ||
await conn.commit() | ||
await pool.dispose() |
Oops, something went wrong.