27 commits
17cbeb2
ci: Run e2e tests in Cloud Build presubmit
bsidhom Feb 18, 2025
2d03c6c
wip: Install cloud-sdk for ambient ADC
bsidhom Feb 18, 2025
4b0b102
wip: Plumb environment variables and allow independent test failures
bsidhom Feb 18, 2025
75d4462
wip: Verify exit status of each step at completion
bsidhom Feb 18, 2025
08b8986
wip: Only inspect steps where .allowFailed is true
bsidhom Feb 18, 2025
183f33f
wip: Remove BUILD_ID tag to enable caching of layers across runs
bsidhom Feb 18, 2025
7978363
wip: Debug build results by printing to screen
bsidhom Feb 18, 2025
f96e88a
wip: Bump up timeout
bsidhom Feb 18, 2025
4200856
wip: test persistent workspace directory
bsidhom Feb 18, 2025
54b627b
wip: use sentinel files to signal success
bsidhom Feb 18, 2025
07f503f
wip: Add executable bit on bash scripts
bsidhom Feb 18, 2025
5f370d6
wip: Wire service account into integration tests
bsidhom Feb 19, 2025
45fcfe7
wip: Use default compute service account for integration tests
bsidhom Feb 19, 2025
2a87758
Revert "test: Remove invalid session config field"
bsidhom Mar 14, 2025
7920762
wip: fix service account enum name
bsidhom Mar 25, 2025
18d6ae1
wip: update config fields
bsidhom Apr 1, 2025
7f59a20
wip: Split out auth_type params and add better error messages
bsidhom Apr 1, 2025
2c43f8f
wip: Wire service account and suppress end-user credentials on gcb
bsidhom Apr 1, 2025
77186c2
wip: fix unit test textproto reference
bsidhom Apr 1, 2025
001b3f4
wip: fix typo and remove auth_type data classes
bsidhom Apr 1, 2025
e8ba269
wip: clean up config rewriting
bsidhom Apr 1, 2025
b13f907
wip: print default config
bsidhom Apr 1, 2025
39f7750
wip: fix service account environment variable
bsidhom Apr 1, 2025
e9cc055
wip: log full session request
bsidhom Apr 2, 2025
4e3b875
wip: increase log level and add some raw printing
bsidhom Apr 2, 2025
394df3c
wip: set REGIONAL_USER_OWNED_BUCKET for cloud build
bsidhom Apr 3, 2025
3354794
wip: set service account in in-line config test
bsidhom Apr 3, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -24,7 +24,7 @@ If you are running the client outside of Google Cloud, you must set following en
 * GOOGLE_CLOUD_PROJECT - The Google Cloud project you use to run Spark workloads
 * GOOGLE_CLOUD_REGION - The Compute Engine [region](https://cloud.google.com/compute/docs/regions-zones#available) where you run the Spark workload.
 * GOOGLE_APPLICATION_CREDENTIALS - Your [Application Credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc)
-* DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG (Optional) - The config location, such as `tests/integration/resources/session.textproto`
+* DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG (Optional) - The config location, such as `tests/integration/resources/session_user.textproto`

 ## Usage

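For reference, a minimal sketch of a client run outside Google Cloud with those variables set. It is illustrative only: the GoogleSparkSession import path is assumed from this repo's package layout (google/cloud/spark_connect/), and the project, region, and credential values are placeholders.

# Sketch: minimal client setup outside Google Cloud, using the environment
# variables documented in the README hunk above. Values are placeholders.
import os

os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"  # placeholder
os.environ["GOOGLE_CLOUD_REGION"] = "us-central1"  # placeholder
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/adc.json"  # placeholder

# Import path assumed from the package layout in this PR.
from google.cloud.spark_connect import GoogleSparkSession

# Creates a Dataproc serverless session and connects to it over Spark Connect.
spark = GoogleSparkSession.builder.getOrCreate()
spark.sql("SELECT 1").show()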
11 changes: 11 additions & 0 deletions cloudbuild/Dockerfile
@@ -1,5 +1,13 @@
 FROM python:3.10-bookworm

+RUN \
+  echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | \
+    tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
+  curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
+    apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
+  apt-get update -y && \
+  apt-get install google-cloud-cli jq -y
+
 WORKDIR /opt/tests
 # NOTE: We copy the requirements files explicitly _before_ installing
 # dependencies to allow layer caching. Afterward, we copy the rest of the build
@@ -11,4 +19,7 @@ RUN python3 -m pip install -U pip
 RUN python3 -m pip install --no-cache-dir -r requirements-dev.txt -r requirements-test.txt
 COPY . .
 RUN python setup.py sdist bdist_wheel egg_info
+# TODO: Figure out how to get the .egg-info/ directory included in the output
+# image. Something is suppressing it at the moment, so we have to `pip install`
+# this package to allow tests to run.
 RUN pip install .
32 changes: 27 additions & 5 deletions cloudbuild/cloudbuild.yaml
@@ -3,11 +3,33 @@ steps:
   # distribution artifacts.
   - name: 'gcr.io/cloud-builders/docker'
     id: 'build-container-image'
-    args: ['build', '--tag=gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit:${BUILD_ID}', -f, 'cloudbuild/Dockerfile', '.']
+    args: ['build', '--tag=gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit', -f, 'cloudbuild/Dockerfile', '.']
   # Run all unit tests
-  - name: 'gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit:${BUILD_ID}'
+  - name: 'gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit'
     id: 'run-unit-tests'
     waitFor: ['build-container-image']
-    entrypoint: 'pytest'
-    args: ['-n', 'auto', 'tests/unit']
-    timeout: 600s
+    allowFailure: true
+    entrypoint: '/opt/tests/cloudbuild/run-unit-tests.sh'
+  # Run all integration tests. These exercise the actual API endpoints and
+  # require credentials.
+  - name: 'gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit'
+    id: 'run-integration-tests'
+    waitFor: ['build-container-image']
+    allowFailure: true
+    entrypoint: '/opt/tests/cloudbuild/run-integration-tests.sh'
+    env:
+      - 'GOOGLE_CLOUD_SUBNET=${_GOOGLE_CLOUD_SUBNET}'
+      - 'GOOGLE_CLOUD_PROJECT=${_GOOGLE_CLOUD_PROJECT}'
+      - 'GOOGLE_CLOUD_REGION=${_GOOGLE_CLOUD_REGION}'
+      - 'GOOGLE_CLOUD_SERVICE_ACCOUNT=${_GOOGLE_CLOUD_SERVICE_ACCOUNT}'
+      - 'TEST_SUPPRESS_END_USER_CREDENTIALS=1'
+  - name: 'gcr.io/${PROJECT_ID}/google-spark-connect/google-spark-connect-presubmit'
+    id: 'verify-status'
+    waitFor: ['run-unit-tests', 'run-integration-tests']
+    entrypoint: '/opt/tests/cloudbuild/verify-status.sh'
+    env:
+      - 'BUILD_ID=${BUILD_ID}'
+      - 'GOOGLE_CLOUD_PROJECT=${PROJECT_ID}'
+timeout: 1800s
+options:
+  defaultLogsBucketBehavior: REGIONAL_USER_OWNED_BUCKET
11 changes: 11 additions & 0 deletions cloudbuild/run-integration-tests.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+function main() {
+  run_tests && touch /workspace/integration-tests.SUCCESS
+}
+
+function run_tests() {
+  pytest -n 10 --log-cli-level=DEBUG tests/integration
+}
+
+main "$@"
11 changes: 11 additions & 0 deletions cloudbuild/run-unit-tests.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+function main() {
+  run_tests && touch /workspace/unit-tests.SUCCESS
+}
+
+function run_tests() {
+  pytest -n 10 tests/unit
+}
+
+main "$@"
59 changes: 59 additions & 0 deletions cloudbuild/verify-status.sh
@@ -0,0 +1,59 @@
+#!/usr/bin/env bash
+
+# TODO: Consider rewriting all of these test drivers into a single python script
+# to share names and logic.
+
+set -euo pipefail
+
+readonly TESTS=(
+  'unit-tests'
+  'integration-tests'
+)
+
+function main() {
+  #echo "/workspace contents:"
+  #ls -R /workspace
+  local failed_tests=()
+  local sentinel=""
+  for test in "${TESTS[@]}" ; do
+    # sentinel="/tmp/${test}.SUCCESS"
+    sentinel="/workspace/${test}.SUCCESS"
+    if [ ! -f "$sentinel" ] ; then
+      failed_tests+=("$test")
+    fi
+  done
+  if [ "${#failed_tests[@]}" -gt 0 ] ; then
+    echo "failed tests: ${failed_tests[@]}"
+    return 1
+  fi
+  # echo ""
+  # echo "BUILD RESULTS:"
+  # describe_build
+  # describe_build | verify_statuses
+  # jq_program
+}
+
+function describe_build() {
+  gcloud builds describe \
+    --project "$GOOGLE_CLOUD_PROJECT" \
+    --format=json \
+    "$BUILD_ID"
+}
+
+function verify_statuses() {
+  jq -f <(jq_program)
+}
+
+function jq_program() {
+  cat <<'EOF'
+reduce
+  (.steps[] | select(.allowFailure)) as $step
+  (
+    [];
+    if $step.status != "SUCCESS" then . + [$step.id] else . end
+  ) |
+if length > 0 then error("the following steps failed: " + (join(", "))) else empty end
EOF
+}
+
+main "$@"
3 changes: 3 additions & 0 deletions google/cloud/spark_connect/session.py
@@ -172,6 +172,7 @@ def __create_spark_connect_session_from_s8s(
         return session

     def __create(self) -> "SparkSession":
+        print("CREATING SPARK SESSION", flush=True)
         with self._lock:

             if self._options.get("spark.remote", False):
@@ -208,6 +209,8 @@ def __create(self) -> "SparkSession":
                 f"projects/{self._project_id}/locations/{self._region}"
             )

+            print("FULL SESSION REQUEST:", session_request, flush=True)
+
             logger.debug("Creating serverless session")
             GoogleSparkSession._active_s8s_session_id = session_id
             s8s_creation_start_time = time.time()
12 changes: 12 additions & 0 deletions tests/integration/resources/session_service_account.textproto
@@ -0,0 +1,12 @@
+environment_config: {
+  execution_config: {
+    subnetwork_uri: "subnet-placeholder"
+    service_account: "service-account-placeholder"
+    authentication_config: {
+      user_workload_authentication_type: SERVICE_ACCOUNT
+    }
+  }
+}
+runtime_config: {
+  version: "2.2"
+}
3 changes: 3 additions & 0 deletions tests/integration/resources/{session.textproto → session_user.textproto}
@@ -1,6 +1,9 @@
 environment_config: {
   execution_config: {
     subnetwork_uri: "subnet-placeholder"
+    authentication_config: {
+      user_workload_authentication_type: END_USER_CREDENTIALS
+    }
   }
 }
 runtime_config: {
116 changes: 97 additions & 19 deletions tests/integration/test_session.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
+import enum
 import os
 import tempfile
 import uuid
@@ -40,29 +41,79 @@
 _SERVICE_ACCOUNT_KEY_FILE_ = "service_account_key.json"


+class AuthType(enum.Enum):
+    SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
+    END_USER_CREDENTIALS = "END_USER_CREDENTIALS"
+
+
+def get_auth_types():
+    service_account = os.environ.get("GOOGLE_CLOUD_SERVICE_ACCOUNT")
+    suppress_end_user_creds = os.environ.get(
+        "TEST_SUPPRESS_END_USER_CREDENTIALS"
+    )
+    if service_account:
+        yield AuthType.SERVICE_ACCOUNT
+    if not suppress_end_user_creds or suppress_end_user_creds == "0":
+        yield AuthType.END_USER_CREDENTIALS
+
+
+def get_config_path(auth_type):
+    resources_dir = os.path.join(os.path.dirname(__file__), "resources")
+    match auth_type:
+        case AuthType.SERVICE_ACCOUNT:
+            return os.path.join(
+                resources_dir, "session_service_account.textproto"
+            )
+        case AuthType.END_USER_CREDENTIALS:
+            return os.path.join(resources_dir, "session_user.textproto")
+        case _:
+            raise Exception(f"unknown auth_type: {auth_type}")
+
+
 @pytest.fixture(params=["2.2", "3.0"])
 def image_version(request):
     return request.param


 @pytest.fixture
 def test_project():
-    return os.environ.get("GOOGLE_CLOUD_PROJECT")
+    project = os.environ.get("GOOGLE_CLOUD_PROJECT")
+    if project is None:
+        raise Exception(
+            "must set a project through GOOGLE_CLOUD_PROJECT environment variable"
+        )
+    return project


-@pytest.fixture
+@pytest.fixture(params=get_auth_types())
 def auth_type(request):
-    return getattr(request, "param", "SYSTEM_SERVICE_ACCOUNT")
+    return request.param


 @pytest.fixture
 def test_region():
-    return os.environ.get("GOOGLE_CLOUD_REGION")
+    region = os.environ.get("GOOGLE_CLOUD_REGION")
+    if region is None:
+        raise Exception(
+            "must set region through GOOGLE_CLOUD_REGION environment variable"
+        )
+    return region


 @pytest.fixture
 def test_subnet():
-    return os.environ.get("GOOGLE_CLOUD_SUBNET")
+    subnet = os.environ.get("GOOGLE_CLOUD_SUBNET")
+    if subnet is None:
+        raise Exception(
+            "must set subnet through GOOGLE_CLOUD_SUBNET environment variable"
+        )
+    return subnet
+
+
+@pytest.fixture
+def test_service_account():
+    # The service account may be empty, in which case we skip SA-based testing.
+    return os.environ.get("GOOGLE_CLOUD_SERVICE_ACCOUNT")


 @pytest.fixture
@@ -72,22 +123,29 @@ def test_subnetwork_uri(test_project, test_region, test_subnet):

 @pytest.fixture
 def default_config(
-    auth_type, image_version, test_project, test_region, test_subnetwork_uri
+    auth_type,
+    test_service_account,
+    image_version,
+    test_project,
+    test_region,
+    test_subnetwork_uri,
 ):
-    resources_dir = os.path.join(os.path.dirname(__file__), "resources")
-    template_file = os.path.join(resources_dir, "session.textproto")
+    template_file = get_config_path(auth_type)
     with open(template_file) as f:
         template = f.read()
-    contents = (
-        template.replace("2.2", image_version)
-        .replace("subnet-placeholder", test_subnetwork_uri)
-        .replace("SYSTEM_SERVICE_ACCOUNT", auth_type)
+    contents = template.replace("2.2", image_version).replace(
+        "subnet-placeholder", test_subnetwork_uri
     )
-    with tempfile.NamedTemporaryFile(delete=False) as t:
-        t.write(contents.encode("utf-8"))
-        t.close()
-        yield t.name
-    os.remove(t.name)
+    if auth_type == AuthType.SERVICE_ACCOUNT:
+        contents = contents.replace(
+            "service-account-placeholder", test_service_account
+        )
+    print("CONFIG CONTENTS:", contents)
+    with tempfile.NamedTemporaryFile(delete=False) as t:
+        t.write(contents.encode("utf-8"))
+        t.close()
+        yield t.name
+    os.remove(t.name)


 @pytest.fixture
@@ -128,6 +186,7 @@ def session_template_controller_client(test_client_options):

 @pytest.fixture
 def connect_session(test_project, test_region, os_environment):
+    print("CREATING SESSION (TEST)", flush=True)
     return GoogleSparkSession.builder.getOrCreate()


@@ -136,12 +195,16 @@ def session_name(test_project, test_region, connect_session):
     return f"projects/{test_project}/locations/{test_region}/sessions/{GoogleSparkSession._active_s8s_session_id}"


-@pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
+# @pytest.mark.parametrize("auth_type", ["END_USER_CREDENTIALS"], indirect=True)
 def test_create_spark_session_with_default_notebook_behavior(
-    auth_type, connect_session, session_name, session_controller_client
+    connect_session,
+    session_name,
+    session_controller_client,
 ):
+    # print("auth type parameter:", auth_type)
     get_session_request = GetSessionRequest()
     get_session_request.name = session_name
+    print("GET SESSION REQUEST (TEST):", get_session_request, flush=True)
     session = session_controller_client.get_session(get_session_request)
     assert session.state == Session.State.ACTIVE

@@ -307,13 +370,28 @@ def session_template_name(


 def test_create_spark_session_with_session_template_and_user_provided_dataproc_config(
+    auth_type,
+    test_service_account,
     image_version,
     test_project,
     test_region,
     session_template_name,
     session_controller_client,
 ):
     dataproc_config = Session()
+    match auth_type:
+        case AuthType.END_USER_CREDENTIALS:
+            # This is the default
+            pass
+        case AuthType.SERVICE_ACCOUNT:
+            dataproc_config.environment_config.execution_config.service_account = (
+                test_service_account
+            )
+            dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type = (
+                "SERVICE_ACCOUNT"
+            )
+        case _:
+            raise Exception(f"unknown auth_type: {auth_type}")
     dataproc_config.environment_config.execution_config.ttl = {"seconds": 64800}
     dataproc_config.session_template = session_template_name
     connect_session = (
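A note on the new fixture parametrization: params=get_auth_types() is evaluated once, when pytest imports the module at collection time, so GOOGLE_CLOUD_SERVICE_ACCOUNT and TEST_SUPPRESS_END_USER_CREDENTIALS must already be set in the process environment before pytest starts — presumably why cloudbuild.yaml passes them as step-level env rather than setting them from within a fixture. A toy sketch of the same pattern, with stand-ins for the real fixtures:

# Sketch: fixture parametrization driven by the environment at collection
# time. AuthType and get_auth_types are toy stand-ins for the versions in
# test_session.py above.
import enum
import os

import pytest

class AuthType(enum.Enum):
    SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
    END_USER_CREDENTIALS = "END_USER_CREDENTIALS"

def get_auth_types():
    # Which credentials to exercise is decided by the environment.
    if os.environ.get("GOOGLE_CLOUD_SERVICE_ACCOUNT"):
        yield AuthType.SERVICE_ACCOUNT
    suppress = os.environ.get("TEST_SUPPRESS_END_USER_CREDENTIALS")
    if not suppress or suppress == "0":
        yield AuthType.END_USER_CREDENTIALS

# The generator is consumed here, during collection, not during the test run.
@pytest.fixture(params=get_auth_types())
def auth_type(request):
    return request.param

def test_auth_type_is_known(auth_type):
    assert auth_type in AuthType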