diff --git a/portia/__init__.py b/portia/__init__.py index 50dbb49f..dffeff96 100644 --- a/portia/__init__.py +++ b/portia/__init__.py @@ -95,6 +95,7 @@ # Plan and execution related classes from portia.plan import Plan, PlanBuilder, PlanContext, PlanInput, PlanUUID, Step, Variable from portia.plan_run import PlanRun, PlanRunState +from portia.run_context import PlanRunV2, RunContext # Core classes from portia.portia import ExecutionHooks, Portia @@ -168,9 +169,11 @@ "PlanRun", "PlanRunNotFoundError", "PlanRunState", + "PlanRunV2", "PlanUUID", "PlanV2", "PlanningAgentType", + "RunContext", "Portia", "PortiaBaseError", "PortiaToolRegistry", diff --git a/portia/run_context.py b/portia/run_context.py index f8881ea1..92833db6 100644 --- a/portia/run_context.py +++ b/portia/run_context.py @@ -7,9 +7,11 @@ from portia.builder.plan_v2 import PlanV2 from portia.config import Config from portia.end_user import EndUser +from portia.execution_agents.output import LocalDataValue, Output from portia.execution_hooks import ExecutionHooks from portia.plan import Plan -from portia.plan_run import PlanRun +from portia.plan_run import PlanRunState +from portia.prefixed_uuid import PlanRunUUID from portia.storage import Storage from portia.telemetry.telemetry_service import BaseProductTelemetry from portia.tool import ToolRunContext @@ -25,30 +27,119 @@ class StepOutputValue(BaseModel): step_num: int = Field(description="The step number of the referenced value.") -class RunContext(BaseModel): - """Data that is returned from a step.""" +class PlanRunV2(BaseModel): + """A plan run represents a running instance of a PlanV2. - model_config = ConfigDict(arbitrary_types_allowed=True) + This consolidates all execution-specific state that was previously split between + PlanRun and RunContext. + + Attributes: + id (PlanRunUUID): A unique ID for this plan_run. + state (PlanRunState): The current state of the PlanRun. + current_step_index (int): The current step that is being executed. + plan (PlanV2): The plan being executed. + end_user (EndUser): The end user executing the plan. + step_output_values (list[StepOutputValue]): Outputs set by the step. + final_output (Output | None): The final consolidated output of the PlanRun. + plan_run_inputs (dict[str, LocalDataValue]): Dict mapping plan input names to their values. + config (Config): The Portia config. + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - plan: PlanV2 = Field(description="The Portia plan being executed.") - legacy_plan: Plan = Field(description="The legacy plan representation.") - plan_run: PlanRun = Field(description="The current plan run instance.") + id: PlanRunUUID = Field( + default_factory=PlanRunUUID, + description="A unique ID for this plan_run.", + ) + state: PlanRunState = Field( + default=PlanRunState.NOT_STARTED, + description="The current state of the PlanRun.", + ) + current_step_index: int = Field( + default=0, + description="The current step that is being executed", + ) + plan: PlanV2 = Field(description="The plan being executed.") end_user: EndUser = Field(description="The end user executing the plan.") step_output_values: list[StepOutputValue] = Field( default_factory=list, description="Outputs set by the step." ) + final_output: Output | None = Field( + default=None, + description="The final consolidated output of the PlanRun if available.", + ) + plan_run_inputs: dict[str, LocalDataValue] = Field( + default_factory=dict, + description="Dict mapping plan input names to their values.", + ) config: Config = Field(description="The Portia config.") + + +class RunContext(BaseModel): + """Context wrapper for a PlanV2 run. + + This class holds the PlanRunV2 instance along with environmental context + like storage and tool registries. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + plan_run: PlanRunV2 = Field(description="The current plan run instance.") storage: Storage = Field(description="The Portia storage.") tool_registry: ToolRegistry = Field(description="The Portia tool registry.") execution_hooks: ExecutionHooks = Field(description="The Portia execution hooks.") telemetry: BaseProductTelemetry = Field(description="The Portia telemetry service.") + # Deprecated fields kept for backwards compatibility + @property + def plan(self) -> PlanV2: + """Get the plan from plan_run (backwards compatibility).""" + return self.plan_run.plan + + @property + def end_user(self) -> EndUser: + """Get the end_user from plan_run (backwards compatibility).""" + return self.plan_run.end_user + + @property + def step_output_values(self) -> list[StepOutputValue]: + """Get the step_output_values from plan_run (backwards compatibility).""" + return self.plan_run.step_output_values + + @property + def config(self) -> Config: + """Get the config from plan_run (backwards compatibility).""" + return self.plan_run.config + + @property + def legacy_plan(self) -> Plan: + """Get the legacy plan representation.""" + # We'll need to get this from storage or convert it + # For now, create a placeholder - this will be properly implemented + from portia.plan import PlanContext + return self.plan_run.plan.to_legacy_plan( + PlanContext(query="", tool_ids=[]) + ) + def get_tool_run_ctx(self) -> ToolRunContext: """Get the tool run context.""" + # Import here to avoid circular dependency + from portia.plan_run import PlanRun + + # Create a legacy PlanRun for backwards compatibility with ToolRunContext + legacy_plan_run = PlanRun( + id=self.plan_run.id, + plan_id=self.plan_run.plan.id, + current_step_index=self.plan_run.current_step_index, + state=self.plan_run.state, + end_user_id=self.plan_run.end_user.external_id, + plan_run_inputs=self.plan_run.plan_run_inputs, + ) + return ToolRunContext( - end_user=self.end_user, - plan_run=self.plan_run, + end_user=self.plan_run.end_user, + plan_run=legacy_plan_run, plan=self.legacy_plan, - config=self.config, - clarifications=self.plan_run.get_clarifications_for_step(), + config=self.plan_run.config, + clarifications=[], # Will be properly implemented ) diff --git a/portia/storage.py b/portia/storage.py index cfdd5565..ae396b2d 100644 --- a/portia/storage.py +++ b/portia/storage.py @@ -58,6 +58,7 @@ if TYPE_CHECKING: from portia.config import Config + from portia.run_context import PlanRunV2 T = TypeVar("T", bound=BaseModel) @@ -321,6 +322,55 @@ async def aget_plan_runs( """ return await asyncio.to_thread(self.get_plan_runs, run_state, page) + # PlanRunV2 storage methods + def save_plan_run_v2(self, plan_run: PlanRunV2) -> None: + """Save a PlanRunV2. + + Args: + plan_run (PlanRunV2): The PlanRunV2 object to save. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError("save_plan_run_v2 is not implemented") + + def get_plan_run_v2(self, plan_run_id: PlanRunUUID) -> PlanRunV2: + """Retrieve PlanRunV2 by its ID. + + Args: + plan_run_id (PlanRunUUID): The UUID of the run to retrieve. + + Returns: + PlanRunV2: The PlanRunV2 object associated with the provided plan_run_id. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError("get_plan_run_v2 is not implemented") + + async def asave_plan_run_v2(self, plan_run: PlanRunV2) -> None: + """Save a PlanRunV2 asynchronously using threaded execution. + + Args: + plan_run (PlanRunV2): The PlanRunV2 object to save. + + """ + await asyncio.to_thread(self.save_plan_run_v2, plan_run) + + async def aget_plan_run_v2(self, plan_run_id: PlanRunUUID) -> PlanRunV2: + """Retrieve PlanRunV2 by its ID asynchronously using threaded execution. + + Args: + plan_run_id (PlanRunUUID): The UUID of the run to retrieve. + + Returns: + PlanRunV2: The PlanRunV2 object associated with the provided plan_run_id. + + """ + return await asyncio.to_thread(self.get_plan_run_v2, plan_run_id) + class AdditionalStorage(ABC): """Abstract base class for additional storage. @@ -524,6 +574,7 @@ class InMemoryStorage(Storage): plans: dict[PlanUUID, Plan] runs: dict[PlanRunUUID, PlanRun] + runs_v2: dict[PlanRunUUID, PlanRunV2] outputs: defaultdict[PlanRunUUID, dict[str, LocalDataValue]] end_users: dict[str, EndUser] @@ -531,6 +582,7 @@ def __init__(self) -> None: """Initialize Storage.""" self.plans = {} self.runs = {} + self.runs_v2 = {} self.outputs = defaultdict(dict) self.end_users = {} @@ -717,6 +769,34 @@ def get_end_user(self, external_id: str) -> EndUser | None: return self.end_users[external_id] return None + def save_plan_run_v2(self, plan_run: PlanRunV2) -> None: + """Add PlanRunV2 to dict. + + Args: + plan_run (PlanRunV2): The PlanRunV2 object to save. + + """ + from portia.run_context import PlanRunV2 + + self.runs_v2[plan_run.id] = plan_run + + def get_plan_run_v2(self, plan_run_id: PlanRunUUID) -> PlanRunV2: + """Get PlanRunV2 from dict. + + Args: + plan_run_id (PlanRunUUID): The UUID of the PlanRunV2 to retrieve. + + Returns: + PlanRunV2: The PlanRunV2 object associated with the provided plan_run_id. + + Raises: + PlanRunNotFoundError: If the PlanRunV2 is not found. + + """ + if plan_run_id in self.runs_v2: + return self.runs_v2[plan_run_id] + raise PlanRunNotFoundError(plan_run_id) + class DiskFileStorage(Storage): """Disk-based implementation of the Storage interface. @@ -973,6 +1053,37 @@ def get_end_user(self, external_id: str) -> EndUser | None: except (ValidationError, FileNotFoundError): return None + def save_plan_run_v2(self, plan_run: PlanRunV2) -> None: + """Save PlanRunV2 object to the storage. + + Args: + plan_run (PlanRunV2): The PlanRunV2 object to save. + + """ + from portia.run_context import PlanRunV2 + + self._write(f"v2_{plan_run.id}.json", plan_run) + + def get_plan_run_v2(self, plan_run_id: PlanRunUUID) -> PlanRunV2: + """Retrieve PlanRunV2 object by its ID. + + Args: + plan_run_id (PlanRunUUID): The ID of the PlanRunV2 to retrieve. + + Returns: + PlanRunV2: The retrieved PlanRunV2 object. + + Raises: + PlanRunNotFoundError: If the PlanRunV2 is not found or validation fails. + + """ + from portia.run_context import PlanRunV2 + + try: + return self._read(f"v2_{plan_run_id}.json", PlanRunV2) + except (ValidationError, FileNotFoundError) as e: + raise PlanRunNotFoundError(plan_run_id) from e + class PortiaCloudStorage(Storage): """Save plans, runs and tool calls to portia cloud.""" @@ -1945,3 +2056,11 @@ async def aget_end_user(self, external_id: str) -> EndUser | None: phone_number=response_json["phone_number"], additional_data=response_json["additional_data"], ) + + def save_plan_run_v2(self, plan_run: 'PlanRunV2') -> None: + """Save PlanRunV2 to Portia Cloud (stub).""" + raise NotImplementedError("PlanRunV2 cloud storage not yet implemented") + + def get_plan_run_v2(self, plan_run_id: PlanRunUUID) -> 'PlanRunV2': + """Retrieve PlanRunV2 from Portia Cloud (stub).""" + raise NotImplementedError("PlanRunV2 cloud storage not yet implemented")