from __future__ import annotations
import asyncio
import importlib
import inspect
import json
import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Iterator
from copy import deepcopy
from typing import Annotated, Generic, Self, TypeAlias, TypeVar, cast
from pydantic import (
BaseModel,
ConfigDict,
Field,
JsonValue,
ValidationInfo,
WrapSerializer,
field_validator,
)
from aviary.message import Message
from aviary.tools import Tool, ToolCall, ToolRequestMessage, ToolResponseMessage
from aviary.utils import is_coroutine_callable
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# TODO: make TypeVar after https://github.com/pydantic/pydantic/milestone/13
# NOTE: pydantic.JsonValue can't be used here because pydantic would deep copy
# all the way down the JSON structure, and we want to keep the ability to
# shallow copy (see the `deepcopy` flag on Frame).
# Union of the value types accepted by Frame.state and Frame.info.
Serializable: TypeAlias = dict | list | int | float | str | bool | BaseModel
class Frame(BaseModel):
    """A frame is a snapshot at a given timestep. The name comes from video frame."""

    # NOTE: `deepcopy` has to be declared before `state` and `info`: pydantic
    # validates fields in definition order, and make_deepcopy() reads this flag
    # from the data validated so far.
    deepcopy: bool = Field(
        default=True,
        description=(
            "Whether to deepcopy the state and info fields. "
            "Disable if you're sure they're immutable or desire mutability."
        ),
    )

    @staticmethod
    def _custom_serializer(value: Serializable, handler, info):  # noqa: ARG004
        """Serialize a state/info value, dumping pydantic models via model_dump().

        Args:
            value: The field value being serialized.
            handler: The wrapped (default) serializer, used for every value
                that is not a BaseModel.
            info: Serialization info passed in by pydantic (unused).

        Returns:
            `value.model_dump()` for BaseModel instances, otherwise the result
            of `handler(value)`.
        """
        if isinstance(value, BaseModel):
            return value.model_dump()
        return handler(value)

    state: Annotated[Serializable | None, WrapSerializer(_custom_serializer)] = Field(
        default=None,
        description=(
            "Either entire (or a subset of) the current state. Leave as default of None"
            " if state is irrelevant."
        ),
    )
    info: Annotated[Serializable | None, WrapSerializer(_custom_serializer)] = Field(
        default=None, description="Optional metadata that doesn't vary with state."
    )

    @field_validator("state", "info")
    @classmethod
    def make_deepcopy(
        cls, v: Serializable | None, info: ValidationInfo
    ) -> Serializable | None:
        """Copy an incoming state/info value unless deep copying was disabled.

        Args:
            v: Value supplied for the field being validated.
            info: Validation context; `info.data` holds the fields validated
                before this one.

        Returns:
            A deep copy of `v` when the `deepcopy` flag is set, otherwise `v`
            itself.
        """
        # `info.data` only contains fields that passed validation. If the
        # `deepcopy` value itself was rejected, indexing with ["deepcopy"] would
        # raise a KeyError that escapes pydantic's error reporting. Fall back
        # to the field's default so the caller sees the ValidationError for
        # `deepcopy` instead.
        if info.data.get("deepcopy", True):
            return deepcopy(v)
        return v
# NOTE: setting to None means there is no state
TEnvState = TypeVar("TEnvState")
[docs]
class Environment(ABC, Generic[TEnvState]):
"""
An environment is a stateful place where agents use tools and make observations.
Tools are housed in the environment because they can interact with the environment.
Environments (and their contained tools) are not trainable.
"""
tools: list[Tool]
state: TEnvState
[docs]
@abstractmethod
async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:
"""Take a step in the environment.
Args:
action: Action to take.
Returns:
Four-tuple of new observations, instantaneous reward for this action, a flag
symbolizing if the episode is done, and a flag symbolizing if the
episode was truncated (e.g. via early stopping).
"""
[docs]
@abstractmethod
async def reset(self) -> tuple[list[Message], list[Tool]]:
"""
Reset the environment and collect initial observation(s).
Possible observations could be instructions on how tools are related,
or the goal of the environment.
Returns:
Two-tuple of initial observations and tools.
"""
[docs]
def export_frame(self) -> Frame:
"""
Export a snapshot of the environment as a Frame for visualization or debugging.
If you are not sure what to put in the Frame, just give it the entire state.
See the Frame class itself for more information.
"""
return Frame()
[docs]
async def close(self) -> None:
"""
Shutdown the environment.
If this is unimplemented, __del__ will manage cleanup.
"""
[docs]
@classmethod
def from_task(cls, task: str) -> Self:
"""Create an environment from a task description.
A task is meant to be closer to a user prompt - like what you would expect
in calling an LLM. This is how the environment should be used after training
and in deployment. We don't take config here, because the default environment config
should be general for arbitrary tasks.
For example, with GSM8k/calculator: "What is 18 * (number of legs on a cat) / moons of mars?"
"""
raise NotImplementedError(f"{cls.__name__} does not implement from_task")
[docs]
@classmethod
def from_name(cls, name: str, task: str | None = None, **env_kwargs) -> Self:
"""Create an environment from the name of the class. Call `Environment.available()` to see list."""
new_cls = _get_cls_from_name(ENV_REGISTRY, name)
if task is not None:
if env_kwargs:
raise ValueError("Cannot pass both a task and environment kwargs.")
return new_cls.from_task(task)
return new_cls(**env_kwargs)
[docs]
@classmethod
def available(cls) -> set[str]:
"""See list of available environment classes for `from_name`.
This is not exhaustive, because some may be importable and so you should just
try to call `from_name`. This is more for logging/debugging purposes.
"""
return set(ENV_REGISTRY.keys())
# Built-in environments, keyed by the short name accepted by
# Environment.from_name(). Each value is a (module path, class name) pair that
# _get_cls_from_name() resolves by importing the module at lookup time.
ENV_REGISTRY: dict[str, tuple[str, str]] = {
    "dummy": (
        "aviary.env",
        "DummyEnv",
    ),
    "calculator": (
        "aviary.envs.gsm8k.env",
        "CalculatorEnv",
    ),
    "hotpotqa": (
        "aviary.envs.hotpotqa.env",
        "HotPotQAEnv",
    ),
}
TEnvironment = TypeVar("TEnvironment", bound=Environment)


class TaskDataset(ABC, Generic[TEnvironment]):
    """A base class for a dataset of tasks as environments.

    Examples of task datasets: GSM8k, HotPotQA, etc.
    These are related environment instances with different problem
    specifications and reward conditions.

    A subclass is either finite (it implements `__len__` and
    `get_new_env_by_idx`) or unbounded (it implements `get_new_env`);
    `iter_batches` picks the access pattern based on whether `len()` works.
    """

    @classmethod
    def from_name(cls, name: str, **env_kwargs) -> TaskDataset:
        """Instantiate the dataset class registered under `name`.

        Args:
            name: Key in TASK_DATASET_REGISTRY.
            **env_kwargs: Forwarded to the resolved class's constructor.
        """
        return _get_cls_from_name(TASK_DATASET_REGISTRY, name)(**env_kwargs)

    def __len__(self) -> int:
        # Deliberately a TypeError (what len() raises for unsized objects):
        # iter_batches() relies on it to detect datasets without a length.
        raise TypeError(f'"Object of type {self.__class__.__name__}" has no len()')

    def get_new_env_by_idx(self, idx: int) -> TEnvironment:
        """Get an env from a finite dataset."""
        raise NotImplementedError(
            f'"{self.__class__.__name__}" does not implement get_new_env_by_idx'
        )

    def get_new_env(self) -> TEnvironment:
        """Get an env from a non-indexable dataset."""
        raise NotImplementedError(
            f'"{self.__class__.__name__}" does not implement get_new_env'
        )

    def iter_batches(
        self, batch_size: int, shuffle: bool = False
    ) -> Iterator[list[TEnvironment]]:
        """Construct batches from this dataset.

        Args:
            batch_size: Size of each batch; must be at least 1.
                Note that if this dataset's size is finite and isn't evenly
                divisible by this value, the last yielded batch will be
                smaller than batch_size.
            shuffle: Opt-in flag to shuffle without replacement. Only used for
                finite datasets, where it randomizes the index order.

        Yields:
            Batches (lists) of environments. A dataset without a length
            yields batches forever.

        Raises:
            ValueError: If batch_size is smaller than 1. Because this is a
                generator, the error surfaces when iteration starts.
        """
        # A non-positive batch size used to spin forever yielding empty
        # batches (the index list never shrank), so reject it up front.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}."
            )

        try:
            n: int | None = len(self)
        except TypeError:
            # No length available: treat as an unbounded dataset. The loop
            # lives outside this handler so that errors raised by
            # get_new_env() are not chained to this TypeError.
            n = None

        if n is None:
            while True:
                yield [self.get_new_env() for _ in range(batch_size)]

        # Finite-length dataset: walk the (optionally shuffled) indices in
        # strides of batch_size instead of repeatedly re-slicing the list.
        idcs = list(range(n))
        if shuffle:
            random.shuffle(idcs)
        for start in range(0, n, batch_size):
            batch_idcs = idcs[start : start + batch_size]
            yield [self.get_new_env_by_idx(idx) for idx in batch_idcs]
# Built-in task datasets, keyed by the short name accepted by
# TaskDataset.from_name(). Each value is a (module path, class name) pair that
# _get_cls_from_name() resolves by importing the module at lookup time.
TASK_DATASET_REGISTRY: dict[str, tuple[str, str]] = {
    "dummy": (
        "aviary.env",
        "DummyTaskDataset",
    ),
    "gsm8k": (
        "aviary.envs.gsm8k.env",
        "GSM8kDataset",
    ),
    "hotpotqa": (
        "aviary.envs.hotpotqa.env",
        "HotPotQADataset",
    ),
}
[docs]
class TaskConfig(BaseModel):
"""Convenience for making a config file entry for a TaskDataset."""
model_config = ConfigDict(extra="forbid")
name: str
task_kwargs: dict[str, BaseModel | JsonValue] = Field(
default_factory=dict, description="Arguments to pass to TaskDataset.from_name()"
)
train_kwargs: dict[str, BaseModel | JsonValue] = Field(
default_factory=dict, description="Additional arguments for the training split."
)
eval_kwargs: dict[str, BaseModel | JsonValue] = Field(
default_factory=dict,
description="Additional arguments for the evaluation split.",
)
test_kwargs: dict[str, BaseModel | JsonValue] = Field(
default_factory=dict, description="Additional arguments for the test split."
)
[docs]
def make_dataset(self, split: str) -> TaskDataset:
if split == "train":
split_kw = self.task_kwargs | self.train_kwargs
elif split == "eval":
split_kw = self.task_kwargs | self.eval_kwargs
elif split == "test":
split_kw = self.task_kwargs | self.test_kwargs
else:
raise NotImplementedError(f"Didn't handle split {split!r}.")
return TaskDataset.from_name(self.name, **split_kw)
class DummyEnvState(BaseModel):
    # State container used by DummyEnv. Documented with comments rather than
    # a docstring so the model's generated schema stays untouched.

    # Messages collected so far: the initial prompt plus the observations
    # DummyEnv.step() appends.
    messages: list[Message]
    # Reward reported by DummyEnv.step(); the print_story tool sets it to 1.
    reward: float = Field(default=0)
    # Episode-finished flag reported by DummyEnv.step(); set by print_story.
    done: bool = Field(default=False)
class DummyEnv(Environment[DummyEnvState]):
    """Simple Environment with basic functionality and no network usage."""

    # State class instantiated by reset(); looked up through type(self), so a
    # subclass can swap it out.
    State = DummyEnvState

    def __init__(self, task: str | None = None, end_immediately: bool = True):
        """Store the configuration; tools and state are only created by reset().

        Args:
            task: Optional topic appended to the initial prompt.
            end_immediately: Value the print_story tool writes to `state.done`.
        """
        self.task = task
        self.end_immediately = end_immediately

    @classmethod
    def from_task(cls, task: str) -> DummyEnv:
        """Create an environment whose prompt is about `task`."""
        return cls(task=task)

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        """Run the requested tool calls and record their responses.

        Returns:
            Four-tuple of (tool response messages, current reward, done flag,
            False); this environment never reports truncation.
        """
        # NOTE(review): exec_tool_calls is not defined in the visible part of
        # this file; presumably it is provided by Environment -- confirm.
        observations: list[Message] = await self.exec_tool_calls(
            action, state=self.state
        )
        state = self.state
        state.messages.extend(observations)
        return observations, state.reward, state.done, False

    async def reset(self) -> tuple[list[Message], list[Tool]]:
        """(Re)create the three tools and a fresh state holding the prompt.

        Returns:
            Two-tuple of (the state's message list, the tools).
        """

        # The nested functions' names, signatures and docstrings are handed to
        # Tool.from_function, so they are part of the tool definitions.
        def print_story(story: str, state: DummyEnvState) -> None:  # noqa: ARG001
            """Print a story.

            Args:
                story: Story to print.
                state: Environment state.
            """
            state.reward = 1
            state.done = self.end_immediately

        def cast_float(x: str) -> float:
            """Cast the input argument x to a float."""
            return float(x)

        def cast_int(x: float) -> int:
            """Cast the input argument x to an integer."""
            return int(x)

        self.tools = [
            Tool.from_function(print_story),
            Tool.from_function(cast_float, allow_empty_param_descriptions=True),
            Tool.from_function(cast_int, allow_empty_param_descriptions=True),
        ]

        prompt = "Write a 5 word story via print_story"
        if self.task:
            prompt += f" about {self.task}"
        self.state = type(self).State(messages=[Message(content=prompt)])
        return self.state.messages, self.tools

    def export_frame(self) -> Frame:
        """Export message contents as state, and tool names/done/reward as info."""
        state = self.state
        return Frame(
            state={"messages": [msg.content for msg in state.messages]},
            info={
                "tool_names": [tool.info.name for tool in self.tools],
                "done": state.done,
                "reward": state.reward,
            },
        )
class DummyTaskDataset(TaskDataset[DummyEnv]):
    """Dummy dataset without a length: every request yields a new DummyEnv."""

    def get_new_env(self) -> DummyEnv:
        """Return a fresh DummyEnv built with its default arguments."""
        return DummyEnv()

    def __bool__(self) -> bool:
        # Without this override, truth-testing an instance would fall back to
        # __len__(), which TaskDataset implements by raising TypeError.
        return True
def _get_cls_from_name(registry: dict[str, tuple[str, str]], name: str):
    """Resolve a registry entry to the object it names, importing on demand.

    Args:
        registry: Mapping of short name to a (module path, attribute name) pair.
        name: Key to look up in `registry`.

    Returns:
        The attribute named by the registry entry, fetched from the imported
        module.

    Raises:
        ValueError: If `name` is not a key of `registry`.
        ImportError: If the module cannot be imported. The underlying import
            failure is chained as `__cause__`.
    """
    try:
        module_name, cls_name = registry[name]
    except KeyError:
        # The KeyError adds nothing beyond the message, so suppress it.
        raise ValueError(f"Unknown environment name: {name}") from None

    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        # TODO: before release: add install instructions per env?
        # Chain the original error: the failure may come from a dependency of
        # the module rather than the module itself, and hiding it makes the
        # "you need to install it" hint hard to act on.
        raise ImportError(
            f"Could not import env from {module_name}; you need to install it."
        ) from exc

    return getattr(module, cls_name)