diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py
index c8fdeaa..7cd94f9 100644
--- a/src/seclab_taskflow_agent/agent.py
+++ b/src/seclab_taskflow_agent/agent.py
@@ -15,7 +15,7 @@
 
 from agents.run import RunHooks
 from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled
-from .capi import COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token, AI_API_ENDPOINT_ENUM
+from .capi import get_AI_endpoint, get_AI_token, get_custom_header, AI_API_ENDPOINT_ENUM
 
 # grab our secrets from .env, this must be in .gitignore
 load_dotenv(find_dotenv(usecwd=True))
@@ -156,7 +156,7 @@ def __init__(self, agent_hooks: TaskAgentHooks | None = None):
 
         client = AsyncOpenAI(base_url=api_endpoint,
                              api_key=get_AI_token(),
-                             default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
+                             default_headers=get_custom_header())
         set_default_openai_client(client)
         # CAPI does not yet support the Responses API: https://github.com/github/copilot-api/issues/11185
         # as such we are implementing on chat completions for now
diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py
index a4a308f..c38e2e6 100644
--- a/src/seclab_taskflow_agent/capi.py
+++ b/src/seclab_taskflow_agent/capi.py
@@ -29,8 +29,6 @@ def to_url(self):
             case _:
                 raise ValueError(f"Unsupported endpoint: {self}")
 
-COPILOT_INTEGRATION_ID = 'vscode-chat'
-
 # you can also set https://api.githubcopilot.com if you prefer
 # but beware that your taskflows need to reference the correct model id
 # since different APIs use their own id schema, use -l with your desired
@@ -52,6 +50,31 @@ def get_AI_token():
         return token
     raise RuntimeError("AI_API_TOKEN environment variable is not set.")
 
+def get_custom_header() -> dict[str, str]:
+    """
+    Get custom header from environment variable AI_API_CUSTOM_HEADER.
+    Expected format: name:value
+    Returns a dictionary that can be merged into request headers.
+    """
+    custom_header = os.getenv('AI_API_CUSTOM_HEADER')
+    if not custom_header:
+        return {}
+
+    # Split on first colon to handle values that might contain colons
+    parts = custom_header.split(':', 1)
+    if len(parts) != 2:
+        logging.warning(f"Invalid AI_API_CUSTOM_HEADER format. Expected 'name:value', got: {custom_header}")
+        return {}
+
+    name, value = parts
+    name = name.strip()
+    value = value.strip()
+    if not name or not value:
+        logging.warning(f"Invalid AI_API_CUSTOM_HEADER: header name and value must be non-empty after stripping. Got: '{custom_header}'")
+        return {}
+    return {name: value}
+
+
 # assume we are >= python 3.9 for our type hints
 def list_capi_models(token: str) -> dict[str, dict]:
     """Retrieve a dictionary of available CAPI models"""
@@ -69,12 +92,11 @@ def list_capi_models(token: str) -> dict[str, dict]:
         case _:
             raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}\n"
                              f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}")
-    r = httpx.get(httpx.URL(api_endpoint).join(models_catalog),
-                  headers={
-                      'Accept': 'application/json',
-                      'Authorization': f'Bearer {token}',
-                      'Copilot-Integration-Id': COPILOT_INTEGRATION_ID
-                  })
+    headers = {
+        'Accept': 'application/json',
+        'Authorization': f'Bearer {token}',
+    } | get_custom_header()
+    r = httpx.get(httpx.URL(api_endpoint).join(models_catalog), headers=headers)
     r.raise_for_status()
     # CAPI vs Models API
     match netloc:
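
Illustrative usage sketch (not part of the diff): a few expected behaviors of get_custom_header() based on the implementation above. The import path assumes the package is installed as seclab_taskflow_agent, and the header names and values shown are examples only, not required configuration.

    import os
    from seclab_taskflow_agent.capi import get_custom_header

    # Well-formed "name:value" yields a single-entry dict, with whitespace stripped.
    os.environ['AI_API_CUSTOM_HEADER'] = 'Copilot-Integration-Id: vscode-chat'
    assert get_custom_header() == {'Copilot-Integration-Id': 'vscode-chat'}

    # Values may themselves contain colons, since only the first colon is split on.
    os.environ['AI_API_CUSTOM_HEADER'] = 'X-Example-Header:https://example.com:8443'
    assert get_custom_header() == {'X-Example-Header': 'https://example.com:8443'}

    # Unset or malformed input falls back to an empty dict (a warning is logged for malformed input),
    # so callers can merge the result into their headers unconditionally.
    os.environ.pop('AI_API_CUSTOM_HEADER', None)
    assert get_custom_header() == {}
    os.environ['AI_API_CUSTOM_HEADER'] = 'missing-separator'
    assert get_custom_header() == {}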