Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(alloydb): Added generate batch embeddings sample #12721

Merged
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
06cdd77
docs: Added generate batch embeddings sample
twishabansal Oct 23, 2024
1e24a67
Added outputs
twishabansal Oct 23, 2024
7166b56
Changes to be able to run the notebook in local
twishabansal Oct 23, 2024
f91d286
Improved structure and readability
twishabansal Oct 24, 2024
4cc34e4
Back to old commit
twishabansal Oct 25, 2024
f912ec3
Back to working code
twishabansal Oct 25, 2024
169e7d2
Resolved comments
twishabansal Oct 25, 2024
1a92d0b
Added indentation
twishabansal Oct 25, 2024
baa4f7f
Merge branch 'GoogleCloudPlatform:main' into generate_batch_embeddings
twishabansal Oct 28, 2024
d8cbb33
code cleanup
twishabansal Oct 28, 2024
951cb04
Limit the batch size for text embeddings
twishabansal Oct 29, 2024
6142e9c
fixed: any empty cols to embed, max instances per prediction
twishabansal Oct 30, 2024
888a800
Moved connector above
twishabansal Oct 30, 2024
edca943
lint
twishabansal Oct 30, 2024
37d06a2
cleanup
twishabansal Oct 30, 2024
90b38a4
Merge branch 'main' into generate_batch_embeddings
iennae Oct 30, 2024
f424b23
No errors on empty data to embed
twishabansal Nov 6, 2024
73a8d4b
Retry on all embed errors
twishabansal Nov 6, 2024
a56940e
Added function docstrings
twishabansal Nov 8, 2024
8f97fbe
Renamed notebook
twishabansal Nov 8, 2024
ebda403
Deleted parameter map
twishabansal Nov 8, 2024
93f6611
formatting
twishabansal Nov 8, 2024
482ba08
Improve clarity of embeddings workflow setup
twishabansal Nov 8, 2024
5d37b03
Minor fix
twishabansal Nov 8, 2024
752b26b
small fix
twishabansal Nov 8, 2024
4d9359d
Added tests for ipynb notebook
twishabansal Nov 8, 2024
729e12c
Added license header
twishabansal Nov 8, 2024
e776b8b
formatting fix
twishabansal Nov 14, 2024
0a2542e
log embeddings failure with batch
twishabansal Nov 14, 2024
3a3d065
logged data for which embedding is failing
twishabansal Nov 14, 2024
d1995ec
Moved files to work with automated tests
twishabansal Nov 14, 2024
7c9f122
fix imports
twishabansal Nov 14, 2024
0ee945e
Update alloydb/notebooks/embeddings_batch_processing_e2e_test.py
glasnt Nov 14, 2024
38da55e
fix lint errors
twishabansal Nov 18, 2024
168473b
fix import order
twishabansal Nov 18, 2024
e08caa6
Update alloydb/notebooks/embeddings_batch_processing_e2e_test.py
glasnt Nov 19, 2024
ebaa427
fix: ignore Python 3.8 (pandas deps issue)
glasnt Nov 20, 2024
aa253d6
update tested python versions
glasnt Nov 20, 2024
371825b
Rename e2e_file.py, update header commentary
glasnt Nov 20, 2024
19e2039
Update alloydb/notebooks/e2e_test.py
glasnt Nov 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions alloydb/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import re
import subprocess
import sys
import textwrap
import uuid
from collections.abc import Callable, Iterable
from datetime import datetime
from typing import AsyncIterator

import pytest
import pytest_asyncio


def get_env_var(key: str) -> str:
    """Return the value of environment variable *key*.

    Raises:
        ValueError: If the variable is not set in the environment.
    """
    try:
        return os.environ[key]
    except KeyError:
        raise ValueError(f"Must set env var {key}") from None


@pytest.fixture(scope="session")
def table_name() -> str:
    """Session fixture: name of the table used by the notebook tests."""
    return "investments"


@pytest.fixture(scope="session")
def cluster_name() -> str:
    """Session fixture: AlloyDB cluster name, from the ALLOYDB_CLUSTER env var."""
    return get_env_var("ALLOYDB_CLUSTER")


@pytest.fixture(scope="session")
def instance_name() -> str:
    """Session fixture: AlloyDB instance name, from the ALLOYDB_INSTANCE env var."""
    return get_env_var("ALLOYDB_INSTANCE")


@pytest.fixture(scope="session")
def region() -> str:
    """Session fixture: AlloyDB region, from the ALLOYDB_REGION env var."""
    return get_env_var("ALLOYDB_REGION")


@pytest.fixture(scope="session")
def database_name() -> str:
    """Session fixture: database name, from the ALLOYDB_DATABASE_NAME env var."""
    return get_env_var("ALLOYDB_DATABASE_NAME")


@pytest.fixture(scope="session")
def password() -> str:
    """Session fixture: database password, from the ALLOYDB_PASSWORD env var."""
    return get_env_var("ALLOYDB_PASSWORD")


@pytest.fixture(scope="session")
def project_id() -> str:
    """Session fixture: the GCP project id, also set as the active gcloud project.

    Fix: this is a plain (sync) fixture, so it should use ``@pytest.fixture``;
    ``@pytest_asyncio.fixture`` is intended for ``async def`` fixtures.

    Raises:
        ValueError: If GOOGLE_CLOUD_PROJECT is not set.
        RuntimeError: If any of the gcloud commands fails.
    """
    gcp_project = get_env_var("GOOGLE_CLOUD_PROJECT")
    run_cmd("gcloud", "config", "set", "project", gcp_project)
    # Since everything requires the project, let's configure and show some
    # debugging information here.
    run_cmd("gcloud", "version")
    run_cmd("gcloud", "config", "list")
    return gcp_project


def run_cmd(*cmd: str) -> subprocess.CompletedProcess:
    """Run *cmd* as a subprocess, echoing its stdout/stderr and timing it.

    Returns the completed process on success. On a non-zero exit status,
    prints the captured output and raises RuntimeError including stderr.
    """
    print(f">> {cmd}")
    started = datetime.now()
    try:
        proc = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as err:
        # Include the error message from the failed command.
        print(err.stderr.decode("utf-8"))
        print(err.stdout.decode("utf-8"))
        raise RuntimeError(f"{err}\n\n{err.stderr.decode('utf-8')}") from err
    print(proc.stderr.decode("utf-8"))
    print(proc.stdout.decode("utf-8"))
    minutes, seconds = divmod((datetime.now() - started).seconds, 60)
    print(f"Command `{cmd[0]}` finished in {minutes}m {seconds}s")
    return proc


def run_notebook(
    ipynb_file: str,
    prelude: str = "",
    section: str = "",
    variables: dict | None = None,
    replace: dict[str, str] | None = None,
    preprocess: Callable[[str], str] = lambda source: source,
    skip_shell_commands: bool = False,
    until_end: bool = False,
) -> None:
    """Execute a Jupyter notebook, optionally restricted to one section.

    Fix: ``variables`` and ``replace`` previously used mutable default
    arguments (``{}``), which are shared across calls; they now default to
    ``None`` and are normalized to fresh dicts.

    Args:
        ipynb_file: Path of the notebook file to run.
        prelude: Extra Python source appended to the injected first cell.
        section: Markdown heading that starts the section to run; when empty,
            the whole notebook runs.
        variables: Mapping of variable name to value; simple assignments to
            these names in code cells are rewritten to the given values.
        replace: Literal old-text -> new-text replacements applied to cells.
        preprocess: Callback applied to each code cell's source first.
        skip_shell_commands: When True, `!cmd` lines are replaced with `pass`
            instead of being routed through the injected `_run` helper.
        until_end: When True (with `section`), run from the section start to
            the end of the notebook instead of stopping at the next heading
            with the same prefix.

    Raises:
        ValueError: If `section` is given but not found in the notebook.
        RuntimeError: If any cell raises during execution.
    """
    import nbformat
    from nbclient.client import NotebookClient
    from nbclient.exceptions import CellExecutionError

    # Normalize None to fresh empty dicts (avoids mutable default arguments).
    if variables is None:
        variables = {}
    if replace is None:
        replace = {}

    def notebook_filter_section(
        start: str,
        end: str,
        cells: list[nbformat.NotebookNode],
        until_end: bool = False,
    ) -> Iterable[nbformat.NotebookNode]:
        # Yield cells from the markdown heading `start` up to (not including)
        # the next heading that starts with `end`, or to the end of the
        # notebook when `until_end` is set.
        in_section = False
        for cell in cells:
            if cell["cell_type"] == "markdown":
                if not in_section and cell["source"].startswith(start):
                    in_section = True
                elif in_section and not until_end and cell["source"].startswith(end):
                    return
            if in_section:
                yield cell

    # Regular expression to match and remove shell commands from the notebook.
    # https://regex101.com/r/EHWBpT/1
    shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE)
    # Compile regular expressions for variable substitutions.
    # https://regex101.com/r/e32vfW/1
    compiled_substitutions = [
        (
            re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""),
            f"{name} = {repr(value)}",
        )
        for name, value in variables.items()
    ]
    # Filter the section if any, otherwise use the entire notebook.
    nb = nbformat.read(ipynb_file, as_version=4)
    if section:
        start = section
        end = section.split(" ", 1)[0] + " "
        nb.cells = list(notebook_filter_section(start, end, nb.cells, until_end))
        if len(nb.cells) == 0:
            raise ValueError(
                f"Section {repr(section)} not found in notebook {repr(ipynb_file)}"
            )
    # Preprocess the cells.
    for cell in nb.cells:
        # Only preprocess code cells.
        if cell["cell_type"] != "code":
            continue
        # Run any custom preprocessing functions before.
        cell["source"] = preprocess(cell["source"])
        # Preprocess shell commands.
        if skip_shell_commands:
            cmd = "pass"
            cell["source"] = shell_command_re.sub(cmd, cell["source"])
        else:
            cell["source"] = shell_command_re.sub(r"_run(f'''\1''')", cell["source"])
        # Apply variable substitutions.
        for regex, new_value in compiled_substitutions:
            cell["source"] = regex.sub(new_value, cell["source"])
        # Apply replacements.
        for old, new in replace.items():
            cell["source"] = cell["source"].replace(old, new)
        # Clear outputs.
        cell["outputs"] = []
    # Prepend the prelude cell.
    prelude_src = textwrap.dedent(
        """\
        def _run(cmd):
            import subprocess as _sp
            import sys as _sys
            _p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE)
            _stdout = _p.stdout.decode('utf-8').strip()
            _stderr = _p.stderr.decode('utf-8').strip()
            if _stdout:
                print(f'➜ !{cmd}')
                print(_stdout)
            if _stderr:
                print(f'➜ !{cmd}', file=_sys.stderr)
                print(_stderr, file=_sys.stderr)
            if _p.returncode:
                raise RuntimeError('\\n'.join([
                    f"Command returned non-zero exit status {_p.returncode}.",
                    f"-------- command --------",
                    f"{cmd}",
                    f"-------- stderr --------",
                    f"{_stderr}",
                    f"-------- stdout --------",
                    f"{_stdout}",
                ]))
        """
        + prelude
    )
    nb.cells = [nbformat.v4.new_code_cell(prelude_src)] + nb.cells
    # Run the notebook.
    error = ""
    client = NotebookClient(nb)
    try:
        client.execute()
    except CellExecutionError as e:
        # Remove colors and other escape characters to make it easier to read in the logs.
        # https://stackoverflow.com/a/33925425
        color_chars = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
        error = color_chars.sub("", str(e))
        for cell in nb.cells:
            if cell["cell_type"] != "code":
                continue
            for output in cell["outputs"]:
                if output.get("name") == "stdout":
                    print(color_chars.sub("", output["text"]))
                elif output.get("name") == "stderr":
                    print(color_chars.sub("", output["text"]), file=sys.stderr)
    if error:
        raise RuntimeError(
            f"Error on {repr(ipynb_file)}, section {repr(section)}: {error}"
        )
155 changes: 155 additions & 0 deletions alloydb/notebooks/e2e_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Maintainer Note: this sample presumes data exists in
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database)

import asyncpg # type: ignore
import conftest as conftest # python-docs-samples/alloydb/conftest.py
from google.cloud.alloydb.connector import AsyncConnector, IPTypes
import pytest
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


def preprocess(source: str) -> str:
    """Blank out notebook cells that populate the table or do Colab auth.

    Cells mentioning "df" (data loading) or "colab" (auth) are replaced with
    an empty string; every other cell is returned unchanged.
    """
    skip_markers = ("df", "colab")
    if any(marker in source for marker in skip_markers):
        return ""
    return source


async def _init_connection_pool(
    connector: AsyncConnector,
    db_name: str,
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    password: str,
) -> AsyncEngine:
    """Build a SQLAlchemy async engine backed by the AlloyDB connector.

    Connects as the 'postgres' user over the instance's public IP and caps
    the pool with max_overflow=0.
    """
    instance_uri = (
        f"projects/{project_id}/locations/"
        f"{region}/clusters/{cluster_name}/"
        f"instances/{instance_name}"
    )

    async def getconn() -> asyncpg.Connection:
        # Each pooled connection is created through the AlloyDB connector.
        return await connector.connect(
            instance_uri,
            "asyncpg",
            user="postgres",
            password=password,
            db=db_name,
            ip_type=IPTypes.PUBLIC,
        )

    return create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        max_overflow=0,
    )


@pytest.mark.asyncio
async def test_embeddings_batch_processing(
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    database_name: str,
    password: str,
    table_name: str,
) -> None:
    """Run the batch-embeddings notebook end to end, then verify and reset.

    All parameters are pytest fixtures (see alloydb/conftest.py). The notebook
    is expected to fill the `analysis_embedding` and `overview_embedding`
    columns of `table_name`; afterwards both columns are set back to NULL so
    the test can be re-run against the same table.
    """
    # TODO: Create new table
    # Populate the table with embeddings by running the notebook
    conftest.run_notebook(
        "embeddings_batch_processing.ipynb",
        variables={
            "project_id": project_id,
            "cluster_name": cluster_name,
            "database_name": database_name,
            "region": region,
            "instance_name": instance_name,
            "table_name": table_name,
        },
        # Skip cells that load data or perform Colab auth.
        preprocess=preprocess,
        skip_shell_commands=True,
        replace={
            # Inject the test password instead of prompting interactively.
            (
                "password = input(\"Please provide "
                "a password to be used for 'postgres' "
                "database user: \")"
            ): f"password = '{password}'",
            # The test database already exists; drop the creation call.
            (
                "await create_db("
                "database_name=database_name, "
                "connector=connector)"
            ): "",
        },
        until_end=True,
    )

    # Connect to the populated table for validation and clean up
    async with AsyncConnector() as connector:
        pool = await _init_connection_pool(
            connector,
            database_name,
            project_id,
            cluster_name,
            instance_name,
            region,
            password,
        )
        async with pool.connect() as conn:
            # Validate that embeddings are non-empty for all rows
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"analysis_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"overview_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0

            # Get the table back to the original state
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"analysis_embedding = NULL"
                )
            )
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"overview_embedding = NULL"
                )
            )
            await conn.commit()
        await pool.dispose()
Loading