import uuid
from collections.abc import Callable, Mapping, Sequence
from functools import update_wrapper
from inspect import signature
from itertools import starmap
from types import UnionType
from typing import Any, Union, get_args, get_origin
from docstring_parser import compose, parse
from aviary.utils import is_coroutine_callable
[docs]
def make_pretty_id(prefix: str = "") -> str:
"""
Get an ID that is made using an optional prefix followed by part of an uuid4.
For example:
- No prefix: "ff726cd1"
- With prefix of "foo": "foo-ff726cd1"
"""
uuid_frags: list[str] = str(uuid.uuid4()).split("-")
if not prefix:
return uuid_frags[0]
return prefix + "-" + uuid_frags[0]
DEFAULT_ARGREF_NOTE = "(Pass a string key instead of the full object)"
LIST_ARGREF_NOTE = "(Pass comma-separated string keys instead of the full object)"
[docs]
def argref_wrapper(wrapper, wrapped, args_to_skip: set[str]):
"""Inject the ARGREF_NOTE into the Args."""
# normal wraps
wrapped_func = update_wrapper(wrapper, wrapped)
# when we modify wrapped_func's annotations, we don't want to mutate wrapped
wrapped_func.__annotations__ = wrapped_func.__annotations__.copy()
# now adjust what we need
for a in wrapped_func.__annotations__:
if a in args_to_skip:
continue
wrapped_func.__annotations__[a] = str
orig_annots = wrapped.__annotations__
# now add note to docstring for all relevant Args
if wrapped_func.__doc__:
ds = parse(wrapped_func.__doc__)
for param in ds.params:
if param.arg_name in args_to_skip:
continue
note = DEFAULT_ARGREF_NOTE
if (
param.type_name is None
and (type_hint := orig_annots.get(param.arg_name)) is not None
):
param.type_name = _type_to_str(type_hint)
if list in {type_hint, get_origin(type_hint)}:
note = LIST_ARGREF_NOTE
param.description = (param.description or "") + f" {note}"
wrapped_func.__doc__ = compose(ds)
return wrapped_func
[docs]
def argref_by_name( # noqa: C901, PLR0915
fxn_requires_state: bool = False,
prefix: str = "",
return_direct: bool = False,
type_check: bool = False,
args_to_skip: set[str] | None = None,
):
"""Decorator to allow args to be a string key into a refs dict instead of the full object.
This can prevent LLM-powered tool selections from getting confused by full objects,
instead it enables them to work using named references. If a reference is not found, it
will fallback on passing the original argument unless it is the first argument. If the
first argument str is not found in the state object, it will raise an error.
Args:
fxn_requires_state: Whether to pass the state object to the decorated function.
prefix: A prefix to add to the generated reference ID.
return_direct: Whether to return the result directly or update the state object.
type_check: Whether to type-check arguments with respect to the wrapped function's
type annotations.
args_to_skip: If provided, a set of argument names that should not be referenced by name.
Example 1:
>>> @argref_by_name() # doctest: +SKIP
>>> def my_func(foo: float): ... # doctest: +SKIP
Example 2:
>>> def my_func(foo: float, bar: float) -> list[float]:
... return [foo, bar]
>>> wrapped_fxn = argref_by_name()(my_func)
>>> # Equivalent to my_func(state.refs["foo"])
>>> wrapped_fxn("foo", state=state) # doctest: +SKIP
Working with lists:
- If you return a list, the decorator will create a new reference for each item in the list.
- If you pass multiple args that are strings, the decorator will assume those are the keys.
- If you need to pass a string, then use a keyword argument.
Example 1:
>>> @argref_by_name() # doctest: +SKIP
>>> def my_func(foo: float, bar: float) -> list[float]: # doctest: +SKIP
... return [foo, bar] # doctest: +SKIP
Example 2:
>>> def my_func(foo: float, bar: float) -> list[float]:
... return [foo, bar]
>>> wrapped_fxn = argref_by_name()(my_func)
>>> # Returns a multiline string with the new references
>>> # Equivalent to my_func(state.refs["a"], state.refs["b"])
>>> wrapped_fxn("a", "b", state=state) # doctest: +SKIP
"""
args_to_skip = (args_to_skip or set()) | {"state", "return"}
def decorator(func): # noqa: C901, PLR0915
def get_call_args(*args, **kwargs): # noqa: C901
if "state" not in kwargs:
raise ValueError(
"argref_by_name decorated function must have a 'state' argument."
" Function signature:"
f" {func.__name__}({', '.join(func.__annotations__)}) received"
f" args: {args} kwargs: {kwargs}"
)
# pop the state argument
state = kwargs["state"] if fxn_requires_state else kwargs.pop("state")
# now convert the keynames to actual references (if they are a string)
# tuple is (arg, if was dereferenced)
def maybe_deref_arg(arg, must_exist: bool) -> tuple[Any, bool]:
try:
refs = state.refs
except AttributeError as e:
raise AttributeError(
"The state object must have a 'refs' attribute to use"
" argref_by_name decorator."
) from e
if arg in refs:
return [refs[arg]], True
# sometimes it is not correctly converted to a tuple
# so as an attempt to be helpful...
if (
isinstance(arg, str)
and len(split_args := [a.strip() for a in arg.split(",")]) > 1
):
if not (missing := [a for a in split_args if a not in refs]):
return [refs[a] for a in split_args], True
if must_exist:
# Error message for the agent - cast back to comma-separated, since that's the format the agent
# is expected to use.
raise KeyError(
"The following keys are not present in the current"
f' key-value store: "{", ".join(missing)}"'
)
if not must_exist:
return arg, False
# Error message for the agent
raise KeyError(
f"Key is not present in the current key-value store: {arg!r}"
)
# the split thing makes it complicated and we cannot use comprehension
deref_args = []
for i, arg in enumerate(args):
# In order to support *args, allow arguments that are either ref keys or strings
a, dr = maybe_deref_arg(arg, must_exist=False)
if dr:
deref_args.extend(a)
else:
if i == 0 and isinstance(arg, str):
# This is a bit of a heuristic, but if the first arg is a string and not found
# likely the user intended to use a reference
raise KeyError(f"The key {arg} is not found in state.")
deref_args.append(a)
deref_kwargs = {}
for k, v in kwargs.items():
if args_to_skip and k in args_to_skip:
deref_kwargs[k] = v
continue
# In the kwarg case, force arguments to be ref keys (unless in args_to_skip)
a, _ = maybe_deref_arg(v, must_exist=True)
if len(a) > 1:
# We got multiple items, so pass the whole list
deref_kwargs[k] = a
else:
# We only got one item - pass it directly
deref_kwargs[k] = a[0]
return deref_args, deref_kwargs, state
def update_state(state, result):
if return_direct:
return result
# if it returns a list, rather than storing the list as a single reference
# we store each item in the list as a separate reference
if isinstance(result, list):
msg = []
for item in result:
new_name = make_pretty_id(prefix)
state.refs[new_name] = item
msg.append(f"{new_name} ({item.__class__.__name__}): {item!s}")
return "\n".join(msg)
new_name = make_pretty_id(prefix)
state.refs[new_name] = result
return f"{new_name} ({result.__class__.__name__}): {result!s}"
def wrapper(*args, **kwargs):
args, kwargs, state = get_call_args(*args, **kwargs)
if type_check:
_check_arg_types(func, args, kwargs)
result = func(*args, **kwargs)
return update_state(state, result)
async def awrapper(*args, **kwargs):
args, kwargs, state = get_call_args(*args, **kwargs)
if type_check:
_check_arg_types(func, args, kwargs)
result = await func(*args, **kwargs)
return update_state(state, result)
if is_coroutine_callable(func):
awrapper = argref_wrapper(awrapper, func, args_to_skip)
awrapper.requires_state = True # type: ignore[attr-defined]
return awrapper
wrapper = argref_wrapper(wrapper, func, args_to_skip)
wrapper.requires_state = True # type: ignore[attr-defined]
return wrapper
return decorator
def _check_arg_types(func: Callable, args, kwargs) -> None:
annotations = {
k: v for k, v in func.__annotations__.items() if k not in {"return", "state"}
}
sig = signature(func)
param_names = list(sig.parameters.keys())
# Elements are tuple of (param, expected, provided)
wrong_types: list[tuple[str, str, str]] = []
# Map positional arguments to their parameter names
for idx, arg in enumerate(args):
if idx >= len(param_names):
break # Extra arguments are handled by *args if any
param = param_names[idx]
expected_type = annotations.get(param)
if expected_type and not _isinstance_with_generics(arg, expected_type):
wrong_types.append((
param,
_type_to_str(expected_type),
_type_to_str(type(arg)),
))
# Check keyword arguments
for param, arg in kwargs.items():
expected_type = annotations.get(param)
if expected_type and not _isinstance_with_generics(arg, expected_type):
wrong_types.append((
param,
# sometimes need str for generics like Union
_type_to_str(expected_type),
_type_to_str(type(arg)),
))
if wrong_types:
raise TypeError(
"The following arguments have incorrect types:\n"
+ "\n".join(
f"- {param}: expected {expected}, got {provided}"
for param, expected, provided in wrong_types
)
)
def _type_to_str(t) -> str:
"""
Convert a Python type annotation into its string representation.
Examples:
type_to_str(int) -> "int"
type_to_str(Union[int, float]) -> "int | float"
type_to_str(list[str]) -> "list[str]"
"""
origin = get_origin(t)
args = get_args(t)
if origin is Union:
# Handle Union types, including the new | syntax in Python 3.10+
return " | ".join(_type_to_str(arg) for arg in args)
if origin is not None:
# Handle generic types like list[str], dict[str, int], etc.
origin_name = origin.__name__
args_str = ", ".join(_type_to_str(arg) for arg in args)
return f"{origin_name}[{args_str}]"
if hasattr(t, "__name__"):
# Handle basic types
return t.__name__
# Fallback for types without a __name__ attribute
return str(t)
def _isinstance_with_generics(obj, expected_type) -> bool: # noqa: C901, PLR0911
"""Like isinstance, but with support for generics."""
origin = get_origin(expected_type)
if origin is None:
# Handle special cases like typing.Any
if expected_type is Any:
return True
return isinstance(obj, expected_type)
if origin in {UnionType, Union}:
return any(
_isinstance_with_generics(obj, arg) for arg in get_args(expected_type)
)
if origin in {list, Sequence}:
if not isinstance(obj, Sequence):
return False
elem_type = get_args(expected_type)[0]
return all(_isinstance_with_generics(elem, elem_type) for elem in obj)
if origin in {dict, Mapping}:
if not isinstance(obj, Mapping):
return False
key_type, val_type = get_args(expected_type)
return all(
_isinstance_with_generics(k, key_type)
and _isinstance_with_generics(v, val_type)
for k, v in obj.items()
)
if origin is tuple:
if not isinstance(obj, tuple):
return False
elem_types = get_args(expected_type)
if len(elem_types) == 2 and elem_types[1] is Ellipsis: # noqa: PLR2004
# Tuple of variable length
return all(_isinstance_with_generics(elem, elem_types[0]) for elem in obj)
if len(elem_types) != len(obj):
return False
return all(
starmap(_isinstance_with_generics, zip(obj, elem_types, strict=True))
)
# Fallback to checking the origin type
return isinstance(obj, origin)