Source code for aviary.tools.base

import inspect
import json
import logging
import uuid
from collections.abc import Awaitable, Callable, Iterable
from functools import partial
from itertools import starmap
from typing import Annotated, Any, Literal, NoReturn, Self, TypeAlias

from docstring_parser import DocstringParam, DocstringStyle, parse
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    FieldSerializationInfo,
    PlainSerializer,
    TypeAdapter,
    create_model,
    field_serializer,
    model_validator,
)
from pydantic.fields import FieldInfo

from aviary.message import Message
from aviary.utils import partial_format

try:
    from dicttoxml import dicttoxml
except ImportError:
    dicttoxml = None

logger = logging.getLogger(__name__)

# Mapping from python types to JSON schema types
# SEE: https://json-schema.org/understanding-json-schema/reference/numeric
type_map: dict[type | None, str] = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "list",
    dict: "object",
    None: "null",
}
reverse_type_map = {v: k for k, v in type_map.items()}

# A string to denote an invalid tool. It can be used to indicate
# an attempt to use a non-existent tool, missing/invalid parameters,
# mangled output from the LLM, etc.
INVALID_TOOL_NAME = "INVALID"


class ToolCallFunction(BaseModel):
    arguments: dict[str, Any]
    name: str

    @model_validator(mode="before")
    @classmethod
    def deserialize_args(cls, data: Any) -> Any:
        if isinstance(data, dict) and isinstance(data["arguments"], str | None):
            if not data["arguments"]:
                data["arguments"] = {}
            else:
                try:
                    data["arguments"] = json.loads(data["arguments"])
                except json.JSONDecodeError:
                    # If the arguments are not parseable, mark this ToolCall(Function)
                    # as invalid so we can enable "learn"ing what a valid tool call
                    # looks like
                    logger.warning(
                        f"Failed to JSON load tool {data.get('name')}'s arguments"
                        f" {data['arguments']}, declaring as {INVALID_TOOL_NAME}."
                    )
                    data["name"] = INVALID_TOOL_NAME
                    data["arguments"] = {}
        return data

    @field_serializer("arguments")
    def serialize_arguments(self, arguments: dict[str, Any]) -> str:
        return json.dumps(arguments)

    def __str__(self) -> str:
        arg_str = ", ".join([f"{k}='{v}'" for k, v in self.arguments.items()])
        return f"{self.name}({arg_str})"
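
# A minimal usage sketch (illustrative; not part of the library source): the "before"
# validator above accepts `arguments` either as an already-parsed dict or as the JSON
# string emitted by OpenAI-style APIs, and unparseable strings downgrade the call to
# INVALID_TOOL_NAME instead of raising.
def _example_tool_call_function() -> None:
    fn = ToolCallFunction(name="add", arguments='{"a": 1, "b": 2}')
    assert fn.arguments == {"a": 1, "b": 2}
    assert str(fn) == "add(a='1', b='2')"

    # Mangled LLM output does not raise; it is flagged as invalid so downstream code
    # (or training) can react to it
    bad = ToolCallFunction(name="add", arguments="{not json")
    assert bad.name == INVALID_TOOL_NAME
    assert bad.arguments == {}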

class ToolCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: ToolCallFunction

    @staticmethod
    def generate_id() -> str:
        """Generate a tool call ID of length 9 with values in [a-zA-Z0-9]."""
        return str(uuid.uuid4()).replace("-", "")[:9]

    @classmethod
    def from_tool(cls, tool: "Tool", *args, id: str | None = None, **kwargs) -> Self:  # noqa: A002
        """Create a ToolCall from a Tool and arguments.

        Positional *args are packaged into the ToolCallFunction's arguments dict on a
        best-effort basis by matching them, in order, to the tool's parameters.
        **kwargs are passed through unchanged, since tool call arguments must be named.
        """
        # convert args to kwargs by matching them with the tool's parameters
        for i, name in enumerate(tool.info.parameters.properties.keys()):
            if i < len(args):
                kwargs[name] = args[i]
        return cls(
            id=id or cls.generate_id(),
            function=ToolCallFunction(name=tool.info.name, arguments=kwargs),
        )

    @classmethod
    def from_name(cls, function_name: str, **kwargs) -> Self:
        return cls(
            id=cls.generate_id(),
            function=ToolCallFunction(name=function_name, arguments=kwargs),
        )

    def __str__(self) -> str:
        arg_str = ", ".join([f"{k}='{v}'" for k, v in self.function.arguments.items()])
        return f"{self.function.name}({arg_str})"
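
# Illustrative sketch (not part of the library source): building a ToolCall directly by
# function name. `from_tool` does the same, but first maps positional args onto the
# Tool's parameter names in order.
def _example_tool_call() -> ToolCall:
    call = ToolCall.from_name("multiply", a=2, b=3)
    assert call.function.name == "multiply"
    assert len(call.id) == 9  # generate_id() trims a UUID4's hex digits to 9 characters
    assert str(call) == "multiply(a='2', b='3')"
    return call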

class ToolRequestMessage(Message):
    role: Literal["assistant"] = Field(
        default="assistant", description="Matching LiteLLM structure."
    )
    content: str | None = None
    function_call: None = None
    tool_calls: list[ToolCall] = Field(
        default_factory=list,
        description="List of ToolCalls to make concurrently and independently.",
    )

    def __str__(self) -> str:
        if not self.tool_calls:
            return super().__str__()
        base_msg = f"Tool request message {self.content or ''!r}"
        if len(self.tool_calls) == 1:
            return (
                f"{base_msg} for tool calls: "
                f"{self.tool_calls[0]} [id={self.tool_calls[0].id}]"
            )
        return f"{base_msg} for tool calls: " + "; ".join([
            f"{tc!s} [id={tc.id}]" for tc in self.tool_calls
        ])

class ToolResponseMessage(Message):
    content: str = Field(
        description=(
            "Response message content, required to be a string by OpenAI/Anthropic."
        ),
    )
    role: Literal["tool"] = Field(
        default="tool", description="Matching LiteLLM structure."
    )
    name: str = Field(description="Name of the tool that was called.")
    tool_call_id: str = Field(
        description=(
            "Propagated from ToolCall.id, enabling matching response with"
            " ToolRequestMessage."
        )
    )

    @classmethod
    def from_call(cls, call: ToolCall, content: str) -> Self:
        return cls(content=content, name=call.function.name, tool_call_id=call.id)

    @classmethod
    def from_request(
        cls, request: ToolRequestMessage, contents: Iterable[str]
    ) -> list[Self]:
        return list(
            starmap(cls.from_call, zip(request.tool_calls, contents, strict=True))
        )

    def __str__(self) -> str:
        return (
            f"Tool response message {self.content!r} for tool call ID"
            f" {self.tool_call_id} of tool {self.name!r}"
        )
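
# Illustrative sketch (not part of the library source): pairing each ToolCall in a
# request with its output string. `strict=True` inside from_request means the number
# of contents must match the number of tool calls.
def _example_tool_responses(request: ToolRequestMessage) -> list[ToolResponseMessage]:
    contents = [f"ran {tc.function.name}" for tc in request.tool_calls]
    return ToolResponseMessage.from_request(request, contents)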

def dict_serialize_exclude_none(
    value: dict[str, dict[str, Any]], info: FieldSerializationInfo
) -> dict[str, dict[str, Any]]:
    """Work around Pydantic not applying exclude_none to dict serializations."""
    if info.exclude_none:
        return {
            p_name: {k: v for k, v in config.items() if v is not None}
            for p_name, config in value.items()
        }
    return value

class Parameters(BaseModel):
    """Matches LiteLLM's desired "tools" schema."""

    model_config = ConfigDict(extra="allow")

    type: Literal["object"] = "object"
    properties: Annotated[
        dict[str, dict[str, Any]], PlainSerializer(dict_serialize_exclude_none)
    ]
    required: list[str]
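
# Illustrative sketch (not part of the library source): the PlainSerializer above lets
# exclude_none reach into the nested per-parameter dicts, so None-valued keys (e.g. an
# absent description) are dropped from the serialized schema.
def _example_parameters_serialization() -> dict[str, Any]:
    params = Parameters(
        properties={"x": {"type": "integer", "description": None}}, required=["x"]
    )
    dumped = params.model_dump(exclude_none=True)
    assert dumped["properties"]["x"] == {"type": "integer"}
    return dumped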

class FunctionInfo(BaseModel):
    """Function-level (not arg-level) information.

    Matches LiteLLM's desired "tools" schema, and resembles inspect.Signature.
    """

    name: str
    description: str
    # SEE: https://github.com/openai/openai-openapi/blob/0f5de60a3d2b263dc2ac362371673f7a21811874/openapi.yaml#L7567-L7570
    parameters: Parameters

    def describe_str(self) -> str:
        for value in self.parameters.properties.values():
            if value.get("allOf") or not value.get("type"):
                raise NotImplementedError(
                    f"Complex types are not yet supported. Failed on: {self!r}"
                )
        # Start with the function prototype
        prototype = f"{self.name}("
        prototype += ", ".join([
            f"{arg['type']} {name}"
            for name, arg in self.parameters.properties.items()
        ])
        prototype += ")"
        # Function description
        indented_description_lines = "\n".join([
            f" {line}" if line else "" for line in self.description.split("\n")
        ])
        description = f"DESCRIPTION:\n{indented_description_lines}\n"
        # Parameters description
        params_description = "PARAMETERS:\n"
        for name, arg in self.parameters.properties.items():
            param_desc = (
                f" {name} ({arg['type']}):"
                f" {arg.get('description') or 'No description provided.'}\n"
            )
            params_description += param_desc
        # Constructing the full man page
        return (
            f"NAME: {self.name}\n\n"
            f"SYNOPSIS:\n {prototype}\n\n"
            f"{description}\n{params_description}"
        )

    def describe_xml(self) -> str:
        try:
            return dicttoxml(
                self.model_dump(exclude_none=True, by_alias=True),
                custom_root="function_info",
                attr_type=False,
                xml_declaration=False,
            ).decode()
        except TypeError:
            raise ImportError(
                "XML description requires the 'xml' extra for 'dicttoxml'. Please:"
                " `pip install aviary[xml]`."
            ) from None

    def describe_json(self) -> str:
        return self.model_dump_json(exclude_none=True, by_alias=True)

    def __str__(self):
        return self.describe_str()
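
# Illustrative sketch (not part of the library source): describe_str renders a
# man-page-like summary (NAME / SYNOPSIS / DESCRIPTION / PARAMETERS) from the stored
# schema, while describe_json and describe_xml emit the same information in other
# formats for prompting.
def _example_function_info_description() -> str:
    info = FunctionInfo(
        name="add",
        description="Add two integers.",
        parameters={
            "properties": {
                "a": {"type": "integer", "description": "First operand."},
                "b": {"type": "integer", "description": "Second operand."},
            },
            "required": ["a", "b"],
        },
    )
    return info.describe_str()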

def _raises(exc: Exception) -> NoReturn:
    """Work around lambda not supporting raise statement."""
    raise exc

class Tool(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    type: Literal["function"] = "function"
    info: FunctionInfo = Field(
        alias="function",
        description=(
            "The serialization alias of 'function' is to match LiteLLM structure on"
            " serialization, and the validation alias enables deserialization."
        ),
    )

    def __init__(
        self,
        tool_fn: Callable[..., Any] | Callable[..., Awaitable[Any]] = (
            lambda *_, **__: _raises(
                NotImplementedError("Please provide a tool function to call.")
            )
        ),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE: this Callable is excluded from serialization
        self._tool_fn = tool_fn
        self._force_pickle_fn = False

    def __getstate__(self) -> dict[Any, Any]:
        # Prevent _tool_fn from being pickled, SEE: https://stackoverflow.com/a/2345953
        state = super().__getstate__()
        # allow forcing pickle, e.g., for cloud pickle sending
        if self._force_pickle_fn:
            return state
        state["__dict__"] = state["__dict__"].copy()
        state["__dict__"].pop("_tool_fn", None)
        return state

    @staticmethod
    def _get_param_desc(param: DocstringParam, include_type: bool) -> str:
        if not include_type or not param.type_name:
            return param.description or ""
        return f"({param.type_name}): {param.description or ''}"

    @classmethod
    def from_function(
        cls,
        function: Callable[..., Any] | Callable[..., Awaitable[Any]],
        docstring_style: DocstringStyle = DocstringStyle.AUTO,
        allow_empty_param_descriptions: bool = False,
        types_in_param_descriptions: bool = False,
        **formats,
    ) -> "Tool":
        """Hydrate this class via inspection from a free function with a docstring."""
        fxn_name = function.__name__
        # now we parse descriptions from the docstring
        # SEE: https://github.com/rr-/docstring_parser/issues/88
        docstring = parse(function.__doc__, style=docstring_style)  # type: ignore[arg-type]
        if not docstring.description:
            raise ValueError(f"Missing docstring for function {fxn_name}.")
        try:
            # Don't include anything below \f, matching FastAPI's solution for this
            # SEE: https://fastapi.tiangolo.com/advanced/path-operation-advanced-configuration/#advanced-description-from-docstring
            description_stop_index: int | None = docstring.description.index("\\f")
        except ValueError:
            description_stop_index = None
        field_definitions: dict[str, tuple[type, FieldInfo]] = {}
        required: dict[str, bool] = {}
        annotations = function.__annotations__
        for pname, parameter in inspect.signature(function).parameters.items():
            if pname == "state":
                # NOTE: ToolRequestMessage passes state for us, not the LLM
                continue
            d = next(
                (
                    cls._get_param_desc(
                        p, include_type=types_in_param_descriptions
                    ).replace("\n", " ")
                    for p in docstring.params
                    if p.arg_name == pname
                ),
                "",
            )
            if not d and not allow_empty_param_descriptions:
                raise ValueError(f"Missing description for parameter {pname}.")
            required[pname] = parameter.default == inspect.Parameter.empty
            field_config: dict[str, Any] = {}
            if description := partial_format(d, **formats):
                field_config["description"] = description
            if not required[pname]:
                field_config["default"] = parameter.default
            # Annotation resolution order:
            # 1. function.__annotations__: type hints in the function signature or
            #    injected by argref_by_name. If a function has an opinion on a type
            #    hint, take it at face value.
            # 2. parameter.annotation: this will descend into wrapped functions. For
            #    argref_by_name, this is undesirable, since the wrapper overwrites
            #    type hints. Hence, this is second in resolution order.
            field_definitions[pname] = (
                annotations.get(pname) or parameter.annotation or type(None),
                Field(**field_config),  # type: ignore[pydantic-field]
            )
        json_schema = create_model(  # type: ignore[call-overload]
            "FieldDefinitions", **field_definitions
        ).model_json_schema()
        json_schema.pop("title")  # Remove the throwaway model name
        if "required" not in json_schema:
            # The API schema doesn't require this, and gpt-3.5-turbo doesn't
            # need this, but claude-3-haiku-20240307 does
            json_schema["required"] = []
        return cls(
            tool_fn=function,
            info=FunctionInfo(
                name=fxn_name,
                description=partial_format(
                    docstring.description[:description_stop_index].strip(), **formats
                ),
                parameters=json_schema,
            ),
        )
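
# Illustrative sketch (not part of the library source): from_function reads the
# signature and docstring to build the JSON schema the LLM sees. Parameter
# descriptions come from the docstring's Args section, and a `state` parameter
# (not present here) would be hidden from the schema.
def _example_tool_from_function() -> Tool:
    def add(a: int, b: int) -> int:
        """Add two integers.

        Args:
            a: First integer.
            b: Second integer.
        """
        return a + b

    tool = Tool.from_function(add)
    assert tool.info.name == "add"
    assert set(tool.info.parameters.required) == {"a", "b"}
    return tool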

def wraps_doc_only(wrapped):
    """A decorator to copy only the docstring from the wrapped function.

    You cannot use functools.wraps directly because it will set the __wrapped__
    attribute, which causes inspect.signature to inspect the wrapped function
    instead of the wrapper.

    Usage:

        def my_documented_function(foo):
            '''This is a function that does something with foo.'''
            pass

        @wraps_doc_only(my_documented_function)
        def my_other_function(foo, state):
            pass

    In this example, the second function can have different arguments, types, etc.,
    and only the docstring will be copied over.
    """

    def _wraps_doc_only(wrapper, wrapped):
        wrapper.__doc__ = wrapped.__doc__
        return wrapper

    return partial(_wraps_doc_only, wrapped=wrapped)

# Conveniences for deserialization
Messages: TypeAlias = list[ToolRequestMessage | ToolResponseMessage | Message]
MessagesAdapter = TypeAdapter(Messages)
Tools: TypeAlias = list[Tool]
ToolsAdapter = TypeAdapter(Tools)
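
# Illustrative sketch (not part of the library source): MessagesAdapter validates a
# heterogeneous JSON list back into ToolRequestMessage / ToolResponseMessage / Message
# instances, e.g. when reloading a saved conversation.
def _example_messages_round_trip(payload: str) -> Messages:
    return MessagesAdapter.validate_json(payload)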