Source code for aviary.env_client
from abc import ABC, abstractmethod
from typing import Any
import httpx
from aviary.env import Environment, TEnvState
from aviary.message import Message
from aviary.tools import MessagesAdapter, Tool, ToolRequestMessage, ToolsAdapter
[docs]
class EnvironmentClient(Environment[TEnvState], ABC):
    """Environment whose ``reset`` and ``step`` are executed by a remote HTTP server.

    Each call POSTs a JSON payload (built by :meth:`_make_post_json`) to a
    configured endpoint and validates the decoded JSON response into aviary
    message/tool objects. Subclasses must implement :meth:`_make_post_json`.
    """

    def __init__(
        self,
        reset_endpoint_url: str,
        step_endpoint_url: str,
        request_params: httpx._types.QueryParamTypes | None = None,
        request_headers: httpx._types.HeaderTypes | None = None,
        request_timeout: float | None = None,
    ):
        """Store the endpoint URLs and the options shared by every request.

        Args:
            reset_endpoint_url: URL that ``reset`` POSTs to.
            step_endpoint_url: URL that ``step`` POSTs to.
            request_params: Query parameters sent with every request.
            request_headers: HTTP headers sent with every request.
            request_timeout: Timeout in seconds passed to httpx on every request.
                It is passed explicitly, so ``None`` disables the timeout
                entirely rather than falling back to httpx's client default.
        """
        self._reset_request_url = reset_endpoint_url
        self._step_request_url = step_endpoint_url
        self._request_params = request_params
        self._request_headers = request_headers
        self._request_timeout = request_timeout

    async def _post_and_decode(self, url: str, payload: dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``url`` using the shared request options.

        Args:
            url: Endpoint to POST to.
            payload: JSON-serializable request body.

        Returns:
            The decoded JSON body of the response.

        Raises:
            httpx.HTTPStatusError: If the server responds with a 4xx/5xx status.
        """
        # A fresh client is opened per call, so no connections are pooled
        # between requests.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                json=payload,
                params=self._request_params,
                headers=self._request_headers,
                timeout=self._request_timeout,
            )
            response.raise_for_status()
            return response.json()

    async def reset(self) -> tuple[list[Message], list[Tool]]:
        """Reset the remote environment.

        POSTs ``self._make_post_json(self.state)`` to the reset endpoint.
        NOTE(review): assumes ``self.state`` is already set before ``reset`` is
        called (presumably by the subclass) — confirm.

        Returns:
            The initial messages and the available tools. The response body
            must unpack into exactly two items: (messages, tools).

        Raises:
            httpx.HTTPStatusError: If the server responds with a 4xx/5xx status.
        """
        msgs, tools = await self._post_and_decode(
            self._reset_request_url, self._make_post_json(self.state)
        )
        return MessagesAdapter.validate_python(msgs), ToolsAdapter.validate_python(
            tools
        )

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        """Send ``action`` to the remote environment and advance it one step.

        The request body is ``self._make_post_json(self.state)`` merged with an
        ``"action"`` key holding the serialized ``action``. On a key clash,
        ``"action"`` wins.

        Args:
            action: Tool request to execute remotely.

        Returns:
            ``(messages, reward, done, truncated)`` from the response body. The
            body must unpack into exactly four items. Only ``messages`` is
            validated; the other three are returned as decoded from JSON.

        Raises:
            httpx.HTTPStatusError: If the server responds with a 4xx/5xx status.
        """
        payload = self._make_post_json(self.state) | {
            # mode="json" converts non-JSON-native values (datetimes, UUIDs,
            # sets, ...) so httpx's json.dumps cannot fail on them. Data that
            # was already JSON-native encodes to the same body as before.
            "action": action.model_dump(mode="json")
        }
        messages, reward, done, truncated = await self._post_and_decode(
            self._step_request_url, payload
        )
        return MessagesAdapter.validate_python(messages), reward, done, truncated

    @abstractmethod
    def _make_post_json(self, state: TEnvState) -> dict[str, Any]:
        """Extract values from state to send as JSON for all reset/step POSTs.

        The returned dict is the whole request body for ``reset``. For ``step``
        it is merged with an ``"action"`` key, so a returned ``"action"`` entry
        would be overwritten.
        """