Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ errors/
repl_state/
venv/
*.egg-info
*.iml
CLAUDE.md
.idea/
10 changes: 10 additions & 0 deletions src/strands_tools/generate_image_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,16 @@ def generate_image_stability(tool: ToolUse, **kwargs: Any) -> ToolResult:
f"Generated image using {model_id}. Finish reason: {finish_reason}"
f"{' ' + save_info if save_info else ''}"
),
"json": {
"image_prompt": prompt,
"output_filename": filename,
},
},
{
"image": {
"format": output_format,
"source": {"bytes": image_bytes},
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this duplicating the data in image_object?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh good catch. Forgot to remove it, result of a bad merge from the code you had and what I already had in the repo from when you and were working together on this.
All I wanted is to add . I will update accordingly.

 "json": {
                        "image_prompt": prompt,
                        "output_filename": filename,
                    }

},
{"image": image_object},
],
Expand Down
75 changes: 75 additions & 0 deletions tests/test_generate_image_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,78 @@ def test_tool_spec_exists():
properties = generate_image_stability.TOOL_SPEC["inputSchema"]["properties"]
assert "model_id" not in properties
assert "prompt" in properties


def test_generate_image_stability_no_output_dir(mock_env_api_key, mock_requests):
"""Test that output_filename is None when STABILITY_OUTPUT_DIR is not set."""
mock_post, mock_response = mock_requests

# Ensure STABILITY_OUTPUT_DIR is not set
if "STABILITY_OUTPUT_DIR" in os.environ:
del os.environ["STABILITY_OUTPUT_DIR"]

tool_use = {
"toolUseId": "test-tool-use-id",
"input": {
"prompt": "A test image",
},
}

result = generate_image_stability.generate_image_stability(tool=tool_use)

# Check that the result is successful
assert result["status"] == "success"
assert result["toolUseId"] == "test-tool-use-id"

# Check that output_filename is None in the JSON response
json_content = result["content"][0]["json"]
assert json_content["output_filename"] is None

# Check that the text response doesn't include file save info
text_content = result["content"][0]["text"]
assert "Image saved to" not in text_content


def test_generate_image_stability_with_output_dir(mock_env_api_key, mock_requests, tmp_path):
"""Test that output_filename is set correctly when STABILITY_OUTPUT_DIR is set."""
mock_post, mock_response = mock_requests

# Set STABILITY_OUTPUT_DIR to a temporary directory
output_dir = str(tmp_path)
os.environ["STABILITY_OUTPUT_DIR"] = output_dir

tool_use = {
"toolUseId": "test-tool-use-id",
"input": {
"prompt": "A test image for file output",
},
}

result = generate_image_stability.generate_image_stability(tool=tool_use)

# Check that the result is successful
assert result["status"] == "success"
assert result["toolUseId"] == "test-tool-use-id"

# Check that output_filename is set in the JSON response
json_content = result["content"][0]["json"]
assert json_content["output_filename"] is not None
assert json_content["output_filename"].startswith(output_dir)
assert json_content["output_filename"].endswith(".png") # default format

# Check that the text response includes file save info
text_content = result["content"][0]["text"]
assert "Image saved to" in text_content
assert output_dir in text_content

# Verify the file actually exists
filename = json_content["output_filename"]
assert os.path.exists(filename)

# Verify the file content matches the mock response
with open(filename, "rb") as f:
file_content = f.read()
assert file_content == b"mock_image_data"

# Clean up environment variable
del os.environ["STABILITY_OUTPUT_DIR"]