-
Notifications
You must be signed in to change notification settings - Fork 6.6k
feat(alloydb): Added generate batch embeddings sample #12721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
glasnt
merged 40 commits into
GoogleCloudPlatform:main
from
twishabansal:generate_batch_embeddings
Nov 20, 2024
Merged
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
06cdd77
docs: Added generate batch embeddings sample
twishabansal 1e24a67
Added outputs
twishabansal 7166b56
Changes to be able to run the notebook in local
twishabansal f91d286
Improved structure and readability
twishabansal 4cc34e4
Back to old commit
twishabansal f912ec3
Back to working code
twishabansal 169e7d2
Resolved comments
twishabansal 1a92d0b
Added indentation
twishabansal baa4f7f
Merge branch 'GoogleCloudPlatform:main' into generate_batch_embeddings
twishabansal d8cbb33
code cleanup
twishabansal 951cb04
Limit the batch size for text embeddings
twishabansal 6142e9c
fixed: any empty cols to embed, max instances per prediction
twishabansal 888a800
Moved connector above
twishabansal edca943
lint
twishabansal 37d06a2
cleanup
twishabansal 90b38a4
Merge branch 'main' into generate_batch_embeddings
iennae f424b23
No errors on empty data to embed
twishabansal 73a8d4b
Retry on all embed errors
twishabansal a56940e
Added function docstrings
twishabansal 8f97fbe
Renamed notebook
twishabansal ebda403
Deleted parameter map
twishabansal 93f6611
formatting
twishabansal 482ba08
Improve clarity of embeddings workflow setup
twishabansal 5d37b03
Minor fix
twishabansal 752b26b
small fix
twishabansal 4d9359d
Added tests for ipynb notebook
twishabansal 729e12c
Added license header
twishabansal e776b8b
formatting fix
twishabansal 0a2542e
log embeddings failure with batch
twishabansal 3a3d065
logged data for which embedding is failing
twishabansal d1995ec
Moved files to work with automated tests
twishabansal 7c9f122
fix imports
twishabansal 0ee945e
Update alloydb/notebooks/embeddings_batch_processing_e2e_test.py
glasnt 38da55e
fix lint errors
twishabansal 168473b
fix import order
twishabansal e08caa6
Update alloydb/notebooks/embeddings_batch_processing_e2e_test.py
glasnt ebaa427
fix: ignore Python 3.8 (pandas deps issue)
glasnt aa253d6
update tested python versions
glasnt 371825b
Rename e2e_file.py, update header commentary
glasnt 19e2039
Update alloydb/notebooks/e2e_test.py
glasnt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
# Copyright 2022 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import os | ||
import re | ||
import subprocess | ||
import sys | ||
import textwrap | ||
import uuid | ||
from collections.abc import Callable, Iterable | ||
from datetime import datetime | ||
from typing import AsyncIterator | ||
|
||
import pytest | ||
import pytest_asyncio | ||
|
||
|
||
def get_env_var(key: str) -> str:
    """Return the value of environment variable *key*.

    Raises:
        ValueError: If the variable is not set.
    """
    try:
        return os.environ[key]
    except KeyError:
        raise ValueError(f"Must set env var {key}") from None
|
||
|
||
@pytest.fixture(scope="session")
def table_name() -> str:
    """Fixed name of the sample's AlloyDB table."""
    # NOTE(review): the e2e test presumes this table already exists and
    # holds data — confirm against the sample's setup instructions.
    investments_table = "investments"
    return investments_table
|
||
|
||
@pytest.fixture(scope="session")
def cluster_name() -> str:
    """AlloyDB cluster name, read from the ALLOYDB_CLUSTER env var."""
    cluster = get_env_var("ALLOYDB_CLUSTER")
    return cluster
|
||
|
||
@pytest.fixture(scope="session")
def instance_name() -> str:
    """AlloyDB instance name, read from the ALLOYDB_INSTANCE env var."""
    instance = get_env_var("ALLOYDB_INSTANCE")
    return instance
|
||
|
||
@pytest.fixture(scope="session")
def region() -> str:
    """AlloyDB region, read from the ALLOYDB_REGION env var."""
    alloydb_region = get_env_var("ALLOYDB_REGION")
    return alloydb_region
|
||
|
||
@pytest.fixture(scope="session")
def database_name() -> str:
    """AlloyDB database name, read from the ALLOYDB_DATABASE_NAME env var."""
    db_name = get_env_var("ALLOYDB_DATABASE_NAME")
    return db_name
|
||
|
||
@pytest.fixture(scope="session")
def password() -> str:
    """Database password, read from the ALLOYDB_PASSWORD env var."""
    db_password = get_env_var("ALLOYDB_PASSWORD")
    return db_password
|
||
|
||
@pytest.fixture(scope="session")
def project_id() -> str:
    """Session fixture: the GCP project ID from GOOGLE_CLOUD_PROJECT.

    Also points gcloud at that project and prints gcloud version/config
    output, since nearly every other step depends on the project.

    Fix: this is a synchronous fixture, so it uses `@pytest.fixture` like
    every other fixture in this file — `@pytest_asyncio.fixture` is meant
    for `async def` fixtures.
    """
    gcp_project = get_env_var("GOOGLE_CLOUD_PROJECT")
    run_cmd("gcloud", "config", "set", "project", gcp_project)
    # Since everything requires the project, let's configure and show some
    # debugging information here.
    run_cmd("gcloud", "version")
    run_cmd("gcloud", "config", "list")
    return gcp_project
|
||
|
||
def run_cmd(*cmd: str) -> subprocess.CompletedProcess:
    """Run *cmd*, echoing its output and how long it took.

    Raises:
        RuntimeError: If the command exits with a non-zero status; the
            command's stderr is included in the message.
    """
    print(f">> {cmd}")
    started = datetime.now()
    try:
        proc = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as err:
        # Include the error message from the failed command.
        print(err.stderr.decode("utf-8"))
        print(err.stdout.decode("utf-8"))
        raise RuntimeError(f"{err}\n\n{err.stderr.decode('utf-8')}") from err
    print(proc.stderr.decode("utf-8"))
    print(proc.stdout.decode("utf-8"))
    elapsed = (datetime.now() - started).seconds
    minutes, seconds = divmod(elapsed, 60)
    print(f"Command `{cmd[0]}` finished in {minutes}m {seconds}s")
    return proc
|
||
|
||
def run_notebook(
    ipynb_file: str,
    prelude: str = "",
    section: str = "",
    variables: dict | None = None,
    replace: dict[str, str] | None = None,
    preprocess: Callable[[str], str] = lambda source: source,
    skip_shell_commands: bool = False,
    until_end: bool = False,
) -> None:
    """Execute a Jupyter notebook, raising on any cell error.

    Args:
        ipynb_file: Path to the notebook to run.
        prelude: Extra Python source appended to the injected setup cell.
        section: Markdown heading marking the first cell to run; cells from
            that heading up to the next heading at the same level are run
            (or through the end of the notebook when `until_end` is True).
        variables: Variable names mapped to values; simple assignments to
            those names in code cells are rewritten to the given values.
        replace: Literal source-text replacements applied to code cells.
        preprocess: Callback applied first to each code cell's source.
        skip_shell_commands: If True, `!cmd` lines become `pass`; otherwise
            they are rewritten to run through the injected `_run()` helper.
        until_end: Run from `section` through the end of the notebook.

    Raises:
        ValueError: If `section` is given but not found in the notebook.
        RuntimeError: If any cell fails to execute.
    """
    # Imported lazily so importing this conftest does not require nbclient.
    import nbformat
    from nbclient.client import NotebookClient
    from nbclient.exceptions import CellExecutionError

    # Fix: the original signature used mutable default arguments
    # (`variables: dict = {}`); normalize None to fresh dicts instead.
    if variables is None:
        variables = {}
    if replace is None:
        replace = {}

    def notebook_filter_section(
        start: str,
        end: str,
        cells: list[nbformat.NotebookNode],
        until_end: bool = False,
    ) -> Iterable[nbformat.NotebookNode]:
        # Yield cells from the markdown cell starting with `start` up to
        # (but not including) the next markdown cell starting with `end`.
        in_section = False
        for cell in cells:
            if cell["cell_type"] == "markdown":
                if not in_section and cell["source"].startswith(start):
                    in_section = True
                elif in_section and not until_end and cell["source"].startswith(end):
                    return
            if in_section:
                yield cell

    # Regular expression to match and remove shell commands from the notebook.
    # https://regex101.com/r/EHWBpT/1
    shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE)
    # Compile regular expressions for variable substitutions.
    # https://regex101.com/r/e32vfW/1
    compiled_substitutions = [
        (
            re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""),
            f"{name} = {repr(value)}",
        )
        for name, value in variables.items()
    ]
    # Filter the section if any, otherwise use the entire notebook.
    nb = nbformat.read(ipynb_file, as_version=4)
    if section:
        start = section
        end = section.split(" ", 1)[0] + " "
        nb.cells = list(notebook_filter_section(start, end, nb.cells, until_end))
        if len(nb.cells) == 0:
            raise ValueError(
                f"Section {repr(section)} not found in notebook {repr(ipynb_file)}"
            )
    # Preprocess the cells.
    for cell in nb.cells:
        # Only preprocess code cells.
        if cell["cell_type"] != "code":
            continue
        # Run any custom preprocessing functions before.
        cell["source"] = preprocess(cell["source"])
        # Preprocess shell commands.
        if skip_shell_commands:
            cmd = "pass"
            cell["source"] = shell_command_re.sub(cmd, cell["source"])
        else:
            cell["source"] = shell_command_re.sub(r"_run(f'''\1''')", cell["source"])
        # Apply variable substitutions.
        for regex, new_value in compiled_substitutions:
            cell["source"] = regex.sub(new_value, cell["source"])
        # Apply replacements.
        for old, new in replace.items():
            cell["source"] = cell["source"].replace(old, new)
        # Clear outputs.
        cell["outputs"] = []
    # Prepend the prelude cell.
    # Fix: dedent only the literal below, not `prelude` — a caller-supplied
    # prelude at column 0 would otherwise defeat textwrap.dedent (no common
    # prefix) and leave the `_run` helper indented, i.e. invalid Python.
    prelude_src = (
        textwrap.dedent(
            """\
            def _run(cmd):
                import subprocess as _sp
                import sys as _sys
                _p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE)
                _stdout = _p.stdout.decode('utf-8').strip()
                _stderr = _p.stderr.decode('utf-8').strip()
                if _stdout:
                    print(f'➜ !{cmd}')
                    print(_stdout)
                if _stderr:
                    print(f'➜ !{cmd}', file=_sys.stderr)
                    print(_stderr, file=_sys.stderr)
                if _p.returncode:
                    raise RuntimeError('\\n'.join([
                        f"Command returned non-zero exit status {_p.returncode}.",
                        f"-------- command --------",
                        f"{cmd}",
                        f"-------- stderr --------",
                        f"{_stderr}",
                        f"-------- stdout --------",
                        f"{_stdout}",
                    ]))
            """
        )
        + prelude
    )
    nb.cells = [nbformat.v4.new_code_cell(prelude_src)] + nb.cells
    # Run the notebook.
    error = ""
    client = NotebookClient(nb)
    try:
        client.execute()
    except CellExecutionError as e:
        # Remove colors and other escape characters to make it easier to read in the logs.
        # https://stackoverflow.com/a/33925425
        color_chars = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
        error = color_chars.sub("", str(e))
        for cell in nb.cells:
            if cell["cell_type"] != "code":
                continue
            for output in cell["outputs"]:
                if output.get("name") == "stdout":
                    print(color_chars.sub("", output["text"]))
                elif output.get("name") == "stderr":
                    print(color_chars.sub("", output["text"]), file=sys.stderr)
    if error:
        raise RuntimeError(
            f"Error on {repr(ipynb_file)}, section {repr(section)}: {error}"
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2022 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Maintainer Note: this sample presumes data exists in | ||
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database) | ||
|
||
import asyncpg # type: ignore | ||
import conftest as conftest # python-docs-samples/alloydb/conftest.py | ||
from google.cloud.alloydb.connector import AsyncConnector, IPTypes | ||
import pytest | ||
import sqlalchemy | ||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine | ||
|
||
|
||
def preprocess(source: str) -> str:
    """Blank out notebook cells that must not run under test.

    Cells that load data into the table (they mention "df") and the Colab
    auth cell are replaced with empty source; all others pass through.
    """
    skip_markers = ("df", "colab")
    if any(marker in source for marker in skip_markers):
        return ""
    return source
|
||
|
||
async def _init_connection_pool(
    connector: AsyncConnector,
    db_name: str,
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    password: str,
) -> AsyncEngine:
    """Build a SQLAlchemy async engine backed by the AlloyDB connector.

    Connects as the `postgres` user over the instance's public IP.
    """
    instance_uri = (
        f"projects/{project_id}/locations/"
        f"{region}/clusters/{cluster_name}/"
        f"instances/{instance_name}"
    )

    async def getconn() -> asyncpg.Connection:
        # Every pooled connection is created through the AlloyDB connector.
        return await connector.connect(
            instance_uri,
            "asyncpg",
            user="postgres",
            password=password,
            db=db_name,
            ip_type=IPTypes.PUBLIC,
        )

    return create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        max_overflow=0,
    )
|
||
|
||
@pytest.mark.asyncio
async def test_embeddings_batch_processing(
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    database_name: str,
    password: str,
    table_name: str,
) -> None:
    """End-to-end test for the batch embeddings notebook.

    Runs embeddings_batch_processing.ipynb against a live AlloyDB instance
    (all parameters are session fixtures resolved from env vars), checks
    that both embedding columns were populated for every row, then sets
    the embedding columns back to NULL to restore the table's state.

    NOTE(review): assumes `table_name` already exists and is populated with
    rows that have `analysis_embedding` / `overview_embedding` columns —
    see the maintainer note at the top of this file.
    """
    # TODO: Create new table
    # Populate the table with embeddings by running the notebook
    conftest.run_notebook(
        "embeddings_batch_processing.ipynb",
        variables={
            "project_id": project_id,
            "cluster_name": cluster_name,
            "database_name": database_name,
            "region": region,
            "instance_name": instance_name,
            "table_name": table_name,
        },
        preprocess=preprocess,
        # Shell (!) cells are skipped when run under test.
        skip_shell_commands=True,
        replace={
            # Inject the test password instead of prompting interactively.
            (
                "password = input(\"Please provide "
                "a password to be used for 'postgres' "
                "database user: \")"
            ): f"password = '{password}'",
            # The test database already exists, so drop the creation call.
            (
                "await create_db("
                "database_name=database_name, "
                "connector=connector)"
            ): "",
        },
        until_end=True,
    )

    # Connect to the populated table for validation and clean up
    async with AsyncConnector() as connector:
        pool = await _init_connection_pool(
            connector,
            database_name,
            project_id,
            cluster_name,
            instance_name,
            region,
            password,
        )
        async with pool.connect() as conn:
            # Validate that embeddings are non-empty for all rows
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"analysis_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"overview_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0

            # Get the table back to the original state
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"analysis_embedding = NULL"
                )
            )
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"overview_embedding = NULL"
                )
            )
            await conn.commit()
        await pool.dispose()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.