
Reinforcement Learning API

agentlightning.verl

NamedResources = Dict[str, ResourceUnion] module-attribute

A dictionary-like type to hold named resources.

Example

resources: NamedResources = {
    'main_llm': LLM(
        endpoint="http://localhost:8080",
        model="llama3",
        sampling_parameters={'temperature': 0.7, 'max_tokens': 100}
    ),
    'system_prompt': PromptTemplate(
        template="You are a helpful assistant.",
        engine='f-string'
    )
}

AgentLightningServer

The main SDK class for developers to control the Agent Lightning Server.

This class manages the server lifecycle, task queueing, resource updates, and result retrieval, providing a simple interface for the optimization logic.
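
A minimal end-to-end sketch of the optimization-side workflow, assembled from the methods documented below (the import path for the resource types is an assumption; adjust it to your installation): start the server, publish resources, queue a task, poll for the completed rollout, and stop.

import asyncio

from agentlightning.server import AgentLightningServer
from agentlightning.types import LLM, NamedResources  # assumed import path for resource types

async def main():
    server = AgentLightningServer(host="127.0.0.1", port=8000, task_timeout_seconds=300.0)
    await server.start()

    # Publish a versioned set of named resources for clients to fetch.
    resources: NamedResources = {
        "main_llm": LLM(endpoint="http://localhost:8080", model="llama3"),
    }
    resources_id = await server.update_resources(resources)

    # Queue one task and wait for a client to report the completed rollout.
    rollout_id = await server.queue_task(
        sample={"question": "2 + 2 = ?"},  # illustrative payload
        mode="train",
        resources_id=resources_id,
    )
    rollout = await server.poll_completed_rollout(rollout_id, timeout=120.0)
    print(rollout)

    await server.stop()

asyncio.run(main())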

Source code in agentlightning/server.py
class AgentLightningServer:
    """
    The main SDK class for developers to control the Agent Lightning Server.

    This class manages the server lifecycle, task queueing, resources updates,
    and retrieval of results, providing a simple interface for the optimization logic.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
        """
        Initializes the server controller.

        Args:
            host: The host to bind the server to.
            port: The port to bind the server to.
            task_timeout_seconds: Time in seconds after which a claimed task is considered stale and requeued.
        """
        self.host = host
        self.port = port
        self.endpoint = f"http://{host}:{port}"
        self._task_timeout_seconds = task_timeout_seconds

        # Defer initialization and use event for cross-thread communication
        self._store: Optional[ServerDataStore] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.startup_event = threading.Event()

        # Create FastAPI app instance with a lifespan manager
        self._app = FastAPI(lifespan=self._lifespan)
        self._setup_routes()

        self._uvicorn_config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
        self._uvicorn_server = uvicorn.Server(self._uvicorn_config)

    # --- ADDED: Lifespan context manager ---
    @asynccontextmanager
    async def _lifespan(self, app: FastAPI):
        """
        Manages server startup and shutdown. This runs inside the server's event loop.
        """
        logger.info("Server is starting up...")
        self.loop = asyncio.get_running_loop()
        self._store = ServerDataStore()  # Initialize data store here
        self.startup_event.set()  # Signal that the server is ready

        yield

        logger.info("Server is shutting down.")
        self._store = None
        self.startup_event.clear()  # Clear the startup event
        self.loop = None

    async def _check_and_requeue_stale_tasks(self):
        """
        Check for stale tasks and requeue them. Called reactively during get_next_task.
        """
        current_time = time.time()
        # Ensure store is initialized before checking
        if not self._store:
            return
        processing_tasks = self._store.get_processing_tasks()

        for _, task in processing_tasks.items():
            if task.last_claim_time and current_time - task.last_claim_time > self._task_timeout_seconds:
                await self._store.requeue_task(task)
                logger.warning(
                    f"Task {task.rollout_id} timed out after {self._task_timeout_seconds}s, requeued (attempt {task.num_claims})"
                )

    def _setup_routes(self):
        """Setup FastAPI routes."""

        @self._app.get("/task", response_model=TaskIfAny)
        async def next_task() -> TaskIfAny:  # type: ignore
            """Endpoint for clients to poll for the next available task."""
            await self._check_and_requeue_stale_tasks()

            if not self._store:
                return TaskIfAny(is_available=False)

            task = await self._store.get_next_task()
            if task:
                logger.debug(f"Serving task {task.rollout_id} to a client.")
                return TaskIfAny(is_available=True, task=task)
            else:
                logger.debug("No task available for client.")
                return TaskIfAny(is_available=False)

        @self._app.get("/resources/latest", response_model=ResourcesUpdate)
        async def fetch_latest_resources() -> ResourcesUpdate:  # type: ignore
            """Endpoint for clients to poll for the latest available resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_latest_resources()
            if not resources_update:
                raise HTTPException(status_code=404, detail="No resources have been set on the server.")
            logger.debug(f"Serving latest resources '{resources_update.resources_id}' to a client.")
            return resources_update

        @self._app.get("/resources/{resource_id}", response_model=ResourcesUpdate)
        async def fetch_resources_by_id(  # type: ignore
            resource_id: str = Path(..., description="The unique identifier for the resource version.")
        ) -> ResourcesUpdate:
            """Endpoint for clients to fetch a specific version of resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_resources_by_id(resource_id)
            if not resources_update:
                raise HTTPException(status_code=404, detail=f"Resource ID '{resource_id}' not found.")
            logger.debug(f"Serving resources for ID '{resource_id}' to a client.")
            return resources_update

        @self._app.post("/rollout", response_model=GenericResponse)
        async def post_rollout(payload: Rollout) -> GenericResponse:  # type: ignore
            """Endpoint for clients to report a completed rollout."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            await self._store.store_rollout(payload)
            return GenericResponse(
                status="ok",
                message=f"Rollout {payload.rollout_id} received and stored.",
            )

    async def start(self):
        """Starts the FastAPI server in the background."""
        logger.info(f"Starting server at {self.endpoint}")
        asyncio.create_task(self._uvicorn_server.serve())
        await asyncio.sleep(1)  # Allow time for server to start up.

    async def stop(self):
        """Gracefully stops the running FastAPI server."""
        if self._uvicorn_server.started:
            logger.info("Stopping server...")
            self._uvicorn_server.should_exit = True
            await asyncio.sleep(1)  # Allow time for graceful shutdown.
            logger.info("Server stopped.")

    async def run_forever(self):
        """
        Runs the server indefinitely until stopped.
        This is useful when async start and stop methods do not work.
        """
        await self._uvicorn_server.serve()

    async def queue_task(
        self,
        sample: Any,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> str:
        """
        Adds a task to the queue for a client to process.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

    async def update_resources(self, resources: NamedResources) -> str:
        """
        Updates the resources, creating a new version and setting it as the latest.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        resources_id = f"res-{uuid.uuid4()}"
        update = ResourcesUpdate(resources_id=resources_id, resources=resources)
        await self._store.update_resources(update)
        return resources_id

    async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
        """
        Retrieves a specific completed rollout by its ID.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_rollout(rollout_id)

    async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
        """
        Polls for a completed rollout by its ID, waiting up to `timeout` seconds.
        """
        start_time = time.time()
        while True:
            rollout = await self.get_completed_rollout(rollout_id)
            if rollout:
                return rollout
            if timeout and (time.time() - start_time) >= timeout:
                return None
            await asyncio.sleep(1)

    async def retrieve_completed_rollouts(self) -> List[Rollout]:
        """
        Retrieves all available completed trajectories and clears the internal store.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_completed_rollouts()

__init__(host='127.0.0.1', port=8000, task_timeout_seconds=300.0)

Initializes the server controller.

Parameters:

    host (str): The host to bind the server to. Default: '127.0.0.1'
    port (int): The port to bind the server to. Default: 8000
    task_timeout_seconds (float): Time in seconds after which a claimed task is considered stale and requeued. Default: 300.0
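
For example, a server that binds to all interfaces and requeues tasks that remain claimed for more than two minutes could be configured like this (values are illustrative):

server = AgentLightningServer(host="0.0.0.0", port=9090, task_timeout_seconds=120.0)
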
Source code in agentlightning/server.py
def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
    """
    Initializes the server controller.

    Args:
        host: The host to bind the server to.
        port: The port to bind the server to.
        task_timeout_seconds: Time in seconds after which a claimed task is considered stale and requeued.
    """
    self.host = host
    self.port = port
    self.endpoint = f"http://{host}:{port}"
    self._task_timeout_seconds = task_timeout_seconds

    # Defer initialization and use event for cross-thread communication
    self._store: Optional[ServerDataStore] = None
    self.loop: Optional[asyncio.AbstractEventLoop] = None
    self.startup_event = threading.Event()

    # Create FastAPI app instance with a lifespan manager
    self._app = FastAPI(lifespan=self._lifespan)
    self._setup_routes()

    self._uvicorn_config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
    self._uvicorn_server = uvicorn.Server(self._uvicorn_config)

get_completed_rollout(rollout_id) async

Retrieves a specific completed rollout by its ID.

Source code in agentlightning/server.py
async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
    """
    Retrieves a specific completed rollout by its ID.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.retrieve_rollout(rollout_id)

poll_completed_rollout(rollout_id, timeout=None) async

Polls for a completed rollout by its ID, waiting up to timeout seconds.
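
A short, hedged sketch of how the timeout is typically handled; the method returns None if the rollout does not arrive within the given number of seconds:

rollout = await server.poll_completed_rollout(rollout_id, timeout=30.0)
if rollout is None:
    print(f"Rollout {rollout_id} was not completed within 30 seconds.")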

Source code in agentlightning/server.py
async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
    """
    Polls for a completed rollout by its ID, waiting up to `timeout` seconds.
    """
    start_time = time.time()
    while True:
        rollout = await self.get_completed_rollout(rollout_id)
        if rollout:
            return rollout
        if timeout and (time.time() - start_time) >= timeout:
            return None
        await asyncio.sleep(1)

queue_task(sample, mode=None, resources_id=None, metadata=None) async

Adds a task to the queue for a client to process.
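
A hedged example of queueing a validation task; the sample payload and metadata keys are illustrative only:

rollout_id = await server.queue_task(
    sample={"question": "What is the capital of France?"},
    mode="val",
    resources_id=resources_id,
    metadata={"data_id": "sample-42"},
)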

Source code in agentlightning/server.py
async def queue_task(
    self,
    sample: Any,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> str:
    """
    Adds a task to the queue for a client to process.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

retrieve_completed_rollouts() async

Retrieves all available completed trajectories and clears the internal store.
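
Because this call clears the internal store, a typical pattern (a sketch mirroring the polling loop in AgentModeDaemon, with a hypothetical num_queued counter) is to drain it periodically and accumulate the results:

completed: List[Rollout] = []
while len(completed) < num_queued:
    completed.extend(await server.retrieve_completed_rollouts())
    await asyncio.sleep(5)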

Source code in agentlightning/server.py
async def retrieve_completed_rollouts(self) -> List[Rollout]:
    """
    Retrieves all available completed trajectories and clears the internal store.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.retrieve_completed_rollouts()

run_forever() async

Runs the server indefinitely until stopped. This is useful when async start and stop methods do not work.
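
A minimal sketch of using run_forever as the main entry point, mirroring how the v0 AgentModeDaemon drives the server:

server = AgentLightningServer(host="0.0.0.0", port=8000)
asyncio.run(server.run_forever())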

Source code in agentlightning/server.py
async def run_forever(self):
    """
    Runs the server indefinitely until stopped.
    This is useful when async start and stop methods do not work.
    """
    await self._uvicorn_server.serve()

start() async

Starts the FastAPI server in the background.

Source code in agentlightning/server.py
async def start(self):
    """Starts the FastAPI server in the background."""
    logger.info(f"Starting server at {self.endpoint}")
    asyncio.create_task(self._uvicorn_server.serve())
    await asyncio.sleep(1)  # Allow time for server to start up.

stop() async

Gracefully stops the running FastAPI server.

Source code in agentlightning/server.py
async def stop(self):
    """Gracefully stops the running FastAPI server."""
    if self._uvicorn_server.started:
        logger.info("Stopping server...")
        self._uvicorn_server.should_exit = True
        await asyncio.sleep(1)  # Allow time for graceful shutdown.
        logger.info("Server stopped.")

update_resources(resources) async

Updates the resources, creating a new version and setting it as the latest.
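
A hedged example that publishes a new LLM endpoint and a prompt template as the latest resource version (the endpoint and model name are placeholders, following the NamedResources example above):

resources_id = await server.update_resources({
    "main_llm": LLM(
        endpoint="http://localhost:8080",
        model="llama3",
        sampling_parameters={"temperature": 0.7},
    ),
    "system_prompt": PromptTemplate(
        template="You are a helpful assistant.",
        engine="f-string",
    ),
})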

Source code in agentlightning/server.py
async def update_resources(self, resources: NamedResources) -> str:
    """
    Updates the resources, creating a new version and setting it as the latest.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    resources_id = f"res-{uuid.uuid4()}"
    update = ResourcesUpdate(resources_id=resources_id, resources=resources)
    await self._store.update_resources(update)
    return resources_id

AgentLightningTrainer

Bases: RayPPOTrainer

Specialized PPO trainer for agent-based reinforcement learning.

This trainer is designed specifically for scenarios where the model interacts with external environments, tools, or APIs through an AgentLightningServer. It simplifies the training loop by removing the complex conditional logic present in the original RayPPOTrainer and focusing on the agent mode workflow.

Key differences from RayPPOTrainer:

1. Uses AgentModeDaemon for server communication
2. Simplified data flow without pop/union operations
3. Direct batch processing through agent daemon
4. Streamlined validation using agent_mode validation

Source code in agentlightning/verl/trainer.py
class AgentLightningTrainer(RayPPOTrainer):
    """
    Specialized PPO trainer for agent-based reinforcement learning.

    This trainer is designed specifically for scenarios where the model interacts with
    external environments, tools, or APIs through an AgentLightningServer. It simplifies
    the training loop by removing the complex conditional logic present in the original
    RayPPOTrainer and focusing on the agent mode workflow.

    Key differences from RayPPOTrainer:
    1. Uses AgentModeDaemon for server communication
    2. Simplified data flow without pop/union operations
    3. Direct batch processing through agent daemon
    4. Streamlined validation using agent_mode validation
    """

    def __init__(
        self, store: LightningStore | None, llm_proxy: LLMProxy | None, adapter: TraceAdapter | None, **kwargs
    ):
        super().__init__(**kwargs)
        self.store = store
        self.llm_proxy = llm_proxy
        self.adapter = adapter

    def _validate(self):
        assert len(self.val_dataloader) == 1, "Please set val_batch_size to None for better throughput."

        test_data = next(iter(self.val_dataloader))
        test_batch = DataProto.from_single_dict(test_data)

        self.async_rollout_manager.wake_up()
        self.agent_mode_daemon.set_up_data_and_server(
            test_batch.non_tensor_batch,
            self.async_rollout_manager.server_addresses,
            is_train=False,
        )
        self.agent_mode_daemon.run_until_all_finished()
        test_metrics = self.agent_mode_daemon.get_test_metrics()
        self.agent_mode_daemon.clear_data_and_server()
        self.async_rollout_manager.sleep()
        return test_metrics

    def _train_step(self, batch_dict: dict) -> dict:
        # Isolate in a separate method to automatically recycle the variables before validation.
        batch: DataProto = DataProto.from_single_dict(batch_dict)
        metrics = {}
        timing_raw = {}

        with _timer("step", timing_raw):

            # When agent mode is enabled, we read the batch as it is.
            gen_batch = batch

            # generate a batch
            with _timer("gen", timing_raw):
                self.async_rollout_manager.wake_up()
                self.agent_mode_daemon.set_up_data_and_server(
                    gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
                )
                self.agent_mode_daemon.run_until_all_finished()
                batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
                    max_prompt_length=self.config.data.max_prompt_length,
                    max_response_length=self.config.data.max_response_length,
                    device=gen_batch.batch["fake_ids"].device,
                )
                metrics.update(agent_metrics)
                self.agent_mode_daemon.clear_data_and_server()
                self.async_rollout_manager.sleep()

            if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                with _timer("gen_max", timing_raw):
                    gen_baseline_batch = deepcopy(gen_batch)
                    gen_baseline_batch.meta_info["do_sample"] = False
                    gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)

                    batch = batch.union(gen_baseline_output)
                    reward_baseline_tensor = self.reward_fn(batch)
                    reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                    batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                    batch.batch["reward_baselines"] = reward_baseline_tensor

                    del gen_baseline_batch, gen_baseline_output

            # uid is used for algorithm like GRPO, should be aligned to data id
            batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

            batch.batch["response_mask"] = compute_response_mask(batch)

            # compute global_valid tokens
            batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

            with _timer("reward", timing_raw):
                # compute reward model score
                if self.use_rm:
                    reward_tensor = self.rm_wg.compute_rm_score(batch)
                    batch = batch.union(reward_tensor)

                reward_extra_infos_dict = {}

            # for agent mode, pad the lengths to calculate old log prob, ref, and values
            batch, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)

            # recompute old_log_probs
            with _timer("old_log_prob", timing_raw):
                old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                entropys = old_log_prob.batch["entropys"]
                response_masks = batch.batch["response_mask"]
                loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
                metrics.update(old_log_prob_metrics)
                old_log_prob.batch.pop("entropys")
                batch = batch.union(old_log_prob)

            if self.use_reference_policy:
                # compute reference log_prob
                with _timer("ref", timing_raw):
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)

            # compute values
            if self.use_critic:
                with _timer("values", timing_raw):
                    values = self.critic_wg.compute_values(batch)
                    batch = batch.union(values)

            # for agent mode, unpad to calculate adv
            # it is important, as adv should be based on the raw traces
            batch = unpad_dataproto(batch, pad_size=pad_size)

            with _timer("adv", timing_raw):
                # if agent_mode is enabled, there is already token_level_scores
                # token_level_scores is not needed to compute here

                # compute rewards. apply_kl_penalty if available
                if self.config.algorithm.use_kl_in_reward:
                    batch, kl_metrics = apply_kl_penalty(
                        batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                    )
                    metrics.update(kl_metrics)
                else:
                    batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                # compute advantages, executed on the driver process

                norm_adv_by_std_in_grpo = self.config.algorithm.get(
                    "norm_adv_by_std_in_grpo", True
                )  # GRPO adv normalization factor

                batch = compute_advantage(
                    batch,
                    adv_estimator=self.config.algorithm.adv_estimator,
                    gamma=self.config.algorithm.gamma,
                    lam=self.config.algorithm.lam,
                    num_repeat=self.config.actor_rollout_ref.rollout.n,
                    norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                    config=self.config.algorithm,
                )

            # after advantages are assigned, we begin to drop (1) samples with too-long prompts and (2) floor to the PPO mini-batch size
            keep_indices = (~batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
            metrics["training/n_triplets_prompt_too_long"] = (
                batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
            )
            batch = batch[keep_indices]
            # next, round to minibatch size
            mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
            n_transition = len(batch)
            random_indices = list(range(n_transition))
            random.shuffle(random_indices)
            batch.reorder(torch.tensor(random_indices).type(torch.int32))
            n_remained_transition = n_transition // mini_batch_size * mini_batch_size
            batch = batch[list(range(n_remained_transition))]
            metrics["training/n_triplets_dropped_remainder"] = n_transition - n_remained_transition

            # Agent mode note: Change the order of balance batch;
            #     1. first calculate advantage
            #     2. then drop the samples (too long prompt & floor to ppo minisize)
            #     3. balance
            # balance the number of valid tokens on each dp rank.
            # Note that this breaks the order of data inside the batch.
            # Please take care when you implement group based adv computation such as GRPO and rloo
            if self.config.trainer.balance_batch:
                self._balance_batch(batch, metrics=metrics)

            # update critic
            if self.use_critic:
                with _timer("update_critic", timing_raw):
                    critic_output = self.critic_wg.update_critic(batch)
                critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                metrics.update(critic_output_metrics)

            # implement critic warmup
            if self.config.trainer.critic_warmup <= self.global_steps:
                # update actor
                with _timer("update_actor", timing_raw):
                    batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                    actor_output = self.actor_rollout_wg.update_actor(batch)
                actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                metrics.update(actor_output_metrics)

            # Log rollout generations if enabled
            rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
            if rollout_data_dir:
                with _timer("dump_rollout_generations", timing_raw):
                    print(batch.batch.keys())
                    inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                    outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                    scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
                    self._dump_generations(
                        inputs=inputs,
                        outputs=outputs,
                        scores=scores,
                        reward_extra_infos_dict=reward_extra_infos_dict,
                        dump_path=rollout_data_dir,
                    )

        # compute training metrics
        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
        # TODO: implement actual tflpo and theoretical tflpo
        n_gpus = self.resource_pool_manager.get_n_gpus()
        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

        return metrics

    def fit(self):
        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        assert self.async_rollout_mode, "If agent mode is enabled, async server must be enabled"
        if self.adapter is not None and not isinstance(self.adapter, BaseTraceTripletAdapter):
            raise ValueError("Adapter must be a BaseTraceTripletAdapter for currently VERL implementation.")
        self.agent_mode_daemon = AgentModeDaemon(
            self.config.agentlightning.port,
            self.config.actor_rollout_ref.rollout.n,
            train_information={
                # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here.
                # However, it is possible that verl updates the naming and causes incompatibility.
                # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
                "model": "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:]),
                "temperature": self.config.actor_rollout_ref.rollout.temperature,
            },
            tokenizer=self.tokenizer,
            mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size,
            pad_token_id=self.tokenizer.pad_token_id,
            mode="v1" if self.store is not None else "v0",
            store=self.store,
            llm_proxy=self.llm_proxy,
            adapter=self.adapter,
        )
        self.agent_mode_daemon.start()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                is_last_step = self.global_steps >= self.total_training_steps

                # train step
                metrics = self._train_step(batch_dict)

                # validate
                if (
                    self.val_reward_fn is not None
                    and self.config.trainer.test_freq > 0
                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
                ):
                    with _timer("validate", timing_raw):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and (
                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0
                ):
                    with _timer("save_checkpoint", timing_raw):
                        self._save_checkpoint()

                # step metrics
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()

                    # This exit logic is to ensure a robust CI.
                    pprint(f"Flush the logger...")
                    del logger  # Make sure the loggers are flushed and closed properly
                    pprint(f"Training finished at step {self.global_steps}.")
                    return

                progress_bar.update(1)
                self.global_steps += 1

AgentModeDaemon

AgentModeDaemon using the AgentLightningServer SDK.

This class manages the server lifecycle, task queueing, and results retrieval, while also running a proxy server for LLM requests. It maintains the original interface for compatibility with the RayPPOTrainer.

Source code in agentlightning/verl/daemon.py
class AgentModeDaemon:
    """
    AgentModeDaemon using the AgentLightningServer SDK.

    This class manages the server lifecycle, task queueing, and results
    retrieval, while also running a proxy server for LLM requests. It maintains
    the original interface for compatibility with the RayPPOTrainer.
    """

    def __init__(
        self,
        port: Optional[int],
        train_rollout_n: int,
        train_information: Dict[str, Any],
        tokenizer: Any,
        mini_batch_size: int,
        pad_token_id: int,
        reward_fillna_value: float = 0.0,
        llm_timeout_seconds: float = 1200.0,
        mode: Literal["v0", "v1"] = "v1",
        llm_proxy: LLMProxy | None = None,
        store: LightningStore | None = None,
        adapter: BaseTraceTripletAdapter | None = None,
    ):
        self.mode = mode
        self.llm_timeout_seconds = llm_timeout_seconds

        # Server and Task Configuration
        if mode == "v0":
            assert port is not None
            self.server_port = port
            self.server = AgentLightningServer(
                host="0.0.0.0", port=self.server_port, task_timeout_seconds=self.llm_timeout_seconds
            )
            self.proxy_port = _find_available_port()  # Run proxy on a different port
        else:
            assert store is not None
            self.store = store
            if llm_proxy is None:
                self.llm_proxy = LLMProxy(
                    port=_find_available_port(),
                    model_list=[],
                    store=store,
                )
            else:
                # Reuse the existing LLM proxy (probably configured by user)
                self.llm_proxy = llm_proxy
            if adapter is None:
                self.adapter = TraceTripletAdapter()
            else:
                # Reuse the one from trainer
                self.adapter = adapter
            self._internal_loop: Optional[asyncio.AbstractEventLoop] = None
            self._internal_loop_thread = threading.Thread(target=self._internal_loop_runner, daemon=True)
            self._internal_loop_thread.start()

        # Training and Data Configuration
        self.train_rollout_n = train_rollout_n
        self.train_information = train_information
        self.mini_batch_size = mini_batch_size
        self.pad_token_id = pad_token_id
        self.tokenizer = tokenizer
        self.reward_fillna_value = reward_fillna_value

        # Internal State
        self.backend_llm_server_addresses: List[str] = []
        self._total_tasks_queued = 0
        self._completed_rollouts_v0: Dict[str, Rollout] = {}
        self._task_id_to_original_sample: Dict[str, Dict[str, Any]] = {}
        self._server_thread: Optional[threading.Thread] = None
        self._proxy_thread: Optional[threading.Thread] = None
        self.is_train = True

    def _internal_loop_runner(self):
        """Run the internal loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self._internal_loop = loop
        loop.run_forever()
        loop.close()

    def _start_proxy_server_v0(self):
        """
        Initializes and runs a Flask-based proxy server in a separate thread.
        This proxy load-balances requests to the actual backend LLM servers.
        """
        app = Flask(__name__)

        num_requests = 0
        last_request_time = 0

        @app.route("/v1/<path:path>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
        def proxy(path: str):  # type: ignore
            if not self.backend_llm_server_addresses:
                abort(503, description="No backend LLM servers available.")

            # Randomly choose a backend server for load balancing
            target_server = random.choice(self.backend_llm_server_addresses)
            target_url = f"http://{target_server}/v1/{path}"

            # Copy client request headers, removing the Host header
            headers = {key: value for key, value in request.headers if key.lower() != "host"}

            # Log the request for debugging
            nonlocal num_requests, last_request_time
            current_time = time.time()
            num_requests += 1
            if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0:
                print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}")
            last_request_time = current_time

            try:
                # Forward the request to the target backend
                resp = requests.request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    params=request.args,  # type: ignore
                    data=request.get_data(),
                    cookies=request.cookies,
                    allow_redirects=False,
                    timeout=self.llm_timeout_seconds,
                )
                # Filter out hop-by-hop headers before returning the response
                excluded_headers = [
                    "content-encoding",
                    "content-length",
                    "transfer-encoding",
                    "connection",
                    "keep-alive",
                    "proxy-authenticate",
                    "proxy-authorization",
                    "te",
                    "trailers",
                    "upgrade",
                ]
                response_headers = [
                    (name, value) for name, value in resp.raw.headers.items() if name.lower() not in excluded_headers
                ]
                if resp.status_code == 200:
                    # NOTE: from Zhiyuan's code.
                    # https://github.com/hzy46/verl_agent_mode/blob/2db65ea9858f645a914120357412a7540f8bd82d/verl/trainer/ppo/ray_trainer.py#L692-L711
                    # request_json = json.loads(request.get_data().decode("utf-8"))
                    response_json = json.loads(resp.content.decode("utf-8"))
                    # response_message = ChatCompletion(**response_json).choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
                    # tool_schemas = request_json.get("tools", None)
                    # prompt_ids = self.tokenizer.apply_chat_template(request_json["messages"], tools=tool_schemas, add_generation_prompt=True, tokenize=True)
                    # full_ids = self.tokenizer.apply_chat_template(request_json["messages"] + [response_message], tools=tool_schemas, add_generation_prompt=False, tokenize=True)
                    # TBD: response_ids sometimes ends with "<eos_id>\n", shall we keep the extra "\n"?
                    # sometimes it has some differences with the hacky method in the end, but this should align with ToolCompletionCallback
                    # response_ids = full_ids[len(prompt_ids):]

                    # NOTE (yuge): They are different. Don't know why.
                    # assert response_json['prompt_token_ids'] == prompt_ids
                    # patched_response_ids = response_json['response_token_ids'][0]
                    # assert patched_response_ids == response_ids[:len(patched_response_ids)], f"{patched_response_ids} != {response_ids[:len(patched_response_ids)]}"
                    # response_json['prompt_token_ids'] = prompt_ids
                    # response_json['response_token_ids'] = [response_ids]
                    replaced_return_content = json.dumps(response_json).encode("utf-8")
                    return Response(replaced_return_content, status=resp.status_code, headers=response_headers)
                return Response(resp.content, resp.status_code, response_headers)
            except requests.exceptions.RequestException as e:
                abort(500, description=f"Error proxying request: {e}")

        def run_app():
            app.run(host="0.0.0.0", port=self.proxy_port, threaded=True, debug=False)

        self._proxy_thread = threading.Thread(target=run_app, daemon=True)
        self._proxy_thread.start()
        print(f"Proxy server running on port {self.proxy_port}")

    def _update_proxy_server_v1(self):
        model_name = self.train_information.get("model")
        if not model_name:
            raise ValueError("Model name is not set.")
        self.llm_proxy.update_model_list(
            [
                ModelConfig(
                    {
                        "model_name": model_name,
                        "litellm_params": {
                            "model": "hosted_vllm/" + model_name,
                            "api_base": f"http://{address}/v1/",
                        },
                    }
                )
                for address in self.backend_llm_server_addresses
            ],
        )

        if self.llm_proxy.is_running():
            # FIXME: Need to switch to a different port right now
            # because the forked processes carried the old fd
            self.llm_proxy.restart(_port=_find_available_port())
        else:
            self.llm_proxy.start()

    def start(self):
        """Starts the main AgentLightningServer and the proxy server."""

        if self.mode == "v0":

            def run_server():
                """Run the AgentLightningServer in a separate thread."""
                asyncio.run(self.server.run_forever())

            self._server_thread = threading.Thread(target=run_server, daemon=True)
            self._server_thread.start()

            # Wait for the server's internal startup event to be set.
            print("Waiting for AgentLightningServer to start...")
            is_ready = self.server.startup_event.wait(timeout=20.0)  # Wait up to 20s
            if not is_ready:
                raise RuntimeError("AgentLightningServer failed to start within the timeout period.")

            print(f"AgentLightningServer control plane running on port {self.server_port}")

            self._start_proxy_server_v0()
        else:
            # Agent lightning server is no longer needed;
            # Start proxy server in _async_set_up
            pass

    async def _async_set_up(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
        """Async helper to set up data and resources on the server."""
        self.clear_data_and_server()
        if server_addresses != self.backend_llm_server_addresses:
            self.backend_llm_server_addresses = server_addresses
            if self.mode == "v1" and not self.llm_proxy.is_running():
                self._update_proxy_server_v1()
        self.is_train = is_train

        # 1. Update resources on the server for clients to use
        if self.mode == "v0":
            llm_resource = LLM(
                endpoint=f"http://127.0.0.1:{self.proxy_port}/v1",
                model=self.train_information.get("model", "default-model"),
                sampling_parameters={
                    "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
                },
            )
        else:
            llm_resource = self.llm_proxy.as_resource(
                sampling_parameters={
                    "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
                },
            )

        resources: NamedResources = {"main_llm": llm_resource}

        if self.mode == "v0":
            resources_id = await self.server.update_resources(resources)
        else:
            resources_update = await self.store.add_resources(resources)
            resources_id = resources_update.resources_id

        # 2. Queue tasks for agents to process
        keys = list(data.keys())
        num_samples = len(data[keys[0]])
        rollouts_per_sample = self.train_rollout_n if is_train else 1

        for i in range(num_samples):
            data_id = str(uuid.uuid4())
            original_sample = {key: data[key][i] for key in keys}
            original_sample["data_id"] = data_id

            # For training, each sample is rolled out multiple times
            for _ in range(rollouts_per_sample):
                task_metadata = {"data_id": data_id, "is_train": is_train}

                # Data ID is different from Rollout ID, as one data can have multiple rollouts.
                if self.mode == "v0":
                    rollout_id = await self.server.queue_task(
                        sample=_to_native(original_sample),
                        mode="train" if is_train else "val",
                        resources_id=resources_id,
                        metadata=task_metadata,
                    )
                else:
                    rollout = await self.store.enqueue_rollout(
                        input=_to_native(original_sample),
                        mode="train" if is_train else "val",
                        resources_id=resources_id,
                        metadata=task_metadata,
                    )
                    await self.store.update_rollout(
                        rollout_id=rollout.rollout_id,
                        config=RolloutConfig(
                            unresponsive_seconds=self.llm_timeout_seconds,
                            timeout_seconds=self.llm_timeout_seconds,
                        ),
                    )
                    rollout_id = rollout.rollout_id

                # Store original sample data to reconstruct batch information later
                self._task_id_to_original_sample[rollout_id] = original_sample
                self._total_tasks_queued += 1

    def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
        """Synchronous wrapper for setting up data and server resources."""
        coro = self._async_set_up(data, server_addresses, is_train)

        if self.mode == "v0":
            if not self.server.loop or not self.server.startup_event.is_set():
                raise RuntimeError("Server is not running or ready.")

            future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)

        else:
            if self._internal_loop is None:
                raise RuntimeError("Internal loop is not running.")
            future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
        try:
            future.result(timeout=60)  # Wait for completion with a timeout
        except Exception as e:
            print(f"Failed to set up data on server: {e}")
            raise

    def _validate_data(self, rollout: Rollout):
        if rollout.final_reward is None:
            print(
                f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}."
            )
        if rollout.triplets is None:
            print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.")
        elif len(rollout.triplets) == 0:
            print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.")
        elif any(not r.response.get("token_ids", []) for r in rollout.triplets):
            print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}")
        elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
            print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}")

    async def _validate_data_v1(self, rollout: RolloutV2) -> Rollout:
        """Convert RolloutV2 to Rollout and validate.

        1. Task: construct from RolloutV2
        2. Triplets: obtained by querying spans and feeding into the adapter
        3. Final reward: extracted from last triplet's reward, searching backwards if not found
        """
        # Query spans for this rollout (latest attempt)
        spans = await self.store.query_spans(rollout.rollout_id, attempt_id="latest")

        # Convert spans to triplets using the adapter
        if not spans:
            # No triplets found, will emit a warning later.
            triplets = []
        else:
            triplets = self.adapter.adapt(spans)

        # Extract final reward from triplets
        final_reward: Optional[float] = None
        if triplets:
            # Search backwards through triplets for the first non-None reward
            for triplet in reversed(triplets):
                if triplet.reward is not None:
                    final_reward = triplet.reward
                    break

        # Construct the Task object from RolloutV2
        task = Task(
            rollout_id=rollout.rollout_id,
            input=rollout.input,
            mode=rollout.mode,
            resources_id=rollout.resources_id,
            metadata=rollout.metadata or {},
        )

        # Create the Rollout object (without trace and logs as per user's note)
        result_rollout = Rollout(
            rollout_id=rollout.rollout_id,
            task=task,
            final_reward=final_reward,
            triplets=triplets,
            metadata=rollout.metadata or {},
        )

        # Run the same validation as v0
        self._validate_data(result_rollout)

        return result_rollout

    async def _async_run_until_finished(self, verbose: bool = True):
        """Async helper to wait for all tasks to complete."""
        while len(self._completed_rollouts_v0) < self._total_tasks_queued:
            if self.mode == "v0":
                completed_batch = await self.server.retrieve_completed_rollouts()
            else:
                completed_batch = await self.store.wait_for_rollouts(
                    rollout_ids=list(self._task_id_to_original_sample.keys()), timeout=0
                )
            for rollout in completed_batch:
                if rollout.rollout_id in self._completed_rollouts_v0:
                    # Already processed, skip
                    continue
                if isinstance(rollout, RolloutV2):
                    rollout = await self._validate_data_v1(rollout)
                else:
                    self._validate_data(rollout)
                if rollout.rollout_id not in self._task_id_to_original_sample:
                    print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.")
                else:
                    self._completed_rollouts_v0[rollout.rollout_id] = rollout
            if verbose:
                print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...")
            await asyncio.sleep(5)

        print("All tasks finished.")

    def run_until_all_finished(self, verbose: bool = True):
        """Synchronously waits for all queued tasks to be completed and reported."""
        if self._total_tasks_queued == 0:
            print("Warning: No tasks were queued.")
            return

        if self.mode == "v0":
            if not self.server.loop or not self.server.startup_event.is_set():
                raise RuntimeError("Server is not running or ready.")
            loop = self.server.loop
        else:
            loop = self._internal_loop
            assert loop is not None

        coro = self._async_run_until_finished(verbose)
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            future.result()  # Wait indefinitely for all tasks to complete
        except Exception as e:
            print(f"Error while waiting for tasks to finish: {e}")
            raise

    def get_test_metrics(self):
        """Calculates and returns metrics for a validation run."""
        assert not self.is_train, "This method should only be called during validation."
        assert len(self._completed_rollouts_v0) == self._total_tasks_queued

        sample_stat_list: List[Dict[str, Any]] = []
        for _, rollout in self._completed_rollouts_v0.items():
            final_reward = self._fillna_reward(rollout)
            if not rollout.triplets:
                print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
                sample_stat_list.append({"reward": final_reward})
                continue
            response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
            sample_stat_list.append(
                {
                    "sum_response_length": np.sum(response_length_list),
                    "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
                    "turn_count": len(rollout.triplets),
                    "reward": final_reward,
                }
            )

        stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
        return {
            "val/n_rollouts": len(sample_stat_list),
            "val/n_rollouts_w_trace": len(stats_w_trace),
            "val/reward": np.mean(
                [stat["reward"] for stat in sample_stat_list]
            ),  # each rollout must have a reward (fillna if missing)
            "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
            "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
            "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
        }

    def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, device: torch.device):
        """
        Processes completed rollouts to generate a training data batch.

        This function reconstructs the logic from the original AgentModeDaemon,
        using data retrieved from the new server architecture. It handles padding,
        truncation, and tensor creation for the PPO training loop.
        """
        assert self.is_train, "This method should only be called during training."
        assert len(self._completed_rollouts_v0) == self._total_tasks_queued

        # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
        finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
        finished_id_to_final_reward: Dict[str, float] = {}
        for rollout_id, rollout in self._completed_rollouts_v0.items():
            original_sample = self._task_id_to_original_sample[rollout_id]

            final_reward = self._fillna_reward(rollout)

            if not rollout.triplets:
                finished_id_to_final_reward[rollout_id] = final_reward
                print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.")
                continue

            # The client should report triplets that contain prompt_ids and response_ids.
            # Example triplet.prompt: {"token_ids": [...]}
            # Example triplet.response: {"token_ids": [...]}
            trace_list = [
                {"prompt_ids": t.prompt.get("token_ids", []), "response_ids": t.response.get("token_ids", [])}
                for t in rollout.triplets
            ]
            info = {
                "reward": final_reward,
                "trace_list": trace_list,
                "data_id": original_sample["data_id"],
            }
            finished_id_to_sample_info[rollout_id] = info
            finished_id_to_final_reward[rollout_id] = final_reward
        #
        # --- Data processing and tensor creation logic ---
        # Get all the reported data.
        # prompt_ids are left-padded.
        # response_ids are right-padded.
        # They are concatenated in the middle.
        # Discard handling:
        #   - Those exceeding max_prompt_length will be marked for discard, but not
        #     discarded here. They are only truncated and marked, to be discarded later.
        #     This is for the correctness of the advantage calculation.
        #   - The discard for the PPO mini-batch should also be handled this way.
        input_ids_list: List[List[int]] = []
        input_attention_mask_list: List[List[int]] = []
        response_ids_list: List[List[int]] = []
        response_attention_mask_list: List[List[int]] = []
        reward_list: List[float] = []
        data_id_list: List[str] = []
        rollout_id_list: List[str] = []
        turn_index_list: List[int] = []
        is_drop_list: List[bool] = []
        n_trunc_sample_because_of_response = 0

        for rollout_id, sample_info in finished_id_to_sample_info.items():
            for turn_index, trace in enumerate(sample_info["trace_list"]):

                reward_list.append(sample_info["reward"])
                prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

                # Mark samples with prompts exceeding max_prompt_length to be dropped later
                if len(prompt_ids) > max_prompt_length:
                    prompt_ids = prompt_ids[:max_prompt_length]
                    is_drop_list.append(True)
                else:
                    is_drop_list.append(False)

                # Truncate responses that exceed max_response_length
                if len(response_ids) > max_response_length:
                    response_ids = response_ids[:max_response_length]
                    n_trunc_sample_because_of_response += 1

                # Pad prompts to the left and responses to the right
                one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
                    prompt_ids, max_prompt_length, self.pad_token_id
                )
                one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
                    response_ids, max_response_length, self.pad_token_id
                )

                input_ids_list.append(one_input_ids)
                input_attention_mask_list.append(one_input_attention_mask)
                response_ids_list.append(one_response_ids)
                response_attention_mask_list.append(one_response_attention_mask)
                data_id_list.append(sample_info["data_id"])
                rollout_id_list.append(rollout_id)
                turn_index_list.append(turn_index)

        n_transition = len(input_ids_list)
        batch_input_ids = torch.LongTensor(input_ids_list).to(device)
        input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
        batch_response_ids = torch.LongTensor(response_ids_list).to(device)
        response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)

        # Concatenate prompts and responses to form the full sequence
        batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
        attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
        position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
        is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
        scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)

        # Create token-level scores by placing the final reward at the last token position
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
        # At the eos_mask_idx position of each sample, fill in the corresponding scores.
        # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension.
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
        # Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
        token_level_scores = token_level_scores[:, -max_response_length:]

        # Form the final batch using TensorDict
        batch = TensorDict(
            {
                "prompts": batch_input_ids,
                "responses": batch_response_ids,
                "input_ids": batch_seq,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "is_drop_mask": is_drop_mask,
                "token_level_scores": token_level_scores.contiguous(),
            },
            batch_size=n_transition,
        )
        data_proto = DataProto(batch=batch)

        data_metrics = {
            "training/reward": np.mean(list(finished_id_to_final_reward.values())),
            "training/n_rollouts": len(finished_id_to_final_reward),
            "training/n_rollouts_w_trace": len(finished_id_to_sample_info),
            "training/n_truncated_triplets": n_trunc_sample_because_of_response,
            "training/n_triplets": n_transition,
        }

        # Add non-tensor data for advantage calculation and logging
        data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)  # type: ignore
        data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)  # type: ignore
        data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)  # type: ignore

        return data_proto, data_metrics

    def clear_data_and_server(self):
        """Resets the internal state of the daemon for the next run."""
        self.backend_llm_server_addresses = []
        self._completed_rollouts_v0.clear()
        self._task_id_to_original_sample.clear()
        self._total_tasks_queued = 0
        # For a true reset, the server's internal queues would also need clearing.
        # This implementation assumes that `set_up_data_and_server` is called
        # for each new run, effectively starting a fresh batch.

    def _fillna_reward(self, rollout: Rollout):
        if rollout.final_reward is None:
            if self.reward_fillna_value is not None:  # type: ignore
                final_reward = self.reward_fillna_value
            else:
                raise ValueError(f"Reward is None for rollout {rollout.rollout_id}, please check the reward function.")
        else:
            final_reward = rollout.final_reward
        return final_reward

clear_data_and_server()

Resets the internal state of the daemon for the next run.

Source code in agentlightning/verl/daemon.py
def clear_data_and_server(self):
    """Resets the internal state of the daemon for the next run."""
    self.backend_llm_server_addresses = []
    self._completed_rollouts_v0.clear()
    self._task_id_to_original_sample.clear()
    self._total_tasks_queued = 0

get_test_metrics()

Calculates and returns metrics for a validation run.

Source code in agentlightning/verl/daemon.py
def get_test_metrics(self):
    """Calculates and returns metrics for a validation run."""
    assert not self.is_train, "This method should only be called during validation."
    assert len(self._completed_rollouts_v0) == self._total_tasks_queued

    sample_stat_list: List[Dict[str, Any]] = []
    for _, rollout in self._completed_rollouts_v0.items():
        final_reward = self._fillna_reward(rollout)
        if not rollout.triplets:
            print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
            sample_stat_list.append({"reward": final_reward})
            continue
        response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
        sample_stat_list.append(
            {
                "sum_response_length": np.sum(response_length_list),
                "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
                "turn_count": len(rollout.triplets),
                "reward": final_reward,
            }
        )

    stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
    return {
        "val/n_rollouts": len(sample_stat_list),
        "val/n_rollouts_w_trace": len(stats_w_trace),
        "val/reward": np.mean(
            [stat["reward"] for stat in sample_stat_list]
        ),  # each rollout must have a reward (fillna if missing)
        "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
        "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
        "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
    }
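
The keys above are what a validation loop typically logs. A minimal sketch of driving a validation pass end to end; `daemon`, `val_data`, and `llm_addresses` are placeholders for an already-constructed daemon, the validation batch dict, and the backend LLM server addresses:

# A minimal sketch, not the definitive workflow; names below are placeholders.
daemon.set_up_data_and_server(data=val_data, server_addresses=llm_addresses, is_train=False)
daemon.run_until_all_finished()

metrics = daemon.get_test_metrics()
print(f"val/reward = {metrics['val/reward']:.4f} over {metrics['val/n_rollouts']} rollouts")

daemon.clear_data_and_server()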

get_train_data_batch(max_prompt_length, max_response_length, device)

Processes completed rollouts to generate a training data batch.

This function reconstructs the logic from the original AgentModeDaemon, using data retrieved from the new server architecture. It handles padding, truncation, and tensor creation for the PPO training loop.

Source code in agentlightning/verl/daemon.py
def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, device: torch.device):
    """
    Processes completed rollouts to generate a training data batch.

    This function reconstructs the logic from the original AgentModeDaemon,
    using data retrieved from the new server architecture. It handles padding,
    truncation, and tensor creation for the PPO training loop.
    """
    assert self.is_train, "This method should only be called during training."
    assert len(self._completed_rollouts_v0) == self._total_tasks_queued

    # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
    finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
    finished_id_to_final_reward: Dict[str, float] = {}
    for rollout_id, rollout in self._completed_rollouts_v0.items():
        original_sample = self._task_id_to_original_sample[rollout_id]

        final_reward = self._fillna_reward(rollout)

        if not rollout.triplets:
            finished_id_to_final_reward[rollout_id] = final_reward
            print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.")
            continue

        # The client should report triplets that contain prompt_ids and response_ids.
        # Example triplet.prompt: {"token_ids": [...]}
        # Example triplet.response: {"token_ids": [...]}
        trace_list = [
            {"prompt_ids": t.prompt.get("token_ids", []), "response_ids": t.response.get("token_ids", [])}
            for t in rollout.triplets
        ]
        info = {
            "reward": final_reward,
            "trace_list": trace_list,
            "data_id": original_sample["data_id"],
        }
        finished_id_to_sample_info[rollout_id] = info
        finished_id_to_final_reward[rollout_id] = final_reward
    #
    # --- Data processing and tensor creation logic ---
    # Get all the reported data.
    # prompt_ids are left-padded.
    # response_ids are right-padded.
    # They are concatenated in the middle.
    # Discard handling:
    #   - Those exceeding max_prompt_length will be marked for discard, but not
    #     discarded here. They are only truncated and marked, to be discarded later.
    #     This is for the correctness of the advantage calculation.
    #   - The discard for the PPO mini-batch should also be handled this way.
    input_ids_list: List[List[int]] = []
    input_attention_mask_list: List[List[int]] = []
    response_ids_list: List[List[int]] = []
    response_attention_mask_list: List[List[int]] = []
    reward_list: List[float] = []
    data_id_list: List[str] = []
    rollout_id_list: List[str] = []
    turn_index_list: List[int] = []
    is_drop_list: List[bool] = []
    n_trunc_sample_because_of_response = 0

    for rollout_id, sample_info in finished_id_to_sample_info.items():
        for turn_index, trace in enumerate(sample_info["trace_list"]):

            reward_list.append(sample_info["reward"])
            prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

            # Mark samples with prompts exceeding max_prompt_length to be dropped later
            if len(prompt_ids) > max_prompt_length:
                prompt_ids = prompt_ids[:max_prompt_length]
                is_drop_list.append(True)
            else:
                is_drop_list.append(False)

            # Truncate responses that exceed max_response_length
            if len(response_ids) > max_response_length:
                response_ids = response_ids[:max_response_length]
                n_trunc_sample_because_of_response += 1

            # Pad prompts to the left and responses to the right
            one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
                prompt_ids, max_prompt_length, self.pad_token_id
            )
            one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
                response_ids, max_response_length, self.pad_token_id
            )

            input_ids_list.append(one_input_ids)
            input_attention_mask_list.append(one_input_attention_mask)
            response_ids_list.append(one_response_ids)
            response_attention_mask_list.append(one_response_attention_mask)
            data_id_list.append(sample_info["data_id"])
            rollout_id_list.append(rollout_id)
            turn_index_list.append(turn_index)

    n_transition = len(input_ids_list)
    batch_input_ids = torch.LongTensor(input_ids_list).to(device)
    input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
    batch_response_ids = torch.LongTensor(response_ids_list).to(device)
    response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)

    # Concatenate prompts and responses to form the full sequence
    batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
    attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
    position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
    is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
    scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)

    # Create token-level scores by placing the final reward at the last token position
    token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
    # At the eos_mask_idx position of each sample, fill in the corresponding scores.
    # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension.
    eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
    token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
    # Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
    token_level_scores = token_level_scores[:, -max_response_length:]

    # Form the final batch using TensorDict
    batch = TensorDict(
        {
            "prompts": batch_input_ids,
            "responses": batch_response_ids,
            "input_ids": batch_seq,  # here input_ids become the whole sentences
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "is_drop_mask": is_drop_mask,
            "token_level_scores": token_level_scores.contiguous(),
        },
        batch_size=n_transition,
    )
    data_proto = DataProto(batch=batch)

    data_metrics = {
        "training/reward": np.mean(list(finished_id_to_final_reward.values())),
        "training/n_rollouts": len(finished_id_to_final_reward),
        "training/n_rollouts_w_trace": len(finished_id_to_sample_info),
        "training/n_truncated_triplets": n_trunc_sample_because_of_response,
        "training/n_triplets": n_transition,
    }

    # Add non-tensor data for advantage calculation and logging
    data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)  # type: ignore
    data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)  # type: ignore
    data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)  # type: ignore

    return data_proto, data_metrics
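
Because over-long prompts are only marked via is_drop_mask (so that advantage estimation stays consistent), the trainer is expected to filter those triplets later. A minimal sketch of consuming the returned batch, assuming `daemon` and `device` already exist; the maximum lengths are illustrative:

# Sketch only; length limits are placeholders.
data_proto, data_metrics = daemon.get_train_data_batch(
    max_prompt_length=2048,
    max_response_length=1024,
    device=device,
)

keep = ~data_proto.batch["is_drop_mask"]  # True for triplets whose prompts fit the budget
print(f"{int(keep.sum())}/{data_metrics['training/n_triplets']} triplets kept for the PPO update")
print(f"mean reward: {data_metrics['training/reward']:.4f}")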

run_until_all_finished(verbose=True)

Synchronously waits for all queued tasks to be completed and reported.

Source code in agentlightning/verl/daemon.py
def run_until_all_finished(self, verbose: bool = True):
    """Synchronously waits for all queued tasks to be completed and reported."""
    if self._total_tasks_queued == 0:
        print("Warning: No tasks were queued.")
        return

    if self.mode == "v0":
        if not self.server.loop or not self.server.startup_event.is_set():
            raise RuntimeError("Server is not running or ready.")
        loop = self.server.loop
    else:
        loop = self._internal_loop
        assert loop is not None

    coro = self._async_run_until_finished(verbose)
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        future.result()  # Wait indefinitely for all tasks to complete
    except Exception as e:
        print(f"Error while waiting for tasks to finish: {e}")
        raise

set_up_data_and_server(data, server_addresses, is_train=True)

Synchronous wrapper for setting up data and server resources.

Source code in agentlightning/verl/daemon.py
def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
    """Synchronous wrapper for setting up data and server resources."""
    coro = self._async_set_up(data, server_addresses, is_train)

    if self.mode == "v0":
        if not self.server.loop or not self.server.startup_event.is_set():
            raise RuntimeError("Server is not running or ready.")

        future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)

    else:
        if self._internal_loop is None:
            raise RuntimeError("Internal loop is not running.")
        future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
    try:
        future.result(timeout=60)  # Wait for completion with a timeout
    except Exception as e:
        print(f"Failed to set up data on server: {e}")
        raise

start()

Starts the main AgentLightningServer and the proxy server.

Source code in agentlightning/verl/daemon.py
def start(self):
    """Starts the main AgentLightningServer and the proxy server."""

    if self.mode == "v0":

        def run_server():
            """Run the AgentLightningServer in a separate thread."""
            asyncio.run(self.server.run_forever())

        self._server_thread = threading.Thread(target=run_server, daemon=True)
        self._server_thread.start()

        # Wait for the server's internal startup event to be set.
        print("Waiting for AgentLightningServer to start...")
        is_ready = self.server.startup_event.wait(timeout=20.0)  # Wait up to 20s
        if not is_ready:
            raise RuntimeError("AgentLightningServer failed to start within the timeout period.")

        print(f"AgentLightningServer control plane running on port {self.server_port}")

        self._start_proxy_server_v0()
    else:
        # Agent lightning server is no longer needed;
        # Start proxy server in _async_set_up
        pass
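
Putting the methods above together, one training iteration in v0 mode roughly follows start, set up, wait, collect, clear. A minimal sketch with placeholder names (`daemon`, `train_batches`, `llm_addresses`, `device`); constructing the daemon itself is omitted:

daemon.start()  # once per run: launches AgentLightningServer and the proxy (v0 mode)

for train_batch in train_batches:  # placeholder iterable of task dicts
    daemon.set_up_data_and_server(data=train_batch, server_addresses=llm_addresses, is_train=True)
    daemon.run_until_all_finished()
    data_proto, data_metrics = daemon.get_train_data_batch(
        max_prompt_length=2048, max_response_length=1024, device=device
    )
    # ... run the PPO update on data_proto ...
    daemon.clear_data_and_server()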

BaseTraceTripletAdapter

Bases: TraceAdapter[List[Triplet]]

Base class for trace triplet adapters.

Source code in agentlightning/adapter/triplet.py
class BaseTraceTripletAdapter(TraceAdapter[List[Triplet]]):
    """
    Base class for trace triplet adapters.
    """

Dataset

Bases: Protocol, Generic[T_co]

The general interface for a dataset.

It's currently implemented as a protocol, having a similar interface to torch.utils.data.Dataset. You don't have to inherit from this class; you can use a simple list if you want to.

Source code in agentlightning/types/core.py
class Dataset(Protocol, Generic[T_co]):
    """The general interface for a dataset.

    It's currently implemented as a protocol, having a similar interface to torch.utils.data.Dataset.
    You don't have to inherit from this class; you can use a simple list if you want to.
    """

    def __getitem__(self, index: int) -> T_co: ...

    def __len__(self) -> int: ...
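
Since Dataset is a Protocol, anything exposing __getitem__ and __len__ qualifies; a plain Python list already does. A minimal sketch of a custom implementation (the record shape is illustrative):

from typing import Dict, List


class JsonlTaskDataset:
    """Wraps a list of task records; satisfies the Dataset protocol structurally."""

    def __init__(self, records: List[Dict[str, str]]):
        self._records = records

    def __getitem__(self, index: int) -> Dict[str, str]:
        return self._records[index]

    def __len__(self) -> int:
        return len(self._records)


tasks = JsonlTaskDataset([{"question": "1 + 1 = ?"}, {"question": "2 + 2 = ?"}])
assert len(tasks) == 2 and tasks[0]["question"] == "1 + 1 = ?"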

LLM

Bases: Resource

Provide an LLM endpoint and model name as a resource.

Attributes:

endpoint (str): The URL of the LLM API endpoint.
model (str): The identifier for the model to be used (e.g., 'gpt-4o').
sampling_parameters (SamplingParameters): A dictionary of hyperparameters for model inference, such as temperature, top_p, etc.

Source code in agentlightning/types/resources.py
class LLM(Resource):
    """
    Provide an LLM endpoint and model name as a resource.

    Attributes:
        endpoint (str): The URL of the LLM API endpoint.
        model (str): The identifier for the model to be used (e.g., 'gpt-4o').
        sampling_parameters (SamplingParameters): A dictionary of hyperparameters
            for model inference, such as temperature, top_p, etc.
    """

    resource_type: Literal["llm"] = "llm"
    endpoint: str
    model: str
    api_key: Optional[str] = None
    sampling_parameters: Dict[str, Any] = Field(default_factory=dict)

    def get_base_url(self, *args: Any, **kwargs: Any) -> str:
        """The base_url to put into openai.OpenAI.

        Users are encouraged to use `base_url` to get the LLM endpoint instead of accessing `endpoint` directly.
        """
        return self.endpoint

get_base_url(*args, **kwargs)

The base_url to put into openai.OpenAI.

Users are encouraged to use base_url to get the LLM endpoint instead of accessing endpoint directly.

Source code in agentlightning/types/resources.py
def get_base_url(self, *args: Any, **kwargs: Any) -> str:
    """The base_url to put into openai.OpenAI.

    Users are encouraged to use `base_url` to get the LLM endpoint instead of accessing `endpoint` directly.
    """
    return self.endpoint
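
As the docstring suggests, get_base_url() is meant to feed an OpenAI-compatible client. A minimal sketch; the endpoint, model name, and the import path of LLM are assumptions for illustration only:

import openai

from agentlightning.types import LLM  # import path assumed for this sketch

llm = LLM(
    endpoint="http://localhost:8080/v1",  # placeholder endpoint
    model="llama3",
    sampling_parameters={"temperature": 0.7},
)

client = openai.OpenAI(base_url=llm.get_base_url(), api_key=llm.api_key or "EMPTY")
response = client.chat.completions.create(
    model=llm.model,
    messages=[{"role": "user", "content": "Hello!"}],
    **llm.sampling_parameters,
)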

LLMProxy

Host a LiteLLM OpenAI-compatible proxy bound to a LightningStore.

The proxy:

  • Serves an OpenAI-compatible API via uvicorn.
  • Adds rollout/attempt routing and headers via middleware.
  • Registers OTEL export and token-id callbacks.
  • Writes a LiteLLM worker config file with model_list and settings.

Lifecycle:

  • start() writes config, starts uvicorn server in a thread, and waits until ready.
  • stop() tears down the server and removes the temp config file.
  • restart() convenience wrapper to stop then start.

Usage Note: As the LLM Proxy sets up an OpenTelemetry tracer, it's recommended to run it in a different process from the main runner (i.e., tracer from agents).

Parameters:

port (int, required): TCP port to bind.
model_list (List[ModelConfig], required): LiteLLM model_list entries.
store (LightningStore, required): LightningStore used for span sequence and persistence.
host (str | None, default None): Publicly reachable host used in resource endpoints. Defaults to best-guess IPv4.
litellm_config (Dict[str, Any] | None, default None): Extra LiteLLM proxy config merged with model_list.
num_retries (int, default 0): Default LiteLLM retry count injected into litellm_settings.
Source code in agentlightning/llm_proxy.py
class LLMProxy:
    """Host a LiteLLM OpenAI-compatible proxy bound to a LightningStore.

    The proxy:

    * Serves an OpenAI-compatible API via uvicorn.
    * Adds rollout/attempt routing and headers via middleware.
    * Registers OTEL export and token-id callbacks.
    * Writes a LiteLLM worker config file with ``model_list`` and settings.

    Lifecycle:

    * ``start()`` writes config, starts uvicorn server in a thread, and waits until ready.
    * ``stop()`` tears down the server and removes the temp config file.
    * ``restart()`` convenience wrapper to stop then start.

    Usage Note:
    As the LLM Proxy sets up an OpenTelemetry tracer, it's recommended to run it in a different
    process from the main runner (i.e., tracer from agents).

    Args:
        port: TCP port to bind.
        model_list: LiteLLM ``model_list`` entries.
        store: LightningStore used for span sequence and persistence.
        host: Publicly reachable host used in resource endpoints. Defaults to best-guess IPv4.
        litellm_config: Extra LiteLLM proxy config merged with ``model_list``.
        num_retries: Default LiteLLM retry count injected into ``litellm_settings``.
    """

    def __init__(
        self,
        port: int,
        model_list: List[ModelConfig],
        store: LightningStore,
        host: str | None = None,
        litellm_config: Dict[str, Any] | None = None,
        num_retries: int = 0,
    ):
        self.store = store
        self.host = host or _get_default_ipv4_address()
        self.port = port
        self.model_list = model_list
        self.litellm_config = litellm_config or {}

        # Ensure num_retries is present inside the litellm_settings block.
        self.litellm_config.setdefault("litellm_settings", {})
        self.litellm_config["litellm_settings"].setdefault("num_retries", num_retries)

        self._server_thread = None
        self._config_file = None
        self._uvicorn_server = None
        self._ready_event = threading.Event()

    def set_store(self, store: LightningStore) -> None:
        """Set the store for the proxy.

        Args:
            store: The store to use for the proxy.
        """
        self.store = store

    def update_model_list(self, model_list: List[ModelConfig]) -> None:
        """Replace the in-memory model list and hot-restart if running.

        Args:
            model_list: New list of model entries.
        """
        self.model_list = model_list
        logger.info(f"Updating LLMProxy model list to: {model_list}")
        if self.is_running():
            self.restart()
        # Do nothing if the server is not running.

    def _wait_until_started(self, startup_timeout: float = 20.0):
        """Block until the uvicorn server reports started or timeout.

        Args:
            startup_timeout: Maximum seconds to wait.
        """
        start = time.time()
        while True:
            if self._uvicorn_server is None:
                break
            if self._uvicorn_server.started:
                self._ready_event.set()
                break
            if self._uvicorn_server.should_exit:
                break
            if time.time() - start > startup_timeout:
                break
            time.sleep(0.01)

    def start(self):
        """Start the proxy server thread and initialize global wiring.

        Side effects:

        * Sets the module-level global store for middleware/exporter access.
        * Calls ``initialize()`` once to register middleware and callbacks.
        * Writes a temporary YAML config consumed by LiteLLM worker.
        * Launches uvicorn in a daemon thread and waits for readiness.
        """
        if self.is_running():
            # Trigger restart
            self.stop()

        global _global_store

        _global_store = self.store

        # Initialize global middleware and callbacks once.
        initialize()

        # Persist a temp worker config for LiteLLM and point the proxy at it.
        self._config_file = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False).name
        with open(self._config_file, "w") as fp:
            yaml.safe_dump(
                {
                    "model_list": self.model_list,
                    **self.litellm_config,
                },
                fp,
            )

        save_worker_config(config=self._config_file)

        # Bind to all interfaces to allow other hosts to reach it if needed.
        self._uvicorn_server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=self.port))

        def run_server():
            # Serve uvicorn in this background thread with its own event loop.
            assert self._uvicorn_server is not None
            asyncio.run(self._uvicorn_server.serve())

        logger.info("Starting LLMProxy server thread...")
        self._ready_event.clear()
        self._server_thread = threading.Thread(target=run_server, daemon=True)
        self._server_thread.start()
        self._wait_until_started()

    def stop(self):
        """Stop the proxy server and clean up temporary artifacts.

        This is a best-effort graceful shutdown with a bounded join timeout.
        """
        if not self.is_running():
            logger.warning("LLMProxy is not running. Nothing to stop.")
            return

        # Remove worker config to avoid stale references.
        if self._config_file and os.path.exists(self._config_file):
            os.unlink(self._config_file)

        logger.info("Stopping LLMProxy server thread...")
        stop_success = True
        if self._server_thread is not None and self._uvicorn_server is not None and self._uvicorn_server.started:
            self._uvicorn_server.should_exit = True
            self._server_thread.join(timeout=10.0)  # Allow time for graceful shutdown.
            if self._server_thread.is_alive():
                logger.error(
                    "LLMProxy server thread is still alive after 10 seconds. Cannot kill it because it's a thread."
                )
                stop_success = False
            self._server_thread = None
            self._uvicorn_server = None
            self._config_file = None
            self._ready_event.clear()
            if not _check_port(self.host, self.port):
                logger.error(f"Port {self.port} is still in use. Stopping LLMProxy is not successful.")
                stop_success = False
        if stop_success:
            logger.info("LLMProxy server thread stopped.")
        else:
            logger.error("LLMProxy server is not stopped successfully.")

    def restart(self, *, _port: int | None = None) -> None:
        """Restart the proxy if running, else start it.

        Convenience wrapper calling ``stop()`` followed by ``start()``.
        """
        logger.info("Restarting LLMProxy server...")
        if self.is_running():
            self.stop()
        if _port is not None:
            self.port = _port
        self.start()

    def is_running(self) -> bool:
        """Return whether the uvicorn server is active.

        Returns:
            bool: True if server was started and did not signal exit.
        """
        return self._uvicorn_server is not None and self._uvicorn_server.started

    def as_resource(
        self,
        rollout_id: str | None = None,
        attempt_id: str | None = None,
        model: str | None = None,
        sampling_parameters: Dict[str, Any] | None = None,
    ) -> LLM:
        """Create an ``LLM`` resource pointing at this proxy with rollout context.

        The returned endpoint is:
            ``http://{host}:{port}/rollout/{rollout_id}/attempt/{attempt_id}``

        Args:
            rollout_id: Rollout identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
            attempt_id: Attempt identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
            model: Logical model name to use. If omitted and exactly one model
                is configured, that model is used.
            sampling_parameters: Optional default sampling parameters.

        Returns:
            LLM: Configured resource ready for OpenAI-compatible calls.

        Raises:
            ValueError: If ``model`` is omitted and zero or multiple models are configured.
        """
        if model is None:
            if len(self.model_list) == 1:
                model = self.model_list[0]["model_name"]
            else:
                raise ValueError(
                    f"Multiple or zero models found in model_list: {self.model_list}. Please specify the model."
                )

        if rollout_id is None and attempt_id is None:
            return ProxyLLM(
                endpoint=f"http://{self.host}:{self.port}",
                model=model,
                sampling_parameters=dict(sampling_parameters or {}),
            )
        elif rollout_id is not None and attempt_id is not None:
            return LLM(
                endpoint=f"http://{self.host}:{self.port}/rollout/{rollout_id}/attempt/{attempt_id}",
                model=model,
                sampling_parameters=dict(sampling_parameters or {}),
            )
        else:
            raise ValueError("Either rollout_id and attempt_id must be provided, or neither.")

as_resource(rollout_id=None, attempt_id=None, model=None, sampling_parameters=None)

Create an LLM resource pointing at this proxy with rollout context.

The returned endpoint is

http://{host}:{port}/rollout/{rollout_id}/attempt/{attempt_id}

Parameters:

rollout_id (str | None, default None): Rollout identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
attempt_id (str | None, default None): Attempt identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
model (str | None, default None): Logical model name to use. If omitted and exactly one model is configured, that model is used.
sampling_parameters (Dict[str, Any] | None, default None): Optional default sampling parameters.

Returns:

LLM: Configured resource ready for OpenAI-compatible calls.

Raises:

ValueError: If model is omitted and zero or multiple models are configured.

Source code in agentlightning/llm_proxy.py
def as_resource(
    self,
    rollout_id: str | None = None,
    attempt_id: str | None = None,
    model: str | None = None,
    sampling_parameters: Dict[str, Any] | None = None,
) -> LLM:
    """Create an ``LLM`` resource pointing at this proxy with rollout context.

    The returned endpoint is:
        ``http://{host}:{port}/rollout/{rollout_id}/attempt/{attempt_id}``

    Args:
        rollout_id: Rollout identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
        attempt_id: Attempt identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
        model: Logical model name to use. If omitted and exactly one model
            is configured, that model is used.
        sampling_parameters: Optional default sampling parameters.

    Returns:
        LLM: Configured resource ready for OpenAI-compatible calls.

    Raises:
        ValueError: If ``model`` is omitted and zero or multiple models are configured.
    """
    if model is None:
        if len(self.model_list) == 1:
            model = self.model_list[0]["model_name"]
        else:
            raise ValueError(
                f"Multiple or zero models found in model_list: {self.model_list}. Please specify the model."
            )

    if rollout_id is None and attempt_id is None:
        return ProxyLLM(
            endpoint=f"http://{self.host}:{self.port}",
            model=model,
            sampling_parameters=dict(sampling_parameters or {}),
        )
    elif rollout_id is not None and attempt_id is not None:
        return LLM(
            endpoint=f"http://{self.host}:{self.port}/rollout/{rollout_id}/attempt/{attempt_id}",
            model=model,
            sampling_parameters=dict(sampling_parameters or {}),
        )
    else:
        raise ValueError("Either rollout_id and attempt_id must be provided, or neither.")

is_running()

Return whether the uvicorn server is active.

Returns:

bool: True if server was started and did not signal exit.

Source code in agentlightning/llm_proxy.py
def is_running(self) -> bool:
    """Return whether the uvicorn server is active.

    Returns:
        bool: True if server was started and did not signal exit.
    """
    return self._uvicorn_server is not None and self._uvicorn_server.started

restart(*, _port=None)

Restart the proxy if running, else start it.

Convenience wrapper calling stop() followed by start().

Source code in agentlightning/llm_proxy.py
def restart(self, *, _port: int | None = None) -> None:
    """Restart the proxy if running, else start it.

    Convenience wrapper calling ``stop()`` followed by ``start()``.
    """
    logger.info("Restarting LLMProxy server...")
    if self.is_running():
        self.stop()
    if _port is not None:
        self.port = _port
    self.start()

set_store(store)

Set the store for the proxy.

Parameters:

store (LightningStore, required): The store to use for the proxy.
Source code in agentlightning/llm_proxy.py
def set_store(self, store: LightningStore) -> None:
    """Set the store for the proxy.

    Args:
        store: The store to use for the proxy.
    """
    self.store = store

start()

Start the proxy server thread and initialize global wiring.

Side effects:

  • Sets the module-level global store for middleware/exporter access.
  • Calls initialize() once to register middleware and callbacks.
  • Writes a temporary YAML config consumed by LiteLLM worker.
  • Launches uvicorn in a daemon thread and waits for readiness.
Source code in agentlightning/llm_proxy.py
def start(self):
    """Start the proxy server thread and initialize global wiring.

    Side effects:

    * Sets the module-level global store for middleware/exporter access.
    * Calls ``initialize()`` once to register middleware and callbacks.
    * Writes a temporary YAML config consumed by LiteLLM worker.
    * Launches uvicorn in a daemon thread and waits for readiness.
    """
    if self.is_running():
        # Trigger restart
        self.stop()

    global _global_store

    _global_store = self.store

    # Initialize global middleware and callbacks once.
    initialize()

    # Persist a temp worker config for LiteLLM and point the proxy at it.
    self._config_file = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False).name
    with open(self._config_file, "w") as fp:
        yaml.safe_dump(
            {
                "model_list": self.model_list,
                **self.litellm_config,
            },
            fp,
        )

    save_worker_config(config=self._config_file)

    # Bind to all interfaces to allow other hosts to reach it if needed.
    self._uvicorn_server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=self.port))

    def run_server():
        # Serve uvicorn in this background thread with its own event loop.
        assert self._uvicorn_server is not None
        asyncio.run(self._uvicorn_server.serve())

    logger.info("Starting LLMProxy server thread...")
    self._ready_event.clear()
    self._server_thread = threading.Thread(target=run_server, daemon=True)
    self._server_thread.start()
    self._wait_until_started()

stop()

Stop the proxy server and clean up temporary artifacts.

This is a best-effort graceful shutdown with a bounded join timeout.

Source code in agentlightning/llm_proxy.py
def stop(self):
    """Stop the proxy server and clean up temporary artifacts.

    This is a best-effort graceful shutdown with a bounded join timeout.
    """
    if not self.is_running():
        logger.warning("LLMProxy is not running. Nothing to stop.")
        return

    # Remove worker config to avoid stale references.
    if self._config_file and os.path.exists(self._config_file):
        os.unlink(self._config_file)

    logger.info("Stopping LLMProxy server thread...")
    stop_success = True
    if self._server_thread is not None and self._uvicorn_server is not None and self._uvicorn_server.started:
        self._uvicorn_server.should_exit = True
        self._server_thread.join(timeout=10.0)  # Allow time for graceful shutdown.
        if self._server_thread.is_alive():
            logger.error(
                "LLMProxy server thread is still alive after 10 seconds. Cannot kill it because it's a thread."
            )
            stop_success = False
        self._server_thread = None
        self._uvicorn_server = None
        self._config_file = None
        self._ready_event.clear()
        if not _check_port(self.host, self.port):
            logger.error(f"Port {self.port} is still in use. Stopping LLMProxy is not successful.")
            stop_success = False
    if stop_success:
        logger.info("LLMProxy server thread stopped.")
    else:
        logger.error("LLMProxy server is not stopped successfully.")

update_model_list(model_list)

Replace the in-memory model list and hot-restart if running.

Parameters:

model_list (List[ModelConfig], required): New list of model entries.
Source code in agentlightning/llm_proxy.py
def update_model_list(self, model_list: List[ModelConfig]) -> None:
    """Replace the in-memory model list and hot-restart if running.

    Args:
        model_list: New list of model entries.
    """
    self.model_list = model_list
    logger.info(f"Updating LLMProxy model list to: {model_list}")
    if self.is_running():
        self.restart()
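
For example, when the backend weights are reloaded on a new address, swapping the model list hot-restarts a running proxy. A minimal sketch; the entry below is a placeholder in LiteLLM's model_list format:

proxy.update_model_list(
    [
        {
            "model_name": "llama3",
            "litellm_params": {"model": "hosted_vllm/llama3", "api_base": "http://localhost:8081/v1"},
        }
    ]
)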

LightningStore

A centralized, thread-safe, async data store for Lightning's state. It holds the task queue, versioned resources, and completed rollouts.

The store has a built-in clock and is responsible for tracking time. All time-based operations, such as retries and timeouts, should be handled by the store.

Source code in agentlightning/store/base.py
class LightningStore:
    """
    A centralized, thread-safe, async data store for Lightning's state.
    This holds the task queue, versioned resources, and completed rollouts.

    The store has a built-in clock and is responsible for tracking time.
    All time-based operations, such as retries and timeouts, should be handled by the store.
    """

    async def start_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> AttemptedRollout:
        """
        Add one incomplete rollout to the store, and get an attempt created for it.
        This will immediately set the rollout to a preparing state, and should be
        used by whoever is going to execute the rollout.

        Return a special rollout with attempt object. Do not update it directly.

        But if the rollout fails or times out, it's still possible that the watchdog
        sends it back to the queue for retry.

        To enqueue a rollout to the task queue, use `enqueue_rollout` instead.
        """
        raise NotImplementedError()

    async def enqueue_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> RolloutV2:
        """
        Adds a new task to the queue with specific metadata and
        returns the rollout object with its unique ID.
        """
        raise NotImplementedError()

    async def dequeue_rollout(self) -> Optional[AttemptedRollout]:
        """
        Retrieves the next task from the queue without blocking.
        Returns None if the queue is empty.

        Will set the rollout status to preparing.
        """
        raise NotImplementedError()

    async def start_attempt(self, rollout_id: str) -> AttemptedRollout:
        """
        Create a new attempt for a given rollout ID and return the attempt details.
        """
        raise NotImplementedError()

    async def add_span(self, span: Span) -> Span:
        """
        Add a span to the store.

        This method is responsible for updating the rollout/attempt status to "running" if needed.
        """
        raise NotImplementedError()

    async def add_otel_span(
        self,
        rollout_id: str,
        attempt_id: str,
        readable_span: ReadableSpan,
        sequence_id: int | None = None,
    ) -> Span:
        """
        Add an opentelemetry span to the store.

        If sequence_id is not provided, it will be fetched from `get_next_span_sequence_id` and assigned automatically.
        """
        raise NotImplementedError()

    async def query_rollouts(
        self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None
    ) -> List[RolloutV2]:
        """
        Query and retrieve rollouts filtered by their status.
        If no status is provided, returns all rollouts.
        """
        raise NotImplementedError()

    async def query_attempts(self, rollout_id: str) -> List[Attempt]:
        """
        Query and retrieve all attempts associated with a specific rollout ID.
        Returns an empty list if no attempts are found.
        """
        raise NotImplementedError()

    async def get_rollout_by_id(self, rollout_id: str) -> Optional[RolloutV2]:
        """
        Safely retrieves a specific rollout by its ID.
        """
        raise NotImplementedError()

    async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
        """
        Safely retrieves the latest attempt for a given rollout ID.
        """
        raise NotImplementedError()

    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
        """
        Safely retrieves a specific version of named resources by its ID.
        """
        raise NotImplementedError()

    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """
        Safely retrieves the latest version of named resources.
        """
        raise NotImplementedError()

    async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
        """
        Get the next span sequence ID for a given rollout and attempt.
        This should be used to assign a unique sequence ID to each span within an attempt.

        It is recommended to get the ID before the operation even begins, to avoid race conditions.
        """
        raise NotImplementedError()

    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[RolloutV2]:
        """
        Wait for specified rollouts to complete with a timeout.
        Returns the completed rollouts; the list may be incomplete if the timeout is reached.

        TODO: Add support for waiting for 20 new rollouts, or wait until 80% of the pending ids are completed.
        """
        raise NotImplementedError()

    async def query_spans(self, rollout_id: str, attempt_id: str | Literal["latest"] | None = None) -> List[Span]:
        """
        Query and retrieve all spans associated with a specific rollout ID.
        Returns an empty list if no spans are found.
        """
        raise NotImplementedError()

    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
        """
        Safely stores a new version of named resources and sets it as the latest.
        Not implemented by many stores yet.
        """
        raise NotImplementedError()

    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
        """
        Safely stores a new version or updates an existing version of named resources and sets it as the latest.
        """
        raise NotImplementedError()

    async def update_rollout(
        self,
        rollout_id: str,
        input: TaskInput | Unset = UNSET,
        mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
        resources_id: Optional[str] | Unset = UNSET,
        status: RolloutStatus | Unset = UNSET,
        config: RolloutConfig | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> RolloutV2:
        """
        Update the rollout status and related metadata.

        Fields not listed here either cannot be updated or are updated automatically (e.g., end_time).

        When status is updated to a finished / problematic state, other states like task
        queues will be updated accordingly.

        Args:
            rollout_id: Unique identifier for the rollout to update
            input: New input data for the rollout. If set, will be updated. Can be updated to None
            mode: New mode for the rollout. If set, will be updated. Can be updated to None
            resources_id: New resources ID for the rollout. If set, will be updated. Can be updated to None
            status: New status for the rollout. If set, will be updated
            config: New config for the rollout. If set, will be updated
            metadata: Dictionary of additional metadata to update. If set, will replace the existing metadata
        """
        raise NotImplementedError()

    async def update_attempt(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"],
        status: AttemptStatus | Unset = UNSET,
        worker_id: str | Unset = UNSET,
        last_heartbeat_time: float | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> Attempt:
        """
        Update a specific or latest attempt for a given rollout.

        Updating the latest attempt will NOT affect the corresponding rollout status.

        Args:
            rollout_id: Unique identifier for the rollout
            attempt_id: Unique identifier for the attempt
            status: Status to set for the attempt, update if provided
            worker_id: Worker identifier, update if provided
            last_heartbeat_time: Timestamp of the last heartbeat from the worker
            metadata: Dictionary of additional metadata to update, will replace the existing metadata
        """
        raise NotImplementedError()
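
The split of responsibilities is easiest to see as two coroutines sharing one store: the algorithm enqueues and waits, the runner dequeues and reports. A minimal sketch against any concrete LightningStore implementation; the task input and timeout are placeholders:

import asyncio


async def algorithm_side(store: "LightningStore"):
    # Enqueue a task and wait for some runner to finish it.
    rollout = await store.enqueue_rollout(input={"question": "1 + 1 = ?"}, mode="train")
    finished = await store.wait_for_rollouts(rollout_ids=[rollout.rollout_id], timeout=300.0)
    spans = await store.query_spans(rollout.rollout_id, attempt_id="latest")
    return finished, spans


async def runner_side(store: "LightningStore"):
    # Claim the next task; the store marks it as preparing.
    claimed = await store.dequeue_rollout()
    if claimed is None:
        return  # queue is empty
    # ... execute the rollout, reporting spans via add_span / add_otel_span ...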

add_otel_span(rollout_id, attempt_id, readable_span, sequence_id=None) async

Add an opentelemetry span to the store.

If sequence_id is not provided, it will be fetched from get_next_span_sequence_id and assigned automatically.

Source code in agentlightning/store/base.py
async def add_otel_span(
    self,
    rollout_id: str,
    attempt_id: str,
    readable_span: ReadableSpan,
    sequence_id: int | None = None,
) -> Span:
    """
    Add an opentelemetry span to the store.

    If sequence_id is not provided, it will be fetched from `get_next_span_sequence_id` and assigned automatically.
    """
    raise NotImplementedError()

add_resources(resources) async

Safely stores a new version of named resources and sets it as the latest. Not implemented by many stores yet.

Source code in agentlightning/store/base.py
async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
    """
    Safely stores a new version of named resources and sets it as the latest.
    Not implemented by many stores yet.
    """
    raise NotImplementedError()

add_span(span) async

Add a span to the store.

This method is responsible for updating the rollout/attempt status to "running" if needed.

Source code in agentlightning/store/base.py
async def add_span(self, span: Span) -> Span:
    """
    Add a span to the store.

    This method is responsible for updating the rollout/attempt status to "running" if needed.
    """
    raise NotImplementedError()

dequeue_rollout() async

Retrieves the next task from the queue without blocking. Returns None if the queue is empty.

Will set the rollout status to preparing.

Source code in agentlightning/store/base.py
async def dequeue_rollout(self) -> Optional[AttemptedRollout]:
    """
    Retrieves the next task from the queue without blocking.
    Returns None if the queue is empty.

    Will set the rollout status to preparing.
    """
    raise NotImplementedError()

enqueue_rollout(input, mode=None, resources_id=None, metadata=None) async

Adds a new task to the queue with specific metadata and returns the rollout object with its unique ID.

Source code in agentlightning/store/base.py
async def enqueue_rollout(
    self,
    input: TaskInput,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> RolloutV2:
    """
    Adds a new task to the queue with specific metadata and
    returns the rollout object with its unique ID.
    """
    raise NotImplementedError()

get_latest_attempt(rollout_id) async

Safely retrieves the latest attempt for a given rollout ID.

Source code in agentlightning/store/base.py
async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
    """
    Safely retrieves the latest attempt for a given rollout ID.
    """
    raise NotImplementedError()

get_latest_resources() async

Safely retrieves the latest version of named resources.

Source code in agentlightning/store/base.py
async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
    """
    Safely retrieves the latest version of named resources.
    """
    raise NotImplementedError()

get_next_span_sequence_id(rollout_id, attempt_id) async

Get the next span sequence ID for a given rollout and attempt. This should be used to assign a unique sequence ID to each span within an attempt.

It is recommended to get the ID before the operation even begins, to avoid race conditions.

Source code in agentlightning/store/base.py
async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
    """
    Get the next span sequence ID for a given rollout and attempt.
    This should be used to assign a unique sequence ID to each span within an attempt.

    It is recommended to get the ID before the operation even begins, to avoid race conditions.
    """
    raise NotImplementedError()

get_resources_by_id(resources_id) async

Safely retrieves a specific version of named resources by its ID.

Source code in agentlightning/store/base.py
async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
    """
    Safely retrieves a specific version of named resources by its ID.
    """
    raise NotImplementedError()

get_rollout_by_id(rollout_id) async

Safely retrieves a specific rollout by its ID.

Source code in agentlightning/store/base.py
async def get_rollout_by_id(self, rollout_id: str) -> Optional[RolloutV2]:
    """
    Safely retrieves a specific rollout by its ID.
    """
    raise NotImplementedError()

query_attempts(rollout_id) async

Query and retrieve all attempts associated with a specific rollout ID. Returns an empty list if no attempts are found.

Source code in agentlightning/store/base.py
async def query_attempts(self, rollout_id: str) -> List[Attempt]:
    """
    Query and retrieve all attempts associated with a specific rollout ID.
    Returns an empty list if no attempts are found.
    """
    raise NotImplementedError()

query_rollouts(*, status=None, rollout_ids=None) async

Query and retrieve rollouts filtered by their status. If no status is provided, returns all rollouts.

Source code in agentlightning/store/base.py
async def query_rollouts(
    self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None
) -> List[RolloutV2]:
    """
    Query and retrieve rollouts filtered by their status.
    If no status is provided, returns all rollouts.
    """
    raise NotImplementedError()

query_spans(rollout_id, attempt_id=None) async

Query and retrieve all spans associated with a specific rollout ID. Returns an empty list if no spans are found.

Source code in agentlightning/store/base.py
async def query_spans(self, rollout_id: str, attempt_id: str | Literal["latest"] | None = None) -> List[Span]:
    """
    Query and retrieve all spans associated with a specific rollout ID.
    Returns an empty list if no spans are found.
    """
    raise NotImplementedError()

start_attempt(rollout_id) async

Create a new attempt for a given rollout ID and return the attempt details.

Source code in agentlightning/store/base.py
async def start_attempt(self, rollout_id: str) -> AttemptedRollout:
    """
    Create a new attempt for a given rollout ID and return the attempt details.
    """
    raise NotImplementedError()

start_rollout(input, mode=None, resources_id=None, metadata=None) async

Add one incomplete rollout to the store and create an attempt for it. This immediately sets the rollout to a preparing state, and it should be used by whoever is going to execute the rollout.

Returns a rollout object bundled with its attempt. Do not update it directly.

If the rollout fails or times out, the watchdog may still send it back to the queue for retry.

To enqueue a rollout to the task queue, use enqueue_rollout instead.

Source code in agentlightning/store/base.py
async def start_rollout(
    self,
    input: TaskInput,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> AttemptedRollout:
    """
    Add one incomplete rollout to the store and create an attempt for it.
    This immediately sets the rollout to a preparing state, and should be
    used by whoever is going to execute the rollout.

    Returns a rollout object bundled with its attempt. Do not update it directly.

    If the rollout fails or times out, the watchdog may still send it
    back to the queue for retry.

    To enqueue a rollout to the task queue, use `enqueue_rollout` instead.
    """
    raise NotImplementedError()
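
Example (illustrative sketch; `store` and the input payload are assumptions): start_rollout is meant for a worker that executes the rollout itself, whereas enqueue_rollout hands the task to the queue for other workers to claim.

# Create a rollout that this worker will execute immediately (preparing state).
attempted = await store.start_rollout(input={"question": "2 + 2 = ?"}, mode="train")
print(attempted.rollout_id)  # inspect, but do not update this object directly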

update_attempt(rollout_id, attempt_id, status=UNSET, worker_id=UNSET, last_heartbeat_time=UNSET, metadata=UNSET) async

Update a specific attempt (or the latest attempt) for a given rollout.

Updating the latest attempt will NOT affect the corresponding rollout status.

Parameters:

    rollout_id (str, required): Unique identifier for the rollout.
    attempt_id (str | Literal['latest'], required): Unique identifier for the attempt.
    status (AttemptStatus | Unset, default UNSET): Status to set for the attempt; updated if provided.
    worker_id (str | Unset, default UNSET): Worker identifier; updated if provided.
    last_heartbeat_time (float | Unset, default UNSET): Timestamp of the last heartbeat from the worker.
    metadata (Optional[Dict[str, Any]] | Unset, default UNSET): Dictionary of additional metadata to update; replaces the existing metadata.
Source code in agentlightning/store/base.py
async def update_attempt(
    self,
    rollout_id: str,
    attempt_id: str | Literal["latest"],
    status: AttemptStatus | Unset = UNSET,
    worker_id: str | Unset = UNSET,
    last_heartbeat_time: float | Unset = UNSET,
    metadata: Optional[Dict[str, Any]] | Unset = UNSET,
) -> Attempt:
    """
    Update a specific attempt (or the latest attempt) for a given rollout.

    Updating the latest attempt will NOT affect the corresponding rollout status.


    Args:
        rollout_id: Unique identifier for the rollout
        attempt_id: Unique identifier for the attempt
        status: Status to set for the attempt, update if provided
        worker_id: Worker identifier, update if provided
        last_heartbeat_time: Timestamp of the last heartbeat from the worker
        metadata: Dictionary of additional metadata to update, will replace the existing metadata
    """
    raise NotImplementedError()
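
Example (illustrative sketch; `store` and the worker identifier are assumptions): a worker refreshing its heartbeat on the latest attempt.

import time

await store.update_attempt(
    rollout_id="rollout-1",
    attempt_id="latest",              # target the most recent attempt
    worker_id="worker-0",             # illustrative worker identifier
    last_heartbeat_time=time.time(),  # refresh the heartbeat timestamp
)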

update_resources(resources_id, resources) async

Safely stores a new version or updates an existing version of named resources and sets it as the latest.

Source code in agentlightning/store/base.py
async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
    """
    Safely stores a new version or updates an existing version of named resources and sets it as the latest.
    """
    raise NotImplementedError()

update_rollout(rollout_id, input=UNSET, mode=UNSET, resources_id=UNSET, status=UNSET, config=UNSET, metadata=UNSET) async

Update the rollout status and related metadata.

Fields not listed here either cannot be updated or are updated automatically (e.g., end_time).

When the status is updated to a finished or problematic state, related state such as the task queue is updated accordingly.

Parameters:

    rollout_id (str, required): Unique identifier for the rollout to update.
    input (TaskInput | Unset, default UNSET): New input data for the rollout. If set, will be updated. Can be updated to None.
    mode (Optional[Literal['train', 'val', 'test']] | Unset, default UNSET): New mode for the rollout. If set, will be updated. Can be updated to None.
    resources_id (Optional[str] | Unset, default UNSET): New resources ID for the rollout. If set, will be updated. Can be updated to None.
    status (RolloutStatus | Unset, default UNSET): New status for the rollout. If set, will be updated.
    config (RolloutConfig | Unset, default UNSET): New config for the rollout. If set, will be updated.
    metadata (Optional[Dict[str, Any]] | Unset, default UNSET): Dictionary of additional metadata to update. If set, will replace the existing metadata.
Source code in agentlightning/store/base.py
async def update_rollout(
    self,
    rollout_id: str,
    input: TaskInput | Unset = UNSET,
    mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
    resources_id: Optional[str] | Unset = UNSET,
    status: RolloutStatus | Unset = UNSET,
    config: RolloutConfig | Unset = UNSET,
    metadata: Optional[Dict[str, Any]] | Unset = UNSET,
) -> RolloutV2:
    """
    Update the rollout status and related metadata.

    Fields not listed here either cannot be updated or are updated automatically (e.g., end_time).

    When the status is updated to a finished or problematic state, related state such as
    the task queue is updated accordingly.

    Args:
        rollout_id: Unique identifier for the rollout to update
        input: New input data for the rollout. If set, will be updated. Can be updated to None
        mode: New mode for the rollout. If set, will be updated. Can be updated to None
        resources_id: New resources ID for the rollout. If set, will be updated. Can be updated to None
        status: New status for the rollout. If set, will be updated
        config: New config for the rollout. If set, will be updated
        metadata: Dictionary of additional metadata to update. If set, will replace the existing metadata
    """
    raise NotImplementedError()
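
Example (illustrative sketch; `store` is an assumed concrete implementation and the status literal is a placeholder that must be a valid RolloutStatus value):

await store.update_rollout(
    rollout_id="rollout-1",
    status="succeeded",              # placeholder; use a valid RolloutStatus value
    metadata={"final_reward": 1.0},  # replaces any existing metadata
)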

wait_for_rollouts(*, rollout_ids, timeout=None) async

Wait for the specified rollouts to complete, with an optional timeout. Returns the completed rollouts; the returned list may be shorter than requested if the timeout is reached.

TODO: Add support for waiting for 20 new rollouts, or wait until 80% of the pending ids are completed.

Source code in agentlightning/store/base.py
async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[RolloutV2]:
    """
    Wait for specified rollouts to complete with a timeout.
    Returns the completed rollouts, potentially incomplete if timeout is reached.

    TODO: Add support for waiting for 20 new rollouts, or wait until 80% of the pending ids are completed.
    """
    raise NotImplementedError()
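
Example (illustrative sketch; `store` is an assumed concrete implementation): wait up to 60 seconds and check how many of the requested rollouts completed.

done = await store.wait_for_rollouts(rollout_ids=["rollout-1", "rollout-2"], timeout=60.0)
print(f"{len(done)} of 2 rollouts completed within the timeout")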

ModelConfig

Bases: TypedDict

LiteLLM model registration entry.

This mirrors the items in LiteLLM's model_list section.

Attributes:

    model_name (str): Logical model name exposed by the proxy.
    litellm_params (Dict[str, Any]): Parameters passed to LiteLLM for this model (e.g., backend model id, api_base, additional options).

Source code in agentlightning/llm_proxy.py
class ModelConfig(TypedDict):
    """LiteLLM model registration entry.

    This mirrors the items in LiteLLM's ``model_list`` section.

    Attributes:
        model_name: Logical model name exposed by the proxy.
        litellm_params: Parameters passed to LiteLLM for this model
            (e.g., backend model id, api_base, additional options).
    """  # Google style kept concise.

    model_name: str
    litellm_params: Dict[str, Any]
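
Example (illustrative; the endpoint URL and backend model id are placeholders):

from agentlightning.llm_proxy import ModelConfig  # import path follows the source location above

model_entry: ModelConfig = {
    "model_name": "llama3",  # logical name exposed by the proxy
    "litellm_params": {
        "model": "openai/llama3",                # backend model id (placeholder)
        "api_base": "http://localhost:8080/v1",  # backend endpoint (placeholder)
    },
}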

Rollout

Bases: BaseModel

The standard reporting object from client to server.

Source code in agentlightning/types/core.py
class Rollout(BaseModel):
    """The standard reporting object from client to server."""

    rollout_id: str

    # Echoing the input task
    task: Optional[Task] = None

    # Primary, high-level feedback
    final_reward: Optional[float] = None

    # Structured, sequential feedback for RL-style optimization
    triplets: Optional[List[Triplet]] = None

    # Optional, rich-context data for deep analysis
    trace: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="A list of spans that conform to the OpenTelemetry JSON format. "
        "Users of the opentelemetry-sdk can generate this by calling "
        "json.loads(readable_span.to_json()).",
    )
    logs: Optional[List[str]] = None

    # A bucket for any other relevant information
    metadata: Dict[str, Any] = Field(default_factory=dict)
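
Example (illustrative values):

from agentlightning.types.core import Rollout  # import path follows the source location above

report = Rollout(
    rollout_id="rollout-1",
    final_reward=1.0,                     # high-level scalar feedback
    metadata={"agent_version": "0.1.0"},  # any extra context
)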

RolloutConfig

Bases: BaseModel

Configurations for rollout execution.

Source code in agentlightning/types/core.py
class RolloutConfig(BaseModel):
    """Configurations for rollout execution."""

    timeout_seconds: Optional[float] = None  # none indicates no timeout
    unresponsive_seconds: Optional[float] = None  # none indicates no unresponsive timeout
    max_attempts: int = Field(default=1, ge=1)  # including the first attempt
    retry_condition: List[AttemptStatus] = Field(
        default_factory=cast(Callable[[], List[AttemptStatus]], list)
    )  # list of statuses that should trigger a retry
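
Example (illustrative; retry_condition takes AttemptStatus values and is omitted here):

from agentlightning.types.core import RolloutConfig  # import path follows the source location above

config = RolloutConfig(
    timeout_seconds=600.0,  # fail an attempt after 10 minutes
    max_attempts=3,         # the first attempt plus up to two retries
    # retry_condition=[...] # AttemptStatus values that should trigger a retry
)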

Task

Bases: BaseModel

A task (rollout request) to be processed by the client agent.

Source code in agentlightning/types/core.py
class Task(BaseModel):
    """A task (rollout request) to be processed by the client agent."""

    rollout_id: str
    input: TaskInput

    mode: Optional[RolloutMode] = None
    resources_id: Optional[str] = None

    # Optional fields for tracking task lifecycle
    create_time: Optional[float] = None
    last_claim_time: Optional[float] = None
    num_claims: Optional[int] = None

    # Allow additional metadata fields
    metadata: Dict[str, Any] = Field(default_factory=dict)
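
Example (illustrative; the input payload and mode value are placeholders):

from agentlightning.types.core import Task  # import path follows the source location above

task = Task(
    rollout_id="rollout-1",
    input={"question": "What is the capital of France?"},  # TaskInput payload (illustrative)
    mode="train",
)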

TraceAdapter

Bases: Adapter[List[Span], T_to], Generic[T_to]

Base class for adapters that convert trace spans into other formats.

This class specializes Adapter for working with trace spans. It expects a list of Agent-lightning spans as input and produces a custom target format (e.g., reinforcement learning training data, SFT datasets, logs, metrics).

Source code in agentlightning/adapter/base.py
class TraceAdapter(Adapter[List[Span], T_to], Generic[T_to]):
    """Base class for adapters that convert trace spans into other formats.

    This class specializes `Adapter` for working with trace spans. It expects a list of
    Agent-lightning spans as input and produces a custom target format
    (e.g., reinforcement learning training data, SFT datasets, logs, metrics).
    """

TraceTripletAdapter

Bases: BaseTraceTripletAdapter

An adapter to convert OpenTelemetry spans to triplet data.

Attributes:

    repair_hierarchy: When set to True, the trace hierarchy is repaired using timing information. See TraceTree.repair_hierarchy for more details.
    llm_call_match: Regular expression pattern to match LLM call span names.
    agent_match: Optional regular expression pattern to match agent span names. If None, all agents are matched.
    exclude_llm_call_in_reward: Whether to exclude LLM calls that occur within reward spans.
    reward_match: Policy for matching rewards to LLM calls.

Source code in agentlightning/adapter/triplet.py
class TraceTripletAdapter(BaseTraceTripletAdapter):
    """
    An adapter to convert OpenTelemetry spans to triplet data.

    Attributes:
        repair_hierarchy: When `repair_hierarchy` is set to True, the trace will be repaired with the time information.
            See `TraceTree.repair_hierarchy` for more details.
        llm_call_match: Regular expression pattern to match LLM call span names.
        agent_match: Optional regular expression pattern to match agent span names. If None, all agents are matched.
        exclude_llm_call_in_reward: Whether to exclude LLM calls that occur within reward spans.
        reward_match: Policy for matching rewards to LLM calls.
    """

    def __init__(
        self,
        repair_hierarchy: bool = True,
        llm_call_match: str = r"openai\.chat\.completion",
        agent_match: Optional[str] = None,
        exclude_llm_call_in_reward: bool = True,
        reward_match: RewardMatchPolicy = RewardMatchPolicy.FIRST_OCCURRENCE,
    ):
        self.repair_hierarchy = repair_hierarchy
        self.llm_call_match = llm_call_match
        self.agent_match = agent_match
        self.exclude_llm_call_in_reward = exclude_llm_call_in_reward
        self.reward_match = reward_match

    def visualize(
        self,
        source: Union[List[Span], List[ReadableSpan]],
        /,
        filename: str = "trace_tree",
        interested_span_match: str | None = None,
    ) -> TraceTree:
        """
        Visualize the trace tree.

        Args:
            source (List[Span]): The list of OpenTelemetry spans to visualize.
            filename (str): The base filename for the output visualization (default: "trace_tree").
            interested_span_match (str | None): Optional regular expression pattern to highlight or focus on specific spans in the visualization.

        Returns:
            TraceTree: The constructed trace tree object.
        """
        source_normalized = [
            Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
            for span in source
        ]
        trace_tree = TraceTree.from_spans(source_normalized)
        if self.repair_hierarchy:
            trace_tree.repair_hierarchy()
        trace_tree.visualize(filename, interested_span_match=interested_span_match)
        return trace_tree

    def adapt(self, source: Union[List[Span], List[ReadableSpan]], /) -> List[Triplet]:  # type: ignore
        """Convert OpenTelemetry spans to a list of Triplet objects."""
        source_normalized = [
            Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
            for span in source
        ]
        trace_tree = TraceTree.from_spans(source_normalized)
        if self.repair_hierarchy:
            trace_tree.repair_hierarchy()
        trajectory = trace_tree.to_trajectory(
            llm_call_match=self.llm_call_match,
            agent_match=self.agent_match,
            exclude_llm_call_in_reward=self.exclude_llm_call_in_reward,
            reward_match=self.reward_match,
        )
        return trajectory
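
Example (illustrative sketch; `spans` is assumed to be a list of Agent-lightning Span or OpenTelemetry ReadableSpan objects collected by a tracer):

from agentlightning.adapter.triplet import TraceTripletAdapter  # import path follows the source location above

adapter = TraceTripletAdapter()  # defaults match "openai.chat.completion" spans
triplets = adapter.adapt(spans)  # -> List[Triplet] for RL-style optimization
adapter.visualize(spans, filename="trace_tree")  # optionally render the trace tree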

adapt(source)

Convert OpenTelemetry spans to a list of Triplet objects.

Source code in agentlightning/adapter/triplet.py
def adapt(self, source: Union[List[Span], List[ReadableSpan]], /) -> List[Triplet]:  # type: ignore
    """Convert OpenTelemetry spans to a list of Triplet objects."""
    source_normalized = [
        Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
        for span in source
    ]
    trace_tree = TraceTree.from_spans(source_normalized)
    if self.repair_hierarchy:
        trace_tree.repair_hierarchy()
    trajectory = trace_tree.to_trajectory(
        llm_call_match=self.llm_call_match,
        agent_match=self.agent_match,
        exclude_llm_call_in_reward=self.exclude_llm_call_in_reward,
        reward_match=self.reward_match,
    )
    return trajectory

visualize(source, /, filename='trace_tree', interested_span_match=None)

Visualize the trace tree.

Parameters:

    source (List[Span], required): The list of OpenTelemetry spans to visualize.
    filename (str, default 'trace_tree'): The base filename for the output visualization.
    interested_span_match (str | None, default None): Optional regular expression pattern to highlight or focus on specific spans in the visualization.

Returns:

    TraceTree: The constructed trace tree object.

Source code in agentlightning/adapter/triplet.py
def visualize(
    self,
    source: Union[List[Span], List[ReadableSpan]],
    /,
    filename: str = "trace_tree",
    interested_span_match: str | None = None,
) -> TraceTree:
    """
    Visualize the trace tree.

    Args:
        source (List[Span]): The list of OpenTelemetry spans to visualize.
        filename (str): The base filename for the output visualization (default: "trace_tree").
        interested_span_match (str | None): Optional regular expression pattern to highlight or focus on specific spans in the visualization.

    Returns:
        TraceTree: The constructed trace tree object.
    """
    source_normalized = [
        Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
        for span in source
    ]
    trace_tree = TraceTree.from_spans(source_normalized)
    if self.repair_hierarchy:
        trace_tree.repair_hierarchy()
    trace_tree.visualize(filename, interested_span_match=interested_span_match)
    return trace_tree

get_left_padded_ids_and_attention_mask(ids, max_length, pad_token_id)

Left-pad (or truncate) a sequence of token IDs to a fixed length, and build the corresponding attention mask.

Parameters:

    ids (List[int], required): The original list of token IDs.
    max_length (int, required): Desired total length after padding/truncation.
    pad_token_id (int, required): ID to use for padding.

Returns:

    padded_ids (List[int]): List of length == max_length.
    attention_mask (List[int]): List of the same length; 1 for non-pad tokens, 0 for pads.

Source code in agentlightning/verl/daemon.py
def get_left_padded_ids_and_attention_mask(
    ids: List[int], max_length: int, pad_token_id: int
) -> Tuple[List[int], List[int]]:
    """
    Left-pad (or truncate) a sequence of token IDs to a fixed length,
    and build the corresponding attention mask.

    Args:
        ids:             the original list of token IDs.
        max_length:      desired total length after padding/truncation.
        pad_token_id:    ID to use for padding.

    Returns:
        padded_ids (any):      list of length == max_length.
        attention_mask (any):  list of same length: 1 for non-pad tokens, 0 for pads.
    """
    seq_len = len(ids)

    if seq_len >= max_length:
        # too long → truncate from the left, keep the last max_length tokens
        trimmed = ids[-max_length:]
        attention_mask = [1] * max_length
        return trimmed, attention_mask

    # too short → pad on the left
    pad_len = max_length - seq_len
    padded_ids = [pad_token_id] * pad_len + ids
    attention_mask = [0] * pad_len + [1] * seq_len
    return padded_ids, attention_mask
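
Worked example of the behavior described above:

from agentlightning.verl.daemon import get_left_padded_ids_and_attention_mask  # import path follows the source location above

padded, mask = get_left_padded_ids_and_attention_mask([5, 6, 7], max_length=5, pad_token_id=0)
# padded == [0, 0, 5, 6, 7]
# mask   == [0, 0, 1, 1, 1]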

get_right_padded_ids_and_attention_mask(ids, max_length, pad_token_id)

Right-pad (or truncate) a sequence of token IDs to a fixed length, and build the corresponding attention mask.

Parameters:

    ids (List[int], required): The original list of token IDs.
    max_length (int, required): Desired total length after padding/truncation.
    pad_token_id (int, required): ID to use for padding.

Returns:

    padded_ids (List[int]): List of length == max_length.
    attention_mask (List[int]): List of the same length; 1 for non-pad tokens, 0 for pads.

Source code in agentlightning/verl/daemon.py
def get_right_padded_ids_and_attention_mask(
    ids: List[int], max_length: int, pad_token_id: int
) -> Tuple[List[int], List[int]]:
    """
    Right-pad (or truncate) a sequence of token IDs to a fixed length,
    and build the corresponding attention mask.

    Args:
        ids:            the original list of token IDs.
        max_length:     desired total length after padding/truncation.
        pad_token_id:   ID to use for padding.

    Returns:
        padded_ids (any):     list of length == max_length.
        attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads.
    """
    seq_len = len(ids)

    if seq_len >= max_length:
        # too long → truncate to the first max_length tokens
        trimmed = ids[:max_length]
        attention_mask = [1] * max_length
        return trimmed, attention_mask

    # too short → pad on the right
    pad_len = max_length - seq_len
    padded_ids = ids + [pad_token_id] * pad_len
    attention_mask = [1] * seq_len + [0] * pad_len
    return padded_ids, attention_mask
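
Worked example of the behavior described above:

from agentlightning.verl.daemon import get_right_padded_ids_and_attention_mask  # import path follows the source location above

padded, mask = get_right_padded_ids_and_attention_mask([5, 6, 7], max_length=5, pad_token_id=0)
# padded == [5, 6, 7, 0, 0]
# mask   == [1, 1, 1, 0, 0]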