Skip to content

Agent Lightning Core

Client Side

agentlightning.litagent

LitAgent

Bases: Generic[T]

Base class for the training and validation logic of an agent.

Developers should subclass this class and implement the rollout methods to define the agent's behavior for a single task. The agent's logic is completely decoupled from the server communication and training infrastructure.

Source code in agentlightning/litagent/litagent.py
class LitAgent(Generic[T]):
    """Base class for the training and validation logic of an agent.

    Developers should subclass this class and implement the rollout methods
    to define the agent's behavior for a single task. The agent's logic
    is completely decoupled from the server communication and training
    infrastructure.
    """

    def __init__(self, *, trained_agents: Optional[str] = None) -> None:  # FIXME: str | None won't work for cli
        """
        Initialize the LitAgent.

        Args:
            trained_agents: Optional string representing the trained agents.
                            This can be used to track which agents have been trained by this instance.
                            Deprecated. Configure `agent_match` in adapter instead.
        """
        if trained_agents is not None:
            warnings.warn(
                "`trained_agents` is deprecated. Configure `agent_match` in adapter instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.trained_agents = trained_agents

        # Weak references: the agent must not keep the trainer/runner alive.
        # Whoever calls set_trainer()/set_runner() must hold the strong reference.
        self._trainer_ref: weakref.ReferenceType[Trainer] | None = None
        self._runner_ref: weakref.ReferenceType[BaseRunner[T]] | None = None

    def is_async(self) -> bool:
        """
        Check if the agent implements asynchronous rollout methods.
        Override this property for customized async detection logic.

        Detection compares the subclass's attributes against ``LitAgent``'s own
        defaults, so overriding any of `training_rollout_async`,
        `validation_rollout_async`, or `rollout_async` makes the agent count
        as asynchronous.

        Returns:
            True if the agent has custom async rollout methods, False otherwise.
        """
        return (
            (
                hasattr(self, "training_rollout_async")
                and self.__class__.training_rollout_async is not LitAgent.training_rollout_async  # type: ignore
            )
            or (
                hasattr(self, "validation_rollout_async")
                and self.__class__.validation_rollout_async is not LitAgent.validation_rollout_async  # type: ignore
            )
            or (hasattr(self, "rollout_async") and self.__class__.rollout_async is not LitAgent.rollout_async)  # type: ignore
        )

    def set_trainer(self, trainer: Trainer) -> None:
        """
        Set the trainer for this agent.

        Only a weak reference is stored; the caller must keep the trainer alive.

        Args:
            trainer: The Trainer instance that will handle training and validation.
        """
        self._trainer_ref = weakref.ref(trainer)

    def get_trainer(self) -> Trainer:
        """
        Get the trainer for this agent.

        Returns:
            The Trainer instance associated with this agent.

        Raises:
            ValueError: If no trainer was set, or the weakly-referenced
                trainer has already been garbage collected.
        """
        if self._trainer_ref is None:
            raise ValueError("Trainer has not been set for this agent.")
        trainer = self._trainer_ref()
        if trainer is None:
            raise ValueError("Trainer reference is no longer valid (object has been garbage collected).")
        return trainer

    @property
    def trainer(self) -> Trainer:
        """Convenient shortcut of self.get_trainer()."""
        return self.get_trainer()

    def get_tracer(self) -> BaseTracer:
        """
        Get the tracer for this agent.

        The tracer is looked up through the trainer, so a trainer must be set.

        Returns:
            The BaseTracer instance associated with this agent.

        Raises:
            ValueError: If the trainer has not been set or is no longer alive.
        """
        return self.trainer.tracer

    @property
    def tracer(self) -> BaseTracer:
        """Convenient shortcut of self.get_tracer()."""
        return self.get_tracer()

    def set_runner(self, runner: BaseRunner[T]) -> None:
        """
        Set the runner for this agent.

        Only a weak reference is stored; the caller must keep the runner alive.

        Args:
            runner: The runner instance that will handle the execution of rollouts.
        """
        self._runner_ref = weakref.ref(runner)

    def get_runner(self) -> BaseRunner[T]:
        """
        Get the runner for this agent.

        Returns:
            The runner instance associated with this agent.

        Raises:
            ValueError: If no runner was set, or the weakly-referenced
                runner has already been garbage collected.
        """
        if self._runner_ref is None:
            raise ValueError("Runner has not been set for this agent.")
        runner = self._runner_ref()
        if runner is None:
            raise ValueError("Runner reference is no longer valid (object has been garbage collected).")
        return runner

    @property
    def runner(self) -> BaseRunner[T]:
        """Convenient shortcut of self.get_runner()."""
        return self.get_runner()

    def on_rollout_start(self, task: Task, runner: BaseRunner[T], tracer: BaseTracer) -> None:
        """Hook called immediately before a rollout begins.

        Deprecated in favor of `on_rollout_start` in the `Hook` interface.

        Args:
            task: The :class:`Task` object that will be processed.
            runner: The :class:`BaseRunner` managing the rollout.
            tracer: The tracer instance associated with the runner.

        Subclasses can override this method to implement custom logic such as
        logging, metric collection, or resource setup. By default, this is a
        no-op.
        """

    def on_rollout_end(self, task: Task, rollout: RolloutV2, runner: BaseRunner[T], tracer: BaseTracer) -> None:
        """Hook called after a rollout completes.

        Deprecated in favor of `on_rollout_end` in the `Hook` interface.

        Args:
            task: The :class:`Task` object that was processed.
            rollout: The resulting :class:`Rollout` object.
            runner: The :class:`BaseRunner` managing the rollout.
            tracer: The tracer instance associated with the runner.

        Subclasses can override this method for cleanup or additional
        logging. By default, this is a no-op.
        """

    def rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
        """Main entry point for executing a rollout.

        This is the default synchronous entry point. The base implementation
        raises `NotImplementedError`; synchronous agents should override it
        (asynchronous agents override `rollout_async` instead).

        If you don't wish to implement both training rollout and validation
        rollout separately, you can just implement `rollout` which will work for both.

        Args:
            task: The task object received from the server, containing the
                  input data and metadata.
            resources: A dictionary of named resources (e.g., LLMs, prompt
                       templates) for the agent to use.
            rollout: The full rollout object, please avoid from directly modifying it.
                     Most agents should only use `task` and `resources`. Use `rollout`
                     only if you need to access metadata like `rollout_id`.

        Returns:
            The result of the rollout, which can be one of:
            - None. The tracing should be handled by the agent runner.
            - A float representing the final reward.
            - A list of `Triplet` objects for detailed, step-by-step feedback.
            - A list of `ReadableSpan` objects for OpenTelemetry tracing.
            - A list of dictionaries for any trace spans.
            - A complete `Rollout` object for full control over reporting.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Agents must implement the `rollout` method.")

    async def rollout_async(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
        """Asynchronous version of the main rollout method.

        The base implementation raises `NotImplementedError`; asynchronous
        agents should override this method (synchronous agents override
        `rollout` instead).

        Args:
            task: The task object received from the server, containing the
                  input data and metadata.
            resources: A dictionary of named resources (e.g., LLMs, prompt
                       templates) for the agent to use.
            rollout: The full rollout object, please avoid from directly modifying it.
                     Most agents should only use `task` and `resources`. Use `rollout`
                     only if you need to access metadata like `rollout_id`.

        Returns:
            The result of the rollout, which can be one of:
            - None. The tracing should be handled by the agent runner.
            - A float representing the final reward.
            - A list of `Triplet` objects for detailed, step-by-step feedback.
            - A list of `ReadableSpan` objects for OpenTelemetry tracing.
            - A list of dictionaries for any trace spans.
            - A complete `Rollout` object for full control over reporting.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Agents must implement the `rollout_async` method for async operations.")

    def training_rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
        """Defines the agent's behavior for a single training task.

        This method should contain the logic for how the agent processes an
        input, uses the provided resources (like LLMs or prompts), and
        produces a result. By default, it delegates to `rollout`.

        Args:
            task: The task object received from the server, containing the
                  input data and metadata.
            resources: A dictionary of named resources (e.g., LLMs, prompt
                       templates) for the agent to use.
            rollout: The full rollout object, please avoid from directly modifying it.

        Returns:
            The result of the training rollout. See `rollout` for
            possible return types.
        """
        return self.rollout(task, resources, rollout)

    def validation_rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
        """Defines the agent's behavior for a single validation task.

        By default, this method delegates to `rollout`. Override it
        if the agent should behave differently during validation.

        Args:
            task: The task object received from the server, containing the
                  input data and metadata.
            resources: A dictionary of named resources for the agent to use.
            rollout: The full rollout object, avoid from modifying it.

        Returns:
            The result of the validation rollout. See `rollout` for
            possible return types.
        """
        return self.rollout(task, resources, rollout)

    async def training_rollout_async(
        self, task: T, resources: NamedResources, rollout: RolloutV2
    ) -> RolloutRawResultV2:
        """Asynchronous version of `training_rollout`.

        This method should be implemented by agents that perform asynchronous
        operations (e.g., non-blocking I/O, concurrent API calls). By default,
        it delegates to `rollout_async`.

        Args:
            task: The task object received from the server.
            resources: A dictionary of named resources for the agent to use.
            rollout: The full rollout object, avoid from modifying it.

        Returns:
            The result of the asynchronous training rollout. See `rollout` for
            possible return types.
        """
        return await self.rollout_async(task, resources, rollout)

    async def validation_rollout_async(
        self, task: T, resources: NamedResources, rollout: RolloutV2
    ) -> RolloutRawResultV2:
        """Asynchronous version of `validation_rollout`.

        By default, this method delegates to `rollout_async`.
        Override it for different asynchronous validation behavior.

        Args:
            task: The task object received from the server.
            resources: A dictionary of named resources for the agent to use.
            rollout: The full rollout object, avoid from modifying it.

        Returns:
            The result of the asynchronous validation rollout. See `rollout` for
            possible return types.
        """
        return await self.rollout_async(task, resources, rollout)

runner property

Convenient shortcut of self.get_runner().

tracer property

Convenient shortcut of self.get_tracer().

trainer property

Convenient shortcut of self.get_trainer().

__init__(*, trained_agents=None)

Initialize the LitAgent.

Parameters:

Name Type Description Default
trained_agents Optional[str]

Optional string representing the trained agents. This can be used to track which agents have been trained by this instance. Deprecated. Configure agent_match in adapter instead.

None
Source code in agentlightning/litagent/litagent.py
def __init__(self, *, trained_agents: Optional[str] = None) -> None:  # FIXME: str | None won't work for cli
    """
    Initialize the LitAgent.

    Args:
        trained_agents: Optional string representing the trained agents.
                        This can be used to track which agents have been trained by this instance.
                        Deprecated. Configure `agent_match` in adapter instead.
    """
    if trained_agents is not None:
        warnings.warn(
            "`trained_agents` is deprecated. Configure `agent_match` in adapter instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    self.trained_agents = trained_agents

    self._trainer_ref: weakref.ReferenceType[Trainer] | None = None
    self._runner_ref: weakref.ReferenceType[BaseRunner[T]] | None = None

get_runner()

Get the runner for this agent.

Returns:

Type Description
BaseRunner[T]

The runner instance associated with this agent.

Source code in agentlightning/litagent/litagent.py
def get_runner(self) -> BaseRunner[T]:
    """
    Get the runner for this agent.

    Returns:
        The runner instance associated with this agent.
    """
    if self._runner_ref is None:
        raise ValueError("Runner has not been set for this agent.")
    runner = self._runner_ref()
    if runner is None:
        raise ValueError("Runner reference is no longer valid (object has been garbage collected).")
    return runner

get_tracer()

Get the tracer for this agent.

Returns:

Type Description
BaseTracer

The BaseTracer instance associated with this agent.

Source code in agentlightning/litagent/litagent.py
def get_tracer(self) -> BaseTracer:
    """
    Get the tracer for this agent.

    Returns:
        The BaseTracer instance associated with this agent.
    """
    return self.trainer.tracer

get_trainer()

Get the trainer for this agent.

Returns:

Type Description
Trainer

The Trainer instance associated with this agent.

Source code in agentlightning/litagent/litagent.py
def get_trainer(self) -> Trainer:
    """
    Get the trainer for this agent.

    Returns:
        The Trainer instance associated with this agent.
    """
    if self._trainer_ref is None:
        raise ValueError("Trainer has not been set for this agent.")
    trainer = self._trainer_ref()
    if trainer is None:
        raise ValueError("Trainer reference is no longer valid (object has been garbage collected).")
    return trainer

is_async()

Check if the agent implements asynchronous rollout methods. Override this property for customized async detection logic.

Returns:

Type Description
bool

True if the agent has custom async rollout methods, False otherwise.

Source code in agentlightning/litagent/litagent.py
def is_async(self) -> bool:
    """
    Check if the agent implements asynchronous rollout methods.
    Override this property for customized async detection logic.

    Returns:
        True if the agent has custom async rollout methods, False otherwise.
    """
    return (
        (
            hasattr(self, "training_rollout_async")
            and self.__class__.training_rollout_async is not LitAgent.training_rollout_async  # type: ignore
        )
        or (
            hasattr(self, "validation_rollout_async")
            and self.__class__.validation_rollout_async is not LitAgent.validation_rollout_async  # type: ignore
        )
        or (hasattr(self, "rollout_async") and self.__class__.rollout_async is not LitAgent.rollout_async)  # type: ignore
    )

on_rollout_end(task, rollout, runner, tracer)

Hook called after a rollout completes.

Deprecated in favor of on_rollout_end in the Hook interface.

Parameters:

Name Type Description Default
task Task

The :class:Task object that was processed.

required
rollout RolloutV2

The resulting :class:Rollout object.

required
runner BaseRunner[T]

The :class:BaseRunner managing the rollout.

required
tracer BaseTracer

The tracer instance associated with the runner.

required

Subclasses can override this method for cleanup or additional logging. By default, this is a no-op.

Source code in agentlightning/litagent/litagent.py
def on_rollout_end(self, task: Task, rollout: RolloutV2, runner: BaseRunner[T], tracer: BaseTracer) -> None:
    """Hook called after a rollout completes.

    Deprecated in favor of `on_rollout_end` in the `Hook` interface.

    Args:
        task: The :class:`Task` object that was processed.
        rollout: The resulting :class:`Rollout` object.
        runner: The :class:`BaseRunner` managing the rollout.
        tracer: The tracer instance associated with the runner.

    Subclasses can override this method for cleanup or additional
    logging. By default, this is a no-op.
    """

on_rollout_start(task, runner, tracer)

Hook called immediately before a rollout begins.

Deprecated in favor of on_rollout_start in the Hook interface.

Parameters:

Name Type Description Default
task Task

The :class:Task object that will be processed.

required
runner BaseRunner[T]

The :class:BaseRunner managing the rollout.

required
tracer BaseTracer

The tracer instance associated with the runner.

required

Subclasses can override this method to implement custom logic such as logging, metric collection, or resource setup. By default, this is a no-op.

Source code in agentlightning/litagent/litagent.py
def on_rollout_start(self, task: Task, runner: BaseRunner[T], tracer: BaseTracer) -> None:
    """Hook called immediately before a rollout begins.

    Deprecated in favor of `on_rollout_start` in the `Hook` interface.

    Args:
        task: The :class:`Task` object that will be processed.
        runner: The :class:`BaseRunner` managing the rollout.
        tracer: The tracer instance associated with the runner.

    Subclasses can override this method to implement custom logic such as
    logging, metric collection, or resource setup. By default, this is a
    no-op.
    """

rollout(task, resources, rollout)

Main entry point for executing a rollout.

This method determines whether to call the synchronous or asynchronous rollout method based on the agent's implementation.

If you don't wish to implement both training rollout and validation rollout separately, you can just implement rollout which will work for both.

Parameters:

Name Type Description Default
task T

The task object received from the server, containing the input data and metadata.

required
resources NamedResources

A dictionary of named resources (e.g., LLMs, prompt templates) for the agent to use.

required
rollout RolloutV2

The full rollout object, please avoid from directly modifying it. Most agents should only use task and resources. Use rollout only if you need to access metadata like rollout_id.

required

Returns:

Type Description
RolloutRawResultV2

The result of the rollout, which can be one of:

- None. The tracing should be handled by the agent runner.
- A float representing the final reward.
- A list of Triplet objects for detailed, step-by-step feedback.
- A list of ReadableSpan objects for OpenTelemetry tracing.
- A list of dictionaries for any trace spans.
- A complete Rollout object for full control over reporting.
Source code in agentlightning/litagent/litagent.py
def rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
    """Main entry point for executing a rollout.

    This method determines whether to call the synchronous or
    asynchronous rollout method based on the agent's implementation.

    If you don't wish to implement both training rollout and validation
    rollout separately, you can just implement `rollout` which will work for both.

    Args:
        task: The task object received from the server, containing the
              input data and metadata.
        resources: A dictionary of named resources (e.g., LLMs, prompt
                   templates) for the agent to use.
        rollout: The full rollout object, please avoid from directly modifying it.
                 Most agents should only use `task` and `resources`. Use `rollout`
                 only if you need to access metadata like `rollout_id`.

    Returns:
        The result of the rollout, which can be one of:
        - None. The tracing should be handled by the agent runner.
        - A float representing the final reward.
        - A list of `Triplet` objects for detailed, step-by-step feedback.
        - A list of `ReadableSpan` objects for OpenTelemetry tracing.
        - A list of dictionaries for any trace spans.
        - A complete `Rollout` object for full control over reporting.
    """
    raise NotImplementedError("Agents must implement the `rollout` method.")

rollout_async(task, resources, rollout) async

Asynchronous version of the main rollout method.

This method determines whether to call the synchronous or asynchronous rollout method based on the agent's implementation.

Parameters:

Name Type Description Default
task T

The task object received from the server, containing the input data and metadata.

required
resources NamedResources

A dictionary of named resources (e.g., LLMs, prompt templates) for the agent to use.

required
rollout RolloutV2

The full rollout object, please avoid from directly modifying it. Most agents should only use task and resources. Use rollout only if you need to access metadata like rollout_id.

required

Returns:

Type Description
RolloutRawResultV2

The result of the rollout, which can be one of:

- None. The tracing should be handled by the agent runner.
- A float representing the final reward.
- A list of Triplet objects for detailed, step-by-step feedback.
- A list of ReadableSpan objects for OpenTelemetry tracing.
- A list of dictionaries for any trace spans.
- A complete Rollout object for full control over reporting.
Source code in agentlightning/litagent/litagent.py
async def rollout_async(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
    """Asynchronous version of the main rollout method.

    This method determines whether to call the synchronous or
    asynchronous rollout method based on the agent's implementation.

    Args:
        task: The task object received from the server, containing the
              input data and metadata.
        resources: A dictionary of named resources (e.g., LLMs, prompt
                   templates) for the agent to use.
        rollout: The full rollout object, please avoid from directly modifying it.
                 Most agents should only use `task` and `resources`. Use `rollout`
                 only if you need to access metadata like `rollout_id`.

    Returns:
        The result of the rollout, which can be one of:
        - None. The tracing should be handled by the agent runner.
        - A float representing the final reward.
        - A list of `Triplet` objects for detailed, step-by-step feedback.
        - A list of `ReadableSpan` objects for OpenTelemetry tracing.
        - A list of dictionaries for any trace spans.
        - A complete `Rollout` object for full control over reporting.
    """
    raise NotImplementedError("Agents must implement the `rollout_async` method for async operations.")

set_runner(runner)

Set the runner for this agent.

Parameters:

Name Type Description Default
runner BaseRunner[T]

The runner instance that will handle the execution of rollouts.

required
Source code in agentlightning/litagent/litagent.py
def set_runner(self, runner: BaseRunner[T]) -> None:
    """
    Set the runner for this agent.

    Args:
        runner: The runner instance that will handle the execution of rollouts.
    """
    self._runner_ref = weakref.ref(runner)

set_trainer(trainer)

Set the trainer for this agent.

Parameters:

Name Type Description Default
trainer Trainer

The Trainer instance that will handle training and validation.

required
Source code in agentlightning/litagent/litagent.py
def set_trainer(self, trainer: Trainer) -> None:
    """
    Set the trainer for this agent.

    Args:
        trainer: The Trainer instance that will handle training and validation.
    """
    self._trainer_ref = weakref.ref(trainer)

training_rollout(task, resources, rollout)

Defines the agent's behavior for a single training task.

This method should contain the logic for how the agent processes an input, uses the provided resources (like LLMs or prompts), and produces a result.

Parameters:

Name Type Description Default
task T

The task object received from the server, containing the input data and metadata.

required
resources NamedResources

A dictionary of named resources (e.g., LLMs, prompt templates) for the agent to use.

required
rollout RolloutV2

The full rollout object, please avoid from directly modifying it.

required
Source code in agentlightning/litagent/litagent.py
def training_rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
    """Defines the agent's behavior for a single training task.

    This method should contain the logic for how the agent processes an
    input, uses the provided resources (like LLMs or prompts), and
    produces a result.

    Args:
        task: The task object received from the server, containing the
              input data and metadata.
        resources: A dictionary of named resources (e.g., LLMs, prompt
                   templates) for the agent to use.
        rollout: The full rollout object, please avoid from directly modifying it.
    """
    return self.rollout(task, resources, rollout)

training_rollout_async(task, resources, rollout) async

Asynchronous version of training_rollout.

This method should be implemented by agents that perform asynchronous operations (e.g., non-blocking I/O, concurrent API calls).

Parameters:

Name Type Description Default
task T

The task object received from the server.

required
resources NamedResources

A dictionary of named resources for the agent to use.

required
rollout RolloutV2

The full rollout object, avoid from modifying it.

required

Returns:

Type Description
RolloutRawResultV2

The result of the asynchronous training rollout. See rollout for possible return types.

Source code in agentlightning/litagent/litagent.py
async def training_rollout_async(
    self, task: T, resources: NamedResources, rollout: RolloutV2
) -> RolloutRawResultV2:
    """Asynchronous version of `training_rollout`.

    This method should be implemented by agents that perform asynchronous
    operations (e.g., non-blocking I/O, concurrent API calls).

    Args:
        task: The task object received from the server.
        resources: A dictionary of named resources for the agent to use.
        rollout: The full rollout object, avoid from modifying it.

    Returns:
        The result of the asynchronous training rollout. See `rollout` for
        possible return types.
    """
    return await self.rollout_async(task, resources, rollout)

validation_rollout(task, resources, rollout)

Defines the agent's behavior for a single validation task.

By default, this method redirects to training_rollout. Override it if the agent should behave differently during validation.

Parameters:

Name Type Description Default
task T

The task object received from the server, containing the input data and metadata.

required
resources NamedResources

A dictionary of named resources for the agent to use.

required
rollout RolloutV2

The full rollout object, avoid from modifying it.

required

Returns:

Type Description
RolloutRawResultV2

The result of the validation rollout. See rollout for possible return types.

Source code in agentlightning/litagent/litagent.py
def validation_rollout(self, task: T, resources: NamedResources, rollout: RolloutV2) -> RolloutRawResultV2:
    """Defines the agent's behavior for a single validation task.

    By default, this method redirects to `training_rollout`. Override it
    if the agent should behave differently during validation.

    Args:
        task: The task object received from the server, containing the
              input data and metadata.
        resources: A dictionary of named resources for the agent to use.
        rollout: The full rollout object, avoid from modifying it.

    Returns:
        The result of the validation rollout. See `rollout` for
        possible return types.
    """
    return self.rollout(task, resources, rollout)

validation_rollout_async(task, resources, rollout) async

Asynchronous version of validation_rollout.

By default, this method redirects to training_rollout_async. Override it for different asynchronous validation behavior.

Parameters:

Name Type Description Default
task T

The task object received from the server.

required
resources NamedResources

A dictionary of named resources for the agent to use.

required
rollout RolloutV2

The full rollout object, avoid from modifying it.

required

Returns:

Type Description
RolloutRawResultV2

The result of the asynchronous validation rollout. See rollout for possible return types.

Source code in agentlightning/litagent/litagent.py
async def validation_rollout_async(
    self, task: T, resources: NamedResources, rollout: RolloutV2
) -> RolloutRawResultV2:
    """Asynchronous version of `validation_rollout`.

    By default, this method redirects to `training_rollout_async`.
    Override it for different asynchronous validation behavior.

    Args:
        task: The task object received from the server.
        resources: A dictionary of named resources for the agent to use.
        rollout: The full rollout object, avoid from modifying it.

    Returns:
        The result of the asynchronous validation rollout. See `rollout` for
        possible return types.
    """
    return await self.rollout_async(task, resources, rollout)

is_v0_1_rollout_api(func)

Check if the rollout API is v0.1. Inspect the function signature to see if it has a rollout_id parameter.

Parameters:

Name Type Description Default
func Callable[..., Any]

The function to check.

required
Source code in agentlightning/litagent/litagent.py
def is_v0_1_rollout_api(func: Callable[..., Any]) -> bool:
    """Detect whether `func` follows the legacy v0.1 rollout API.

    A v0.1 rollout callable is identified purely by its signature: it
    declares a ``rollout_id`` parameter, which later API versions do not.

    Args:
        func: The function to check.

    Returns:
        True if the signature declares a ``rollout_id`` parameter.
    """
    parameters = inspect.signature(func).parameters
    return any(name == "rollout_id" for name in parameters)

llm_rollout(func=None, *, strip_proxy=True)

llm_rollout(
    func: LlmRolloutFunc[T],
) -> FunctionalLitAgent[T]
llm_rollout(
    *, strip_proxy: bool = True
) -> Callable[[LlmRolloutFunc[T]], FunctionalLitAgent[T]]

Create a FunctionalLitAgent from a function that takes (task, llm[, rollout]).

This decorator allows you to define an agent using a simple function instead of creating a full LitAgent subclass. The returned FunctionalLitAgent instance is callable, preserving the original function's behavior.

Parameters:

Name Type Description Default
func LlmRolloutFunc[T] | None

A function that defines the agent's behavior. Can be: - sync: (task, llm) -> result - sync with rollout: (task, llm, rollout) -> result - async: async (task, llm) -> result - async with rollout: async (task, llm, rollout) -> result

None
strip_proxy bool

Whether to strip the ProxyLLM resource into a LLM resource. Defaults to True.

True

Returns:

Type Description
FunctionalLitAgent[T] | Callable[[LlmRolloutFunc[T]], FunctionalLitAgent[T]]

A callable FunctionalLitAgent instance that preserves the original function's type hints and behavior while providing all agent functionality.

Example

@llm_rollout def my_agent(task, llm): # Agent logic here return response

@llm_rollout(strip_proxy=False) def my_agent_no_strip(task, llm): # Agent logic here return response

Function is still callable with original behavior

result = my_agent(task, llm)

Agent methods are also available

result = my_agent.rollout(task, resources, rollout)

Source code in agentlightning/litagent/decorator.py
def llm_rollout(
    func: LlmRolloutFunc[T] | None = None, *, strip_proxy: bool = True
) -> FunctionalLitAgent[T] | Callable[[LlmRolloutFunc[T]], FunctionalLitAgent[T]]:
    """Wrap a ``(task, llm[, rollout])`` function into a `FunctionalLitAgent`.

    Lets you define an agent as a plain function instead of a full LitAgent
    subclass. The returned FunctionalLitAgent instance stays callable, so the
    original function's behavior and type hints are preserved while the full
    agent API becomes available on the same object.

    Args:
        func: The function implementing the agent's behavior. Accepted shapes:
              - sync: (task, llm) -> result
              - sync with rollout: (task, llm, rollout) -> result
              - async: async (task, llm) -> result
              - async with rollout: async (task, llm, rollout) -> result
        strip_proxy: Whether to strip the ProxyLLM resource into a LLM resource.
                     Defaults to True.

    Returns:
        A callable FunctionalLitAgent instance, or — when invoked with keyword
        arguments only, as ``@llm_rollout(strip_proxy=False)`` — a decorator
        that produces one.

    Example:
        @llm_rollout
        def my_agent(task, llm):
            # Agent logic here
            return response

        @llm_rollout(strip_proxy=False)
        def my_agent_no_strip(task, llm):
            # Agent logic here
            return response

        # Function is still callable with original behavior
        result = my_agent(task, llm)

        # Agent methods are also available
        result = my_agent.rollout(task, resources, rollout)
    """

    def wrap(target: LlmRolloutFunc[T]) -> FunctionalLitAgent[T]:
        # Reject functions whose signature doesn't match an accepted shape.
        _validate_llm_rollout_func(target)
        return FunctionalLitAgent(target, strip_proxy=strip_proxy)

    # Bare usage `@llm_rollout` hands us the function directly.
    if func is not None:
        return wrap(func)
    # Parameterized usage `@llm_rollout(...)` returns the decorator to apply.
    return wrap

prompt_rollout(func=None)

prompt_rollout(
    func: PromptRolloutFunc[T],
) -> FunctionalLitAgent[T]
prompt_rollout() -> (
    Callable[[PromptRolloutFunc[T]], FunctionalLitAgent[T]]
)

Create a FunctionalLitAgent from a function that takes (task, prompt_template[, rollout]).

This decorator is designed for agents that work with tunable prompt templates. It enables a workflow where algorithms manage and optimize the prompt template, while agents consume the template to perform rollouts. This is particularly useful for prompt optimization scenarios.

Parameters:

Name Type Description Default
func PromptRolloutFunc[T] | None

A function that defines the agent's behavior. Can be: - sync: (task, prompt_template) -> result - sync with rollout: (task, prompt_template, rollout) -> result - async: async (task, prompt_template) -> result - async with rollout: async (task, prompt_template, rollout) -> result

None

Returns:

Type Description
FunctionalLitAgent[T] | Callable[[PromptRolloutFunc[T]], FunctionalLitAgent[T]]

A callable FunctionalLitAgent instance that preserves the original function's type hints and behavior while providing all agent functionality.

Example

@prompt_rollout def my_agent(task, prompt_template): # Use the prompt template to generate a response messages = prompt_template.format(task=task.input) # ... perform rollout with the formatted prompt return response

Function is still callable with original behavior

result = my_agent(task, prompt_template)

Agent methods are also available

result = my_agent.rollout(task, resources, rollout)

Source code in agentlightning/litagent/decorator.py
def prompt_rollout(
    func: PromptRolloutFunc[T] | None = None,
) -> FunctionalLitAgent[T] | Callable[[PromptRolloutFunc[T]], FunctionalLitAgent[T]]:
    """Wrap a ``(task, prompt_template[, rollout])`` function into a `FunctionalLitAgent`.

    Intended for agents built around tunable prompt templates: an algorithm
    manages and optimizes the template while the decorated function consumes
    it to perform rollouts. This is particularly useful for prompt
    optimization scenarios.

    Args:
        func: The function implementing the agent's behavior. Accepted shapes:
              - sync: (task, prompt_template) -> result
              - sync with rollout: (task, prompt_template, rollout) -> result
              - async: async (task, prompt_template) -> result
              - async with rollout: async (task, prompt_template, rollout) -> result

    Returns:
        A callable FunctionalLitAgent instance, or — when invoked with no
        arguments, as ``@prompt_rollout()`` — a decorator that produces one.

    Example:
        @prompt_rollout
        def my_agent(task, prompt_template):
            # Use the prompt template to generate a response
            messages = prompt_template.format(task=task.input)
            # ... perform rollout with the formatted prompt
            return response

        # Function is still callable with original behavior
        result = my_agent(task, prompt_template)

        # Agent methods are also available
        result = my_agent.rollout(task, resources, rollout)
    """

    def wrap(target: PromptRolloutFunc[T]) -> FunctionalLitAgent[T]:
        # Reject functions whose signature doesn't match an accepted shape.
        _validate_prompt_rollout_func(target)
        return FunctionalLitAgent(target)

    # Bare usage `@prompt_rollout` hands us the function directly.
    if func is not None:
        return wrap(func)
    # Parameterized usage `@prompt_rollout()` returns the decorator to apply.
    return wrap

rollout(func)

Create a LitAgent from a function, automatically detecting the appropriate type.

This function inspects the provided callable and creates the appropriate agent type based on its signature. It supports both LLM-based and prompt-template-based agents. The returned agent instance is callable, preserving the original function's behavior and type hints.

Parameters:

Name Type Description Default
func Union[LlmRolloutFunc[T], PromptRolloutFunc[T], Callable[..., Any]]

A function that defines the agent's behavior. Supported signatures: - (task, llm[, rollout]) for LLM-based agents - (task, prompt_template[, rollout]) for prompt-template-based agents

required

Returns:

Type Description
FunctionalLitAgent[T]

A callable FunctionalLitAgent instance that preserves the original function's type hints and behavior while providing all agent functionality.

Example

LLM-based agent

@rollout def my_llm_agent(task, llm): client = OpenAI(base_url=llm.endpoint) response = client.chat.completions.create( model=llm.model, messages=[{"role": "user", "content": task.input}], ) return response

Prompt-template-based agent

@rollout def my_prompt_agent(task, prompt_template): messages = prompt_template.format(task=task.input) # ... perform rollout with the formatted prompt return response

Function is still callable with original behavior

result = my_llm_agent(task, llm)

Agent methods are also available

result = my_llm_agent.rollout(task, resources, rollout)

Raises:

Type Description
NotImplementedError

If the function signature doesn't match any known patterns.

Source code in agentlightning/litagent/decorator.py
def rollout(func: Union[LlmRolloutFunc[T], PromptRolloutFunc[T], Callable[..., Any]]) -> FunctionalLitAgent[T]:
    """Create a LitAgent from a function, automatically detecting the appropriate type.

    Inspects the callable's signature and dispatches to `llm_rollout` for
    ``(task, llm[, rollout])`` functions or `prompt_rollout` for
    ``(task, prompt_template[, rollout])`` functions. The returned agent
    instance is callable, preserving the wrapped function's behavior and
    type hints.

    Args:
        func: A function that defines the agent's behavior. Supported signatures:
              - (task, llm[, rollout]) for LLM-based agents
              - (task, prompt_template[, rollout]) for prompt-template-based agents

    Returns:
        A callable FunctionalLitAgent instance that preserves the original
        function's type hints and behavior while providing all agent
        functionality.

    Example:
        # LLM-based agent
        @rollout
        def my_llm_agent(task, llm):
            ...

        # Prompt-template-based agent
        @rollout
        def my_prompt_agent(task, prompt_template):
            ...

    Raises:
        NotImplementedError: If the function signature doesn't match any known patterns.
    """
    sig = inspect.signature(func)

    # Try each known agent pattern in turn. A ValueError from a validator
    # means "not this pattern", so fall through to the next candidate.
    candidates = (
        (_validate_llm_rollout_func, llm_rollout),
        (_validate_prompt_rollout_func, prompt_rollout),
    )
    for validator, factory in candidates:
        try:
            if validator(func):
                return factory(func)
        except ValueError:
            continue

    raise NotImplementedError(
        f"Function signature {sig} does not match any known agent patterns. "
        "Expected signatures: (task, llm[, rollout]) or (task, prompt_template[, rollout]). "
        "Functions can be sync or async."
    )

agentlightning.client

Legacy client for interacting with a legacy Agent Lightning server.

AgentLightningClient

Client for interacting with a version-aware Agent Lightning Server.

This client handles polling for tasks, fetching specific versions of resources (like model configurations), and posting completed rollouts back to the server. It provides both synchronous and asynchronous methods for these operations and includes a cache for resources.

Source code in agentlightning/client.py
class AgentLightningClient:
    """
    Client for interacting with a version-aware Agent Lightning Server.

    This client handles polling for tasks, fetching specific versions of resources
    (like model configurations), and posting completed rollouts back to the server.
    It provides both synchronous and asynchronous methods for these operations and
    includes a cache for resources.
    """

    # Server route templates, joined onto `endpoint` with urllib.parse.urljoin.
    # The leading '/' makes urljoin replace any existing path on `endpoint`.
    _next_task_uri = "/task"
    _resources_uri = "/resources"
    _latest_resources_uri = "/resources/latest"
    _report_rollout_uri = "/rollout"

    def __init__(self, endpoint: str, poll_interval: float = 5.0, timeout: float = 10.0):
        """Initializes the AgentLightningClient.

        Args:
            endpoint: The root URL of the Agent Lightning server.
            poll_interval: The interval in seconds to wait between polling for new tasks.
            timeout: The timeout in seconds for HTTP requests.
        """
        self.endpoint = endpoint
        self.task_count = 0  # tasks received so far; used only for log messages
        self.poll_interval = poll_interval
        self.timeout = timeout
        # Cache keyed by resources_id; entries are never evicted.
        self._resource_cache: Dict[str, ResourcesUpdate] = {}  # TODO: mechanism to evict cache
        # Marker header so the server can recognize client traffic.
        self._default_headers = {"X-AgentLightning-Client": "true"}

    async def _request_json_async(self, url: str) -> Optional[Dict[str, Any]]:
        """Makes an async GET request to the specified URL and returns the JSON response.

        Any failure (connection error, HTTP error status, bad JSON) is logged
        at debug level and collapsed to None; callers treat None as "no data".

        Args:
            url: The URL to request.

        Returns:
            The JSON response as a dictionary or None if the request fails.
        """
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        # A fresh session per call keeps the client stateless between requests.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get(url, headers=self._default_headers) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception as e:
                logger.debug(f"Async GET request failed for {url}: {e}")
                return None

    async def _post_json_async(self, url: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Makes an async POST request with a JSON payload.

        Failures are logged at debug level and collapsed to None, mirroring
        `_request_json_async`.

        Args:
            url: The URL to post to.
            payload: The dictionary data to send as JSON.

        Returns:
            The JSON response as a dictionary or None if the request fails.
        """
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.post(url, json=payload, headers=self._default_headers) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception as e:
                logger.debug(f"Async POST request failed for {url}: {e}")
                return None

    async def poll_next_task_async(self) -> Optional[Task]:
        """Polls the server asynchronously for the next task until one is available.

        Loops forever, sleeping `poll_interval` seconds between attempts;
        request failures are treated the same as "no task yet".

        Returns:
            A Task object containing the task details.
        """
        url = urllib.parse.urljoin(self.endpoint, self._next_task_uri)
        while True:
            response = await self._request_json_async(url)
            if response:
                task_if_any = TaskIfAny.model_validate(response)
                if task_if_any.is_available and task_if_any.task:
                    self.task_count += 1
                    logger.info(f"[Task {self.task_count} Received] ID: {task_if_any.task.rollout_id}")
                    return task_if_any.task
            logger.debug(f"No task available yet. Retrying in {self.poll_interval} seconds...")
            await asyncio.sleep(self.poll_interval)

    async def get_resources_by_id_async(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Fetches a specific version of resources by its ID, using a cache.

        Args:
            resource_id: The ID of the resources to fetch, usually from a Task's metadata.

        Returns:
            A ResourcesUpdate object containing the versioned resources, or None if not found.
        """
        if resource_id in self._resource_cache:
            logger.debug(f"Found resources '{resource_id}' in cache.")
            return self._resource_cache[resource_id]

        url = urllib.parse.urljoin(self.endpoint, f"{self._resources_uri}/{resource_id}")
        response = await self._request_json_async(url)
        if response:
            resources_update = ResourcesUpdate.model_validate(response)
            self._resource_cache[resource_id] = resources_update
            logger.info(f"Fetched and cached resources for ID: {resource_id}")
            return resources_update
        return None

    async def get_latest_resources_async(self) -> Optional[ResourcesUpdate]:
        """Fetches the latest available resources from the server.

        Returns:
            A ResourcesUpdate object containing the latest resources.
        """
        url = urllib.parse.urljoin(self.endpoint, self._latest_resources_uri)
        response = await self._request_json_async(url)
        if response:
            resources_update = ResourcesUpdate.model_validate(response)
            # Cache this result as well
            self._resource_cache[resources_update.resources_id] = resources_update
            return resources_update
        return None

    async def post_rollout_async(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
        """Posts a completed rollout to the server asynchronously.

        Args:
            rollout: A Rollout object containing the results of a task.

        Returns:
            The server's JSON response as a dictionary.
        """
        url = urllib.parse.urljoin(self.endpoint, self._report_rollout_uri)
        payload = rollout.model_dump(mode="json")
        return await self._post_json_async(url, payload)

    def _request_json(self, url: str) -> Optional[Dict[str, Any]]:
        """Makes a sync GET request to the specified URL and returns the JSON response.

        Synchronous mirror of `_request_json_async`: failures are logged at
        debug level and collapsed to None.

        Args:
            url: The URL to request.

        Returns:
            The JSON response as a dictionary or None if the request fails.
        """
        try:
            response = requests.get(url, timeout=self.timeout, headers=self._default_headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.debug(f"Sync GET request failed for {url}: {e}")
            return None

    def _post_json(self, url: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Makes a sync POST request with a JSON payload.

        Args:
            url: The URL to post to.
            payload: The dictionary data to send as JSON.

        Returns:
            The JSON response as a dictionary or None if the request fails.
        """
        try:
            response = requests.post(url, json=payload, timeout=self.timeout, headers=self._default_headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.debug(f"Sync POST request failed for {url}: {e}")
            return None

    def poll_next_task(self) -> Optional[Task]:
        """Polls the server synchronously for the next task until one is available.

        Loops forever, sleeping `poll_interval` seconds between attempts.

        Returns:
            A Task object containing the task details, including the required `resources_id`.
        """
        url = urllib.parse.urljoin(self.endpoint, self._next_task_uri)
        while True:
            response = self._request_json(url)
            if response:
                task_if_any = TaskIfAny.model_validate(response)
                if task_if_any.is_available and task_if_any.task:
                    self.task_count += 1
                    logger.info(f"[Task {self.task_count} Received] ID: {task_if_any.task.rollout_id}")
                    return task_if_any.task
            logger.debug(f"No task available yet. Retrying in {self.poll_interval} seconds...")
            time.sleep(self.poll_interval)

    def get_resources_by_id(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Fetches a specific version of resources by its ID synchronously, using a cache.

        Args:
            resource_id: The ID of the resources to fetch, usually from a Task's metadata.

        Returns:
            A ResourcesUpdate object containing the versioned resources, or None if not found.
        """
        if resource_id in self._resource_cache:
            logger.debug(f"Found resources '{resource_id}' in cache.")
            return self._resource_cache[resource_id]

        url = urllib.parse.urljoin(self.endpoint, f"{self._resources_uri}/{resource_id}")
        response = self._request_json(url)
        if response:
            resources_update = ResourcesUpdate.model_validate(response)
            self._resource_cache[resource_id] = resources_update
            logger.info(f"Fetched and cached resources for ID: {resource_id}")
            return resources_update
        return None

    def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """Fetches the latest available resources from the server synchronously.

        Returns:
            A ResourcesUpdate object containing the latest resources.
        """
        url = urllib.parse.urljoin(self.endpoint, self._latest_resources_uri)
        response = self._request_json(url)
        if response:
            resources_update = ResourcesUpdate.model_validate(response)
            self._resource_cache[resources_update.resources_id] = resources_update
            return resources_update
        return None

    def post_rollout(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
        """Posts a completed rollout to the server synchronously.

        Args:
            rollout: A Rollout object containing the results of a task.

        Returns:
            The server's JSON response as a dictionary.
        """
        url = urllib.parse.urljoin(self.endpoint, self._report_rollout_uri)
        payload = rollout.model_dump(mode="json")
        return self._post_json(url, payload)

__init__(endpoint, poll_interval=5.0, timeout=10.0)

Initializes the AgentLightningClient.

Parameters:

Name Type Description Default
endpoint str

The root URL of the Agent Lightning server.

required
poll_interval float

The interval in seconds to wait between polling for new tasks.

5.0
timeout float

The timeout in seconds for HTTP requests.

10.0
Source code in agentlightning/client.py
def __init__(self, endpoint: str, poll_interval: float = 5.0, timeout: float = 10.0):
    """Initializes the AgentLightningClient.

    Args:
        endpoint: The root URL of the Agent Lightning server.
        poll_interval: The interval in seconds to wait between polling for new tasks.
        timeout: The timeout in seconds for HTTP requests.
    """
    self.endpoint = endpoint
    self.task_count = 0  # tasks received so far; used only for log messages
    self.poll_interval = poll_interval
    self.timeout = timeout
    # Cache keyed by resources_id; entries are never evicted.
    self._resource_cache: Dict[str, ResourcesUpdate] = {}  # TODO: mechanism to evict cache
    # Marker header so the server can recognize client traffic.
    self._default_headers = {"X-AgentLightning-Client": "true"}

get_latest_resources()

Fetches the latest available resources from the server synchronously.

Returns:

Type Description
Optional[ResourcesUpdate]

A ResourcesUpdate object containing the latest resources.

Source code in agentlightning/client.py
def get_latest_resources(self) -> Optional[ResourcesUpdate]:
    """Fetches the latest available resources from the server synchronously.

    Returns:
        A ResourcesUpdate object containing the latest resources.
    """
    payload = self._request_json(urllib.parse.urljoin(self.endpoint, self._latest_resources_uri))
    if not payload:
        return None
    update = ResourcesUpdate.model_validate(payload)
    # Keep the latest snapshot in the cache so by-id lookups can reuse it.
    self._resource_cache[update.resources_id] = update
    return update

get_latest_resources_async() async

Fetches the latest available resources from the server.

Returns:

Type Description
Optional[ResourcesUpdate]

A ResourcesUpdate object containing the latest resources.

Source code in agentlightning/client.py
async def get_latest_resources_async(self) -> Optional[ResourcesUpdate]:
    """Fetches the latest available resources from the server.

    Returns:
        A ResourcesUpdate object containing the latest resources.
    """
    payload = await self._request_json_async(urllib.parse.urljoin(self.endpoint, self._latest_resources_uri))
    if not payload:
        return None
    update = ResourcesUpdate.model_validate(payload)
    # Keep the latest snapshot in the cache so by-id lookups can reuse it.
    self._resource_cache[update.resources_id] = update
    return update

get_resources_by_id(resource_id)

Fetches a specific version of resources by its ID synchronously, using a cache.

Parameters:

Name Type Description Default
resource_id str

The ID of the resources to fetch, usually from a Task's metadata.

required

Returns:

Type Description
Optional[ResourcesUpdate]

A ResourcesUpdate object containing the versioned resources, or None if not found.

Source code in agentlightning/client.py
def get_resources_by_id(self, resource_id: str) -> Optional[ResourcesUpdate]:
    """Fetches a specific version of resources by its ID synchronously, using a cache.

    Args:
        resource_id: The ID of the resources to fetch, usually from a Task's metadata.

    Returns:
        A ResourcesUpdate object containing the versioned resources, or None if not found.
    """
    cached = self._resource_cache.get(resource_id)
    if cached is not None:
        logger.debug(f"Found resources '{resource_id}' in cache.")
        return cached

    payload = self._request_json(urllib.parse.urljoin(self.endpoint, f"{self._resources_uri}/{resource_id}"))
    if not payload:
        return None
    update = ResourcesUpdate.model_validate(payload)
    self._resource_cache[resource_id] = update
    logger.info(f"Fetched and cached resources for ID: {resource_id}")
    return update

get_resources_by_id_async(resource_id) async

Fetches a specific version of resources by its ID, using a cache.

Parameters:

Name Type Description Default
resource_id str

The ID of the resources to fetch, usually from a Task's metadata.

required

Returns:

Type Description
Optional[ResourcesUpdate]

A ResourcesUpdate object containing the versioned resources, or None if not found.

Source code in agentlightning/client.py
async def get_resources_by_id_async(self, resource_id: str) -> Optional[ResourcesUpdate]:
    """Fetches a specific version of resources by its ID, using a cache.

    Args:
        resource_id: The ID of the resources to fetch, usually from a Task's metadata.

    Returns:
        A ResourcesUpdate object containing the versioned resources, or None if not found.
    """
    cached = self._resource_cache.get(resource_id)
    if cached is not None:
        logger.debug(f"Found resources '{resource_id}' in cache.")
        return cached

    payload = await self._request_json_async(urllib.parse.urljoin(self.endpoint, f"{self._resources_uri}/{resource_id}"))
    if not payload:
        return None
    update = ResourcesUpdate.model_validate(payload)
    self._resource_cache[resource_id] = update
    logger.info(f"Fetched and cached resources for ID: {resource_id}")
    return update

poll_next_task()

Polls the server synchronously for the next task until one is available.

Returns:

Type Description
Optional[Task]

A Task object containing the task details, including the required resources_id.

Source code in agentlightning/client.py
def poll_next_task(self) -> Optional[Task]:
    """Polls the server synchronously for the next task until one is available.

    Loops forever, sleeping `poll_interval` seconds between attempts.

    Returns:
        A Task object containing the task details, including the required `resources_id`.
    """
    url = urllib.parse.urljoin(self.endpoint, self._next_task_uri)
    while True:
        payload = self._request_json(url)
        if payload:
            wrapper = TaskIfAny.model_validate(payload)
            if wrapper.is_available and wrapper.task:
                self.task_count += 1
                logger.info(f"[Task {self.task_count} Received] ID: {wrapper.task.rollout_id}")
                return wrapper.task
        logger.debug(f"No task available yet. Retrying in {self.poll_interval} seconds...")
        time.sleep(self.poll_interval)

poll_next_task_async() async

Polls the server asynchronously for the next task until one is available.

Returns:

Type Description
Optional[Task]

A Task object containing the task details.

Source code in agentlightning/client.py
async def poll_next_task_async(self) -> Optional[Task]:
    """Polls the server asynchronously for the next task until one is available.

    Loops forever, sleeping `poll_interval` seconds between attempts.

    Returns:
        A Task object containing the task details.
    """
    url = urllib.parse.urljoin(self.endpoint, self._next_task_uri)
    while True:
        payload = await self._request_json_async(url)
        if payload:
            wrapper = TaskIfAny.model_validate(payload)
            if wrapper.is_available and wrapper.task:
                self.task_count += 1
                logger.info(f"[Task {self.task_count} Received] ID: {wrapper.task.rollout_id}")
                return wrapper.task
        logger.debug(f"No task available yet. Retrying in {self.poll_interval} seconds...")
        await asyncio.sleep(self.poll_interval)

post_rollout(rollout)

Posts a completed rollout to the server synchronously.

Parameters:

Name Type Description Default
rollout Rollout

A Rollout object containing the results of a task.

required

Returns:

Type Description
Optional[Dict[str, Any]]

The server's JSON response as a dictionary.

Source code in agentlightning/client.py
def post_rollout(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
    """Posts a completed rollout to the server synchronously.

    Args:
        rollout: A Rollout object containing the results of a task.

    Returns:
        The server's JSON response as a dictionary.
    """
    return self._post_json(
        urllib.parse.urljoin(self.endpoint, self._report_rollout_uri),
        rollout.model_dump(mode="json"),
    )

post_rollout_async(rollout) async

Posts a completed rollout to the server asynchronously.

Parameters:

Name Type Description Default
rollout Rollout

A Rollout object containing the results of a task.

required

Returns:

Type Description
Optional[Dict[str, Any]]

The server's JSON response as a dictionary.

Source code in agentlightning/client.py
async def post_rollout_async(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
    """Posts a completed rollout to the server asynchronously.

    Args:
        rollout: A Rollout object containing the results of a task.

    Returns:
        The server's JSON response as a dictionary.
    """
    return await self._post_json_async(
        urllib.parse.urljoin(self.endpoint, self._report_rollout_uri),
        rollout.model_dump(mode="json"),
    )

DevTaskLoader

Bases: AgentLightningClient

A local task manager for development that provides sample tasks and resources.

This client mocks the server APIs by maintaining a local queue of tasks and resources within the same process. It's designed for development, testing, and scenarios where a full Agent Lightning server is not needed.

The DevTaskLoader overrides the polling and resource fetching methods to return data from local collections instead of making HTTP requests to a remote server.

Source code in agentlightning/client.py
class DevTaskLoader(AgentLightningClient):
    """An in-process task source for development and testing.

    Instead of talking to a real Agent Lightning server, this client answers
    polling and resource-lookup calls from local collections held in memory.
    That makes it suitable for unit tests and quick local experiments where
    spinning up a full server would be overkill.
    """

    def __init__(
        self,
        tasks: Union[List[TaskInput], List[Task]],
        resources: Union[NamedResources, ResourcesUpdate],
        **kwargs: Any,
    ):
        """Create the loader from a fixed task list and resource set.

        Args:
            tasks: Either a List of TaskInput objects or a List of Task objects.
            resources: Either NamedResources or ResourcesUpdate object.
            **kwargs: Forwarded to the parent AgentLightningClient constructor.

        Raises:
            ValueError: If the task list is empty, or if it mixes Task and
                TaskInput entries.
        """
        super().__init__(endpoint="local://", **kwargs)
        self._tasks = tasks.copy()
        if not self._tasks:
            raise ValueError("DevTaskLoader requires at least one task to be provided.")

        # Reject a heterogeneous mix of Task and TaskInput entries.
        task_flags = [isinstance(item, Task) for item in self._tasks]
        if any(task_flags) and not all(task_flags):
            raise ValueError("All tasks must be either Task or TaskInput objects.")

        self._task_index = 0

        # Normalize bare NamedResources into a ResourcesUpdate wrapper.
        if not isinstance(resources, ResourcesUpdate):
            resources = ResourcesUpdate(resources_id="local", resources=resources)
        self._resources_update = resources

        # Rollouts reported back are kept around for inspection in local runs.
        self._rollouts: List[Rollout] = []

    @property
    def rollouts(self) -> List[Rollout]:
        """Rollouts that have been posted back to this loader."""
        return self._rollouts

    def poll_next_task(self) -> Optional[Task]:
        """Hand out the next task, cycling through the local list.

        TaskInput entries are wrapped into freshly assembled Task objects;
        entries that already are Task objects are returned unchanged.

        Returns:
            The next Task object from the local task list.
        """
        # Wrap around once the end of the list is reached.
        if self._task_index >= len(self._tasks):
            self._task_index = 0

        candidate = self._tasks[self._task_index]

        if not isinstance(candidate, Task):
            task = Task(
                rollout_id=f"local_task_{self._task_index + 1:03d}",
                input=candidate,
                resources_id=self._resources_update.resources_id,
                create_time=time.time(),
            )
        else:
            task = candidate

        self._task_index += 1
        self.task_count += 1
        logger.info(f"[Task {self.task_count} Received] Task ID: {task.rollout_id}")
        return task

    def get_resources_by_id(self, resource_id: str) -> Optional[ResourcesUpdate]:
        logger.debug(f"DevTaskLoader checking resources for ID: {resource_id}")
        # Only the single locally-registered resource set exists.
        if resource_id != self._resources_update.resources_id:
            raise ValueError(
                f"Resource ID '{resource_id}' not found. Only '{self._resources_update.resources_id}' is available."
            )
        return self._resources_update

    def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        logger.debug("DevTaskLoader returning latest resources.")
        return self._resources_update

    def post_rollout(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
        logger.debug(f"DevTaskLoader received rollout for task: {rollout.rollout_id}")
        self._rollouts.append(rollout)
        return {"status": "received", "rollout_id": rollout.rollout_id}

    # The async variants simply delegate to the synchronous implementations.
    async def poll_next_task_async(self) -> Optional[Task]:
        return self.poll_next_task()

    async def get_resources_by_id_async(self, resource_id: str) -> Optional[ResourcesUpdate]:
        return self.get_resources_by_id(resource_id)

    async def get_latest_resources_async(self) -> Optional[ResourcesUpdate]:
        return self.get_latest_resources()

    async def post_rollout_async(self, rollout: Rollout) -> Optional[Dict[str, Any]]:
        return self.post_rollout(rollout)

    def __repr__(self):
        return f"DevTaskLoader(num_tasks={len(self._tasks)}, resources={self._resources_update.resources})"

rollouts property

Return rollouts that have been posted back to the loader.

__init__(tasks, resources, **kwargs)

Initializes the DevTaskLoader with pre-defined tasks and resources.

Parameters:

Name Type Description Default
tasks Union[List[TaskInput], List[Task]]

Either a List of TaskInput objects or a List of Task objects.

required
resources Union[NamedResources, ResourcesUpdate]

Either NamedResources or ResourcesUpdate object.

required
**kwargs Any

Additional arguments passed to the parent AgentLightningClient.

{}
Source code in agentlightning/client.py
def __init__(
    self,
    tasks: Union[List[TaskInput], List[Task]],
    resources: Union[NamedResources, ResourcesUpdate],
    **kwargs: Any,
):
    """Initializes the DevTaskLoader with pre-defined tasks and resources.

    Args:
        tasks: Either a List of TaskInput objects or a List of Task objects.
        resources: Either NamedResources or ResourcesUpdate object.
        **kwargs: Additional arguments passed to the parent AgentLightningClient.

    Raises:
        ValueError: If the task list is empty, or if it mixes Task and
            TaskInput objects.
    """
    super().__init__(endpoint="local://", **kwargs)
    # Copy so later mutation of the caller's list does not affect the loader.
    self._tasks = tasks.copy()
    if len(self._tasks) == 0:
        raise ValueError("DevTaskLoader requires at least one task to be provided.")

    # Check if tasks are mixture of TaskInput and Task
    if any(isinstance(task, Task) for task in self._tasks):
        if not all(isinstance(task, Task) for task in self._tasks):
            raise ValueError("All tasks must be either Task or TaskInput objects.")

    self._task_index = 0

    # Normalize a bare NamedResources into a ResourcesUpdate with a fixed id.
    if isinstance(resources, ResourcesUpdate):
        self._resources_update = resources
    else:
        self._resources_update = ResourcesUpdate(resources_id="local", resources=resources)

    # Store rollouts posted back to the loader for easy debugging of local runs
    self._rollouts: List[Rollout] = []

poll_next_task()

Returns the next task from the local queue.

If tasks are TaskInput objects, assembles them into Task objects. If tasks are already Task objects, returns them directly.

Returns:

Type Description
Optional[Task]

The next Task object from the local task list.

Source code in agentlightning/client.py
def poll_next_task(self) -> Optional[Task]:
    """Returns the next task from the local queue.

    If tasks are TaskInput objects, assembles them into Task objects.
    If tasks are already Task objects, returns them directly.

    Returns:
        The next Task object from the local task list.
    """
    # Wrap around to the start once every task has been handed out.
    if self._task_index >= len(self._tasks):
        self._task_index = 0

    task_or_input = self._tasks[self._task_index]

    if isinstance(task_or_input, Task):
        task = task_or_input
    else:
        # Assemble a Task on the fly with a human-readable, 1-based id.
        rollout_id = f"local_task_{self._task_index + 1:03d}"
        task = Task(
            rollout_id=rollout_id,
            input=task_or_input,
            resources_id=self._resources_update.resources_id,
            create_time=time.time(),
        )

    self._task_index += 1
    self.task_count += 1  # NOTE(review): task_count is not set in __init__ — presumably inherited from the parent client; confirm.
    logger.info(f"[Task {self.task_count} Received] Task ID: {task.rollout_id}")
    return task

agentlightning.runner

AgentRunner

Bases: BaseRunner[Any]

Manages the agent's execution loop and integrates with AgentOps.

This class orchestrates the interaction between the agent (LitAgent) and the server (AgentLightningClient). It handles polling for tasks, executing the agent's logic, and reporting results back to the server. If enabled, it will also automatically trace each rollout using AgentOps.

Attributes:

Name Type Description
agent

The LitAgent instance containing the agent's logic.

client

The AgentLightningClient for server communication.

tracer

The tracer instance for this runner/worker.

worker_id

An optional identifier for the worker process.

max_tasks

The maximum number of tasks to process before stopping.

Source code in agentlightning/runner/legacy.py
class AgentRunner(BaseRunner[Any]):
    """Manages the agent's execution loop and integrates with AgentOps.

    This class orchestrates the interaction between the agent (`LitAgent`) and
    the server (`AgentLightningClient`). It handles polling for tasks, executing
    the agent's logic, and reporting results back to the server. If enabled,
    it will also automatically trace each rollout using AgentOps.

    Attributes:
        agent: The `LitAgent` instance containing the agent's logic.
        client: The `AgentLightningClient` for server communication.
        tracer: The tracer instance for this runner/worker.
        worker_id: An optional identifier for the worker process.
        max_tasks: The maximum number of tasks to process before stopping.
    """

    def __init__(
        self,
        agent: LitAgent[Any],
        client: AgentLightningClient,
        tracer: BaseTracer,
        triplet_exporter: TraceTripletAdapter,
        worker_id: Optional[int] = None,
        max_tasks: Optional[int] = None,
    ):
        """Initialize the runner.

        Args:
            agent: The `LitAgent` whose rollout methods are executed.
            client: Client used to poll tasks, fetch resources, and post results.
            tracer: Tracer that records spans for each rollout.
            triplet_exporter: Adapter that extracts `Triplet`s from trace spans.
            worker_id: Optional identifier for this worker process.
            max_tasks: Maximum number of tasks to process before stopping.
        """
        super().__init__()
        self.agent = agent
        self.client = client
        self.tracer = tracer
        self.triplet_exporter = triplet_exporter

        # Worker-specific attributes
        self.worker_id = worker_id
        self.max_tasks = max_tasks

    # These methods are overridden by BaseRunner, getting them back to old behavior.
    def init(self, *args: Any, **kwargs: Any) -> None:
        """No-op; retained for interface compatibility with `BaseRunner`."""
        pass

    def init_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Record the worker id assigned by the execution strategy."""
        self.worker_id = worker_id

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """No-op; retained for interface compatibility with `BaseRunner`."""
        pass

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """No-op; retained for interface compatibility with `BaseRunner`."""
        pass

    def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
        """Generates a standardized log prefix for the current worker."""
        if self.worker_id is not None:
            if rollout_id:
                return f"[Worker {self.worker_id} | Rollout {rollout_id}]"
            else:
                return f"[Worker {self.worker_id}]"
        if rollout_id:
            return f"[Rollout {rollout_id}]"
        return "[Default Worker]"

    def _to_rollout_object(
        self,
        result: RolloutRawResult,
        rollout_id: str,
    ) -> Rollout:
        """Standardizes the agent's return value into a Rollout object.

        Args:
            result: The output from the agent's rollout method.
            rollout_id: The unique identifier for the current task.

        Returns:
            A standardized `Rollout` object for reporting to the server.
        """
        trace: Any = None
        final_reward: Optional[float] = None
        triplets: Optional[List[Triplet]] = None
        trace_spans: Optional[List[ReadableSpan]] = None

        # Handle different types of results from the agent
        # NOTE(review): the cases below are independent `if`s, not `elif`s,
        # so a later match can overwrite fields set by an earlier one; an
        # empty list satisfies every `all(...)` check — confirm intended.
        # Case 1: result is a float (final reward)
        if isinstance(result, float):
            final_reward = result
        # Case 2: result is a list of Triplets
        if isinstance(result, list) and all(isinstance(t, Triplet) for t in result):
            triplets = result  # type: ignore
        # Case 3: result is a list of ReadableSpan (OpenTelemetry spans)
        if isinstance(result, list) and all(isinstance(t, ReadableSpan) for t in result):
            trace_spans = result  # type: ignore
            trace = [json.loads(readable_span.to_json()) for readable_span in trace_spans]  # type: ignore
        # Case 4: result is a list of dict (trace JSON)
        if isinstance(result, list) and all(isinstance(t, dict) for t in result):
            trace = result
        # Case 5: result is a Rollout object
        if isinstance(result, Rollout):
            final_reward = result.final_reward
            triplets = result.triplets
            trace = result.trace

        # If the agent has tracing enabled, use the tracer's last trace if not already set
        if self.tracer and (trace is None or trace_spans is None):
            spans = self.tracer.get_last_trace()
            if spans:
                trace = [json.loads(readable_span.to_json()) for readable_span in spans]
                trace_spans = spans

        # Always extract triplets from the trace using TraceTripletAdapter
        if trace_spans:
            triplets = self.triplet_exporter(trace_spans)

        # If the agent has triplets, use the last one for final reward if not set
        if triplets and triplets[-1].reward is not None and final_reward is None:
            final_reward = triplets[-1].reward

        # Create the Rollout object with standardized fields
        result_dict: Dict[str, Any] = {
            "rollout_id": rollout_id,
        }
        if final_reward is not None:
            result_dict["final_reward"] = final_reward
        if triplets is not None:
            result_dict["triplets"] = triplets
        if trace is not None:
            result_dict["trace"] = trace

        # Preserve any extra fields the agent set on a returned Rollout.
        if isinstance(result, Rollout):
            return result.model_copy(update=result_dict)
        return Rollout(**result_dict)

    def run(self) -> bool:  # type: ignore
        """Poll the task and rollout once synchronously.

        Returns:
            True when a task was polled and a rollout attempt was reported;
            False when no task or no resources were available.
        """
        self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

        task = self.client.poll_next_task()
        if task is None:
            logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
            return False
        rollout_id = task.rollout_id

        resources_id = task.resources_id
        resources_update = None
        if resources_id:
            resources_update = self.client.get_resources_by_id(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = self.client.get_latest_resources()
        if not resources_update:
            logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
            return False

        rollout_obj = Rollout(rollout_id=task.rollout_id, task=task)  # Default empty rollout

        try:
            try:
                self.agent.on_rollout_start(task, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

            with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
                start_time = time.time()
                rollout_method = self.agent.training_rollout if task.mode == "train" else self.agent.validation_rollout
                # Pass the task input, not the whole task object
                if is_v0_1_rollout_api(rollout_method):
                    result = cast(
                        RolloutRawResult,
                        rollout_method(
                            task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                        ),
                    )  # type: ignore
                else:
                    result = rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)
                rollout_obj = self._to_rollout_object(result, task.rollout_id)
                end_time = time.time()
                logger.info(
                    f"{self._log_prefix(rollout_id)} Completed in "
                    f"{end_time - start_time:.2f}s. Triplet length: "
                    f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                    f"Reward: {rollout_obj.final_reward}"
                )

        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
        finally:
            # The rollout is always reported, even when the rollout itself raised.
            try:
                self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
            self.client.post_rollout(rollout_obj)

        return True

    def iter(self) -> int:  # type: ignore
        """Executes the synchronous polling and rollout loop."""
        num_tasks_processed = 0
        logger.info(f"{self._log_prefix()} Started sync rollouts (max: {self.max_tasks or 'unlimited'}).")

        # NOTE(review): a False return from run() does not break the loop;
        # the runner keeps polling until max_tasks is reached — confirm.
        while self.max_tasks is None or num_tasks_processed < self.max_tasks:
            if self.run():
                num_tasks_processed += 1

            if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
                logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}")

        logger.info(f"{self._log_prefix()} Finished sync rollouts. Processed {num_tasks_processed} tasks.")
        return num_tasks_processed

    async def run_async(self) -> bool:
        """Poll the task and rollout once.

        Async counterpart of `run`; uses the client's async APIs and the
        agent's async rollout methods.
        """
        self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

        task = await self.client.poll_next_task_async()
        if task is None:
            logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
            return False
        rollout_id = task.rollout_id

        resources_id = task.resources_id
        resources_update = None
        if resources_id:
            resources_update = await self.client.get_resources_by_id_async(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = await self.client.get_latest_resources_async()
        if not resources_update:
            logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
            return False

        rollout_obj = Rollout(rollout_id=task.rollout_id, task=task)  # Default empty rollout

        try:
            try:
                self.agent.on_rollout_start(task, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

            with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
                start_time = time.time()
                rollout_method = (
                    self.agent.training_rollout_async if task.mode == "train" else self.agent.validation_rollout_async
                )
                # Pass the task input, not the whole task object
                if is_v0_1_rollout_api(rollout_method):
                    result = cast(
                        RolloutRawResult,
                        await rollout_method(
                            task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                        ),
                    )  # type: ignore
                else:
                    result = await rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)
                rollout_obj = self._to_rollout_object(result, task.rollout_id)
                end_time = time.time()
                logger.info(
                    f"{self._log_prefix(rollout_id)} Completed in "
                    f"{end_time - start_time:.2f}s. Triplet length: "
                    f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                    f"Reward: {rollout_obj.final_reward}"
                )
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
        finally:
            # The rollout is always reported, even when the rollout itself raised.
            try:
                self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
            await self.client.post_rollout_async(rollout_obj)

        return True

    async def iter_async(self) -> int:
        """Executes the asynchronous polling and rollout loop."""
        num_tasks_processed = 0
        logger.info(f"{self._log_prefix()} Started async rollouts (max: {self.max_tasks or 'unlimited'}).")

        # NOTE(review): as in iter(), a False return from run_async() does not
        # break the loop — confirm this is the intended polling behavior.
        while self.max_tasks is None or num_tasks_processed < self.max_tasks:
            if await self.run_async():
                num_tasks_processed += 1

            if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
                logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}")
        logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")
        return num_tasks_processed

iter()

Executes the synchronous polling and rollout loop.

Source code in agentlightning/runner/legacy.py
def iter(self) -> int:  # type: ignore
    """Run the synchronous poll-and-rollout loop until max_tasks is reached."""
    completed = 0
    logger.info(f"{self._log_prefix()} Started sync rollouts (max: {self.max_tasks or 'unlimited'}).")

    while True:
        # Stop once the configured task budget has been spent.
        if self.max_tasks is not None and completed >= self.max_tasks:
            break
        if self.run():
            completed += 1

        # Emit progress on the first completion and on every tenth one.
        if completed == 1 or completed % 10 == 0:
            logger.info(f"{self._log_prefix()} Progress: {completed}/{self.max_tasks or 'unlimited'}")

    logger.info(f"{self._log_prefix()} Finished sync rollouts. Processed {completed} tasks.")
    return completed

iter_async() async

Executes the asynchronous polling and rollout loop.

Source code in agentlightning/runner/legacy.py
async def iter_async(self) -> int:
    """Run the asynchronous poll-and-rollout loop until max_tasks is reached."""
    completed = 0
    logger.info(f"{self._log_prefix()} Started async rollouts (max: {self.max_tasks or 'unlimited'}).")

    while True:
        # Stop once the configured task budget has been spent.
        if self.max_tasks is not None and completed >= self.max_tasks:
            break
        if await self.run_async():
            completed += 1

        # Emit progress on the first completion and on every tenth one.
        if completed == 1 or completed % 10 == 0:
            logger.info(f"{self._log_prefix()} Progress: {completed}/{self.max_tasks or 'unlimited'}")
    logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {completed} tasks.")
    return completed

run()

Poll the task and rollout once synchronously.

Source code in agentlightning/runner/legacy.py
def run(self) -> bool:  # type: ignore
    """Poll the task and rollout once synchronously.

    Returns:
        True when a task was polled and a rollout attempt was reported;
        False when no task or no resources were available.
    """
    self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

    task = self.client.poll_next_task()
    if task is None:
        logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
        return False
    rollout_id = task.rollout_id

    # Resolve the resources the rollout should run against.
    resources_id = task.resources_id
    resources_update = None
    if resources_id:
        resources_update = self.client.get_resources_by_id(resources_id)
    else:
        logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
        resources_update = self.client.get_latest_resources()
    if not resources_update:
        logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
        return False

    rollout_obj = Rollout(rollout_id=task.rollout_id, task=task)  # Default empty rollout

    try:
        # Hook failures are logged but never abort the rollout itself.
        try:
            self.agent.on_rollout_start(task, self, self.tracer)
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

        with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
            start_time = time.time()
            rollout_method = self.agent.training_rollout if task.mode == "train" else self.agent.validation_rollout
            # Pass the task input, not the whole task object
            if is_v0_1_rollout_api(rollout_method):
                result = cast(
                    RolloutRawResult,
                    rollout_method(
                        task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                    ),
                )  # type: ignore
            else:
                result = rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)
            rollout_obj = self._to_rollout_object(result, task.rollout_id)
            end_time = time.time()
            logger.info(
                f"{self._log_prefix(rollout_id)} Completed in "
                f"{end_time - start_time:.2f}s. Triplet length: "
                f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                f"Reward: {rollout_obj.final_reward}"
            )

    except Exception:
        logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
    finally:
        # The rollout is always reported, even when the rollout itself raised.
        try:
            self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
        self.client.post_rollout(rollout_obj)

    return True

run_async() async

Poll the task and rollout once.

Source code in agentlightning/runner/legacy.py
async def run_async(self) -> bool:
    """Poll the task and rollout once.

    Async counterpart of `run`; uses the client's async APIs and the
    agent's async rollout methods.
    """
    self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

    task = await self.client.poll_next_task_async()
    if task is None:
        logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
        return False
    rollout_id = task.rollout_id

    # Resolve the resources the rollout should run against.
    resources_id = task.resources_id
    resources_update = None
    if resources_id:
        resources_update = await self.client.get_resources_by_id_async(resources_id)
    else:
        logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
        resources_update = await self.client.get_latest_resources_async()
    if not resources_update:
        logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
        return False

    rollout_obj = Rollout(rollout_id=task.rollout_id, task=task)  # Default empty rollout

    try:
        # Hook failures are logged but never abort the rollout itself.
        try:
            self.agent.on_rollout_start(task, self, self.tracer)
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

        with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
            start_time = time.time()
            rollout_method = (
                self.agent.training_rollout_async if task.mode == "train" else self.agent.validation_rollout_async
            )
            # Pass the task input, not the whole task object
            if is_v0_1_rollout_api(rollout_method):
                result = cast(
                    RolloutRawResult,
                    await rollout_method(
                        task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                    ),
                )  # type: ignore
            else:
                result = await rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)
            rollout_obj = self._to_rollout_object(result, task.rollout_id)
            end_time = time.time()
            logger.info(
                f"{self._log_prefix(rollout_id)} Completed in "
                f"{end_time - start_time:.2f}s. Triplet length: "
                f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                f"Reward: {rollout_obj.final_reward}"
            )
    except Exception:
        logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
    finally:
        # The rollout is always reported, even when the rollout itself raised.
        try:
            self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
        await self.client.post_rollout_async(rollout_obj)

    return True

AgentRunnerV2

Bases: BaseRunner[T_task]

Runner implementation for executing agent tasks with distributed support.

This runner manages the complete lifecycle of agent rollout execution, including task polling, resource management, tracing, and hooks. It supports both continuous iteration over tasks from the store and single-step execution.

Attributes:

Name Type Description
worker_id Optional[int]

The unique identifier for this worker process.

Source code in agentlightning/runner/agent.py
(Source listing covers lines 44–514 of agentlightning/runner/agent.py; the raw per-line number gutter was removed as an extraction artifact.)
class AgentRunnerV2(BaseRunner[T_task]):
    """Runner implementation for executing agent tasks with distributed support.

    This runner manages the complete lifecycle of agent rollout execution,
    including task polling, resource management, tracing, and hooks. It supports
    both continuous iteration over tasks from the store and single-step execution.

    Attributes:
        worker_id: The unique identifier for this worker process.
    """

    def __init__(self, tracer: BaseTracer, max_rollouts: Optional[int] = None, poll_interval: float = 5.0) -> None:
        """Initialize the agent runner.

        Args:
            tracer: The tracer instance for recording execution traces and spans.
            max_rollouts: Maximum number of tasks to process in iter() mode. If None,
                the runner will continue indefinitely until interrupted.
            poll_interval: Time in seconds to wait between polling attempts when
                no tasks are available in the store.
        """
        super().__init__()
        self._tracer = tracer
        self._max_rollouts = max_rollouts
        self._poll_interval = poll_interval

        # Set later: agent/hooks via init(), store/worker_id via init_worker()
        self._agent: Optional[LitAgent[T_task]] = None
        self._hooks: Sequence[Hook] = []
        self._store: Optional[LightningStore] = None
        self.worker_id: Optional[int] = None

    def init(self, agent: LitAgent[T_task], *, hooks: Optional[Sequence[Hook]] = None, **kwargs: Any) -> None:
        """Initialize the runner with the agent.

        This sets up the agent-runner relationship, registers hooks, and
        initializes the tracer.

        Args:
            agent: The LitAgent instance to be managed by this runner.
            hooks: Optional sequence of Hook objects to be called at various
                lifecycle stages (on_trace_start, on_trace_end, on_rollout_start,
                on_rollout_end).
            **kwargs: Additional initialization arguments (currently unused).
        """
        self._agent = agent
        self._agent.set_runner(self)
        # Copy the hooks so later mutation of the caller's sequence has no effect here.
        self._hooks = [*hooks] if hooks is not None else []

        self._tracer.init()

    def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
        """Initialize the runner for each worker with worker_id and store.

        This method is called once per worker in a distributed setup to provide
        the worker with its ID and store connection.

        Args:
            worker_id: Unique identifier for this worker process.
            store: The LightningStore instance for task coordination and data persistence.
            **kwargs: Additional worker-specific initialization arguments (currently unused).
        """
        self._store = store
        self.worker_id = worker_id

        self._tracer.init_worker(worker_id)

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Teardown the runner and clean up all resources.

        This method resets all internal state including the agent, store,
        hooks, and worker ID, and calls the tracer's teardown method.

        Args:
            *args: Additional teardown arguments (currently unused).
            **kwargs: Additional teardown keyword arguments (currently unused).
        """
        self._agent = None
        self._store = None
        self.worker_id = None
        self._hooks = []

        self._tracer.teardown()

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Teardown the runner for a specific worker.

        This method cleans up worker-specific resources and resets the worker ID.

        Args:
            worker_id: The unique identifier of the worker being torn down.
            *args: Additional teardown arguments (currently unused).
            **kwargs: Additional teardown keyword arguments (currently unused).
        """
        self.worker_id = None

        self._tracer.teardown_worker(worker_id)

    def get_agent(self) -> LitAgent[T_task]:
        """Get the agent instance.

        Returns:
            The LitAgent instance managed by this runner.

        Raises:
            ValueError: If the agent has not been initialized via init().
        """
        if self._agent is None:
            raise ValueError("Agent not initialized. Call init() first.")
        return self._agent

    def get_store(self) -> LightningStore:
        """Get the store instance.

        Returns:
            The LightningStore instance for this worker.

        Raises:
            ValueError: If the store has not been initialized via init_worker().
        """
        if self._store is None:
            raise ValueError("Store not initialized. Call init_worker() first.")
        return self._store

    def get_worker_id(self) -> str:
        """Get the formatted worker ID string.

        Returns:
            A formatted string like "Worker-0" if initialized, or "Worker-Unknown"
            if the worker ID has not been set.
        """
        return f"Worker-{self.worker_id}" if self.worker_id is not None else "Worker-Unknown"

    def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
        """Generate a standardized log prefix for the current worker.

        This creates a consistent prefix format for log messages to identify
        which worker and rollout the message is associated with.

        Args:
            rollout_id: Optional rollout ID to include in the prefix.

        Returns:
            A formatted log prefix string like "[Worker 0 | Rollout xyz]",
            "[Worker 0]", "[Rollout xyz]", or "[Default Worker]".
        """
        if self.worker_id is not None:
            if rollout_id:
                return f"[Worker {self.worker_id} | Rollout {rollout_id}]"
            else:
                return f"[Worker {self.worker_id}]"
        if rollout_id:
            return f"[Rollout {rollout_id}]"
        return "[Default Worker]"

    async def _trigger_hooks(
        self,
        hook_type: Literal["on_trace_start", "on_trace_end", "on_rollout_start", "on_rollout_end"],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Trigger all registered hooks of a specific type.

        This method calls the specified hook method on all registered hooks,
        catching and logging any exceptions that occur during hook execution
        to prevent them from disrupting the main execution flow.

        Args:
            hook_type: The type of hook to trigger. Valid values are:
                "on_trace_start", "on_trace_end", "on_rollout_start", "on_rollout_end".
            *args: Positional arguments to pass to the hook methods.
            **kwargs: Keyword arguments to pass to the hook methods.
        """
        for hook in self._hooks:
            try:
                await getattr(hook, hook_type)(*args, **kwargs)
            except Exception:
                logger.exception(f"{self._log_prefix()} Exception during {hook_type} hook {hook}.")

    async def _post_process_rollout_result(
        self, rollout: AttemptedRollout, raw_result: RolloutRawResultV2
    ) -> List[ReadableSpan] | List[Span]:
        """Standardizes the agent's return value and report what's needed to report to the store.

        Args:
            rollout: The rollout object for the current task.
            raw_result: The output from the agent's rollout method.

        Returns:
            The spans that are assumed to be added to the store.
            This only serves as an estimation for logging purposes. For precise tracking, use the store directly.
        """
        store = self.get_store()

        trace_spans: list[ReadableSpan] | list[Span] = []

        # Case 0: result is None
        if raw_result is None:
            trace_spans = self._tracer.get_last_trace()

        # Case 1: result is a float (final reward)
        # NOTE(review): an `int` reward would match no case here and report nothing —
        # confirm RolloutRawResultV2 cannot be int.
        if isinstance(raw_result, float):
            # Preserve the existing spans before another span is emitted
            trace_spans = list(self._tracer.get_last_trace())
            # This will emit another span to the tracer
            reward_span = emit_reward(raw_result)
            await store.add_otel_span(rollout.rollout_id, rollout.attempt.attempt_id, reward_span)
            trace_spans.append(reward_span)

        if isinstance(raw_result, list):
            # For rollout methods that return a list, we assume that the returned spans
            # are the complete span set from the whole rollout
            trace_spans = raw_result

            # Case 2: result is a list of ReadableSpan (OpenTelemetry spans)
            if len(raw_result) > 0 and all(isinstance(t, ReadableSpan) for t in raw_result):

                if not isinstance(
                    self._tracer, AgentOpsTracer
                ):  # TODO: this should be replaced with general OpenTelemetry tracer in next version
                    for span in raw_result:
                        await store.add_otel_span(
                            rollout.rollout_id, rollout.attempt.attempt_id, cast(ReadableSpan, span)
                        )
                else:
                    logger.warning(
                        f"{self._log_prefix(rollout.rollout_id)} Tracer is already an OpenTelemetry tracer. "
                        "The traces should have already been added to the store. "
                        "No need to return anything from rollout."
                    )

            # Case 3: result is a list of Span (agentlightning spans)
            elif len(raw_result) > 0 and all(isinstance(t, Span) for t in raw_result):
                # Add the spans directly to the store
                for span in raw_result:
                    await store.add_span(cast(Span, span))
                trace_spans = raw_result

            # Left over cases for list
            elif len(raw_result) == 0:
                logger.warning(
                    f"{self._log_prefix(rollout.rollout_id)} The rollout returns an empty list. "
                    "Please check your rollout implementation."
                )
                trace_spans = raw_result

            else:
                # Mixed or unrecognized element types; show at most 10 type names in the error.
                types = [type(t).__name__ for t in raw_result][:10]
                raise ValueError(
                    f"Invalid raw result type. It's expected to be a list of ReadableSpan or Span, "
                    f"but got: {', '.join(types)}..."
                )

        return trace_spans

    async def _sleep_until_next_poll(self, event: Optional[Event] = None) -> None:
        """Sleep until the next poll interval, with optional event-based interruption.

        If an event is provided, the method will check it periodically (every 0.1s)
        and return early if the event is set.

        Args:
            event: Optional Event object that can be used to interrupt the sleep.
                If set during the sleep period, the method returns immediately.
        """
        if event is None:
            await asyncio.sleep(self._poll_interval)
            return
        current_time = time.time()
        next_time = current_time + self._poll_interval
        # Sleep in 0.1s slices so a set event interrupts within ~100ms.
        while time.time() < next_time:
            await asyncio.sleep(0.1)
            if event.is_set():
                return

    async def _step_impl(self, next_rollout: AttemptedRollout, raise_on_exception: bool = False) -> None:
        """Execute a single rollout implementation.

        This is the core method that handles the execution of a single rollout,
        including resource fetching, hook triggering, agent invocation, tracing,
        and result processing.

        Args:
            next_rollout: The rollout to execute, containing input data, mode,
                and resources information.
            raise_on_exception: If True, exceptions during rollout execution will
                be re-raised. If False, exceptions are logged but not propagated.
        """
        store = self.get_store()
        agent = self.get_agent()

        rollout_id = next_rollout.rollout_id

        resources_id = next_rollout.resources_id
        resources_update = None
        if resources_id:
            resources_update = await store.get_resources_by_id(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = await store.get_latest_resources()
        if not resources_update:
            if raise_on_exception:
                raise RuntimeError(f"{self._log_prefix(rollout_id)} Failed to fetch resources")
            else:
                logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
                return

        trace_spans: List[ReadableSpan] | List[Span] = []
        has_exception: bool = False

        try:
            await self._trigger_hooks(hook_type="on_rollout_start", agent=agent, runner=self, rollout=next_rollout)

            start_time = time.time()
            with self._tracer.trace_context(
                name=rollout_id, store=store, rollout_id=rollout_id, attempt_id=next_rollout.attempt.attempt_id
            ):
                await self._trigger_hooks(
                    hook_type="on_trace_start", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
                )

                # NOTE: This is the most costly step in the whole function
                # If the rollout method becomes unresponsive or timeouts, there is nothing we can do within the runner.
                # We might need some mechanisms in execution strategy to restart the runner. But that's a future work.
                if agent.is_async():
                    rollout_method = (
                        agent.training_rollout_async if next_rollout.mode == "train" else agent.validation_rollout_async
                    )
                    result = await rollout_method(
                        next_rollout.input, resources=resources_update.resources, rollout=next_rollout
                    )
                else:
                    rollout_method = (
                        agent.training_rollout if next_rollout.mode == "train" else agent.validation_rollout
                    )
                    result = rollout_method(
                        next_rollout.input, resources=resources_update.resources, rollout=next_rollout
                    )

                await self._trigger_hooks(
                    hook_type="on_trace_end", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
                )

            # Possible exceptions in post_process will be caught in the overall exception handler
            trace_spans = await self._post_process_rollout_result(next_rollout, result)
            last_reward = find_final_reward(trace_spans)

            end_time = time.time()
            logger.info(
                f"{self._log_prefix(rollout_id)} Completed in "
                f"{end_time - start_time:.2f}s. Collected {len(trace_spans)} span(s). "
                f"Final reward: {last_reward}"
            )

        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
            has_exception = True

            if raise_on_exception:
                raise
        finally:
            try:
                await self._trigger_hooks(
                    hook_type="on_rollout_end", agent=agent, runner=self, rollout=next_rollout, spans=trace_spans
                )
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")

            try:
                if has_exception:
                    # possibly timed out and cancelled?
                    await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="failed")
                else:
                    await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="succeeded")
            except Exception:
                logger.exception(
                    f"{self._log_prefix(rollout_id)} Exception during update_attempt. Giving up the update."
                )

    async def iter(self, *, event: Optional[Event] = None) -> None:
        """Run the runner, continuously iterating over tasks in the store.

        This method polls the store for new rollouts and executes them until:
        - The event is set (if provided)
        - The max_rollouts limit is reached (if configured)
        - No more tasks are available

        All exceptions during rollout execution are caught and logged but not
        propagated, allowing the runner to continue processing subsequent tasks.

        Args:
            event: Optional Event object to signal the runner to stop. The runner
                will check this event periodically and stop gracefully when set.
        """
        num_tasks_processed = 0
        logger.info(f"{self._log_prefix()} Started async rollouts (max: {self._max_rollouts or 'unlimited'}).")
        store = self.get_store()

        while not (event is not None and event.is_set()) and (
            self._max_rollouts is None or num_tasks_processed < self._max_rollouts
        ):
            # Retrieve the next rollout
            next_rollout: Optional[RolloutV2] = None
            while not (event is not None and event.is_set()):
                logger.debug(f"{self._log_prefix()} Try to poll for next rollout.")
                next_rollout = await store.dequeue_rollout()
                if next_rollout is None:
                    logger.debug(f"{self._log_prefix()} No rollout to poll. Waiting for {self._poll_interval} seconds.")
                    await self._sleep_until_next_poll(event)
                else:
                    break

            if next_rollout is None:
                return

            try:
                # Claim the rollout by updating the attempt with this worker's id
                await store.update_attempt(
                    next_rollout.rollout_id, next_rollout.attempt.attempt_id, worker_id=self.get_worker_id()
                )
            except Exception:
                # This exception could happen if the rollout is dequeued and the other end died for some reason
                logger.exception(f"{self._log_prefix()} Exception during update_attempt, giving up the rollout.")
                continue

            # Execute the step
            await self._step_impl(next_rollout)

            num_tasks_processed += 1
            if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
                logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}")

        logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")

    async def step(
        self,
        input: T_task,
        *,
        resources: Optional[NamedResources] = None,
        mode: Optional[RolloutMode] = None,
        event: Optional[Event] = None,
    ) -> None:
        """Execute a single task directly, bypassing the task queue.

        This method creates a new rollout for the given input and executes it
        immediately. Unlike iter(), exceptions are propagated to the caller.

        Args:
            input: The task input to be processed by the agent.
            resources: Optional named resources to be used for this specific task.
                If provided, a new resources entry will be created in the store.
                If not provided, the latest resources from the store will be used.
            mode: Optional rollout mode ("train" or "validation"). If not provided,
                the agent's default mode will be used.
            event: Optional Event object to signal interruption (currently unused
                but included for interface consistency).

        Raises:
            Exception: Any exception that occurs during rollout execution will be
                re-raised to the caller.
        """
        store = self.get_store()

        if resources is not None:
            resources_update = await store.add_resources(resources)
            resources_id = resources_update.resources_id
        else:
            resources_id = None

        # NOTE(review): `self.get_store()` below duplicates the `store` local above.
        attempted_rollout = await self.get_store().start_rollout(input=input, mode=mode, resources_id=resources_id)
        await self._step_impl(attempted_rollout, raise_on_exception=True)

__init__(tracer, max_rollouts=None, poll_interval=5.0)

Initialize the agent runner.

Parameters:

Name Type Description Default
tracer BaseTracer

The tracer instance for recording execution traces and spans.

required
max_rollouts Optional[int]

Maximum number of tasks to process in iter() mode. If None, the runner will continue indefinitely until interrupted.

None
poll_interval float

Time in seconds to wait between polling attempts when no tasks are available in the store.

5.0
Source code in agentlightning/runner/agent.py
def __init__(self, tracer: BaseTracer, max_rollouts: Optional[int] = None, poll_interval: float = 5.0) -> None:
    """Initialize the agent runner.

    Args:
        tracer: The tracer instance for recording execution traces and spans.
        max_rollouts: Maximum number of tasks to process in iter() mode. If None,
            the runner will continue indefinitely until interrupted.
        poll_interval: Time in seconds to wait between polling attempts when
            no tasks are available in the store.
    """
    super().__init__()
    self._tracer = tracer
    self._max_rollouts = max_rollouts
    self._poll_interval = poll_interval

    # Set later
    self._agent: Optional[LitAgent[T_task]] = None
    self._hooks: Sequence[Hook] = []
    self._store: Optional[LightningStore] = None
    self.worker_id: Optional[int] = None

get_agent()

Get the agent instance.

Returns:

Type Description
LitAgent[T_task]

The LitAgent instance managed by this runner.

Raises:

Type Description
ValueError

If the agent has not been initialized via init().

Source code in agentlightning/runner/agent.py
def get_agent(self) -> LitAgent[T_task]:
    """Get the agent instance.

    Returns:
        The LitAgent instance managed by this runner.

    Raises:
        ValueError: If the agent has not been initialized via init().
    """
    if self._agent is None:
        raise ValueError("Agent not initialized. Call init() first.")
    return self._agent

get_store()

Get the store instance.

Returns:

Type Description
LightningStore

The LightningStore instance for this worker.

Raises:

Type Description
ValueError

If the store has not been initialized via init_worker().

Source code in agentlightning/runner/agent.py
def get_store(self) -> LightningStore:
    """Get the store instance.

    Returns:
        The LightningStore instance for this worker.

    Raises:
        ValueError: If the store has not been initialized via init_worker().
    """
    if self._store is None:
        raise ValueError("Store not initialized. Call init_worker() first.")
    return self._store

get_worker_id()

Get the formatted worker ID string.

Returns:

Type Description
str

A formatted string like "Worker-0" if initialized, or "Worker-Unknown" if the worker ID has not been set.

Source code in agentlightning/runner/agent.py
def get_worker_id(self) -> str:
    """Get the formatted worker ID string.

    Returns:
        A formatted string like "Worker-0" if initialized, or "Worker-Unknown"
        if the worker ID has not been set.
    """
    return f"Worker-{self.worker_id}" if self.worker_id is not None else "Worker-Unknown"

init(agent, *, hooks=None, **kwargs)

Initialize the runner with the agent.

This sets up the agent-runner relationship, registers hooks, and initializes the tracer.

Parameters:

Name Type Description Default
agent LitAgent[T_task]

The LitAgent instance to be managed by this runner.

required
hooks Optional[Sequence[Hook]]

Optional sequence of Hook objects to be called at various lifecycle stages (on_trace_start, on_trace_end, on_rollout_start, on_rollout_end).

None
**kwargs Any

Additional initialization arguments (currently unused).

{}
Source code in agentlightning/runner/agent.py
def init(self, agent: LitAgent[T_task], *, hooks: Optional[Sequence[Hook]] = None, **kwargs: Any) -> None:
    """Initialize the runner with the agent.

    This sets up the agent-runner relationship, registers hooks, and
    initializes the tracer.

    Args:
        agent: The LitAgent instance to be managed by this runner.
        hooks: Optional sequence of Hook objects to be called at various
            lifecycle stages (on_trace_start, on_trace_end, on_rollout_start,
            on_rollout_end).
        **kwargs: Additional initialization arguments (currently unused).
    """
    self._agent = agent
    self._agent.set_runner(self)
    self._hooks = [*hooks] if hooks is not None else []

    self._tracer.init()

init_worker(worker_id, store, **kwargs)

Initialize the runner for each worker with worker_id and store.

This method is called once per worker in a distributed setup to provide the worker with its ID and store connection.

Parameters:

Name Type Description Default
worker_id int

Unique identifier for this worker process.

required
store LightningStore

The LightningStore instance for task coordination and data persistence.

required
**kwargs Any

Additional worker-specific initialization arguments (currently unused).

{}
Source code in agentlightning/runner/agent.py
def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
    """Initialize the runner for each worker with worker_id and store.

    This method is called once per worker in a distributed setup to provide
    the worker with its ID and store connection.

    Args:
        worker_id: Unique identifier for this worker process.
        store: The LightningStore instance for task coordination and data persistence.
        **kwargs: Additional worker-specific initialization arguments (currently unused).
    """
    self._store = store
    self.worker_id = worker_id

    self._tracer.init_worker(worker_id)

iter(*, event=None) async

Run the runner, continuously iterating over tasks in the store.

This method polls the store for new rollouts and executes them until the event is set (if provided), the max_rollouts limit is reached (if configured), or no more tasks are available.

All exceptions during rollout execution are caught and logged but not propagated, allowing the runner to continue processing subsequent tasks.

Parameters:

Name Type Description Default
event Optional[Event]

Optional Event object to signal the runner to stop. The runner will check this event periodically and stop gracefully when set.

None
Source code in agentlightning/runner/agent.py
async def iter(self, *, event: Optional[Event] = None) -> None:
    """Run the runner, continuously iterating over tasks in the store.

    This method polls the store for new rollouts and executes them until:
    - The event is set (if provided)
    - The max_rollouts limit is reached (if configured)
    - No more tasks are available

    All exceptions during rollout execution are caught and logged but not
    propagated, allowing the runner to continue processing subsequent tasks.

    Args:
        event: Optional Event object to signal the runner to stop. The runner
            will check this event periodically and stop gracefully when set.
    """
    num_tasks_processed = 0
    logger.info(f"{self._log_prefix()} Started async rollouts (max: {self._max_rollouts or 'unlimited'}).")
    store = self.get_store()

    while not (event is not None and event.is_set()) and (
        self._max_rollouts is None or num_tasks_processed < self._max_rollouts
    ):
        # Retrieve the next rollout
        next_rollout: Optional[RolloutV2] = None
        while not (event is not None and event.is_set()):
            logger.debug(f"{self._log_prefix()} Try to poll for next rollout.")
            next_rollout = await store.dequeue_rollout()
            if next_rollout is None:
                logger.debug(f"{self._log_prefix()} No rollout to poll. Waiting for {self._poll_interval} seconds.")
                await self._sleep_until_next_poll(event)
            else:
                break

        if next_rollout is None:
            return

        try:
            # Claim the rollout by updating the attempt with this worker's id
            await store.update_attempt(
                next_rollout.rollout_id, next_rollout.attempt.attempt_id, worker_id=self.get_worker_id()
            )
        except Exception:
            # This exception could happen if the rollout is dequeued and the other end died for some reason
            logger.exception(f"{self._log_prefix()} Exception during update_attempt, giving up the rollout.")
            continue

        # Execute the step
        await self._step_impl(next_rollout)

        num_tasks_processed += 1
        if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
            logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}")

    logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")

step(input, *, resources=None, mode=None, event=None) async

Execute a single task directly, bypassing the task queue.

This method creates a new rollout for the given input and executes it immediately. Unlike iter(), exceptions are propagated to the caller.

Parameters:

Name Type Description Default
input T_task

The task input to be processed by the agent.

required
resources Optional[NamedResources]

Optional named resources to be used for this specific task. If provided, a new resources entry will be created in the store. If not provided, the latest resources from the store will be used.

None
mode Optional[RolloutMode]

Optional rollout mode ("train" or "validation"). If not provided, the agent's default mode will be used.

None
event Optional[Event]

Optional Event object to signal interruption (currently unused but included for interface consistency).

None

Raises:

Type Description
Exception

Any exception that occurs during rollout execution will be re-raised to the caller.

Source code in agentlightning/runner/agent.py
async def step(
    self,
    input: T_task,
    *,
    resources: Optional[NamedResources] = None,
    mode: Optional[RolloutMode] = None,
    event: Optional[Event] = None,
) -> None:
    """Execute a single task directly, bypassing the task queue.

    This method creates a new rollout for the given input and executes it
    immediately. Unlike iter(), exceptions are propagated to the caller.

    Args:
        input: The task input to be processed by the agent.
        resources: Optional named resources to be used for this specific task.
            If provided, a new resources entry will be created in the store.
            If not provided, the latest resources from the store will be used.
        mode: Optional rollout mode ("train" or "validation"). If not provided,
            the agent's default mode will be used.
        event: Optional Event object to signal interruption (currently unused
            but included for interface consistency).

    Raises:
        Exception: Any exception that occurs during rollout execution will be
            re-raised to the caller.
    """
    store = self.get_store()

    if resources is not None:
        resources_update = await store.add_resources(resources)
        resources_id = resources_update.resources_id
    else:
        resources_id = None

    attempted_rollout = await self.get_store().start_rollout(input=input, mode=mode, resources_id=resources_id)
    await self._step_impl(attempted_rollout, raise_on_exception=True)

teardown(*args, **kwargs)

Teardown the runner and clean up all resources.

This method resets all internal state including the agent, store, hooks, and worker ID, and calls the tracer's teardown method.

Parameters:

Name Type Description Default
*args Any

Additional teardown arguments (currently unused).

()
**kwargs Any

Additional teardown keyword arguments (currently unused).

{}
Source code in agentlightning/runner/agent.py
def teardown(self, *args: Any, **kwargs: Any) -> None:
    """Teardown the runner and clean up all resources.

    This method resets all internal state including the agent, store,
    hooks, and worker ID, and calls the tracer's teardown method.

    Args:
        *args: Additional teardown arguments (currently unused).
        **kwargs: Additional teardown keyword arguments (currently unused).
    """
    self._agent = None
    self._store = None
    self.worker_id = None
    self._hooks = []

    self._tracer.teardown()

teardown_worker(worker_id, *args, **kwargs)

Teardown the runner for a specific worker.

This method cleans up worker-specific resources and resets the worker ID.

Parameters:

Name Type Description Default
worker_id int

The unique identifier of the worker being torn down.

required
*args Any

Additional teardown arguments (currently unused).

()
**kwargs Any

Additional teardown keyword arguments (currently unused).

{}
Source code in agentlightning/runner/agent.py
def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
    """Release worker-scoped resources for the given worker.

    Resets the runner's worker ID and forwards the worker teardown to
    the tracer.

    Args:
        worker_id: Identifier of the worker being shut down.
        *args: Extra teardown arguments (ignored).
        **kwargs: Extra teardown keyword arguments (ignored).
    """
    self.worker_id = None
    self._tracer.teardown_worker(worker_id)

BaseRunner

Bases: ParallelWorkerBase, Generic[T_task]

Base class for all runners.

This abstract base class defines the interface that all runner implementations must follow. Runners are responsible for executing agent tasks, managing the execution lifecycle, and coordinating with the store.

Source code in agentlightning/runner/base.py
class BaseRunner(ParallelWorkerBase, Generic[T_task]):
    """Common interface implemented by every runner.

    A runner drives an agent: it executes the agent's tasks, owns the
    execution lifecycle, and coordinates with the store. Concrete
    subclasses supply the behavior for every method that raises
    ``NotImplementedError`` here.
    """

    def init(self, agent: LitAgent[T_task], **kwargs: Any) -> None:
        """Attach the agent this runner will execute.

        Called exactly once during setup to hand the runner the agent
        it is responsible for.

        Args:
            agent: The LitAgent instance this runner manages.
            **kwargs: Implementation-specific initialization options.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

    def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
        """Prepare a single worker with its ID and the shared store.

        Called once per worker process in a distributed setup, giving the
        worker its unique identifier and the store to coordinate through.

        Args:
            worker_id: Unique identifier of this worker process.
            store: LightningStore used for task coordination and data persistence.
            **kwargs: Implementation-specific worker options.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

    def run(self, *args: Any, **kwargs: Any) -> None:
        """Unsupported entry point; call iter() or step() instead.

        Execution is deliberately left undefined here: continuous polling
        belongs to ``iter()`` and single-task execution to ``step()``.

        Args:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Raises:
            RuntimeError: Unconditionally, to steer callers to iter()/step().
        """
        raise RuntimeError("The behavior of run() of Runner is undefined. Use iter() or step() instead.")

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Release runner resources and reset internal state.

        Invoked once at shutdown so implementations can undo whatever
        ``init()`` allocated.

        Args:
            *args: Extra teardown arguments.
            **kwargs: Extra teardown keyword arguments.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Release resources owned by a single worker.

        Invoked once per worker at shutdown.

        Args:
            worker_id: Identifier of the worker being shut down.
            *args: Extra teardown arguments.
            **kwargs: Extra teardown keyword arguments.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

    @contextmanager
    def run_context(
        self, *, agent: LitAgent[T_task], store: LightningStore, hooks: Optional[Sequence[Hook]] = None
    ) -> Iterator[BaseRunner[T_task]]:
        """Context manager for quickly init and teardown the runner,
        so that you can debug the runner without a trainer environment.

        Args:
            agent: The LitAgent instance to be managed by this runner.
                   It should be the same agent that is to be run within the context.
            store: The LightningStore instance for task coordination and data persistence.
                   If you don't have one, you can easily create one with `InMemoryLightningStore()`.
            hooks: Optional sequence of Hook instances to be used by the runner.
                   Only some runners support hooks.
        """
        did_init = False
        did_init_worker = False
        try:
            self.init(agent=agent, hooks=hooks)
            did_init = True
            self.init_worker(worker_id=0, store=store)
            did_init_worker = True
            yield self
        finally:
            # Tear down in reverse order of initialization, logging (not
            # re-raising) failures so one broken teardown cannot mask the
            # other.
            if did_init_worker:
                try:
                    self.teardown_worker(worker_id=0)
                except Exception:
                    logger.error("Error during runner worker teardown", exc_info=True)
            if did_init:
                try:
                    self.teardown()
                except Exception:
                    logger.error("Error during runner teardown", exc_info=True)

    async def iter(self, *, event: Optional[Event] = None) -> None:
        """Continuously poll the store for tasks and execute them.

        Loops until the stop event fires or no more tasks remain.

        Args:
            event: Optional Event used to request a graceful stop; when
                set, the runner finishes its current task and exits the loop.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

    async def step(
        self,
        input: T_task,
        *,
        resources: Optional[NamedResources] = None,
        mode: Optional[RolloutMode] = None,
        event: Optional[Event] = None,
    ) -> None:
        """Execute one task directly, bypassing the store's task queue.

        Args:
            input: The task input to be processed by the agent.
            resources: Optional named resources for this task; when absent,
                the latest resources from the store are used.
            mode: Optional rollout mode (e.g., "train", "test"); when absent,
                the default mode is used.
            event: Optional Event to signal interruption; when set, the
                runner may abort the current execution.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError()

init(agent, **kwargs)

Initialize the runner with the agent.

This method is called once during setup to configure the runner with the agent it will execute.

Parameters:

Name Type Description Default
agent LitAgent[T_task]

The LitAgent instance to be managed by this runner.

required
**kwargs Any

Additional initialization arguments specific to the runner implementation.

{}

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
def init(self, agent: LitAgent[T_task], **kwargs: Any) -> None:
    """Attach the agent this runner will execute.

    Called exactly once during setup to hand the runner the agent it is
    responsible for.

    Args:
        agent: The LitAgent instance this runner manages.
        **kwargs: Implementation-specific initialization options.

    Raises:
        NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError()

init_worker(worker_id, store, **kwargs)

Initialize the runner for each worker with worker_id and store.

This method is called once per worker process in a distributed setup. It provides the worker with its unique ID and the store instance for task coordination.

Parameters:

Name Type Description Default
worker_id int

Unique identifier for this worker process.

required
store LightningStore

The LightningStore instance for task coordination and data persistence.

required
**kwargs Any

Additional worker-specific initialization arguments.

{}

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
    """Prepare a single worker with its ID and the shared store.

    Called once per worker process in a distributed setup, giving the
    worker its unique identifier and the store to coordinate through.

    Args:
        worker_id: Unique identifier of this worker process.
        store: LightningStore used for task coordination and data persistence.
        **kwargs: Implementation-specific worker options.

    Raises:
        NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError()

iter(*, event=None) async

Run the runner, continuously iterating over tasks in the store.

This method runs in a loop, polling the store for new tasks and executing them until interrupted by the event or when no more tasks are available.

Parameters:

Name Type Description Default
event Optional[Event]

Optional Event object that can be used to signal the runner to stop gracefully. When set, the runner should finish its current task and exit the iteration loop.

None

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
async def iter(self, *, event: Optional[Event] = None) -> None:
    """Continuously poll the store for tasks and execute them.

    Loops until the stop event fires or no more tasks remain.

    Args:
        event: Optional Event used to request a graceful stop; when set,
            the runner finishes its current task and exits the loop.

    Raises:
        NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError()

run(*args, **kwargs)

Undefined method - use iter() or step() instead.

This method is intentionally not implemented as the execution behavior should be defined through iter() for continuous execution or step() for single-task execution.

Parameters:

Name Type Description Default
*args Any

Unused positional arguments.

()
**kwargs Any

Unused keyword arguments.

{}

Raises:

Type Description
RuntimeError

Always raised to indicate this method should not be used.

Source code in agentlightning/runner/base.py
def run(self, *args: Any, **kwargs: Any) -> None:
    """Unsupported entry point; call iter() or step() instead.

    Execution is deliberately left undefined here: continuous polling
    belongs to ``iter()`` and single-task execution to ``step()``.

    Args:
        *args: Ignored positional arguments.
        **kwargs: Ignored keyword arguments.

    Raises:
        RuntimeError: Unconditionally, to steer callers to iter()/step().
    """
    raise RuntimeError("The behavior of run() of Runner is undefined. Use iter() or step() instead.")

run_context(*, agent, store, hooks=None)

Context manager for quickly init and teardown the runner, so that you can debug the runner without a trainer environment.

Parameters:

Name Type Description Default
agent LitAgent[T_task]

The LitAgent instance to be managed by this runner. It should be the same agent that is to be run within the context.

required
store LightningStore

The LightningStore instance for task coordination and data persistence. If you don't have one, you can easily create one with InMemoryLightningStore().

required
hooks Optional[Sequence[Hook]]

Optional sequence of Hook instances to be used by the runner. Only some runners support hooks.

None
Source code in agentlightning/runner/base.py
@contextmanager
def run_context(
    self, *, agent: LitAgent[T_task], store: LightningStore, hooks: Optional[Sequence[Hook]] = None
) -> Iterator[BaseRunner[T_task]]:
    """Context manager for quickly init and teardown the runner,
    so that you can debug the runner without a trainer environment.

    Args:
        agent: The LitAgent instance to be managed by this runner.
               It should be the same agent that is to be run within the context.
        store: The LightningStore instance for task coordination and data persistence.
               If you don't have one, you can easily create one with `InMemoryLightningStore()`.
        hooks: Optional sequence of Hook instances to be used by the runner.
               Only some runners support hooks.
    """
    did_init = False
    did_init_worker = False
    try:
        self.init(agent=agent, hooks=hooks)
        did_init = True
        self.init_worker(worker_id=0, store=store)
        did_init_worker = True
        yield self
    finally:
        # Tear down in reverse order of initialization, logging (not
        # re-raising) failures so one broken teardown cannot mask the other.
        if did_init_worker:
            try:
                self.teardown_worker(worker_id=0)
            except Exception:
                logger.error("Error during runner worker teardown", exc_info=True)
        if did_init:
            try:
                self.teardown()
            except Exception:
                logger.error("Error during runner teardown", exc_info=True)

step(input, *, resources=None, mode=None, event=None) async

Execute a single task with the given input.

This method provides fine-grained control for executing individual tasks directly, bypassing the store's task queue.

Parameters:

Name Type Description Default
input T_task

The task input to be processed by the agent.

required
resources Optional[NamedResources]

Optional named resources to be used for this specific task. If not provided, the latest resources from the store will be used.

None
mode Optional[RolloutMode]

Optional rollout mode (e.g., "train", "test"). If not provided, the default mode will be used.

None
event Optional[Event]

Optional Event object to signal interruption. When set, the runner may abort the current execution.

None

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
async def step(
    self,
    input: T_task,
    *,
    resources: Optional[NamedResources] = None,
    mode: Optional[RolloutMode] = None,
    event: Optional[Event] = None,
) -> None:
    """Execute one task directly, bypassing the store's task queue.

    Gives fine-grained control for running an individual task without
    going through the continuous ``iter()`` loop.

    Args:
        input: The task input to be processed by the agent.
        resources: Optional named resources for this task; when absent,
            the latest resources from the store are used.
        mode: Optional rollout mode (e.g., "train", "test"); when absent,
            the default mode is used.
        event: Optional Event to signal interruption; when set, the runner
            may abort the current execution.

    Raises:
        NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError()

teardown(*args, **kwargs)

Clean up runner resources and reset state.

This method is called once during shutdown to clean up any resources allocated during initialization and reset the runner state.

Parameters:

Name Type Description Default
*args Any

Additional teardown arguments.

()
**kwargs Any

Additional teardown keyword arguments.

{}

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
def teardown(self, *args: Any, **kwargs: Any) -> None:
    """Clean up runner resources and reset state.

    This method is called once during shutdown to clean up any resources
    allocated during initialization and reset the runner state.

    Args:
        *args: Additional teardown arguments.
        **kwargs: Additional teardown keyword arguments.

    Raises:
        NotImplementedError: Must be implemented by subclasses.
    """
    raise NotImplementedError()

teardown_worker(worker_id, *args, **kwargs)

Clean up worker-specific resources.

This method is called once per worker during shutdown to clean up any resources specific to that worker.

Parameters:

Name Type Description Default
worker_id int

The unique identifier of the worker being torn down.

required
*args Any

Additional teardown arguments.

()
**kwargs Any

Additional teardown keyword arguments.

{}

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Source code in agentlightning/runner/base.py
def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
    """Release resources owned by a single worker.

    Invoked once per worker at shutdown.

    Args:
        worker_id: Identifier of the worker being shut down.
        *args: Extra teardown arguments.
        **kwargs: Extra teardown keyword arguments.

    Raises:
        NotImplementedError: Subclasses must override this method.
    """
    raise NotImplementedError()

agentlightning.trainer

Trainer

Bases: ParallelWorkerBase

Orchestrates the distributed execution of agent rollouts.

The Trainer is responsible for launching one or more worker processes that run the agent's execution loop. It manages multiprocessing, handles graceful shutdown, and serves as the main entry point for running a client-side agent fleet.

Attributes:

Name Type Description
algorithm

An instance of BaseAlgorithm to use for training.

store

An instance of LightningStore to use for storing tasks and traces.

runner

An instance of BaseRunner to use for running the agent.

initial_resources

An instance of Resources to use for bootstrapping the fit/dev process. The resources will be handed over to the algorithm. Note that not all algorithms support seeding resources.

n_runners

Number of agent runners to run in parallel.

max_rollouts

Maximum number of rollouts to process per runner. If None, workers run until no more rollouts are available.

strategy

An instance of ExecutionStrategy to use for spawning the algorithm and runners.

tracer

A tracer instance, or a string pointing to the class full name or a dictionary with a 'type' key that specifies the class full name and other initialization parameters. If None, a default AgentOpsTracer will be created with the current settings.

hooks

A sequence of Hook instances to be called at various lifecycle stages (e.g., on_trace_start, on_trace_end, on_rollout_start, on_rollout_end).

adapter

An instance of TraceTripletAdapter to export data consumable by algorithms from traces.

llm_proxy

An instance of LLMProxy to use for intercepting the LLM calls. If not provided, algorithm will create one on its own.

n_workers

Number of agent workers to run in parallel. Deprecated in favor of n_runners.

max_tasks

Maximum number of tasks to process per runner. Deprecated in favor of max_rollouts.

daemon

Whether worker processes should be daemons. Daemon processes are terminated automatically when the main process exits. Deprecated. Only has an effect with fit_v0.

triplet_exporter

An instance of TraceTripletAdapter to export triplets from traces, or a dictionary with the initialization parameters for the exporter. Deprecated. Use adapter instead.

dev None

If True, rollouts are run against the dev endpoint provided in fit. Deprecated in favor of dev() method.

Source code in agentlightning/trainer/trainer.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
class Trainer(ParallelWorkerBase):
    """Orchestrates the distributed execution of agent rollouts.

    The Trainer is responsible for launching one or more worker processes
    that run the agent's execution loop. It manages multiprocessing,
    handles graceful shutdown, and serves as the main entry point for
    running a client-side agent fleet.

    Attributes:
        algorithm: An instance of `BaseAlgorithm` to use for training.
        store: An instance of `LightningStore` to use for storing tasks and traces.
        runner: An instance of `BaseRunner` to use for running the agent.
        initial_resources: An instance of `Resources` to use for bootstrapping the fit/dev process.
            The resources will be handed over to the algorithm.
            Note that not all algorithms support seeding resources.
        n_runners: Number of agent runners to run in parallel.
        max_rollouts: Maximum number of rollouts to process per runner. If None,
                      workers run until no more rollouts are available.
        strategy: An instance of `ExecutionStrategy` to use for spawning the algorithm and runners.
        tracer: A tracer instance, or a string pointing to the class full name or a dictionary with a 'type' key
                that specifies the class full name and other initialization parameters.
                If None, a default `AgentOpsTracer` will be created with the current settings.
        hooks: A sequence of `Hook` instances to be called at various lifecycle stages (e.g., on_trace_start,
               on_trace_end, on_rollout_start, on_rollout_end).
        adapter: An instance of `TraceTripletAdapter` to export data consumable by algorithms from traces.
        llm_proxy: An instance of `LLMProxy` to use for intercepting the LLM calls.
                   If not provided, algorithm will create one on its own.
        n_workers: Number of agent workers to run in parallel. Deprecated in favor of `n_runners`.
        max_tasks: Maximum number of tasks to process per runner. Deprecated in favor of `max_rollouts`.
        daemon: Whether worker processes should be daemons. Daemon processes
                are terminated automatically when the main process exits. Deprecated.
                Only have effect with `fit_v0`.
        triplet_exporter: An instance of `TraceTripletAdapter` to export triplets from traces,
                          or a dictionary with the initialization parameters for the exporter.
                          Deprecated. Use `adapter` instead.
        dev: If True, rollouts are run against the dev endpoint provided in `fit`.
             Deprecated in favor of `dev()` method.
    """

    def __init__(
        self,
        *,
        dev: bool = False,
        n_runners: Optional[int] = None,
        max_rollouts: Optional[int] = None,
        initial_resources: Optional[NamedResources] = None,
        tracer: ComponentSpec[BaseTracer] = None,
        adapter: ComponentSpec[TraceAdapter[Any]] = None,
        store: ComponentSpec[LightningStore] = None,
        runner: ComponentSpec[BaseRunner[Any]] = None,
        strategy: ComponentSpec[ExecutionStrategy] = None,
        algorithm: ComponentSpec[BaseAlgorithm] = None,
        llm_proxy: ComponentSpec[LLMProxy] = None,
        n_workers: Optional[int] = None,
        max_tasks: Optional[int] = None,
        daemon: bool = True,
        triplet_exporter: ComponentSpec[TraceTripletAdapter] = None,
        hooks: Optional[Union[Hook, Sequence[Hook]]] = None,
    ):
        """Construct the trainer and materialize every configured component.

        Deprecated parameters (``n_workers``, ``max_tasks``,
        ``triplet_exporter``) are reconciled with their replacements here, so
        the rest of the class reads only the new attributes; the deprecated
        names are still mirrored for backwards compatibility. See the class
        docstring for the meaning of each parameter.
        """
        super().__init__()
        self._dev = dev
        self.daemon = daemon
        self._client: AgentLightningClient | None = None  # Will be initialized in fit or fit_v0

        # --- reconcile deprecated `n_workers` with `n_runners` ---
        if n_workers is not None:
            warnings.warn(
                "`n_workers` is deprecated. Please use `n_runners`.",
                DeprecationWarning,
                stacklevel=2,
            )

        if n_runners is None:
            # Fall back to the deprecated value, else a single runner.
            n_runners = n_workers if n_workers is not None else 1
        else:
            if n_workers is not None and n_workers != n_runners:
                warnings.warn(
                    "`n_workers` is ignored when `n_runners` is provided.",
                    DeprecationWarning,
                    stacklevel=2,
                )

        self.n_runners = n_runners
        self.n_workers = n_runners  # Backwards compatibility for fit_v0

        # --- reconcile deprecated `max_tasks` with `max_rollouts` ---
        if max_tasks is not None:
            warnings.warn(
                "`max_tasks` is deprecated. Please use `max_rollouts`.",
                DeprecationWarning,
                stacklevel=2,
            )

        if max_rollouts is None:
            max_rollouts = max_tasks
        elif max_tasks is not None and max_tasks != max_rollouts:
            warnings.warn(
                "`max_tasks` is ignored when `max_rollouts` is provided.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.max_rollouts = max_rollouts
        self.max_tasks = max_tasks if max_tasks is not None else max_rollouts

        self.tracer = self._make_tracer(tracer)

        # --- reconcile deprecated `triplet_exporter` with `adapter` ---
        if adapter is not None and triplet_exporter is not None:
            warnings.warn(
                "`triplet_exporter` is deprecated and ignored because `adapter` is provided.",
                DeprecationWarning,
                stacklevel=2,
            )

        adapter_spec = adapter if adapter is not None else triplet_exporter
        self.adapter = self._make_adapter(adapter_spec)
        self.triplet_exporter = self.adapter  # Backwards compatibility

        self.algorithm = self._make_algorithm(algorithm)

        # We might be able to support a list of resources in future.
        self.initial_resources = initial_resources

        # The active store for the current execution context
        self.store = self._make_store(store)
        self.runner = self._make_runner(runner)

        self.strategy = self._make_strategy(strategy, n_runners=self.n_runners)
        # If the constructed strategy carries its own positive runner count,
        # let it win so both stay in sync (also mirrored into `n_workers`).
        if hasattr(self.strategy, "n_runners"):
            strategy_runners = getattr(self.strategy, "n_runners")
            if isinstance(strategy_runners, int) and strategy_runners > 0:
                self.n_runners = strategy_runners
                self.n_workers = strategy_runners

        self.llm_proxy = self._make_llm_proxy(llm_proxy, store=self.store)

        self.hooks = self._normalize_hooks(hooks)

        if not self.daemon:
            logger.warning(
                "daemon=False. Worker processes are non-daemonic. "
                "The worker processes will NOT be terminated when the main process exits. "
                "The cleanup must be handled manually."
            )

    def _make_tracer(self, tracer: ComponentSpec[BaseTracer]) -> BaseTracer:
        """Creates a tracer instance based on the provided configuration.

        Args:
            tracer: A `BaseTracer` instance, class path string, config dict
                (with a ``type`` key), or None for the default.

        Returns:
            A `BaseTracer`; defaults to an `AgentOpsTracer` that manages its
            own instrumentation and inherits the trainer's daemon setting.
        """
        # PEP 8 (E731): use a named def instead of a lambda bound to a name.
        def default_factory() -> BaseTracer:
            return AgentOpsTracer(
                agentops_managed=True,
                instrument_managed=True,
                daemon=self.daemon,
            )

        return build_component(
            tracer,
            expected_type=BaseTracer,
            spec_name="tracer",
            default_factory=default_factory,
            dict_requires_type=True,
            invalid_spec_error_fmt="Invalid tracer type: {actual_type}. Expected BaseTracer, str, dict, or None.",
            type_error_fmt="Tracer factory returned {type_name}, which is not a BaseTracer subclass.",
        )

    def _make_algorithm(self, algorithm: ComponentSpec[BaseAlgorithm]) -> Optional[BaseAlgorithm]:
        """Resolve the ``algorithm`` spec into a `BaseAlgorithm`, or None.

        Unlike most other components, the algorithm is optional: a None spec
        yields None rather than a default instance.
        """
        return build_component(
            algorithm,
            spec_name="algorithm",
            expected_type=BaseAlgorithm,
            allow_none=True,
            invalid_spec_error_fmt="Invalid algorithm type: {actual_type}. Expected BaseAlgorithm, str, dict, or None.",
            type_error_fmt="Algorithm factory returned {type_name}, which is not a BaseAlgorithm subclass.",
        )

    def _make_adapter(self, adapter: ComponentSpec[TraceAdapter[Any]]) -> TraceAdapter[Any]:
        """Resolve the ``adapter`` spec, defaulting to a `TraceTripletAdapter`.

        A bare dict spec (no ``type`` key) is interpreted as keyword
        arguments for `TraceTripletAdapter`.
        """
        return build_component(
            adapter,
            spec_name="adapter",
            expected_type=TraceAdapter,
            dict_requires_type=False,
            default_factory=TraceTripletAdapter,
            dict_default_cls=TraceTripletAdapter,
            invalid_spec_error_fmt="Invalid adapter type: {actual_type}. Expected TraceAdapter, dict, or None.",
            type_error_fmt="Adapter factory returned {type_name}, which is not a TraceAdapter subclass.",
        )

    def _make_store(self, store: ComponentSpec[LightningStore]) -> LightningStore:
        """Resolve the ``store`` spec, defaulting to an `InMemoryLightningStore`."""
        return build_component(
            store,
            spec_name="store",
            expected_type=LightningStore,
            default_factory=InMemoryLightningStore,
            invalid_spec_error_fmt="Invalid store type: {actual_type}. Expected LightningStore, str, dict, or None.",
            type_error_fmt="Store factory returned {type_name}, which is not a LightningStore subclass.",
        )

    def _make_strategy(
        self,
        strategy: ComponentSpec[ExecutionStrategy],
        *,
        n_runners: int,
    ) -> ExecutionStrategy:
        """Resolve the ``strategy`` spec into an `ExecutionStrategy`.

        An already-constructed strategy is returned unchanged. Otherwise the
        spec is built against `ExecutionStrategyRegistry`, with ``n_runners``
        offered as an optional default, falling back to a client/server
        strategy that runs both roles in-process.
        """
        if isinstance(strategy, ExecutionStrategy):
            return strategy

        def _default_strategy() -> ExecutionStrategy:
            return ClientServerExecutionStrategy(n_runners=n_runners, role="both")

        defaults: Dict[str, Callable[[], Any]] = {"n_runners": lambda: n_runners}
        return build_component(
            strategy,
            spec_name="strategy",
            expected_type=ExecutionStrategy,
            default_factory=_default_strategy,
            optional_defaults=defaults,
            registry=ExecutionStrategyRegistry,
            invalid_spec_error_fmt="Invalid strategy type: {actual_type}. Expected ExecutionStrategy, str, dict, or None.",
            type_error_fmt="Strategy factory returned {type_name}, which is not an ExecutionStrategy subclass.",
        )

    def _make_llm_proxy(
        self,
        llm_proxy: ComponentSpec[LLMProxy],
        *,
        store: LightningStore,
    ) -> Optional[LLMProxy]:
        """Resolve the ``llm_proxy`` spec into an `LLMProxy`, or None.

        The active ``store`` is injected into the proxy unless the spec
        already carries one; a None spec yields None.
        """
        if isinstance(llm_proxy, LLMProxy):
            return llm_proxy

        if isinstance(llm_proxy, dict):
            # Copy before injecting the store so the caller's dict is untouched.
            spec = dict(llm_proxy)
            spec.setdefault("store", store)
            llm_proxy = spec

        return build_component(
            llm_proxy,
            spec_name="llm_proxy",
            expected_type=LLMProxy,
            allow_none=True,
            optional_defaults={"store": lambda: store},
            invalid_spec_error_fmt="Invalid llm_proxy type: {actual_type}. Expected LLMProxy, dict, str, or None.",
            type_error_fmt="llm_proxy factory returned {type_name}, which is not an LLMProxy subclass.",
        )

    def _make_runner(self, runner: ComponentSpec[BaseRunner[Any]]) -> BaseRunner[Any]:
        """Resolve the ``runner`` spec into a `BaseRunner`.

        The trainer's tracer — and ``max_rollouts`` when configured — are
        offered as optional defaults; the fallback builds an `AgentRunnerV2`
        with those same defaults.
        """
        defaults: Dict[str, Callable[[], Any]] = {"tracer": lambda: self.tracer}
        if self.max_rollouts is not None:
            defaults["max_rollouts"] = lambda: self.max_rollouts

        def _default_runner() -> BaseRunner[Any]:
            return instantiate_component(AgentRunnerV2, optional_defaults=defaults)

        return build_component(
            runner,
            spec_name="runner",
            expected_type=BaseRunner,
            default_factory=_default_runner,
            optional_defaults=defaults,
            invalid_spec_error_fmt="Invalid runner type: {actual_type}. Expected BaseRunner, callable, str, dict, or None.",
            type_error_fmt="Runner factory returned {type_name}, which is not a BaseRunner subclass.",
        )

    def _normalize_hooks(self, hooks: Optional[Union[Hook, Sequence[Hook]]]) -> Sequence[Hook]:
        """Coerce the ``hooks`` argument into an immutable tuple of hooks."""
        if hooks is None:
            return ()
        # A single hook becomes a one-element tuple; a sequence is copied.
        return (hooks,) if isinstance(hooks, Hook) else tuple(hooks)

    def fit_v2(
        self,
        agent: LitAgent[T_co],
        train_dataset: Optional[Dataset[T_co]] = None,
        *,
        val_dataset: Optional[Dataset[T_co]] = None,
    ) -> None:
        """Run the training loop using the configured strategy, store, and runner.

        Args:
            agent: The LitAgent instance to be trained on.
            train_dataset: The dataset to train on.
            val_dataset: The dataset to validate on.
        """
        agent.set_trainer(self)

        # Bind the datasets and the configured algorithm into the two bundles;
        # the strategy decides where and how each bundle executes.
        self.strategy.execute(
            functools.partial(
                self._algorithm_bundle,
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                algorithm=self.algorithm,
            ),
            functools.partial(self._runner_bundle, agent=agent),
            self.store,
        )

    def dev(
        self,
        agent: LitAgent[T_co],
        train_dataset: Optional[Dataset[T_co]] = None,
        *,
        val_dataset: Optional[Dataset[T_co]] = None,
    ) -> None:
        """Dry run the training loop with a `MockAlgorithm` and the real runner.

        Same wiring as `fit_v2`, except that when no algorithm was configured
        a `MockAlgorithm` stands in, so the runner path can be exercised
        without a real training algorithm.

        Args:
            agent: The LitAgent instance to be trained on.
            train_dataset: The dataset to train on.
            val_dataset: The dataset to validate on.
        """
        agent.set_trainer(self)

        # Substitute a mock algorithm when none was configured.
        if self.algorithm is None:
            algorithm = MockAlgorithm()
        else:
            algorithm = self.algorithm

        algorithm_bundle = functools.partial(
            self._algorithm_bundle,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            algorithm=algorithm,
        )
        runner_bundle = functools.partial(self._runner_bundle, agent=agent)
        self.strategy.execute(algorithm_bundle, runner_bundle, self.store)

    async def _algorithm_bundle(
        self,
        store: LightningStore,
        event: Event,
        train_dataset: Optional[Dataset[T_co]],
        val_dataset: Optional[Dataset[T_co]],
        algorithm: Optional[BaseAlgorithm],
    ) -> None:
        """Algorithm-side bundle handed to the execution strategy.

        Wires the trainer's components (store, adapter, initial resources,
        LLM proxy) into ``algorithm`` and runs it to completion. When
        ``algorithm`` is None, idles until ``event`` is set so the runner
        side can keep working.

        Args:
            store: The store selected by the execution strategy for this run.
            event: Stop signal set by the execution strategy.
            train_dataset: Forwarded to ``algorithm.run``.
            val_dataset: Forwarded to ``algorithm.run``.
            algorithm: The algorithm to run, or None to idle until stopped.
        """
        if algorithm is not None:
            algorithm.set_trainer(self)
            algorithm.set_store(store)
            algorithm.set_adapter(self.adapter)
            if self.initial_resources is not None:
                algorithm.set_initial_resources(self.initial_resources)
            if self.llm_proxy is not None:
                # Re-point the proxy at the strategy-selected store before
                # handing it to the algorithm.
                self.llm_proxy.set_store(store)
                algorithm.set_llm_proxy(self.llm_proxy)

        if algorithm is None:
            # Nothing to run; poll the stop event so this coroutine stays
            # alive as long as the runners do.
            while not event.is_set():
                await asyncio.sleep(0.1)
            return
        try:
            if algorithm.is_async():
                await algorithm.run(  # type: ignore
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                )
            else:
                # This will block the event loop to maximize the debugging experience
                # It's the responsibility of the execution strategy to enable async execution
                algorithm.run(
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                )
        except Exception:
            logger.exception("Algorithm bundle encountered an error.")
            raise

    async def _runner_bundle(self, store: LightningStore, worker_id: int, event: Event, agent: LitAgent[T_co]) -> None:
        """Runner-side bundle handed to the execution strategy.

        Initializes the configured runner for this worker, iterates rollouts
        until ``event`` is set, and tears down only the stages that were
        successfully initialized (worker first, then the runner itself).

        Args:
            store: The store selected by the execution strategy for this run.
            worker_id: Index of this runner worker.
            event: Stop signal set by the execution strategy.
            agent: The agent whose rollouts the runner executes.
        """
        runner_instance: BaseRunner[Any] | None = None
        # Track how far initialization got so the teardown in `finally`
        # only undoes what actually happened.
        runner_initialized = False
        worker_initialized = False
        try:
            # If not using shm execution strategy, we are already in the forked process
            runner_instance = self.runner
            runner_instance.init(agent=agent, hooks=self.hooks)
            runner_initialized = True
            runner_instance.init_worker(worker_id, store)
            worker_initialized = True
            await runner_instance.iter(event=event)
        except Exception:
            logger.exception("Runner bundle encountered an error (worker_id=%s).", worker_id)
            raise
        finally:
            if runner_instance is not None:
                if worker_initialized:
                    try:
                        runner_instance.teardown_worker(worker_id)
                    except Exception:
                        logger.exception("Error during runner worker teardown (worker_id=%s).", worker_id)
                if runner_initialized:
                    try:
                        runner_instance.teardown()
                    except Exception:
                        logger.exception("Error during runner teardown (worker_id=%s).", worker_id)

    def _extract_client_from_data(
        self, data: Union[str, AgentLightningClient, Dataset[Any]]
    ) -> Optional[AgentLightningClient]:
        """Extract client from data if it's a string URL or AgentLightningClient."""
        if isinstance(data, AgentLightningClient):
            return data
        if isinstance(data, str):
            # A string must be an endpoint URL; anything else is a user error.
            if data.startswith(("http://", "https://")):
                return AgentLightningClient(endpoint=data)
            raise ValueError("String data must be a valid URL starting with http:// or https://")
        return None

    def _extract_dataset_from_data(
        self, data: Union[str, AgentLightningClient, Dataset[Any]]
    ) -> Optional[Dataset[Any]]:
        """Extract dataset from data if it's a Dataset.

        Returns:
            ``data`` unchanged when it is a dataset; None when it is a URL
            string or a client (those are handled by
            `_extract_client_from_data`).
        """
        # Idiom: one isinstance with a tuple instead of two chained checks.
        if isinstance(data, (str, AgentLightningClient)):
            return None
        return data

    def _determine_backend(
        self,
        train_data: Union[str, AgentLightningClient, Dataset[Any]],
        dev_data: Union[str, AgentLightningClient, Dataset[Any], None] = None,
    ) -> Union[str, AgentLightningClient]:
        """Determine which backend to use for initialization.

        In dev mode the client must come from ``dev_data``; otherwise it is
        extracted from ``train_data``, falling back to a client created by
        the configured algorithm.

        Raises:
            ValueError: If no client can be derived from the given data.
        """
        if self._dev:
            if dev_data is None:
                raise ValueError("dev_data must be provided when dev=True.")
            client = self._extract_client_from_data(dev_data)
            if client is None:
                raise ValueError("dev_data must be a string URL or AgentLightningClient when dev=True.")
            return client

        client = self._extract_client_from_data(train_data)
        if client is not None:
            return client
        # Fix: the original duplicated this check in an unreachable second
        # branch; flattened to guard clauses with a single raise.
        if self.algorithm is None:
            raise ValueError(
                "train_data must be a string URL or AgentLightningClient when no algorithm is provided."
            )
        # Algorithm will be responsible for creating the client
        client = self.algorithm.get_client()
        logger.info(f"Algorithm created client: {client}")
        return client

    def init(self, backend: Union[str, AgentLightningClient]) -> None:
        """Initialize the trainer for the legacy (`fit`) path.

        Creates (or adopts) the client from ``backend`` and initializes the
        tracer in the main process.

        Args:
            backend: A server URL string or an existing `AgentLightningClient`.
        """
        # Fix: dropped the pointless f-prefix on constant log messages (F541).
        logger.info("Initializing Trainer...")

        self._init_client(backend)

        self.tracer.init()

        logger.info("Trainer main initialization complete.")

    def teardown(self) -> None:
        """Tear down the trainer: stop the tracer and drop the client reference."""
        # Fix: dropped the pointless f-prefix on constant log messages (F541).
        logger.info("Cleaning up Trainer...")
        self.tracer.teardown()

        self._client = None
        logger.info("Trainer main cleanup complete.")

    def client(self) -> AgentLightningClient:
        """Returns the AgentLightningClient instance."""
        current = self._client
        if current is None:
            raise RuntimeError("AgentLightningClient has not been initialized. Call `init` first.")
        return current

    def _init_client(self, backend: Union[str, AgentLightningClient]) -> AgentLightningClient:
        """Create (or reuse) the client from ``backend``, caching it on the trainer."""
        if self._client is not None:
            # Already set up by an earlier call; keep the existing client.
            logger.warning("AgentLightningClient already initialized. Returning existing instance.")
            return self._client

        if isinstance(backend, AgentLightningClient):
            logger.info("Using provided AgentLightningClient instance.")
            self._client = backend
            return self._client

        logger.info(f"Initializing AgentLightningClient with endpoint: {backend}")
        if not isinstance(backend, str):  # type: ignore
            raise ValueError("backend must be a string URL or an AgentLightningClient instance.")
        if not backend.startswith("http://") and not backend.startswith("https://"):
            raise ValueError("backend must be a valid URL starting with http:// or https://")
        # Initialize the client with the provided backend URL
        self._client = AgentLightningClient(endpoint=backend)
        return self._client

    def _worker_main_loop(self, agent: LitAgent[Any], worker_id: int, is_async: bool):
        """The main function for each worker process.

        This function initializes the client and the loop, then starts the
        execution. It also configures process-specific settings like the
        process title and signal handling.

        Args:
            agent: The `LitAgent` instance to run.
            worker_id: The unique ID for this worker.
            is_async: A boolean indicating if the async loop should be run.

        Returns:
            The number of tasks processed by this worker (0 when the loop
            fails before processing any).
        """
        if self.n_workers > 1:
            import setproctitle

            # Ignore Ctrl+C in worker processes; the main process handles it
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            setproctitle.setproctitle(multiprocessing.current_process().name)

        # Now we are in child processes, so we can safely set up the environment.
        agent.set_trainer(self)
        if not isinstance(self.triplet_exporter, TraceTripletAdapter):
            raise ValueError("triplet_exporter must be a TraceTripletAdapter for the legacy trainer.")
        # TODO: this should be set elsewhere
        if agent.trained_agents:
            self.triplet_exporter.agent_match = agent.trained_agents
        self._initialize_worker_env(worker_id)

        mode = "Async" if is_async else "Sync"
        logger.info(f"[Worker {worker_id}] {mode} worker process started.")

        num_processed = 0

        try:
            client = self.client()
            loop = AgentRunner(
                agent=agent,
                client=client,
                tracer=self.tracer,
                triplet_exporter=self.triplet_exporter,
                max_tasks=self.max_tasks,
                worker_id=worker_id,
            )
            loop.init_worker(worker_id)  # type: ignore
            if is_async:
                num_processed = asyncio.run(loop.iter_async())
            else:
                num_processed = loop.iter()
        except Exception:
            # Deliberately swallowed after logging: the worker returns
            # normally so the parent can still read a count and join cleanly.
            logger.exception(f"[Worker {worker_id}] Unhandled exception in worker loop.")
        finally:
            self._teardown_worker_env(worker_id)

        return num_processed

    def _initialize_worker_env(self, worker_id: int):
        """Per-worker setup for the legacy path: bring up the tracer for this worker."""
        logger.info(f"[Worker {worker_id}] Setting up trainer environment...")  # worker_id included in process name
        self.tracer.init_worker(worker_id)

    def _teardown_worker_env(self, worker_id: int):
        """Per-worker cleanup for the legacy path: tear down the tracer for this worker."""
        logger.info(f"[Worker {worker_id}] Cleaning up trainer environment...")
        self.tracer.teardown_worker(worker_id)
        logger.info(f"[Worker {worker_id}] Environment cleanup complete.")

    @staticmethod
    def kill_orphaned_processes() -> None:
        """
        Kill any orphaned processes that may have been left behind by previous runs.
        This is useful for cleaning up after crashes or unexpected exits.
        """
        import psutil

        for proc in psutil.process_iter():  # type: ignore
            # check whether the process name matches
            try:
                if proc.name().startswith("AgentLightning-"):
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # Fix: a process may exit (or be inaccessible) between
                # enumeration and inspection; skip it instead of crashing.
                continue

    def _terminate_processes(self, processes: List[multiprocessing.Process]) -> None:
        """Terminate worker processes, escalating to kill if needed.

        Sends ``terminate()`` to every live worker, then gives each up to 10
        seconds to exit before killing and reaping it. No-op in single-worker
        mode or when no processes were spawned.

        Args:
            processes: Worker processes started by `fit`.
        """
        # Idiom: guard clause with truthiness instead of `len(...) > 0` nesting.
        if self.n_workers <= 1 or not processes:
            return
        for i, p in enumerate(processes):
            if p.is_alive():
                logger.info(f"Terminating worker {i} (name: {p.name}, PID: {p.pid})...")
                p.terminate()
            else:
                logger.info(f"Worker {i} (name: {p.name}, PID: {p.pid}) is not alive or has already terminated.")
        for i, p in enumerate(processes):
            if p.is_alive():
                p.join(timeout=10)  # Give some time to terminate
            if p.is_alive():  # If still alive, kill
                logger.warning(
                    f"Worker {i} (name: {p.name}, PID: {p.pid}) did not terminate gracefully, killing..."
                )
                p.kill()
                p.join(timeout=10)  # Ensure it's reaped

    def fit(
        self,
        agent: LitAgent[T_co],
        train_data: Union[str, AgentLightningClient, Dataset[T_co]],
        *,
        val_data: Union[str, AgentLightningClient, Dataset[T_co], None] = None,
        dev_data: Union[str, AgentLightningClient, Dataset[T_co], None] = None,
        dev_backend: Union[str, AgentLightningClient, None] = None,
    ):
        """Train the agent using the provided data (legacy client-server path).

        Each data argument can be a string URL connecting to an agent-lightning
        server, an AgentLightningClient instance connecting to a server (or
        mock server), or a dataset. If no algorithm is provided when
        instantiating the trainer, the data must be a URL or client so the
        workers can connect to a server. Otherwise, a dataset is also allowed
        and will be passed to the algorithm.

        If the algorithm is instantiated and there is no URL/client provided,
        the algorithm will be responsible for creating a client that will
        connect to itself. It can also create a mock client if the algorithm
        does not require a server.

        Args:
            agent: The agent whose rollouts the workers run.
            train_data: Server URL, client, or training dataset.
            val_data: Optional validation counterpart of ``train_data``.
            dev_data: Backend used when the trainer was created with ``dev=True``.
            dev_backend: Deprecated alias of ``dev_data``.
        """

        if dev_backend is not None:
            warnings.warn("dev_backend is deprecated. Use dev_data instead.")
            if dev_data is not None:
                raise ValueError("dev_data and dev_backend cannot be provided at the same time.")
            dev_data = dev_backend

        # Extract datasets for algorithm if available
        train_dataset = self._extract_dataset_from_data(train_data)
        val_dataset = self._extract_dataset_from_data(val_data) if val_data else None

        # Initialize the algorithm with trainer if provided
        if self.algorithm is not None:
            self.algorithm.set_trainer(self)
            # DO NOT RUN TRAINING HERE. Need to spawn the worker first.

        # Determine the backend to use for client-server mode
        backend = self._determine_backend(train_data, dev_data)

        if self._dev:
            logger.warning(f"Running in dev mode. Using dev backend: {backend}")
        else:
            logger.debug(f"Running in non-dev mode. Using backend: {backend}")

        self.init(backend)

        processes: List[multiprocessing.Process] = []

        # Determine if the agent is asynchronous

        mode = "asynchronous" if agent.is_async() else "synchronous"

        try:
            if self.n_workers == 1:
                # Single-worker mode: run the loop inline in this process.
                logger.info(f"Running with n_workers=1 ({mode} in main process).")

                # Warn if algorithm is set with single worker mode
                if self.algorithm is not None:
                    logger.warning(
                        "Algorithm is set but using single worker mode. Algorithm will never get the chance to run."
                    )
                    # Ideally the single worker should be run in a separate thread or process.

                num_tasks = self._worker_main_loop(agent, 0, agent.is_async())
                logger.info(f"Single worker mode finished. Tasks processed: {num_tasks}")

                # If algorithm is provided and we have datasets, run algorithm after worker completes
                if self.algorithm is not None and train_dataset is not None:
                    logger.info("Running algorithm training after worker completion.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
            else:
                # Multi-worker mode: fork one process per worker.
                logger.info(f"Running with n_workers={self.n_workers} ({mode} multiprocessing).")
                for i in range(self.n_workers):
                    process_name = f"AgentLightning-Worker-{i}"
                    p = multiprocessing.Process(
                        target=self._worker_main_loop,
                        args=(agent, i, agent.is_async()),
                        daemon=self.daemon,
                        name=process_name,
                    )
                    processes.append(p)
                    logger.info(f"Starting worker process {i} (name: {process_name})...")
                    p.start()

                if self.daemon:
                    # If algorithm is provided and we have datasets, pass them to the algorithm
                    if self.algorithm is not None:
                        logger.info("All workers have been spawned. Running algorithm training with provided datasets.")
                        self.algorithm.run(
                            train_dataset=train_dataset,
                            val_dataset=val_dataset,
                        )
                        logger.info("Algorithm exits. Killing the workers.")
                        self._terminate_processes(processes)

                    for i, p in enumerate(processes):
                        p.join()  # Wait for the process to complete
                        logger.info(
                            f"Worker process {i} (name: {p.name}, PID: {p.pid}) joined with exit code {p.exitcode}."
                        )
                        if p.exitcode != 0:
                            logger.warning(
                                f"Worker process {i} (name: {p.name}, PID: {p.pid}) exited with non-zero code: {p.exitcode}."
                            )

                    logger.info(f"All {self.n_workers} worker processes have completed.")
                else:
                    logger.info("All worker processes started. Main process will not wait.")

                    # A hack to stop the main process from waiting for child processes to finish.
                    time.sleep(1)  # Give workers time to start
                    import multiprocessing.process as multiprocessing_process

                    # NOTE(review): `_children` is a private CPython detail;
                    # this may break across Python versions — confirm.
                    multiprocessing_process._children.clear()  # type: ignore

                    if self.algorithm is not None:
                        logger.info("Main process continues to run algorithm.")
                        self.algorithm.run(
                            train_dataset=train_dataset,
                            val_dataset=val_dataset,
                        )
                        logger.info("Algorithm exits. Killing the workers.")
                        self._terminate_processes(processes)

        except KeyboardInterrupt:
            logger.info("KeyboardInterrupt received. Killing the workers.")
            self._terminate_processes(processes)
            logger.info(f"Workers terminated or single worker interrupted.")
            raise
        except Exception:
            logger.exception(f"Unhandled exception in fit method.")
            self._terminate_processes(processes)
            logger.info(f"Workers terminated or single worker interrupted.")
            raise
        finally:
            if self.daemon:
                self.teardown()
            else:
                logger.info("Main process exiting. Please use Trainer.kill_orphaned_processes() for cleanup.")

client()

Returns the AgentLightningClient instance.

Source code in agentlightning/trainer/trainer.py
def client(self) -> AgentLightningClient:
    """Return the AgentLightningClient instance.

    Raises:
        RuntimeError: If the client has not been initialized via `init`.
    """
    existing = self._client
    if existing is None:
        raise RuntimeError("AgentLightningClient has not been initialized. Call `init` first.")
    return existing

dev(agent, train_dataset=None, *, val_dataset=None)

Dry run the training loop with a mock algorithm (MockAlgorithm, used when no algorithm is configured) and the real runner.

Parameters:

Name Type Description Default
agent LitAgent[T_co]

The LitAgent instance to be trained on.

required
train_dataset Optional[Dataset[T_co]]

The dataset to train on.

None
val_dataset Optional[Dataset[T_co]]

The dataset to validate on.

None
Source code in agentlightning/trainer/trainer.py
def dev(
    self,
    agent: LitAgent[T_co],
    train_dataset: Optional[Dataset[T_co]] = None,
    *,
    val_dataset: Optional[Dataset[T_co]] = None,
) -> None:
    """Dry run the training loop with a mock algorithm and the real runner.

    If the trainer already has an algorithm configured, that algorithm is used;
    otherwise a `MockAlgorithm` stands in so the runner wiring can still be
    exercised end-to-end.

    Args:
        agent: The LitAgent instance to be trained on.
        train_dataset: The dataset to train on.
        val_dataset: The dataset to validate on.
    """
    # Give the agent a back-reference to this trainer before execution starts.
    agent.set_trainer(self)

    # Fall back to a MockAlgorithm when none was configured, so the dev run
    # can still drive the runner.
    if self.algorithm is None:
        algorithm = MockAlgorithm()
    else:
        algorithm = self.algorithm

    # Bind datasets and the algorithm/agent into zero-argument callables;
    # the strategy decides how each bundle is executed.
    algorithm_bundle = functools.partial(
        self._algorithm_bundle,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        algorithm=algorithm,
    )
    runner_bundle = functools.partial(self._runner_bundle, agent=agent)
    self.strategy.execute(algorithm_bundle, runner_bundle, self.store)

fit(agent, train_data, *, val_data=None, dev_data=None, dev_backend=None)

Train the agent using the provided data.

Each data argument can be a string URL connecting to an agent-lightning server, an AgentLightningClient instance connecting to a server (or mock server), or a dataset. If no algorithm is provided when instantiating the trainer, the data must be a URL or client so the workers can connect to a server. Otherwise, a dataset is also allowed and will be passed to the algorithm.

If the algorithm is instantiated and there is no URL/client provided, the algorithm will be responsible for creating a client that will connect to itself. It can also create a mock client if the algorithm does not require a server.

Source code in agentlightning/trainer/trainer.py
def fit(
    self,
    agent: LitAgent[T_co],
    train_data: Union[str, AgentLightningClient, Dataset[T_co]],
    *,
    val_data: Union[str, AgentLightningClient, Dataset[T_co], None] = None,
    dev_data: Union[str, AgentLightningClient, Dataset[T_co], None] = None,
    dev_backend: Union[str, AgentLightningClient, None] = None,
):
    """Train the agent using the provided data.

    Each data argument can be a string URL connecting to an agent-lightning
    server, an AgentLightningClient instance connecting to a server (or mock
    server), or a dataset. If no algorithm was provided when instantiating the
    trainer, the data must be a URL or client so the workers can connect to a
    server. Otherwise, a dataset is also allowed and will be passed to the
    algorithm.

    If the algorithm is instantiated and no URL/client is provided,
    the algorithm will be responsible for creating a client that will connect
    to itself. It can also create a mock client if the algorithm does not
    require a server.

    Args:
        agent: The LitAgent instance to be trained.
        train_data: URL, client, or dataset used for training.
        val_data: Optional URL, client, or dataset used for validation.
        dev_data: Optional URL, client, or dataset used in dev mode.
        dev_backend: Deprecated alias of ``dev_data``.
    """

    if dev_backend is not None:
        warnings.warn("dev_backend is deprecated. Use dev_data instead.")
        if dev_data is not None:
            raise ValueError("dev_data and dev_backend cannot be provided at the same time.")
        dev_data = dev_backend

    # Extract datasets for algorithm if available
    train_dataset = self._extract_dataset_from_data(train_data)
    val_dataset = self._extract_dataset_from_data(val_data) if val_data else None

    # Initialize the algorithm with trainer if provided
    if self.algorithm is not None:
        self.algorithm.set_trainer(self)
        # DO NOT RUN TRAINING HERE. Need to spawn the worker first.

    # Determine the backend to use for client-server mode
    backend = self._determine_backend(train_data, dev_data)

    if self._dev:
        logger.warning(f"Running in dev mode. Using dev backend: {backend}")
    else:
        logger.debug(f"Running in non-dev mode. Using backend: {backend}")

    self.init(backend)

    processes: List[multiprocessing.Process] = []

    # Determine if the agent is asynchronous

    mode = "asynchronous" if agent.is_async() else "synchronous"

    try:
        if self.n_workers == 1:
            logger.info(f"Running with n_workers=1 ({mode} in main process).")

            # Warn if algorithm is set with single worker mode
            if self.algorithm is not None:
                logger.warning(
                    "Algorithm is set but using single worker mode. Algorithm will never get the chance to run."
                )
                # Ideally the single worker should be run in a separate thread or process.

            num_tasks = self._worker_main_loop(agent, 0, agent.is_async())
            logger.info(f"Single worker mode finished. Tasks processed: {num_tasks}")

            # If algorithm is provided and we have datasets, run algorithm after worker completes
            if self.algorithm is not None and train_dataset is not None:
                logger.info("Running algorithm training after worker completion.")
                self.algorithm.run(
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                )
        else:
            logger.info(f"Running with n_workers={self.n_workers} ({mode} multiprocessing).")
            for i in range(self.n_workers):
                process_name = f"AgentLightning-Worker-{i}"
                p = multiprocessing.Process(
                    target=self._worker_main_loop,
                    args=(agent, i, agent.is_async()),
                    daemon=self.daemon,
                    name=process_name,
                )
                processes.append(p)
                logger.info(f"Starting worker process {i} (name: {process_name})...")
                p.start()

            if self.daemon:
                # If algorithm is provided and we have datasets, pass them to the algorithm
                if self.algorithm is not None:
                    logger.info("All workers have been spawned. Running algorithm training with provided datasets.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

                for i, p in enumerate(processes):
                    p.join()  # Wait for the process to complete
                    logger.info(
                        f"Worker process {i} (name: {p.name}, PID: {p.pid}) joined with exit code {p.exitcode}."
                    )
                    if p.exitcode != 0:
                        logger.warning(
                            f"Worker process {i} (name: {p.name}, PID: {p.pid}) exited with non-zero code: {p.exitcode}."
                        )

                logger.info(f"All {self.n_workers} worker processes have completed.")
            else:
                logger.info("All worker processes started. Main process will not wait.")

                # A hack to stop the main process from waiting for child processes to finish.
                # Clearing multiprocessing's private child registry prevents the implicit
                # join at interpreter exit; the workers keep running detached.
                time.sleep(1)  # Give workers time to start
                import multiprocessing.process as multiprocessing_process

                multiprocessing_process._children.clear()  # type: ignore

                if self.algorithm is not None:
                    logger.info("Main process continues to run algorithm.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt received. Killing the workers.")
        self._terminate_processes(processes)
        logger.info(f"Workers terminated or single worker interrupted.")
        raise
    except Exception:
        logger.exception(f"Unhandled exception in fit method.")
        self._terminate_processes(processes)
        logger.info(f"Workers terminated or single worker interrupted.")
        raise
    finally:
        if self.daemon:
            self.teardown()
        else:
            logger.info("Main process exiting. Please use Trainer.kill_orphaned_processes() for cleanup.")

fit_v2(agent, train_dataset=None, *, val_dataset=None)

Run the training loop using the configured strategy, store, and runner.

Parameters:

Name Type Description Default
agent LitAgent[T_co]

The LitAgent instance to be trained on.

required
train_dataset Optional[Dataset[T_co]]

The dataset to train on.

None
val_dataset Optional[Dataset[T_co]]

The dataset to validate on.

None
Source code in agentlightning/trainer/trainer.py
def fit_v2(
    self,
    agent: LitAgent[T_co],
    train_dataset: Optional[Dataset[T_co]] = None,
    *,
    val_dataset: Optional[Dataset[T_co]] = None,
) -> None:
    """Run the training loop using the configured strategy, store, and runner.

    Args:
        agent: The LitAgent instance to be trained on.
        train_dataset: The dataset to train on.
        val_dataset: The dataset to validate on.
    """
    # Hand the agent a reference back to this trainer.
    agent.set_trainer(self)

    # Package the algorithm and runner sides as zero-argument callables
    # that the strategy can schedule however it likes.
    algo_task = functools.partial(
        self._algorithm_bundle,
        algorithm=self.algorithm,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )
    rollout_task = functools.partial(self._runner_bundle, agent=agent)

    self.strategy.execute(algo_task, rollout_task, self.store)

kill_orphaned_processes() staticmethod

Kill any orphaned processes that may have been left behind by previous runs. This is useful for cleaning up after crashes or unexpected exits.

Source code in agentlightning/trainer/trainer.py
@staticmethod
def kill_orphaned_processes() -> None:
    """
    Kill any orphaned processes that may have been left behind by previous runs.
    This is useful for cleaning up after crashes or unexpected exits.

    Processes are matched by the ``AgentLightning-`` name prefix used when
    the trainer spawns worker processes.
    """
    import psutil

    for proc in psutil.process_iter():  # type: ignore
        # check whether the process name matches
        try:
            if proc.name().startswith("AgentLightning-"):
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process disappeared between iteration and inspection, or
            # we are not allowed to query/kill it; skip it either way.
            continue

agentlightning.tracer

AgentOpsTracer

Bases: BaseTracer

Traces agent execution using AgentOps.

This tracer provides functionality to capture execution details using the AgentOps library. It manages the AgentOps client initialization, server setup, and integration with the OpenTelemetry tracing ecosystem.

Attributes:

Name Type Description
agentops_managed

Whether to automatically manage agentops. When set to true, the tracer calls agentops.init() automatically and launches an AgentOps endpoint locally. When set to false, you are responsible for initializing AgentOps yourself before using the tracer.

instrument_managed

Whether to automatically manage instrumentation. When set to false, you will manage the instrumentation yourself and the tracer might not work as expected.

daemon

Whether the AgentOps server runs as a daemon process. Only applicable if agentops_managed is True.

Source code in agentlightning/tracer/agentops.py
class AgentOpsTracer(BaseTracer):
    """Traces agent execution using AgentOps.

    This tracer provides functionality to capture execution details using the
    AgentOps library. It manages the AgentOps client initialization, server setup,
    and integration with the OpenTelemetry tracing ecosystem.

    Attributes:
        agentops_managed: Whether to automatically manage `agentops`.
                          When set to true, tracer calls `agentops.init()`
                          automatically and launches an agentops endpoint locally.
                          If not, you are responsible for calling and using it
                          before using the tracer.
        instrument_managed: Whether to automatically manage instrumentation.
                            When set to false, you will manage the instrumentation
                            yourself and the tracer might not work as expected.
        daemon: Whether the AgentOps server runs as a daemon process.
                Only applicable if `agentops_managed` is True.
    """

    def __init__(self, *, agentops_managed: bool = True, instrument_managed: bool = True, daemon: bool = True):
        super().__init__()
        self._lightning_span_processor: Optional[LightningSpanProcessor] = None
        self.agentops_managed = agentops_managed
        self.instrument_managed = instrument_managed
        self.daemon = daemon

        self._agentops_server_manager = AgentOpsServerManager(self.daemon)
        # The port is kept separately from the manager: the manager is not
        # picklable (dropped in __getstate__), while a plain int survives
        # pickling into worker processes.
        self._agentops_server_port_val: Optional[int] = None

        if not self.agentops_managed:
            logger.warning("agentops_managed=False. You are responsible for AgentOps setup.")
        if not self.instrument_managed:
            logger.warning("instrument_managed=False. You are responsible for all instrumentation.")

    def __getstate__(self):
        """Return picklable state, excluding the unpicklable server manager."""
        state = self.__dict__.copy()
        state["_agentops_server_manager"] = None  # Exclude the unpicklable server manager
        # _agentops_server_port_val (int) is inherently picklable and will be included.
        logger.debug(
            f"Getting state for pickling AgentOpsTracer (PID {os.getpid()}). _agentops_server_manager excluded."
        )
        return state

    def __setstate__(self, state: Any):
        """Restore pickled state; the server manager stays None in child processes."""
        self.__dict__.update(state)
        # In child process, self._agentops_server_manager will be None.
        logger.debug(
            f"Setting state for unpickled AgentOpsTracer (PID {os.getpid()}). _agentops_server_manager is None."
        )

    def init(self, *args: Any, **kwargs: Any):
        """Start the managed AgentOps server (if any) and cache its port.

        Raises:
            RuntimeError: If the server is alive but reports no port, or if
                the manager indicates the server failed to start.
        """
        if self.agentops_managed and self._agentops_server_manager:
            self._agentops_server_manager.start()
            self._agentops_server_port_val = self._agentops_server_manager.get_port()
            if self._agentops_server_port_val is None:
                if (
                    self._agentops_server_manager.server_process is not None
                    and self._agentops_server_manager.server_process.is_alive()
                ):
                    raise RuntimeError("AgentOps server started but port is None. Check server manager logic.")
                elif (
                    self._agentops_server_port_val is None and self._agentops_server_manager.server_process is None
                ):  # Server failed to start
                    raise RuntimeError("AgentOps server manager indicates server is not running and port is None.")

    def teardown(self):
        """Stop the managed AgentOps server, if this process owns one.

        The manager is set to ``None`` by ``__getstate__`` when the tracer is
        pickled into a worker process, so it must be guarded here exactly as
        in ``init`` — otherwise teardown in a worker raises AttributeError.
        """
        if self.agentops_managed and self._agentops_server_manager:
            self._agentops_server_manager.stop()
            logger.info("AgentOps server stopped.")

    def instrument(self, worker_id: int):
        """Apply library instrumentation for this worker."""
        instrument_all()

    def uninstrument(self, worker_id: int):
        """Remove library instrumentation for this worker."""
        uninstrument_all()

    def init_worker(self, worker_id: int):
        """Set up instrumentation, AgentOps client, and span processor for a worker.

        Args:
            worker_id: The index of the worker being initialized.
        """
        super().init_worker(worker_id)
        logger.info(f"[Worker {worker_id}] Setting up tracer...")  # worker_id included in process name

        if self.instrument_managed:
            self.instrument(worker_id)
            logger.info(f"[Worker {worker_id}] Instrumentation applied.")

        if self.agentops_managed:
            if self._agentops_server_port_val:  # Use the stored, picklable port value
                base_url = f"http://localhost:{self._agentops_server_port_val}"
                env_vars_to_set = {
                    "AGENTOPS_API_KEY": "dummy",
                    "AGENTOPS_API_ENDPOINT": base_url,
                    "AGENTOPS_APP_URL": f"{base_url}/notavailable",
                    "AGENTOPS_EXPORTER_ENDPOINT": f"{base_url}/traces",
                }
                for key, value in env_vars_to_set.items():
                    os.environ[key] = value
                    logger.info(f"[Worker {worker_id}] Env var set: {key}={value}")
            else:
                logger.warning(
                    f"[Worker {worker_id}] AgentOps managed, but local server port is not available. Client may not connect as expected."
                )

            if not agentops.get_client().initialized:
                agentops.init()  # type: ignore
                logger.info(f"[Worker {worker_id}] AgentOps client initialized.")
            else:
                logger.warning(f"[Worker {worker_id}] AgentOps client was already initialized.")

        self._lightning_span_processor = LightningSpanProcessor()

        try:
            # new versions
            instance = agentops.sdk.core.tracer
            # TODO: The span processor cannot be deleted once added.
            # This might be a problem if the tracer is entered and exited multiple times.
            instance.provider.add_span_processor(self._lightning_span_processor)  # type: ignore
        except AttributeError:
            # old versions
            instance = TracingCore.get_instance()  # type: ignore
            instance._provider.add_span_processor(self._lightning_span_processor)  # type: ignore

    def teardown_worker(self, worker_id: int) -> None:
        """Undo per-worker setup: remove instrumentation if it is managed."""
        super().teardown_worker(worker_id)

        if self.instrument_managed:
            self.uninstrument(worker_id)
            logger.info(f"[Worker {worker_id}] Instrumentation removed.")

    @contextmanager
    def trace_context(
        self,
        name: Optional[str] = None,
        *,
        store: Optional[LightningStore] = None,
        rollout_id: Optional[str] = None,
        attempt_id: Optional[str] = None,
    ) -> Iterator[LightningSpanProcessor]:
        """
        Starts a new tracing context. This should be used as a context manager.

        Args:
            name: Optional name for the tracing context.
            store: Optional store to add the spans to.
            rollout_id: Optional rollout ID to add the spans to.
            attempt_id: Optional attempt ID to add the spans to.

        Yields:
            The LightningSpanProcessor instance to collect spans.

        Raises:
            RuntimeError: If `init_worker` has not been called yet.
            ValueError: If only some of store/rollout_id/attempt_id are given.
        """
        if not self._lightning_span_processor:
            raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")

        if store is not None and rollout_id is not None and attempt_id is not None:
            ctx = self._lightning_span_processor.with_context(store=store, rollout_id=rollout_id, attempt_id=attempt_id)
            with ctx as processor:
                yield processor
        elif store is None and rollout_id is None and attempt_id is None:
            with self._lightning_span_processor:
                yield self._lightning_span_processor
        else:
            raise ValueError("store, rollout_id, and attempt_id must be either all provided or all None")

    def get_last_trace(self) -> List[ReadableSpan]:
        """
        Retrieves the raw list of captured spans from the most recent trace.

        Returns:
            A list of OpenTelemetry `ReadableSpan` objects.

        Raises:
            RuntimeError: If `init_worker` has not been called yet.
        """
        if not self._lightning_span_processor:
            raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")
        return self._lightning_span_processor.spans()

    def get_langchain_callback_handler(self, tags: List[str] | None = None) -> LangchainCallbackHandler:
        """
        Get the Langchain callback handler for integrating with Langchain.

        Args:
            tags: Optional list of tags to apply to the Langchain callback handler.

        Returns:
            An instance of the Langchain callback handler.
        """
        import agentops
        from agentops.integration.callbacks.langchain import LangchainCallbackHandler

        tags = tags or []
        client_instance = agentops.get_client()
        api_key = None
        if client_instance.initialized:
            api_key = client_instance.config.api_key
        else:
            logger.warning(
                "AgentOps client not initialized when creating LangchainCallbackHandler. API key may be missing."
            )
        return LangchainCallbackHandler(api_key=api_key, tags=tags)

get_langchain_callback_handler(tags=None)

Get the Langchain callback handler for integrating with Langchain.

Parameters:

Name Type Description Default
tags List[str] | None

Optional list of tags to apply to the Langchain callback handler.

None

Returns:

Type Description
LangchainCallbackHandler

An instance of the Langchain callback handler.

Source code in agentlightning/tracer/agentops.py
def get_langchain_callback_handler(self, tags: List[str] | None = None) -> LangchainCallbackHandler:
    """Build a Langchain callback handler wired to the current AgentOps client.

    Args:
        tags: Optional list of tags to apply to the Langchain callback handler.

    Returns:
        An instance of the Langchain callback handler.
    """
    import agentops
    from agentops.integration.callbacks.langchain import LangchainCallbackHandler

    ao_client = agentops.get_client()
    if ao_client.initialized:
        api_key = ao_client.config.api_key
    else:
        api_key = None
        logger.warning(
            "AgentOps client not initialized when creating LangchainCallbackHandler. API key may be missing."
        )
    return LangchainCallbackHandler(api_key=api_key, tags=tags or [])

get_last_trace()

Retrieves the raw list of captured spans from the most recent trace.

Returns:

Type Description
List[ReadableSpan]

A list of OpenTelemetry ReadableSpan objects.

Source code in agentlightning/tracer/agentops.py
def get_last_trace(self) -> List[ReadableSpan]:
    """Return the raw spans captured by the most recent trace.

    Returns:
        A list of OpenTelemetry `ReadableSpan` objects.

    Raises:
        RuntimeError: If `init_worker` has not been called yet.
    """
    processor = self._lightning_span_processor
    if not processor:
        raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")
    return processor.spans()

trace_context(name=None, *, store=None, rollout_id=None, attempt_id=None)

Starts a new tracing context. This should be used as a context manager.

Parameters:

Name Type Description Default
name Optional[str]

Optional name for the tracing context.

None
store Optional[LightningStore]

Optional store to add the spans to.

None
rollout_id Optional[str]

Optional rollout ID to add the spans to.

None
attempt_id Optional[str]

Optional attempt ID to add the spans to.

None

Yields:

Type Description
LightningSpanProcessor

The LightningSpanProcessor instance to collect spans.

Source code in agentlightning/tracer/agentops.py
@contextmanager
def trace_context(
    self,
    name: Optional[str] = None,
    *,
    store: Optional[LightningStore] = None,
    rollout_id: Optional[str] = None,
    attempt_id: Optional[str] = None,
) -> Iterator[LightningSpanProcessor]:
    """Open a new tracing context; use as a context manager.

    Args:
        name: Optional name for the tracing context.
        store: Optional store to add the spans to.
        rollout_id: Optional rollout ID to add the spans to.
        attempt_id: Optional attempt ID to add the spans to.

    Yields:
        The LightningSpanProcessor instance to collect spans.

    Raises:
        RuntimeError: If `init_worker` has not been called yet.
        ValueError: If only some of store/rollout_id/attempt_id are given.
    """
    processor = self._lightning_span_processor
    if not processor:
        raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")

    targets = (store, rollout_id, attempt_id)
    if all(t is not None for t in targets):
        # All three given: route spans into the store for this rollout attempt.
        with processor.with_context(store=store, rollout_id=rollout_id, attempt_id=attempt_id) as scoped:
            yield scoped
    elif all(t is None for t in targets):
        # None given: plain local span collection.
        with processor:
            yield processor
    else:
        raise ValueError("store, rollout_id, and attempt_id must be either all provided or all None")

BaseTracer

Bases: ParallelWorkerBase

An abstract base class for tracers.

This class defines a standard interface for tracing code execution, capturing the resulting spans, and providing them for analysis. It is designed to be backend-agnostic, allowing for different implementations (e.g., for AgentOps, OpenTelemetry, Docker, etc.).

The primary interaction pattern is through the trace_context context manager, which ensures that traces are properly started and captured, even in the case of exceptions.

A typical workflow:

tracer = YourTracerImplementation()

try:
    with tracer.trace_context(name="my_traced_task"):
        # ... code to be traced ...
        run_my_agent_logic()
except Exception as e:
    print(f"An error occurred: {e}")

# Retrieve the trace data after the context block
spans: list[ReadableSpan] = tracer.get_last_trace()

# Process the trace data
if spans:
    rl_triplets = TraceTripletAdapter().adapt(spans)
    # ... do something with the triplets
Source code in agentlightning/tracer/base.py
class BaseTracer(ParallelWorkerBase):
    """Abstract interface for tracers.

    A tracer captures the spans produced while a piece of code executes and
    makes them available afterwards for analysis. The interface is
    backend-agnostic, so concrete implementations can be built on AgentOps,
    plain OpenTelemetry, Docker, or anything else.

    The primary entry point is the `trace_context` context manager, which
    guarantees that a trace is started and its spans are collected even when
    the traced code raises.

    A typical workflow:

    ```python
    tracer = YourTracerImplementation()

    try:
        with tracer.trace_context(name="my_traced_task"):
            # ... code to be traced ...
            run_my_agent_logic()
    except Exception as e:
        print(f"An error occurred: {e}")

    # Retrieve the trace data after the context block
    spans: list[ReadableSpan] = tracer.get_last_trace()

    # Process the trace data
    if spans:
        rl_triplets = TraceTripletAdapter().adapt(spans)
        # ... do something with the triplets
    ```
    """

    @contextmanager
    def trace_context(
        self,
        name: Optional[str] = None,
        *,
        store: Optional[LightningStore] = None,
        rollout_id: Optional[str] = None,
        attempt_id: Optional[str] = None,
    ) -> Iterator[Any]:
        """Open a tracing context; use as a context manager.

        Implementations must set up tracing on entry and tear it down on
        exit, collecting every span generated inside the `with` block so it
        can later be returned by `get_last_trace`.

        If a store is provided, the spans will be added to the store when
        tracing.

        Args:
            name: The name for the root span of this trace context.
            store: The store to add the spans to.
            rollout_id: The rollout ID to add the spans to.
            attempt_id: The attempt ID to add the spans to.
        """
        raise NotImplementedError()

    def get_last_trace(self) -> List[ReadableSpan]:
        """Return the raw spans captured by the most recent trace.

        Returns:
            A list of OpenTelemetry `ReadableSpan` objects.
        """
        raise NotImplementedError()

    def trace_run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Trace a single synchronous function call.

        Args:
            func: The synchronous function to execute and trace.
            *args: Positional arguments to pass to the function.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            The return value of the function.
        """
        ctx = self.trace_context(name=func.__name__)
        with ctx:
            return func(*args, **kwargs)

    async def trace_run_async(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> Any:
        """Trace a single asynchronous function call.

        Args:
            func: The asynchronous function to execute and trace.
            *args: Positional arguments to pass to the function.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            The return value of the function.
        """
        ctx = self.trace_context(name=func.__name__)
        with ctx:
            return await func(*args, **kwargs)

get_last_trace()

Retrieves the raw list of captured spans from the most recent trace.

Returns:

Type Description
List[ReadableSpan]

A list of OpenTelemetry ReadableSpan objects.

Source code in agentlightning/tracer/base.py
def get_last_trace(self) -> List[ReadableSpan]:
    """Return the raw spans captured by the most recent trace.

    Returns:
        A list of OpenTelemetry `ReadableSpan` objects.
    """
    # Abstract: concrete tracers must implement span retrieval.
    raise NotImplementedError

trace_context(name=None, *, store=None, rollout_id=None, attempt_id=None)

Starts a new tracing context. This should be used as a context manager.

The implementation should handle the setup and teardown of the tracing for the enclosed code block. It must ensure that any spans generated within the with block are collected and made available via get_last_trace.

If a store is provided, the spans will be added to the store when tracing.

Parameters:

Name Type Description Default
name Optional[str]

The name for the root span of this trace context.

None
store Optional[LightningStore]

The store to add the spans to.

None
rollout_id Optional[str]

The rollout ID to add the spans to.

None
attempt_id Optional[str]

The attempt ID to add the spans to.

None
Source code in agentlightning/tracer/base.py
@contextmanager
def trace_context(
    self,
    name: Optional[str] = None,
    *,
    store: Optional[LightningStore] = None,
    rollout_id: Optional[str] = None,
    attempt_id: Optional[str] = None,
) -> Iterator[Any]:
    """Open a tracing context; use as a context manager.

    Implementations must set up tracing on entry and tear it down on exit,
    collecting every span generated inside the `with` block so it can later
    be returned by `get_last_trace`.

    If a store is provided, the spans will be added to the store when tracing.

    Args:
        name: The name for the root span of this trace context.
        store: The store to add the spans to.
        rollout_id: The rollout ID to add the spans to.
        attempt_id: The attempt ID to add the spans to.
    """
    # Abstract: concrete tracers must implement context setup/teardown.
    raise NotImplementedError

trace_run(func, *args, **kwargs)

A convenience wrapper to trace the execution of a single synchronous function.

Parameters:

Name Type Description Default
func Callable[..., Any]

The synchronous function to execute and trace.

required
*args Any

Positional arguments to pass to the function.

()
**kwargs Any

Keyword arguments to pass to the function.

{}

Returns:

Type Description
Any

The return value of the function.

Source code in agentlightning/tracer/base.py
def trace_run(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Trace a single synchronous function call.

    Args:
        func: The synchronous function to execute and trace.
        *args: Positional arguments to pass to the function.
        **kwargs: Keyword arguments to pass to the function.

    Returns:
        The return value of the function.
    """
    # Name the root span after the function being traced.
    ctx = self.trace_context(name=func.__name__)
    with ctx:
        return func(*args, **kwargs)

trace_run_async(func, *args, **kwargs) async

A convenience wrapper to trace the execution of a single asynchronous function.

Parameters:

Name Type Description Default
func Callable[..., Awaitable[Any]]

The asynchronous function to execute and trace.

required
*args Any

Positional arguments to pass to the function.

()
**kwargs Any

Keyword arguments to pass to the function.

{}

Returns:

Type Description
Any

The return value of the function.

Source code in agentlightning/tracer/base.py
async def trace_run_async(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> Any:
    """Run *func* asynchronously inside a fresh tracing context.

    The context is named after the function (``func.__name__``), so the
    resulting root span is easy to locate in the collected trace.

    Args:
        func: The asynchronous function to execute and trace.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever awaiting ``func`` produces.
    """
    context = self.trace_context(name=func.__name__)
    with context:
        result = await func(*args, **kwargs)
    return result

OtelTracer

Bases: BaseTracer

Tracer that provides a basic OpenTelemetry tracer provider.

You should be able to collect signals like rewards with this tracer, but no other function instrumentations like openai.chat.completion.

Source code in agentlightning/tracer/otel.py
class OtelTracer(BaseTracer):
    """Tracer that provides a basic OpenTelemetry tracer provider.

    You should be able to collect signals like rewards with this tracer,
    but no other function instrumentations like `openai.chat.completion`.
    """

    def __init__(self):
        super().__init__()
        # Both are populated by init_worker() and released by teardown_worker().
        self._tracer_provider: Optional[TracerProvider] = None
        self._lightning_span_processor: Optional[LightningSpanProcessor] = None
        self._initialized: bool = False

    def init_worker(self, worker_id: int):
        """Install a fresh global TracerProvider with a LightningSpanProcessor attached.

        Args:
            worker_id: Identifier of the worker being initialized (used for logging).
        """
        super().init_worker(worker_id)
        logger.info(f"[Worker {worker_id}] Setting up OpenTelemetry tracer...")

        if self._initialized:
            logger.error("Tracer provider is already initialized. OpenTelemetry may not work as expected.")

        tracer_provider = TracerProvider()
        trace_api.set_tracer_provider(tracer_provider)
        self._lightning_span_processor = LightningSpanProcessor()
        tracer_provider.add_span_processor(self._lightning_span_processor)
        # Bug fix: the provider was previously kept only in a local variable,
        # so teardown_worker()'s `self._tracer_provider = None` had nothing
        # to clear.
        self._tracer_provider = tracer_provider
        self._initialized = True

    def teardown_worker(self, worker_id: int):
        """Release the worker's tracer provider so init_worker() can run again.

        The span processor is intentionally kept so get_last_trace() can still
        return spans collected before teardown.
        """
        super().teardown_worker(worker_id)
        logger.info(f"[Worker {worker_id}] Tearing down OpenTelemetry tracer...")
        self._tracer_provider = None
        # Bug fix: reset the flag; otherwise a later init_worker() on the same
        # instance would spuriously log "already initialized".
        self._initialized = False

    @contextmanager
    def trace_context(
        self,
        name: Optional[str] = None,
        *,
        store: Optional[LightningStore] = None,
        rollout_id: Optional[str] = None,
        attempt_id: Optional[str] = None,
    ) -> Iterator[LightningSpanProcessor]:
        """
        Starts a new tracing context. This should be used as a context manager.

        Args:
            name: Optional name for the tracing context (currently unused by
                this tracer; kept for interface parity with other tracers).
            store: Optional store to add the spans to.
            rollout_id: Optional rollout ID to add the spans to.
            attempt_id: Optional attempt ID to add the spans to.

        Yields:
            The LightningSpanProcessor instance to collect spans.

        Raises:
            RuntimeError: If init_worker() has not been called yet.
            ValueError: If only some of store/rollout_id/attempt_id are given.
        """
        if not self._lightning_span_processor:
            raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")

        if store is not None and rollout_id is not None and attempt_id is not None:
            ctx = self._lightning_span_processor.with_context(store=store, rollout_id=rollout_id, attempt_id=attempt_id)
            with ctx as processor:
                yield processor
        elif store is None and rollout_id is None and attempt_id is None:
            with self._lightning_span_processor:
                yield self._lightning_span_processor
        else:
            raise ValueError("store, rollout_id, and attempt_id must be either all provided or all None")

    def get_last_trace(self) -> List[ReadableSpan]:
        """
        Retrieves the raw list of captured spans from the most recent trace.

        Returns:
            A list of OpenTelemetry `ReadableSpan` objects.

        Raises:
            RuntimeError: If init_worker() has not been called yet.
        """
        if not self._lightning_span_processor:
            raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")
        return self._lightning_span_processor.spans()

get_last_trace()

Retrieves the raw list of captured spans from the most recent trace.

Returns:

Type Description
List[ReadableSpan]

A list of OpenTelemetry ReadableSpan objects.

Source code in agentlightning/tracer/otel.py
def get_last_trace(self) -> List[ReadableSpan]:
    """Return the raw spans captured during the most recent trace.

    Returns:
        A list of OpenTelemetry `ReadableSpan` objects.

    Raises:
        RuntimeError: If the span processor was never set up via init_worker().
    """
    processor = self._lightning_span_processor
    if not processor:
        raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")
    return processor.spans()

trace_context(name=None, *, store=None, rollout_id=None, attempt_id=None)

Starts a new tracing context. This should be used as a context manager.

Parameters:

Name Type Description Default
name Optional[str]

Optional name for the tracing context.

None
store Optional[LightningStore]

Optional store to add the spans to.

None
rollout_id Optional[str]

Optional rollout ID to add the spans to.

None
attempt_id Optional[str]

Optional attempt ID to add the spans to.

None

Yields:

Type Description
LightningSpanProcessor

The LightningSpanProcessor instance to collect spans.

Source code in agentlightning/tracer/otel.py
@contextmanager
def trace_context(
    self,
    name: Optional[str] = None,
    *,
    store: Optional[LightningStore] = None,
    rollout_id: Optional[str] = None,
    attempt_id: Optional[str] = None,
) -> Iterator[LightningSpanProcessor]:
    """Open a tracing context; use as a context manager.

    Either all of ``store``/``rollout_id``/``attempt_id`` are supplied (the
    collected spans are then also recorded into the store) or none of them.

    Args:
        name: Optional name for the tracing context.
        store: Optional store to add the spans to.
        rollout_id: Optional rollout ID to add the spans to.
        attempt_id: Optional attempt ID to add the spans to.

    Yields:
        The LightningSpanProcessor instance collecting spans.

    Raises:
        RuntimeError: If init_worker() has not been called.
        ValueError: If only some of the store-related arguments are provided.
    """
    processor = self._lightning_span_processor
    if not processor:
        raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")

    supplied = (store is not None, rollout_id is not None, attempt_id is not None)
    if all(supplied):
        scoped = processor.with_context(store=store, rollout_id=rollout_id, attempt_id=attempt_id)
        with scoped as active:
            yield active
    elif not any(supplied):
        with processor:
            yield processor
    else:
        raise ValueError("store, rollout_id, and attempt_id must be either all provided or all None")

agentlightning.reward

emit_reward(reward)

Record a new reward as a new span.

Source code in agentlightning/emitter/reward.py
def emit_reward(reward: float) -> ReadableSpan:
    """Record a reward value as a dedicated span and return it.

    Integers and booleans are coerced to float; any other type raises.
    The span exists solely to carry the value in its ``reward`` attribute.
    """
    logger.debug(f"Emitting reward: {reward}")
    if isinstance(reward, (int, bool)):
        reward = float(reward)
    if not isinstance(reward, float):
        raise ValueError(f"Reward must be a number, got: {type(reward)}")

    span = get_tracer().start_span(SpanNames.REWARD.value, attributes={"reward": reward})
    with span:
        pass  # No work inside: the span's only purpose is to carry the attribute.
    if not isinstance(span, ReadableSpan):
        raise ValueError(f"Span is not a ReadableSpan: {span}")
    return span

find_final_reward(spans)

Get the last reward value from a list of spans.

Parameters:

Name Type Description Default
spans Sequence[SpanLike]

A list of spans (either ReadableSpan or Span).

required

Returns:

Type Description
Optional[float]

The reward value from the last reward span, or None if not found.

Source code in agentlightning/emitter/reward.py
def find_final_reward(spans: Sequence[SpanLike]) -> Optional[float]:
    """Return the reward carried by the most recent reward span, if any.

    Args:
        spans: A list of spans (either ReadableSpan or Span), oldest first.

    Returns:
        The value from the last span holding a reward, or None when none does.
    """
    rewards = (get_reward_value(span) for span in reversed(spans))
    return next((value for value in rewards if value is not None), None)

find_reward_spans(spans)

Find all reward spans in the given list of spans.

Parameters:

Name Type Description Default
spans Sequence[SpanLike]

A list of spans (either ReadableSpan or Span).

required

Returns:

Type Description
List[SpanLike]

A list of spans whose name matches the reward span name.

Source code in agentlightning/emitter/reward.py
def find_reward_spans(spans: Sequence[SpanLike]) -> List[SpanLike]:
    """Collect every span that carries a reward.

    Args:
        spans: A list of spans (either ReadableSpan or Span).

    Returns:
        The subset of ``spans`` recognized as reward spans, in input order.
    """
    return list(filter(is_reward_span, spans))

get_reward_value(span)

Get the reward value from a span.

Source code in agentlightning/emitter/reward.py
def get_reward_value(span: SpanLike) -> Optional[float]:
    """Extract the reward value from a span, or None if it carries none.

    Two encodings are recognized: the AgentOps output payload
    (``{"type": "reward", "value": ...}`` stored under a task/entity output
    attribute, as a dict or a JSON string) and the native reward span
    produced by ``emit_reward``.
    """
    attributes = span.attributes
    # Older AgentOps-style spans serialize the reward into an output attribute.
    for key in (
        "agentops.task.output",  # newer versions of agentops
        "agentops.entity.output",
    ):
        payload: Dict[str, Any] | None = None
        if attributes:
            raw = attributes.get(key)
            if raw:
                if isinstance(raw, dict):
                    payload = cast(Dict[str, Any], raw)
                elif isinstance(raw, str):
                    try:
                        payload = cast(Dict[str, Any], json.loads(raw))
                    except json.JSONDecodeError:
                        payload = None

        if payload and payload.get("type") == "reward":
            value = payload.get("value", None)
            if value is None:
                return None
            if not isinstance(value, float):
                logger.error(f"Reward is not a number, got: {type(value)}. This may cause undefined behaviors.")
            return cast(float, value)

    # Latest emit_reward() format: a span named REWARD holding a "reward" attribute.
    if span.name == SpanNames.REWARD.value and attributes:
        value = attributes.get("reward", None)
        if value is None:
            return None
        if not isinstance(value, float):
            logger.error(f"Reward is not a number, got: {type(value)}. This may cause undefined behaviors.")
        return cast(float, value)
    return None

is_reward_span(span)

Check if a span is a reward span.

Source code in agentlightning/emitter/reward.py
def is_reward_span(span: SpanLike) -> bool:
    """Check if a span is a reward span.

    A span counts as a reward span when a reward value can be extracted
    from it.
    """
    return get_reward_value(span) is not None

reward(fn)

A decorator to wrap a function that computes rewards. It forwards the call's arguments to the function unchanged, records the function's numeric return value as a reward, and passes that return value through to the caller.

Source code in agentlightning/emitter/reward.py
def reward(fn: FnType) -> FnType:
    """
    A decorator to wrap a function that computes rewards.
    It will automatically handle the input and output of the function.

    When AgentOps is initialized, the reward is recorded through an AgentOps
    operation; otherwise it is emitted directly as a reward span. In both
    modes the wrapped function's return value is passed through unchanged,
    and None / non-numeric results are tolerated (recorded as no reward).
    """
    import functools  # Local import keeps the decorator self-contained.

    def wrap_result(result: Optional[float]) -> RewardSpanData:
        """
        Wrap the result of the function in a dict.
        """
        if result is None:
            return {"type": "reward", "value": None}
        if not isinstance(result, (float, int)):  # type: ignore
            warnings.warn(f"Reward is ignored because it is not a number: {result}")
            return {"type": "reward", "value": None}
        return {"type": "reward", "value": float(result)}

    # Check if the function is async
    is_async = asyncio.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn)

    if is_async:

        @functools.wraps(fn)  # Preserve the wrapped function's name and docstring.
        async def wrapper_async(*args: Any, **kwargs: Any) -> Any:
            if not _agentops_initialized():
                # Track the reward without AgentOps. Normalizing through
                # wrap_result keeps this path consistent with the AgentOps
                # path: None and non-numeric results are skipped (with a
                # warning) instead of raising inside emit_reward.
                result = await fn(*args, **kwargs)
                reward_value = wrap_result(result)["value"]
                if reward_value is not None:
                    emit_reward(reward_value)
                return result

            result: Optional[float] = None

            @operation
            async def agentops_reward_operation() -> RewardSpanData:
                # The reward function we are interested in tracing
                # It takes zero inputs and return a formatted dict
                nonlocal result
                result = await fn(*args, **kwargs)
                return wrap_result(result)

            await agentops_reward_operation()
            return result

        return wrapper_async  # type: ignore

    else:

        @functools.wraps(fn)  # Preserve the wrapped function's name and docstring.
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not _agentops_initialized():
                # Track the reward without AgentOps (see wrapper_async above).
                result = fn(*args, **kwargs)
                reward_value = wrap_result(result)["value"]
                if reward_value is not None:
                    emit_reward(reward_value)
                return result

            result: Optional[float] = None

            @operation
            def agentops_reward_operation() -> RewardSpanData:
                nonlocal result
                result = fn(*args, **kwargs)
                return wrap_result(result)

            agentops_reward_operation()
            return result

        return wrapper  # type: ignore

Server Side

agentlightning.server

Legacy server for the Agent Lightning framework. Deprecated in favor of agentlightning.store.

AgentLightningServer

The main SDK class for developers to control the Agent Lightning Server.

This class manages the server lifecycle, task queueing, resources updates, and retrieval of results, providing a simple interface for the optimization logic.

Source code in agentlightning/server.py
class AgentLightningServer:
    """
    The main SDK class for developers to control the Agent Lightning Server.

    This class manages the server lifecycle, task queueing, resources updates,
    and retrieval of results, providing a simple interface for the optimization logic.

    The data store and event loop exist only while the server is running (they
    are created and destroyed by `_lifespan`), so the task/resource methods
    raise RuntimeError when called outside the server's lifetime.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
        """
        Initializes the server controller.

        Args:
            host: The host to bind the server to.
            port: The port to bind the server to.
            task_timeout_seconds: Time in seconds after which a claimed task is considered stale and requeued.
        """
        self.host = host
        self.port = port
        self.endpoint = f"http://{host}:{port}"
        self._task_timeout_seconds = task_timeout_seconds

        # Defer initialization and use event for cross-thread communication:
        # `startup_event` lets a thread that launched the server wait until the
        # lifespan hook has populated `_store` and `loop`.
        self._store: Optional[ServerDataStore] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.startup_event = threading.Event()

        # Create FastAPI app instance with a lifespan manager
        self._app = FastAPI(lifespan=self._lifespan)
        self._setup_routes()

        self._uvicorn_config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
        self._uvicorn_server = uvicorn.Server(self._uvicorn_config)

    # --- ADDED: Lifespan context manager ---
    @asynccontextmanager
    async def _lifespan(self, app: FastAPI):
        """
        Manages server startup and shutdown. This runs inside the server's event loop.

        Everything before the `yield` runs on startup; everything after runs
        on shutdown, returning the controller to its un-started state.
        """
        logger.info("Server is starting up...")
        self.loop = asyncio.get_running_loop()
        self._store = ServerDataStore()  # Initialize data store here
        self.startup_event.set()  # Signal that the server is ready

        yield

        logger.info("Server is shutting down.")
        self._store = None
        self.startup_event.clear()  # Clear the startup event
        self.loop = None

    async def _check_and_requeue_stale_tasks(self):
        """
        Check for stale tasks and requeue them. Called reactively during get_next_task.

        A task is stale when its last claim is older than `task_timeout_seconds`;
        requeueing lets another client retry it.
        """
        current_time = time.time()
        # Ensure store is initialized before checking
        if not self._store:
            return
        processing_tasks = self._store.get_processing_tasks()

        for _, task in processing_tasks.items():
            if task.last_claim_time and current_time - task.last_claim_time > self._task_timeout_seconds:
                await self._store.requeue_task(task)
                logger.warning(
                    f"Task {task.rollout_id} timed out after {self._task_timeout_seconds}s, requeued (attempt {task.num_claims})"
                )

    def _setup_routes(self):
        """Setup FastAPI routes.

        Routes return 503 while the store is absent (server not yet started or
        already shut down); the /task route degrades to "no task available".
        """

        @self._app.get("/task", response_model=TaskIfAny)
        async def next_task() -> TaskIfAny:  # type: ignore
            """Endpoint for clients to poll for the next available task."""
            await self._check_and_requeue_stale_tasks()

            if not self._store:
                return TaskIfAny(is_available=False)

            task = await self._store.get_next_task()
            if task:
                logger.debug(f"Serving task {task.rollout_id} to a client.")
                return TaskIfAny(is_available=True, task=task)
            else:
                logger.debug("No task available for client.")
                return TaskIfAny(is_available=False)

        @self._app.get("/resources/latest", response_model=ResourcesUpdate)
        async def fetch_latest_resources() -> ResourcesUpdate:  # type: ignore
            """Endpoint for clients to poll for the latest available resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_latest_resources()
            if not resources_update:
                raise HTTPException(status_code=404, detail="No resources have been set on the server.")
            logger.debug(f"Serving latest resources '{resources_update.resources_id}' to a client.")
            return resources_update

        @self._app.get("/resources/{resource_id}", response_model=ResourcesUpdate)
        async def fetch_resources_by_id(  # type: ignore
            resource_id: str = Path(..., description="The unique identifier for the resource version.")
        ) -> ResourcesUpdate:
            """Endpoint for clients to fetch a specific version of resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_resources_by_id(resource_id)
            if not resources_update:
                raise HTTPException(status_code=404, detail=f"Resource ID '{resource_id}' not found.")
            logger.debug(f"Serving resources for ID '{resource_id}' to a client.")
            return resources_update

        @self._app.post("/rollout", response_model=GenericResponse)
        async def post_rollout(payload: Rollout) -> GenericResponse:  # type: ignore
            """Endpoint for clients to report a completed rollout."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            await self._store.store_rollout(payload)
            return GenericResponse(
                status="ok",
                message=f"Rollout {payload.rollout_id} received and stored.",
            )

    async def start(self):
        """Starts the FastAPI server in the background."""
        logger.info(f"Starting server at {self.endpoint}")
        # NOTE(review): the created task is not referenced anywhere; asyncio
        # keeps only weak references to tasks, so consider storing the return
        # value to guarantee the serve task is not garbage-collected.
        asyncio.create_task(self._uvicorn_server.serve())
        await asyncio.sleep(1)  # Allow time for server to start up.

    async def stop(self):
        """Gracefully stops the running FastAPI server."""
        if self._uvicorn_server.started:
            logger.info("Stopping server...")
            # Uvicorn polls this flag and exits its serve loop on its own.
            self._uvicorn_server.should_exit = True
            await asyncio.sleep(1)  # Allow time for graceful shutdown.
            logger.info("Server stopped.")

    async def run_forever(self):
        """
        Runs the server indefinitely until stopped.
        This is useful when async start and stop methods do not work.
        """
        await self._uvicorn_server.serve()

    async def queue_task(
        self,
        sample: Any,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> str:
        """
        Adds a task to the queue for a client to process.

        Returns:
            The rollout ID generated for the queued task.

        Raises:
            RuntimeError: If the server is not running (store not initialized).
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

    async def update_resources(self, resources: NamedResources) -> str:
        """
        Updates the resources, creating a new version and setting it as the latest.

        Returns:
            The freshly generated resources ID.

        Raises:
            RuntimeError: If the server is not running (store not initialized).
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        resources_id = f"res-{uuid.uuid4()}"
        update = ResourcesUpdate(resources_id=resources_id, resources=resources)
        await self._store.update_resources(update)
        return resources_id

    async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
        """
        Retrieves a specific completed rollout by its ID.

        Raises:
            RuntimeError: If the server is not running (store not initialized).
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_rollout(rollout_id)

    async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
        """
        Polls for a completed rollout by its ID, waiting up to `timeout` seconds.

        Checks once per second. A `timeout` of None waits indefinitely.
        """
        start_time = time.time()
        while True:
            rollout = await self.get_completed_rollout(rollout_id)
            if rollout:
                return rollout
            # NOTE(review): a timeout of 0 is falsy and therefore never
            # triggers, which means it waits forever — verify callers never
            # pass 0 expecting an immediate check.
            if timeout and (time.time() - start_time) >= timeout:
                return None
            await asyncio.sleep(1)

    async def retrieve_completed_rollouts(self) -> List[Rollout]:
        """
        Retrieves all available completed trajectories and clears the internal store.

        Raises:
            RuntimeError: If the server is not running (store not initialized).
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_completed_rollouts()

__init__(host='127.0.0.1', port=8000, task_timeout_seconds=300.0)

Initializes the server controller.

Parameters:

Name Type Description Default
host str

The host to bind the server to.

'127.0.0.1'
port int

The port to bind the server to.

8000
task_timeout_seconds float

Time in seconds after which a claimed task is considered stale and requeued.

300.0
Source code in agentlightning/server.py
def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
    """Set up the server controller (does not start serving yet).

    Args:
        host: Host/interface the HTTP server binds to.
        port: TCP port the HTTP server binds to.
        task_timeout_seconds: Age in seconds after which a claimed task counts
            as stale and becomes eligible for requeueing.
    """
    self.host = host
    self.port = port
    self.endpoint = f"http://{self.host}:{self.port}"
    self._task_timeout_seconds = task_timeout_seconds

    # The data store and loop are populated later, inside the server's
    # lifespan; the event lets other threads wait for startup to complete.
    self._store: Optional[ServerDataStore] = None
    self.loop: Optional[asyncio.AbstractEventLoop] = None
    self.startup_event = threading.Event()

    # FastAPI application wired to the lifespan manager, plus its routes.
    self._app = FastAPI(lifespan=self._lifespan)
    self._setup_routes()

    # Uvicorn server object; started later via start()/run_forever().
    self._uvicorn_config = uvicorn.Config(self._app, host=host, port=port, log_level="info")
    self._uvicorn_server = uvicorn.Server(self._uvicorn_config)

get_completed_rollout(rollout_id) async

Retrieves a specific completed rollout by its ID.

Source code in agentlightning/server.py
async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
    """Fetch a single completed rollout by its ID.

    Raises:
        RuntimeError: If the server's data store has not been initialized.
    """
    store = self._store
    if not store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await store.retrieve_rollout(rollout_id)

poll_completed_rollout(rollout_id, timeout=None) async

Polls for a completed rollout by its ID, waiting up to timeout seconds.

Source code in agentlightning/server.py
async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
    """
    Polls for a completed rollout by its ID, waiting up to `timeout` seconds.

    Args:
        rollout_id: The rollout to wait for.
        timeout: Maximum seconds to wait; None waits indefinitely. A value
            of 0 performs a single immediate check.

    Returns:
        The completed rollout, or None if the timeout elapsed first.
    """
    start_time = time.time()
    while True:
        rollout = await self.get_completed_rollout(rollout_id)
        if rollout:
            return rollout
        # Bug fix: compare against None explicitly so `timeout=0` means
        # "check once" rather than silently waiting forever (0 is falsy).
        if timeout is not None and (time.time() - start_time) >= timeout:
            return None
        await asyncio.sleep(1)

queue_task(sample, mode=None, resources_id=None, metadata=None) async

Adds a task to the queue for a client to process.

Source code in agentlightning/server.py
async def queue_task(
    self,
    sample: Any,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> str:
    """Enqueue a task for a client to process and return its rollout ID.

    Raises:
        RuntimeError: If the server's data store has not been initialized.
    """
    store = self._store
    if not store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

retrieve_completed_rollouts() async

Retrieves all available completed trajectories and clears the internal store.

Source code in agentlightning/server.py
async def retrieve_completed_rollouts(self) -> List[Rollout]:
    """Drain and return every completed rollout accumulated so far.

    Raises:
        RuntimeError: If the server's data store has not been initialized.
    """
    store = self._store
    if not store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await store.retrieve_completed_rollouts()

run_forever() async

Runs the server indefinitely until stopped. This is useful when async start and stop methods do not work.

Source code in agentlightning/server.py
async def run_forever(self):
    """Serve until the process is told to stop.

    Unlike start()/stop(), this blocks the current coroutine for the
    server's whole lifetime, which helps when the background-task
    approach does not work.
    """
    server = self._uvicorn_server
    await server.serve()

start() async

Starts the FastAPI server in the background.

Source code in agentlightning/server.py
async def start(self):
    """Starts the FastAPI server in the background.

    Schedules uvicorn's serve() coroutine as a background task, then sleeps
    briefly so the server has time to bind before callers proceed.
    """
    logger.info(f"Starting server at {self.endpoint}")
    # Bug fix: keep a strong reference to the task. asyncio holds only weak
    # references to tasks, so an unreferenced serve task could be
    # garbage-collected before the server finishes running.
    self._serve_task = asyncio.create_task(self._uvicorn_server.serve())
    await asyncio.sleep(1)  # Allow time for server to start up.

stop() async

Gracefully stops the running FastAPI server.

Source code in agentlightning/server.py
async def stop(self):
    """Request a graceful shutdown of the running FastAPI server."""
    if not self._uvicorn_server.started:
        return
    logger.info("Stopping server...")
    # Uvicorn watches this flag and exits its serve loop on its own.
    self._uvicorn_server.should_exit = True
    await asyncio.sleep(1)  # Allow time for graceful shutdown.
    logger.info("Server stopped.")

update_resources(resources) async

Updates the resources, creating a new version and setting it as the latest.

Source code in agentlightning/server.py
async def update_resources(self, resources: NamedResources) -> str:
    """Store a new resources version and mark it as the latest.

    Returns:
        The freshly generated resources ID.

    Raises:
        RuntimeError: If the server's data store has not been initialized.
    """
    store = self._store
    if not store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    resources_id = f"res-{uuid.uuid4()}"
    await store.update_resources(ResourcesUpdate(resources_id=resources_id, resources=resources))
    return resources_id

ServerDataStore

A centralized, async, in-memory data store for the server's state, guarded by asyncio locks (coroutine-safe within a single event loop, but not thread-safe). This holds the task queue, versioned resources, and completed rollouts.

Source code in agentlightning/server.py
class ServerDataStore:
    """
    A centralized, async, in-memory data store for the server's state.
    This holds the task queue, versioned resources, and completed rollouts.

    Access is serialized with `asyncio.Lock`s, making it coroutine-safe
    within a single event loop; it is not designed for concurrent access
    from multiple threads.
    """

    def __init__(self):
        # FIFO queue of tasks waiting to be claimed by a client.
        self._task_queue: asyncio.Queue[Task] = asyncio.Queue()
        self._processing_tasks: Dict[str, Task] = {}  # Currently processing tasks
        self._completed_rollouts: Dict[str, Rollout] = {}

        # Store for versioned resources
        self._resource_versions: Dict[str, NamedResources] = {}
        self._latest_resources_id: Optional[str] = None

        # Asyncio locks: `_results_lock` guards the task queue / processing /
        # completed-rollout state; `_resources_lock` guards the resource maps.
        self._results_lock = asyncio.Lock()
        self._resources_lock = asyncio.Lock()

    async def add_task(
        self,
        sample: Any,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> str:
        """
        Adds a new task to the queue with specific metadata and returns its unique ID.
        """
        rollout_id = f"rollout-{uuid.uuid4()}"
        # num_claims starts at 0 and is bumped each time a client claims the task.
        task = Task(
            rollout_id=rollout_id,
            input=sample,
            mode=mode,
            resources_id=resources_id,
            create_time=time.time(),
            num_claims=0,
            metadata=metadata or {},
        )
        await self._task_queue.put(task)
        logger.info(f"Task queued: {rollout_id} (mode: {mode}, resources_id: {resources_id})")
        return rollout_id

    async def get_next_task(self) -> Optional[Task]:
        """
        Retrieves the next task from the queue without blocking.
        Returns None if the queue is empty.
        """
        try:
            # The whole claim (dequeue + bookkeeping) happens under the lock
            # so a task cannot be handed to two clients at once.
            async with self._results_lock:
                # get_nowait() raises QueueEmpty instead of waiting.
                task = self._task_queue.get_nowait()
                task = task.model_copy(
                    update={
                        "last_claim_time": time.time(),
                        "num_claims": (task.num_claims or 0) + 1,
                    }
                )
                self._processing_tasks[task.rollout_id] = task
                if task.num_claims == 1:
                    logger.debug(f"Next task retrieved: {task.rollout_id}")
                else:
                    logger.info(f"Task {task.rollout_id} re-claimed (attempt {task.num_claims})")
                return task
        except asyncio.QueueEmpty:
            return None

    async def update_resources(self, update: ResourcesUpdate):
        """
        Safely stores a new version of named resources and sets it as the latest.
        """
        # TODO: evict old resources if necessary.
        async with self._resources_lock:
            self._resource_versions[update.resources_id] = update.resources
            self._latest_resources_id = update.resources_id
            logger.info(f"Resources updated. New version '{update.resources_id}' is now latest.")

    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
        """
        Safely retrieves a specific version of named resources by its ID.
        """
        async with self._resources_lock:
            resources = self._resource_versions.get(resources_id)
            if resources:
                return ResourcesUpdate(resources_id=resources_id, resources=resources)
            return None

    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """
        Safely retrieves the latest version of named resources.
        """
        if self._latest_resources_id:
            return await self.get_resources_by_id(self._latest_resources_id)
        return None

    async def store_rollout(self, rollout: Rollout):
        """
        Safely stores a completed rollout from a client.

        Also removes the rollout's task from the processing map, completing
        the claim started in get_next_task().
        """
        async with self._results_lock:
            self._processing_tasks.pop(rollout.rollout_id, None)
            self._completed_rollouts[rollout.rollout_id] = rollout
            logger.info(f"Rollout received and stored: {rollout.rollout_id}")

    async def retrieve_rollout(self, rollout_id: str) -> Optional[Rollout]:
        """
        Safely retrieves a single rollout by its ID, removing it from the store.
        """
        async with self._results_lock:
            return self._completed_rollouts.pop(rollout_id, None)

    async def retrieve_completed_rollouts(self) -> List[Rollout]:
        """
        Retrieves all completed rollouts and clears the store.
        """
        async with self._results_lock:
            rollouts = list(self._completed_rollouts.values())
            self._completed_rollouts.clear()
            return rollouts

    def get_processing_tasks(self) -> Dict[str, Task]:
        """Returns a copy of currently processing tasks for timeout checking."""
        return self._processing_tasks.copy()

    async def requeue_task(self, task: Task):
        """Requeues a task that has timed out and removes it from processing."""
        logger.warning(f"Requeuing task {task.rollout_id} after timeout (attempt {task.num_claims})")
        async with self._results_lock:
            # Remove from processing tasks
            self._processing_tasks.pop(task.rollout_id, None)
            self._task_queue.put_nowait(task)

add_task(sample, mode=None, resources_id=None, metadata=None) async

Adds a new task to the queue with specific metadata and returns its unique ID.

Source code in agentlightning/server.py
async def add_task(
    self,
    sample: Any,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> str:
    """
    Adds a new task to the queue with specific metadata and returns its unique ID.

    Args:
        sample: The task input payload; stored verbatim on the queued ``Task``.
        mode: Optional rollout mode ("train", "val" or "test").
        resources_id: Optional resources version to pin for this task.
        metadata: Extra key/value data attached to the task (defaults to ``{}``).

    Returns:
        The generated rollout ID of the queued task.
    """
    # uuid4 gives a collision-free rollout ID without any coordination.
    rollout_id = f"rollout-{uuid.uuid4()}"
    task = Task(
        rollout_id=rollout_id,
        input=sample,
        mode=mode,
        resources_id=resources_id,
        create_time=time.time(),
        num_claims=0,  # incremented each time a worker claims the task
        metadata=metadata or {},
    )
    await self._task_queue.put(task)
    logger.info(f"Task queued: {rollout_id} (mode: {mode}, resources_id: {resources_id})")
    return rollout_id

get_latest_resources() async

Safely retrieves the latest version of named resources.

Source code in agentlightning/server.py
async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
    """
    Safely retrieves the latest version of named resources.

    Returns:
        The most recently stored ``ResourcesUpdate``, or ``None`` when no
        resources have been registered yet.
    """
    if self._latest_resources_id:
        # Delegate to the lock-protected per-ID lookup.
        return await self.get_resources_by_id(self._latest_resources_id)
    return None

get_next_task() async

Retrieves the next task from the queue without blocking. Returns None if the queue is empty.

Source code in agentlightning/server.py
async def get_next_task(self) -> Optional[Task]:
    """
    Retrieves the next task from the queue without blocking.
    Returns None if the queue is empty.
    """
    try:
        async with self._results_lock:
            task = self._task_queue.get_nowait()
            # Record the claim on a copy so the queued object is never
            # mutated in place.
            task = task.model_copy(
                update={
                    "last_claim_time": time.time(),
                    "num_claims": (task.num_claims or 0) + 1,
                }
            )
            self._processing_tasks[task.rollout_id] = task
            if task.num_claims == 1:
                logger.debug(f"Next task retrieved: {task.rollout_id}")
            else:
                # num_claims > 1 means the task was previously requeued
                # (e.g. after a timeout) — worth surfacing at INFO level.
                logger.info(f"Task {task.rollout_id} re-claimed (attempt {task.num_claims})")
            return task
    except asyncio.QueueEmpty:
        return None

get_processing_tasks()

Returns a copy of currently processing tasks for timeout checking.

Source code in agentlightning/server.py
def get_processing_tasks(self) -> Dict[str, Task]:
    """Returns a copy of currently processing tasks for timeout checking."""
    # Shallow copy: callers can iterate safely while tasks are claimed or completed.
    return self._processing_tasks.copy()

get_resources_by_id(resources_id) async

Safely retrieves a specific version of named resources by its ID.

Source code in agentlightning/server.py
async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
    """
    Safely retrieves a specific version of named resources by its ID.

    Returns:
        A ``ResourcesUpdate`` wrapping the stored resources, or ``None`` if
        the ID is unknown.
    """
    async with self._resources_lock:
        resources = self._resource_versions.get(resources_id)
        if resources:
            # Re-wrap in the transport message so callers get id + payload together.
            return ResourcesUpdate(resources_id=resources_id, resources=resources)
        return None

requeue_task(task) async

Requeues a task that has timed out and removes it from processing.

Source code in agentlightning/server.py
async def requeue_task(self, task: Task):
    """Requeues a task that has timed out and removes it from processing."""
    logger.warning(f"Requeuing task {task.rollout_id} after timeout (attempt {task.num_claims})")
    async with self._results_lock:
        # Remove from processing tasks so the task is not tracked as both
        # in-flight and pending at the same time.
        self._processing_tasks.pop(task.rollout_id, None)
        self._task_queue.put_nowait(task)

retrieve_completed_rollouts() async

Retrieves all completed rollouts and clears the store.

Source code in agentlightning/server.py
async def retrieve_completed_rollouts(self) -> List[Rollout]:
    """
    Retrieves all completed rollouts and clears the store.
    """
    async with self._results_lock:
        # Snapshot then clear atomically so no rollout is lost or returned twice.
        rollouts = list(self._completed_rollouts.values())
        self._completed_rollouts.clear()
        return rollouts

retrieve_rollout(rollout_id) async

Safely retrieves a single rollout by its ID, removing it from the store.

Source code in agentlightning/server.py
async def retrieve_rollout(self, rollout_id: str) -> Optional[Rollout]:
    """
    Safely retrieves a single rollout by its ID, removing it from the store.

    Returns ``None`` when no completed rollout with that ID exists.
    """
    async with self._results_lock:
        return self._completed_rollouts.pop(rollout_id, None)

store_rollout(rollout) async

Safely stores a completed rollout from a client.

Source code in agentlightning/server.py
async def store_rollout(self, rollout: Rollout):
    """
    Safely stores a completed rollout from a client.
    """
    async with self._results_lock:
        # Once the rollout arrives, the task is no longer in flight.
        self._processing_tasks.pop(rollout.rollout_id, None)
        self._completed_rollouts[rollout.rollout_id] = rollout
        logger.info(f"Rollout received and stored: {rollout.rollout_id}")

update_resources(update) async

Safely stores a new version of named resources and sets it as the latest.

Source code in agentlightning/server.py
async def update_resources(self, update: ResourcesUpdate):
    """
    Safely stores a new version of named resources and sets it as the latest.
    """
    # TODO: evict old resources if necessary.
    async with self._resources_lock:
        self._resource_versions[update.resources_id] = update.resources
        # The stored version immediately becomes what get_latest_resources() returns.
        self._latest_resources_id = update.resources_id
        logger.info(f"Resources updated. New version '{update.resources_id}' is now latest.")

Utilities

agentlightning.config

This file is not carefully reviewed. It might contain unintentional bugs and issues. Please always review the parsed construction arguments before using them.

lightning_cli(*classes)

lightning_cli(cls1: Type[_C1]) -> _C1
lightning_cli(
    cls1: Type[_C1], cls2: Type[_C2]
) -> Tuple[_C1, _C2]
lightning_cli(
    cls1: Type[_C1], cls2: Type[_C2], cls3: Type[_C3]
) -> Tuple[_C1, _C2, _C3]
lightning_cli(
    cls1: Type[_C1],
    cls2: Type[_C2],
    cls3: Type[_C3],
    cls4: Type[_C4],
) -> Tuple[_C1, _C2, _C3, _C4]
lightning_cli(
    *classes: Type[CliConfigurable],
) -> Tuple[CliConfigurable, ...]

Parses command-line arguments to configure and instantiate provided CliConfigurable classes.

Parameters:

Name Type Description Default
*classes Type[CliConfigurable]

One or more classes that inherit from CliConfigurable. Each class's init parameters will be exposed as command-line arguments.

()

Returns:

Type Description
CliConfigurable | Tuple[CliConfigurable, ...]

A tuple of instantiated objects, corresponding to the input classes in order.

Source code in agentlightning/config.py
def lightning_cli(*classes: Type[CliConfigurable]) -> CliConfigurable | Tuple[CliConfigurable, ...]:  # type: ignore
    """
    Parses command-line arguments to configure and instantiate provided CliConfigurable classes.

    Args:
        *classes: One or more classes that inherit from CliConfigurable. Each class's
                  __init__ parameters will be exposed as command-line arguments.

    Returns:
        A tuple of instantiated objects, corresponding to the input classes in order.
        When exactly one class is given, the instance itself is returned instead of
        a one-element tuple (matching the declared overloads).
    """
    if not classes:
        return tuple()  # Return an empty tuple if no classes are provided

    parser = _create_argument_parser()

    # This map will store {cls: {init_param_name: argparse_dest_name}}
    class_arg_configs_maps: Dict[Type[CliConfigurable], Dict[str, str]] = {}

    for cls in classes:
        _add_arguments_for_class(parser, cls, class_arg_configs_maps)

    parsed_args = parser.parse_args()  # Uses sys.argv[1:] by default

    # Correctly handle single class case for return type matching overloads
    instances = _instantiate_classes(parsed_args, classes, class_arg_configs_maps)
    if len(classes) == 1:
        return instances[0]
    return instances

nullable_float(value)

Converts specific string values (case-insensitive) to None, otherwise returns the float.

Source code in agentlightning/config.py
def nullable_float(value: str) -> float | None:
    """Converts specific string values (case-insensitive) to None, otherwise returns the float."""
    if value.lower() in ["none", "null", "~", "nil"]:  # Define keywords for None
        return None
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid float value: '{value}'")

nullable_int(value)

Converts specific string values (case-insensitive) to None, otherwise returns the integer.

Source code in agentlightning/config.py
def nullable_int(value: str) -> int | None:
    """Converts specific string values (case-insensitive) to None, otherwise returns the integer."""
    if value.lower() in ["none", "null", "~", "nil"]:  # Define keywords for None
        return None
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid integer value: '{value}'")

nullable_str(value)

Converts specific string values (case-insensitive) to None, otherwise returns the string.

Source code in agentlightning/config.py
def nullable_str(value: str) -> str | None:
    """Converts specific string values (case-insensitive) to None, otherwise returns the string."""
    if value.lower() in ["none", "null", "~", "nil"]:  # Define keywords for None
        return None
    return value

agentlightning.types

NamedResources = Dict[str, ResourceUnion] module-attribute

A dictionary-like class to hold named resources.

Example

resources: NamedResources = {
    'main_llm': LLM(
        endpoint="http://localhost:8080",
        model="llama3",
        sampling_parameters={'temperature': 0.7, 'max_tokens': 100}
    ),
    'system_prompt': PromptTemplate(
        template="You are a helpful assistant.",
        engine='f-string'
    )
}

Attempt

Bases: BaseModel

An attempt to execute a rollout. A rollout can have multiple attempts if retries are needed.

Source code in agentlightning/types/core.py
class Attempt(BaseModel):
    """An attempt to execute a rollout. A rollout can have multiple attempts if retries are needed."""

    rollout_id: str  # the rollout this attempt belongs to
    attempt_id: str  # the universal id for current attempt
    sequence_id: int  # the sequence number of the attempt, starting from 1
    start_time: float  # time when the attempt has started
    end_time: Optional[float] = None  # time when the attempt has ended (None while still running)

    status: AttemptStatus = "preparing"
    # The rollout worker which is executing this attempt
    worker_id: Optional[str] = None

    last_heartbeat_time: Optional[float] = None  # last time when the worker has reported progress

    # A bucket for any other relevant information
    metadata: Optional[Dict[str, Any]] = None

AttemptedRollout

Bases: RolloutV2

A rollout along with its active attempt.

Source code in agentlightning/types/core.py
class AttemptedRollout(RolloutV2):
    """A rollout along with its active attempt."""

    attempt: Attempt

    @model_validator(mode="after")
    def check_consistency(self) -> AttemptedRollout:
        # Guard against pairing an Attempt with the wrong Rollout: both sides
        # must agree on the rollout ID.
        if self.attempt.rollout_id != self.rollout_id:
            raise ValueError("Inconsistent rollout_id between Rollout and Attempt")
        return self

Dataset

Bases: Protocol, Generic[T_co]

The general interface for a dataset.

It's currently implemented as a protocol, having a similar interface to torch.utils.data.Dataset. You don't have to inherit from this class; you can use a simple list if you want to.

Source code in agentlightning/types/core.py
class Dataset(Protocol, Generic[T_co]):
    """The general interface for a dataset.

    It's currently implemented as a protocol, having a similar interface to torch.utils.data.Dataset.
    You don't have to inherit from this class; you can use a simple list if you want to.
    """

    def __getitem__(self, index: int) -> T_co: ...  # random access by integer index

    def __len__(self) -> int: ...  # total number of samples

Event

Bases: BaseModel

Corresponding to opentelemetry.trace.Event

Source code in agentlightning/types/tracer.py
class Event(BaseModel):
    """Corresponding to opentelemetry.trace.Event"""

    name: str
    attributes: Attributes
    timestamp: Optional[float] = None

    class Config:
        # FIX: pydantic's config key is ``extra``; the previous
        # ``allow_extra = True`` is not a recognized option, so the extra
        # fields forwarded by ``from_opentelemetry`` were silently dropped.
        extra = "allow"

    @classmethod
    def from_opentelemetry(cls, src: OtelEvent) -> "Event":
        """Build an Event from an OpenTelemetry event, retaining unknown extra fields."""
        return cls(
            name=src.name,
            attributes=dict(src.attributes) if src.attributes else {},
            timestamp=convert_timestamp(src.timestamp),
            **extract_extra_fields(src, ["name", "attributes", "timestamp"]),
        )

GenericResponse

Bases: BaseModel

A generic response message that can be used for various purposes.

Source code in agentlightning/types/core.py
class GenericResponse(BaseModel):
    """
    A generic response message that can be used for various purposes.
    """

    status: str = "success"  # coarse outcome indicator; defaults to "success"
    message: Optional[str] = None  # optional human-readable detail
    data: Optional[Dict[str, Any]] = None  # optional structured payload

Hook

Bases: ParallelWorkerBase

Base class for defining hooks in the agent runner's lifecycle.

Source code in agentlightning/types/core.py
class Hook(ParallelWorkerBase):
    """Base class for defining hooks in the agent runner's lifecycle.

    Every hook is an async method and defaults to a no-op; subclasses
    override only the lifecycle points they care about.
    """

    async def on_trace_start(
        self, *, agent: LitAgent[Any], runner: BaseRunner[Any], tracer: BaseTracer, rollout: RolloutV2
    ) -> None:
        """Hook called immediately after the tracer enters the trace context but before the rollout begins.

        Args:
            agent: The :class:`LitAgent` instance associated with the runner.
            runner: The :class:`BaseRunner` managing the rollout.
            tracer: The :class:`BaseTracer` instance associated with the runner.
            rollout: The :class:`RolloutV2` object that will be processed.

        Subclasses can override this method to implement custom logic such as logging,
        metric collection, or resource setup. By default, this is a no-op.
        """

    async def on_trace_end(
        self, *, agent: LitAgent[Any], runner: BaseRunner[Any], tracer: BaseTracer, rollout: RolloutV2
    ) -> None:
        """Hook called immediately after the rollout completes but before the tracer exits the trace context.

        Args:
            agent: The :class:`LitAgent` instance associated with the runner.
            runner: The :class:`BaseRunner` managing the rollout.
            tracer: The :class:`BaseTracer` instance associated with the runner.
            rollout: The :class:`RolloutV2` object that has been processed.

        Subclasses can override this method to implement custom logic such as logging,
        metric collection, or resource cleanup. By default, this is a no-op.
        """

    async def on_rollout_start(self, *, agent: LitAgent[Any], runner: BaseRunner[Any], rollout: RolloutV2) -> None:
        """Hook called immediately before a rollout *attempt* begins.

        Args:
            agent: The :class:`LitAgent` instance associated with the runner.
            runner: The :class:`BaseRunner` managing the rollout.
            rollout: The :class:`RolloutV2` object that will be processed.

        Subclasses can override this method to implement custom logic such as
        logging, metric collection, or resource setup. By default, this is a
        no-op.
        """

    async def on_rollout_end(
        self,
        *,
        agent: LitAgent[Any],
        runner: BaseRunner[Any],
        rollout: RolloutV2,
        spans: Union[List[ReadableSpan], List[Span]],
    ) -> None:
        """Hook called after a rollout *attempt* completes.

        Args:
            agent: The :class:`LitAgent` instance associated with the runner.
            runner: The :class:`BaseRunner` managing the rollout.
            rollout: The :class:`RolloutV2` object that has been processed.
            spans: The spans that have been added to the store.

        Subclasses can override this method for cleanup or additional
        logging. By default, this is a no-op.
        """

on_rollout_end(*, agent, runner, rollout, spans) async

Hook called after a rollout attempt completes.

Parameters:

Name Type Description Default
agent LitAgent[Any]

The :class:LitAgent instance associated with the runner.

required
runner BaseRunner[Any]

The :class:BaseRunner managing the rollout.

required
rollout RolloutV2

The :class:RolloutV2 object that has been processed.

required
spans Union[List[ReadableSpan], List[Span]]

The spans that have been added to the store.

required

Subclasses can override this method for cleanup or additional logging. By default, this is a no-op.

Source code in agentlightning/types/core.py
async def on_rollout_end(
    self,
    *,
    agent: LitAgent[Any],
    runner: BaseRunner[Any],
    rollout: RolloutV2,
    spans: Union[List[ReadableSpan], List[Span]],
) -> None:
    """Hook called after a rollout *attempt* completes.

    Args:
        agent: The :class:`LitAgent` instance associated with the runner.
        runner: The :class:`BaseRunner` managing the rollout.
        rollout: The :class:`RolloutV2` object that has been processed.
        spans: The spans that have been added to the store.

    Subclasses can override this method for cleanup or additional
    logging. By default, this is a no-op.
    """
    # Intentionally empty: the base implementation is a no-op.

on_rollout_start(*, agent, runner, rollout) async

Hook called immediately before a rollout attempt begins.

Parameters:

Name Type Description Default
agent LitAgent[Any]

The :class:LitAgent instance associated with the runner.

required
runner BaseRunner[Any]

The :class:BaseRunner managing the rollout.

required
rollout RolloutV2

The :class:RolloutV2 object that will be processed.

required

Subclasses can override this method to implement custom logic such as logging, metric collection, or resource setup. By default, this is a no-op.

Source code in agentlightning/types/core.py
async def on_rollout_start(self, *, agent: LitAgent[Any], runner: BaseRunner[Any], rollout: RolloutV2) -> None:
    """Hook called immediately before a rollout *attempt* begins.

    Args:
        agent: The :class:`LitAgent` instance associated with the runner.
        runner: The :class:`BaseRunner` managing the rollout.
        rollout: The :class:`RolloutV2` object that will be processed.

    Subclasses can override this method to implement custom logic such as
    logging, metric collection, or resource setup. By default, this is a
    no-op.
    """
    # Intentionally empty: the base implementation is a no-op.

on_trace_end(*, agent, runner, tracer, rollout) async

Hook called immediately after the rollout completes but before the tracer exits the trace context.

Parameters:

Name Type Description Default
agent LitAgent[Any]

The :class:LitAgent instance associated with the runner.

required
runner BaseRunner[Any]

The :class:BaseRunner managing the rollout.

required
tracer BaseTracer

The :class:BaseTracer instance associated with the runner.

required
rollout RolloutV2

The :class:RolloutV2 object that has been processed.

required

Subclasses can override this method to implement custom logic such as logging, metric collection, or resource cleanup. By default, this is a no-op.

Source code in agentlightning/types/core.py
async def on_trace_end(
    self, *, agent: LitAgent[Any], runner: BaseRunner[Any], tracer: BaseTracer, rollout: RolloutV2
) -> None:
    """Hook called immediately after the rollout completes but before the tracer exits the trace context.

    Args:
        agent: The :class:`LitAgent` instance associated with the runner.
        runner: The :class:`BaseRunner` managing the rollout.
        tracer: The :class:`BaseTracer` instance associated with the runner.
        rollout: The :class:`RolloutV2` object that has been processed.

    Subclasses can override this method to implement custom logic such as logging,
    metric collection, or resource cleanup. By default, this is a no-op.
    """
    # Intentionally empty: the base implementation is a no-op.

on_trace_start(*, agent, runner, tracer, rollout) async

Hook called immediately after the tracer enters the trace context but before the rollout begins.

Parameters:

Name Type Description Default
agent LitAgent[Any]

The :class:LitAgent instance associated with the runner.

required
runner BaseRunner[Any]

The :class:BaseRunner managing the rollout.

required
tracer BaseTracer

The :class:BaseTracer instance associated with the runner.

required
rollout RolloutV2

The :class:RolloutV2 object that will be processed.

required

Subclasses can override this method to implement custom logic such as logging, metric collection, or resource setup. By default, this is a no-op.

Source code in agentlightning/types/core.py
async def on_trace_start(
    self, *, agent: LitAgent[Any], runner: BaseRunner[Any], tracer: BaseTracer, rollout: RolloutV2
) -> None:
    """Hook called immediately after the tracer enters the trace context but before the rollout begins.

    Args:
        agent: The :class:`LitAgent` instance associated with the runner.
        runner: The :class:`BaseRunner` managing the rollout.
        tracer: The :class:`BaseTracer` instance associated with the runner.
        rollout: The :class:`RolloutV2` object that will be processed.

    Subclasses can override this method to implement custom logic such as logging,
    metric collection, or resource setup. By default, this is a no-op.
    """
    # Intentionally empty: the base implementation is a no-op.

LLM

Bases: Resource

Provide an LLM endpoint and model name as a resource.

Attributes:

Name Type Description
endpoint str

The URL of the LLM API endpoint.

model str

The identifier for the model to be used (e.g., 'gpt-4o').

sampling_parameters SamplingParameters

A dictionary of hyperparameters for model inference, such as temperature, top_p, etc.

Source code in agentlightning/types/resources.py
class LLM(Resource):
    """
    Provide an LLM endpoint and model name as a resource.

    Attributes:
        endpoint (str): The URL of the LLM API endpoint.
        model (str): The identifier for the model to be used (e.g., 'gpt-4o').
        api_key (Optional[str]): API key to send to the endpoint, if one is required.
        sampling_parameters (SamplingParameters): A dictionary of hyperparameters
            for model inference, such as temperature, top_p, etc.
    """

    resource_type: Literal["llm"] = "llm"
    endpoint: str
    model: str
    api_key: Optional[str] = None
    sampling_parameters: Dict[str, Any] = Field(default_factory=dict)

    def get_base_url(self, *args: Any, **kwargs: Any) -> str:
        """The base_url to put into openai.OpenAI.

        Users are encouraged to use `base_url` to get the LLM endpoint instead of accessing `endpoint` directly.
        """
        return self.endpoint

get_base_url(*args, **kwargs)

The base_url to put into openai.OpenAI.

Users are encouraged to use base_url to get the LLM endpoint instead of accessing endpoint directly.

Source code in agentlightning/types/resources.py
def get_base_url(self, *args: Any, **kwargs: Any) -> str:
    """The base_url to put into openai.OpenAI.

    Users are encouraged to use `base_url` to get the LLM endpoint instead of accessing `endpoint` directly.
    Extra positional/keyword arguments are accepted (for subclass compatibility) and ignored here.
    """
    # Base resource: the configured endpoint already is the base URL.
    return self.endpoint

Link

Bases: BaseModel

Corresponding to opentelemetry.trace.Link

Source code in agentlightning/types/tracer.py
class Link(BaseModel):
    """Corresponding to opentelemetry.trace.Link"""

    context: SpanContext
    attributes: Optional[Attributes] = None

    class Config:
        # FIX: pydantic's config key is ``extra``; the previous
        # ``allow_extra = True`` is not a recognized option, so the extra
        # fields forwarded by ``from_opentelemetry`` were silently dropped.
        extra = "allow"

    @classmethod
    def from_opentelemetry(cls, src: trace_api.Link) -> "Link":
        """Build a Link from an OpenTelemetry link, retaining unknown extra fields."""
        return cls(
            context=SpanContext.from_opentelemetry(src.context),
            attributes=dict(src.attributes) if src.attributes else None,
            **extract_extra_fields(src, ["context", "attributes"]),
        )

ParallelWorkerBase

Base class for objects that can be parallelized across multiple worker processes.

This class defines the standard lifecycle for parallel processing:

Main Process
  1. init() - Initialize the object in the main process
  2. spawn workers and call init_worker() in each worker
  3. run() - Execute the main workload in parallel across workers
  4. teardown_worker() - Clean up resources in each worker
  5. teardown() - Final cleanup in the main process

Subclasses should implement the run() method and optionally override the lifecycle methods for custom initialization and cleanup behavior.

Source code in agentlightning/types/core.py
class ParallelWorkerBase:
    """Base class for objects whose workload can be spread over worker processes.

    Lifecycle, driven from the main process:

    1. ``init()`` — one-time setup in the main process.
    2. Workers are spawned; each calls ``init_worker()``.
    3. ``run()`` — the parallel workload itself.
    4. ``teardown_worker()`` — per-worker cleanup.
    5. ``teardown()`` — final cleanup back in the main process.

    Subclasses implement ``run()`` and may override any of the other
    lifecycle hooks; the defaults do nothing.
    """

    def __init__(self) -> None:
        """Set up base state; subclasses may extend."""
        # Populated by init_worker() once this object runs inside a worker.
        self.worker_id: Optional[int] = None

    def init(self, *args: Any, **kwargs: Any) -> None:
        """Main-process initialization hook. Default: no-op."""

    def init_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Per-worker initialization hook; records this worker's ID."""
        self.worker_id = worker_id

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the workload. Default implementation does nothing and returns ``None``."""

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Per-worker cleanup hook. Default: no-op."""

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Main-process cleanup hook. Default: no-op."""

__init__()

Initialize the base class. This method can be overridden by subclasses.

Source code in agentlightning/types/core.py
def __init__(self) -> None:
    """Initialize the base class. This method can be overridden by subclasses."""
    # worker_id is assigned later by init_worker() inside each worker process.
    self.worker_id: Optional[int] = None

PromptTemplate

Bases: Resource

A prompt template as a resource.

Attributes:

Name Type Description
template str

The template string. The format depends on the engine.

engine Literal['jinja', 'f-string', 'poml']

The templating engine to use for rendering the prompt. I imagine users can use their own customized engines, but algos can only well operate on a subset of them.

Source code in agentlightning/types/resources.py
class PromptTemplate(Resource):
    """
    A prompt template as a resource.

    Attributes:
        template (str): The template string. The format depends on the engine.
        engine (Literal['jinja', 'f-string', 'poml']): The templating engine
            to use for rendering the prompt. I imagine users can use their own
            customized engines, but algos can only well operate on a subset of them.
    """

    resource_type: Literal["prompt_template"] = "prompt_template"
    template: str
    engine: Literal["jinja", "f-string", "poml"]

    def format(self, **kwargs: Any) -> str:
        """Format the prompt template with the given kwargs.

        Raises:
            NotImplementedError: For engines other than "f-string".
        """
        if self.engine == "f-string":
            # str.format handles the {placeholder} substitution for f-string templates.
            return self.template.format(**kwargs)
        else:
            raise NotImplementedError(
                "Formatting prompt templates for non-f-string engines with format() helper is not supported yet."
            )

format(**kwargs)

Format the prompt template with the given kwargs.

Source code in agentlightning/types/resources.py
def format(self, **kwargs: Any) -> str:
    """Render the template with the given keyword arguments.

    Only the "f-string" engine is supported by this helper; any other engine
    raises ``NotImplementedError``.
    """
    if self.engine != "f-string":
        raise NotImplementedError(
            "Formatting prompt templates for non-f-string engines with format() helper is not supported yet."
        )
    return self.template.format(**kwargs)

ProxyLLM

Bases: LLM

Proxy LLM resource that is tailored by llm_proxy.LLMProxy.

Source code in agentlightning/types/resources.py
class ProxyLLM(LLM):
    """Proxy LLM resource that is tailored by `llm_proxy.LLMProxy`."""

    resource_type: Literal["proxy_llm"] = "proxy_llm"  # type: ignore
    # Flipped to True at the end of pydantic construction; lets
    # __getattribute__ skip the warning during model initialization.
    _initialized: bool = False

    def model_post_init(self, __context: Any) -> None:
        """Mark initialization as complete after Pydantic finishes setup."""
        super().model_post_init(__context)
        # object.__setattr__ bypasses pydantic's attribute handling for the private flag.
        object.__setattr__(self, "_initialized", True)

    def __getattribute__(self, name: str) -> Any:
        """Override to emit a warning when endpoint is accessed directly."""
        # Check if we're accessing endpoint after initialization and not from base_url
        if name == "endpoint":
            try:
                initialized = object.__getattribute__(self, "_initialized")
            except AttributeError:
                initialized = False

            if initialized:
                # Check the call stack to see if we're being called from base_url
                # NOTE(review): only one frame up is inspected, so any other
                # internal caller of `endpoint` will also trigger the warning.
                frame = inspect.currentframe()
                if frame and frame.f_back:
                    caller_name = frame.f_back.f_code.co_name
                    if caller_name != "get_base_url":
                        logger.warning(
                            "Accessing 'endpoint' directly on ProxyLLM is discouraged. "
                            "Use 'get_base_url(rollout_id, attempt_id)' instead to get the properly formatted endpoint."
                        )
        return super().__getattribute__(name)

    def with_attempted_rollout(self, rollout: AttemptedRollout) -> LLM:
        """Bake the rollout and attempt id into the endpoint."""
        # Returns a plain LLM so downstream consumers need no proxy-specific handling.
        return LLM(
            endpoint=self.get_base_url(rollout.rollout_id, rollout.attempt.attempt_id),
            model=self.model,
            sampling_parameters=self.sampling_parameters,
            api_key=self.api_key,
        )

    def get_base_url(self, rollout_id: Optional[str], attempt_id: Optional[str]) -> str:
        """Return the endpoint with rollout/attempt identifiers inserted into the URL path.

        Both IDs must be provided together; passing both as ``None`` returns the
        raw endpoint. A trailing "/v1" suffix is preserved at the end of the
        rewritten URL.
        """
        if rollout_id is None and attempt_id is None:
            return self.endpoint

        if not (isinstance(rollout_id, str) and isinstance(attempt_id, str)):
            raise ValueError("rollout_id and attempt_id must be strings or all be empty")

        prefix = self.endpoint
        if prefix.endswith("/"):
            prefix = prefix[:-1]
        if prefix.endswith("/v1"):
            prefix = prefix[:-3]
            has_v1 = True
        else:
            has_v1 = False
        # Now the prefix should look like "http://localhost:11434"

        # Append the rollout and attempt id to the prefix
        prefix = prefix + f"/rollout/{rollout_id}/attempt/{attempt_id}"
        if has_v1:
            prefix = prefix + "/v1"
        return prefix

__getattribute__(name)

Override to emit a warning when endpoint is accessed directly.

Source code in agentlightning/types/resources.py
def __getattribute__(self, name: str) -> Any:
    """Override to emit a warning when endpoint is accessed directly."""
    # Check if we're accessing endpoint after initialization and not from base_url
    if name == "endpoint":
        try:
            initialized = object.__getattribute__(self, "_initialized")
        except AttributeError:
            # _initialized may not exist yet during pydantic construction.
            initialized = False

        if initialized:
            # Check the call stack to see if we're being called from base_url
            # NOTE(review): only one frame up is inspected, so any other
            # internal caller of `endpoint` will also trigger the warning.
            frame = inspect.currentframe()
            if frame and frame.f_back:
                caller_name = frame.f_back.f_code.co_name
                if caller_name != "get_base_url":
                    logger.warning(
                        "Accessing 'endpoint' directly on ProxyLLM is discouraged. "
                        "Use 'get_base_url(rollout_id, attempt_id)' instead to get the properly formatted endpoint."
                    )
    return super().__getattribute__(name)

model_post_init(__context)

Mark initialization as complete after Pydantic finishes setup.

Source code in agentlightning/types/resources.py
def model_post_init(self, __context: Any) -> None:
    """Mark initialization as complete after Pydantic finishes setup."""
    super().model_post_init(__context)
    # object.__setattr__ bypasses pydantic's attribute handling for the private flag.
    object.__setattr__(self, "_initialized", True)

with_attempted_rollout(rollout)

Bake the rollout and attempt id into the endpoint.

Source code in agentlightning/types/resources.py
def with_attempted_rollout(self, rollout: AttemptedRollout) -> LLM:
    """Return a plain LLM whose endpoint has the rollout/attempt ids baked in."""
    baked_endpoint = self.get_base_url(rollout.rollout_id, rollout.attempt.attempt_id)
    return LLM(
        model=self.model,
        endpoint=baked_endpoint,
        sampling_parameters=self.sampling_parameters,
        api_key=self.api_key,
    )

Resource

Bases: BaseModel

Corresponding to opentelemetry.sdk.resources.Resource

Source code in agentlightning/types/tracer.py
class Resource(BaseModel):
    """Pydantic mirror of ``opentelemetry.sdk.resources.Resource``."""

    attributes: Attributes
    schema_url: str

    @classmethod
    def from_opentelemetry(cls, src: OtelResource) -> "Resource":
        """Convert an OpenTelemetry resource into this model, keeping extra fields."""
        attrs = dict(src.attributes) if src.attributes else {}
        url = src.schema_url or ""
        extras = extract_extra_fields(src, ["attributes", "schema_url"])
        return cls(attributes=attrs, schema_url=url, **extras)

ResourcesUpdate

Bases: BaseModel

A resource update message to be sent from the server to clients.

This message contains a dictionary of resources that clients should use for subsequent tasks. It is used to update the resources available to clients dynamically.

Source code in agentlightning/types/resources.py
class ResourcesUpdate(BaseModel):
    """
    A resource update message to be sent from the server to clients.

    This message contains a dictionary of resources that clients should use
    for subsequent tasks. It is used to update the resources available to
    clients dynamically.
    """

    # Identifier for this resource snapshot; tasks reference it via `resources_id`.
    resources_id: str
    # The named resources clients should use from now on.
    resources: NamedResources

Rollout

Bases: BaseModel

The standard reporting object from client to server.

Source code in agentlightning/types/core.py
class Rollout(BaseModel):
    """The standard reporting object from client to server.

    Bundles everything a client reports for one rollout: the echoed task,
    scalar and per-step feedback, optional raw traces and logs, and
    free-form metadata.
    """

    # Identifier linking this report back to the originating rollout request.
    rollout_id: str

    # Echoing the input task
    task: Optional[Task] = None

    # Primary, high-level feedback
    final_reward: Optional[float] = None

    # Structured, sequential feedback for RL-style optimization
    triplets: Optional[List[Triplet]] = None

    # Optional, rich-context data for deep analysis
    trace: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="A list of spans that conform to the OpenTelemetry JSON format. "
        "Users of the opentelemetry-sdk can generate this by calling "
        "json.loads(readable_span.to_json()).",
    )
    # Free-form textual log lines collected during the rollout.
    logs: Optional[List[str]] = None

    # A bucket for any other relevant information
    metadata: Dict[str, Any] = Field(default_factory=dict)

RolloutConfig

Bases: BaseModel

Configurations for rollout execution.

Source code in agentlightning/types/core.py
class RolloutConfig(BaseModel):
    """Configurations for rollout execution.

    Controls timeout, liveness and retry behavior for rollout attempts.
    """

    timeout_seconds: Optional[float] = None  # none indicates no timeout
    unresponsive_seconds: Optional[float] = None  # none indicates no unresponsive timeout
    max_attempts: int = Field(default=1, ge=1)  # including the first attempt
    # The cast only narrows the factory's type for checkers; at runtime it is plain `list`.
    retry_condition: List[AttemptStatus] = Field(
        default_factory=cast(Callable[[], List[AttemptStatus]], list)
    )  # list of statuses that should trigger a retry

SpanAttributeNames

Bases: str, Enum

Standard attribute names for AgentLightning spans.

Source code in agentlightning/types/tracer.py
class SpanAttributeNames(str, Enum):
    """Standard attribute names for AgentLightning spans."""

    MESSAGE = "message"  # attribute key used by message spans
    OBJECT = "object"  # attribute key used by object spans

SpanContext

Bases: BaseModel

Corresponding to opentelemetry.trace.SpanContext

Source code in agentlightning/types/tracer.py
class SpanContext(BaseModel):
    """Corresponding to opentelemetry.trace.SpanContext"""

    # Hex-formatted trace/span identifiers (see `from_opentelemetry`).
    trace_id: str
    span_id: str
    is_remote: bool
    trace_state: TraceState

    class Config:
        # BUG FIX: the Pydantic option is `extra`, not `allow_extra`. The
        # invalid key was silently ineffective, so the extra fields injected
        # via `**extract_extra_fields(...)` below were not retained.
        extra = "allow"

    @classmethod
    def from_opentelemetry(cls, src: trace_api.SpanContext) -> "SpanContext":
        """Build from an OpenTelemetry SpanContext, formatting ids as hex strings."""
        return cls(
            trace_id=trace_api.format_trace_id(src.trace_id),
            span_id=trace_api.format_span_id(src.span_id),
            is_remote=src.is_remote,
            trace_state={k: v for k, v in src.trace_state.items()} if src.trace_state else {},
            **extract_extra_fields(src, ["trace_id", "span_id", "is_remote", "trace_state"]),
        )

SpanNames

Bases: str, Enum

Standard span name values for AgentLightning.

Currently reward, message, object, exception and virtual spans are supported. We will add more spans related to error handling in the future.

Source code in agentlightning/types/tracer.py
class SpanNames(str, Enum):
    """Standard span name values for AgentLightning.

    Currently reward, message, object, exception and virtual spans are
    supported. We will add more spans related to error handling in the future.
    """

    REWARD = "agentlightning.reward"
    MESSAGE = "agentlightning.message"
    OBJECT = "agentlightning.object"
    EXCEPTION = "agentlightning.exception"
    VIRTUAL = "agentlightning.virtual"

Task

Bases: BaseModel

A task (rollout request) to be processed by the client agent.

Source code in agentlightning/types/core.py
class Task(BaseModel):
    """A task (rollout request) to be processed by the client agent."""

    # Identifier of the rollout this task belongs to.
    rollout_id: str
    # The actual payload the agent should act on.
    input: TaskInput

    # Rollout mode (e.g. train/val); None when unspecified.
    mode: Optional[RolloutMode] = None
    # Which resource snapshot (ResourcesUpdate) the agent should use; None when unspecified.
    resources_id: Optional[str] = None

    # Optional fields for tracking task lifecycle
    create_time: Optional[float] = None
    last_claim_time: Optional[float] = None
    num_claims: Optional[int] = None

    # Allow additional metadata fields
    metadata: Dict[str, Any] = Field(default_factory=dict)

TraceStatus

Bases: BaseModel

Corresponding to opentelemetry.trace.Status

Source code in agentlightning/types/tracer.py
class TraceStatus(BaseModel):
    """Corresponding to opentelemetry.trace.Status"""

    # Name of the OpenTelemetry StatusCode enum member (e.g. "OK", "ERROR").
    status_code: str
    description: Optional[str] = None

    class Config:
        # BUG FIX: the Pydantic option is `extra`, not `allow_extra`. The
        # invalid key was silently ineffective, so the extra fields injected
        # via `**extract_extra_fields(...)` below were not retained.
        extra = "allow"

    @classmethod
    def from_opentelemetry(cls, src: OtelStatus) -> "TraceStatus":
        """Build from an OpenTelemetry Status, carrying any extra fields along."""
        return cls(
            status_code=src.status_code.name,
            description=src.description,
            **extract_extra_fields(src, ["status_code", "description"]),
        )

Triplet

Bases: BaseModel

A standard structure for a single turn in a trajectory.

Source code in agentlightning/types/core.py
class Triplet(BaseModel):
    """A standard structure for a single turn in a trajectory."""

    # The prompt/input for this turn; deliberately untyped to fit any agent framework.
    prompt: Any
    # The agent's response/output for this turn.
    response: Any
    # Per-turn reward, if one was assigned.
    reward: Optional[float] = None
    # Any additional per-turn information.
    metadata: Dict[str, Any] = Field(default_factory=dict)

agentlightning.logging

agentlightning.instrumentation