Source code for tools.server

import os
import secrets
import sys
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field, create_model

from aviary.tools.base import Tool, ToolCall, ToolRequestMessage, reverse_type_map


[docs] async def make_tool_server( # noqa: C901, PLR0915 environment_factory: Callable, name: str = "Aviary Tool Server", env_path: Path | None = None, ): """Create a FastAPI server for the provided environment. This function exposes one endpoint per tool and endpoints to create/view/delete environments. In contrast to other environment servers that expose an action endpoint, this one exposes all tools individually. This is only for debugging tools and not intended as a strategy for working with environments. Most environments have side-effects from using tools that occur in the step function. This bypasses that and allows you to call tools directly. Args: environment_factory: A callable that returns an environment instance. name: The name of the server. Defaults to Aviary Tool Server. env_path: The path to the directory to store environments """ try: import cloudpickle as pickle from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer except ModuleNotFoundError as exc: raise ImportError( "Please install aviary with the 'server' extra like so:" " `pip install aviary[server]`." ) from exc if not env_path: env_path = Path(tempfile.gettempdir()) auth_scheme = HTTPBearer() async def validate_token( credentials: HTTPAuthorizationCredentials = Depends(auth_scheme), # noqa: B008 ) -> str: # NOTE: don't use os.environ.get() to avoid possible empty string matches, and # to have clearer server failures if the AUTH_TOKEN env var isn't present if not secrets.compare_digest( credentials.credentials, os.environ["AUTH_TOKEN"] ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect bearer token", headers={"WWW-Authenticate": "Bearer"}, ) return credentials.credentials # these seem useful in other contexts, but from what I read # it is discouraged to save/load so leaving it defined here def save_environment(environment, tools, environment_id): # make sure we force all tools to pickle for tool in tools: tool._force_pickle_fn = True with open(env_path / f"{environment_id}.pkl", "wb") as f: pickle.dump((environment, tools), f) def load_environment(environment_id): if not (env_path / f"{environment_id}.pkl").exists(): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Environment {environment_id} not found", ) with open(env_path / f"{environment_id}.pkl", "rb") as f: return pickle.load(f) def make_environment_id(): return f"env{str(uuid4())[:8].replace('-', '')}" def create_request_model_from_tool(tool: Tool) -> BaseModel: fields = {} for pname, info in tool.info.parameters.properties.items(): if pname == "type": continue # we just assume it exists ptype = reverse_type_map[info["type"]] if "type" in info else Any # decipher optional description, optional default, and type if pname in tool.info.parameters.required: if "description" in info: fields[pname] = (ptype, Field(description=info["description"])) else: fields[pname] = (ptype, ...) elif "description" in info: fields[pname] = ( ptype | None, Field(description=info["description"], default=None), ) else: fields[pname] = (ptype | None, None) return create_model(f"{tool.info.name.capitalize()}Params", **fields) # type: ignore[call-overload] web_app = FastAPI( title=name, description="API Server for Aviary Environment Tools", dependencies=[Depends(validate_token)], ) # make a starting environment to save tools env = environment_factory() _, tools = await env.reset() # Dynamically create routes for each tool for tool in (t for t in tools if hasattr(t, "_tool_fn")): tool_name = tool.info.name tool_description = tool.info.description RequestModel = create_request_model_from_tool(tool) # ensure the this will be in fast api scope # because fastapi will barf on a request model that isn't in scope # close your eyes PR reviewers # also fuck your IDE tools RequestModel.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") def create_tool_handler(tool_name, RequestModel, tool_description): async def _tool_handler( data: RequestModel, # type: ignore[valid-type] environment_id: str = "", ): if environment_id: env, env_tools = load_environment(environment_id) else: env = environment_factory() _, env_tools = await env.reset() environment_id = make_environment_id() # ok now find the tool_fn to call it with # that came from the env I just loaded msg = ToolRequestMessage( tool_calls=[ToolCall.from_name(tool_name, **data.model_dump())] # type: ignore[attr-defined] ) try: result_msgs, done, *_ = await env.step(msg) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) from e if done: _, env_tools = await env.reset() save_environment(env, env_tools, environment_id) return { "result": "\n\n".join([ str(msg.content) for msg in result_msgs if msg.content ]), "environment_id": environment_id, } _tool_handler.__doc__ = tool_description return _tool_handler tool_handler = create_tool_handler( tool.info.name, RequestModel, tool_description ) # Add a POST route so we can invoke the tool function web_app.post( f"/{tool_name}", summary=tool_name, name=tool_name, description=tool_description, )(tool_handler) # Add environment endpoints @web_app.get( "/env/create", summary="Create Environment", description="Create a new environment", ) async def create_environment_endpoint(): env = environment_factory() _, tools = await env.reset() environment_id = make_environment_id() save_environment(env, tools, environment_id) return environment_id @web_app.get( "/env/delete/{environment_id}", summary="Delete Environment", description="Delete an environment", ) async def delete_environment_endpoint(environment_id: str): if (env_path / f"{environment_id}.pkl").exists(): (env_path / f"{environment_id}.pkl").unlink() return environment_id @web_app.get( "/env/view/{environment_id}", summary="View Environment", description="View an environment", ) async def view_environment_endpoint(environment_id: str): if not (env_path / f"{environment_id}.pkl").exists(): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Environment {environment_id} not found", ) with (env_path / f"{environment_id}.pkl").open("rb") as f: env, _ = pickle.load(f) return env.state return web_app