Source code for aviary.tools.utils
from collections.abc import Callable
from enum import StrEnum
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, cast
from pydantic import BaseModel, Field
from aviary.message import MalformedMessageError, Message
from .base import (
MessagesAdapter,
Tool,
ToolRequestMessage,
ToolResponseMessage,
ToolsAdapter,
)
if TYPE_CHECKING:
from collections.abc import Awaitable
from litellm import ModelResponse
class EvalAnswerMode(StrEnum):
EXACT = "exact" # strings must match exactly
CONTAINS = "contains" # the correct answer is contained in the supplied answer
LLM = "llm" # Ask an LLM to evaluate
    LLM_SCORE = "llm-score"  # Ask an LLM to evaluate and return a score normalized to [0, 1]
LLM_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a proposed answer"
" to the question. Please tell me if the proposed answer is correct, given the"
" correct answer. ONLY SAY 'YES' OR 'NO'. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nCorrect answer: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
}
LLM_SCORE_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a rubric for"
" evaluating the question. Judge the proposed answer based on the given rubric."
" Give a score from 0 to 10. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nRubric: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
"max_score": 10,
}
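
# Editor's sketch, not part of aviary: llm_eval_config passed to eval_answer may be
# partial, since any key missing from the override falls back to the matching default
# above via config.get(...). The name below is illustrative only.
_CUSTOM_LLM_EVAL_CONFIG = {"model": "gpt-4o", "temperature": 0.5}  # prompt falls back to default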
async def eval_answer(
proposed: str,
correct: str,
question: str | None = None,
eval_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS,
llm_eval_config: dict | None = None,
) -> float:
"""Evaluate a proposed answer against a correct answer.
Will return 0 or 1, except for llm-score which should be between 0 and 1
"""
if eval_mode in {EvalAnswerMode.LLM, EvalAnswerMode.LLM_SCORE}:
try:
from litellm import acompletion
except ImportError as e:
raise ImportError(
"eval_answer requires the 'llm' extra for 'litellm'. Please:"
" `pip install aviary[llm]`."
) from e
if question is None:
raise ValueError("Question must be provided for LLM evaluation mode.")
default_config = (
LLM_EVAL_CONFIG
if eval_mode == EvalAnswerMode.LLM
else LLM_SCORE_EVAL_CONFIG
)
config = llm_eval_config or default_config
prompt = cast(str, config.get("prompt", default_config["prompt"])).format(
question=question,
correct_answer=correct,
proposed_answer=proposed,
)
response = await acompletion(
model=config.get("model", default_config["model"]),
temperature=config.get("temperature", default_config["temperature"]),
messages=[{"content": prompt, "role": "user"}],
)
if eval_mode == EvalAnswerMode.LLM:
return await eval_answer(
response.choices[0].message.content.strip().casefold(),
"yes",
eval_mode=EvalAnswerMode.EXACT,
)
try:
            return float(response.choices[0].message.content.strip()) / float(
config.get("max_score", default_config["max_score"]) # type: ignore[arg-type]
)
except ValueError:
return 0
gt = correct.strip().casefold()
pred = proposed.strip().casefold()
if eval_mode == EvalAnswerMode.EXACT:
return float(pred == gt)
if eval_mode == EvalAnswerMode.CONTAINS:
return float(gt in pred)
raise RuntimeError(f"Invalid evaluation mode: {eval_mode}")
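
# Editor's usage sketch, not part of aviary: the non-LLM modes need no API key.
# EXACT is a case-insensitive comparison after stripping whitespace, while CONTAINS
# checks that the correct answer appears inside the proposed answer. The helper
# name is illustrative only.
async def _example_eval_answer() -> None:
    assert await eval_answer("The answer is Paris.", "paris") == 1.0  # CONTAINS (default)
    assert await eval_answer(" Paris ", "paris", eval_mode=EvalAnswerMode.EXACT) == 1.0
    assert await eval_answer("Paris", "London", eval_mode=EvalAnswerMode.EXACT) == 0.0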
class ToolSelector:
"""Simple entity to select a tool based on messages."""
def __init__(
self,
model_name: str = "gpt-4o",
acompletion: "Callable[..., Awaitable[ModelResponse]] | None" = None,
accum_messages: bool = False,
):
"""Initialize.
Args:
model_name: Name of the model to select a tool with.
            acompletion: Optional async completion function to use. Leaving the
                default of None will use LiteLLM's acompletion. Alternatively,
                specify LiteLLM's Router.acompletion for centralized rate limiting.
accum_messages: Whether the selector should accumulate messages in a ledger.
"""
if acompletion is None:
try:
from litellm import acompletion
except ImportError as e:
raise ImportError(
f"{type(self).__name__} requires the 'llm' extra for 'litellm'."
" Please: `pip install aviary[llm]`."
) from e
self._model_name = model_name
self._bound_acompletion = partial(cast(Callable, acompletion), model_name)
self._ledger = ToolSelectorLedger() if accum_messages else None
# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# > `required` means the model must call one or more tools.
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
async def __call__(
self,
messages: list[Message],
tools: list[Tool],
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
) -> ToolRequestMessage:
"""Run a completion that selects a tool in tools given the messages."""
completion_kwargs: dict[str, Any] = {}
# SEE: https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter
expected_finish_reason: set[str] = {"tool_calls"}
if isinstance(tool_choice, Tool):
completion_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": tool_choice.info.name},
}
expected_finish_reason = {"stop"} # TODO: should this be .add("stop") too?
elif tool_choice is not None:
completion_kwargs["tool_choice"] = tool_choice
if tool_choice == self.TOOL_CHOICE_REQUIRED:
# Even though docs say it should be just 'stop',
# in practice 'tool_calls' shows up too
expected_finish_reason.add("stop")
if self._ledger is not None:
self._ledger.messages.extend(messages)
messages = self._ledger.messages
model_response = await self._bound_acompletion(
messages=MessagesAdapter.dump_python(
messages, exclude_none=True, by_alias=True
),
tools=ToolsAdapter.dump_python(tools, exclude_none=True, by_alias=True),
**completion_kwargs,
)
if (num_choices := len(model_response.choices)) != 1:
raise MalformedMessageError(
f"Expected one choice in LiteLLM model response, got {num_choices}"
f" choices, full response was {model_response}."
)
choice = model_response.choices[0]
if choice.finish_reason not in expected_finish_reason:
raise MalformedMessageError(
f"Expected a finish reason in {expected_finish_reason} in LiteLLM"
f" model response, got finish reason {choice.finish_reason!r}, full"
f" response was {model_response} and tool choice was {tool_choice}."
)
usage = model_response.usage
selection = ToolRequestMessage(
**choice.message.model_dump(),
info={
"usage": (usage.prompt_tokens, usage.completion_tokens),
"model": self._model_name,
},
)
if self._ledger is not None:
self._ledger.messages.append(selection)
return selection
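
# Editor's usage sketch, not part of aviary: a minimal tool-selection round trip.
# It assumes LiteLLM can reach the model (e.g. OPENAI_API_KEY is set); the `add`
# tool, prompt, and helper name are illustrative only.
async def _example_tool_selector() -> ToolRequestMessage:
    def add(a: int, b: int) -> int:
        """Add two integers.

        Args:
            a: First operand.
            b: Second operand.
        """
        return a + b

    selector = ToolSelector(model_name="gpt-4o-mini")
    return await selector(
        messages=[Message(content="Use the add tool to compute 2 + 3.")],
        tools=[Tool.from_function(add)],
    )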
class ToolSelectorLedger(BaseModel):
"""Simple ledger to record tools and messages."""
tools: list[Tool] = Field(default_factory=list)
messages: list[ToolRequestMessage | ToolResponseMessage | Message] = Field(
default_factory=list
)