Reinforcement Learning API

agentlightning.verl

NamedResources = Dict[str, ResourceUnion] module-attribute

A type alias for a dictionary that maps resource names to resources (such as LLM endpoints and prompt templates).

Example

resources: NamedResources = {
    'main_llm': LLM(
        endpoint="http://localhost:8080",
        model="llama3",
        sampling_parameters={'temperature': 0.7, 'max_tokens': 100}
    ),
    'system_prompt': PromptTemplate(
        template="You are a helpful assistant.",
        engine='f-string'
    )
}

AgentLightningServer

The main SDK class for developers to control the Agent Lightning Server.

This class manages the server lifecycle, task queueing, resource updates, and retrieval of results, providing a simple interface for the optimization logic.
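
Example

A minimal driver loop showing the intended control flow: start the server, publish resources, queue a task, and poll for the rollout. This is a hedged sketch; the import paths and the sample payload are assumptions, not guaranteed by the API.

import asyncio

from agentlightning.server import AgentLightningServer  # import path assumed
from agentlightning.types import LLM, NamedResources    # import path assumed

async def main():
    server = AgentLightningServer(host="127.0.0.1", port=8000)
    await server.start()

    # Publish a versioned resource bundle; clients poll /resources/latest for it.
    resources: NamedResources = {
        "main_llm": LLM(
            endpoint="http://localhost:8080",
            model="llama3",
            sampling_parameters={"temperature": 0.7},
        )
    }
    resources_id = await server.update_resources(resources)

    # Queue a task and wait up to 60 seconds for a client to report the rollout.
    rollout_id = await server.queue_task(
        sample={"question": "What is 2 + 2?"},  # illustrative payload
        mode="train",
        resources_id=resources_id,
    )
    rollout = await server.poll_completed_rollout(rollout_id, timeout=60)
    if rollout is not None:
        print(rollout.final_reward)

    await server.stop()

asyncio.run(main())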

Source code in agentlightning/server.py
class AgentLightningServer:
    """
    The main SDK class for developers to control the Agent Lightning Server.

    This class manages the server lifecycle, task queueing, resource updates,
    and retrieval of results, providing a simple interface for the optimization logic.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
        """
        Initializes the server controller.

        Args:
            host: The host to bind the server to.
            port: The port to bind the server to.
            task_timeout_seconds: Time in seconds after which a claimed task is considered stale and requeued.
        """
        self.host = host
        self.port = port
        self.endpoint = f"http://{host}:{port}"
        self._task_timeout_seconds = task_timeout_seconds

        # Defer initialization and use event for cross-thread communication
        self._store: Optional[ServerDataStore] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.startup_event = threading.Event()

        # Create FastAPI app instance with a lifespan manager
        self._app = FastAPI(lifespan=self._lifespan)
        self._setup_routes()

        self._uvicorn_config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
        self._uvicorn_server = uvicorn.Server(self._uvicorn_config)

    # --- ADDED: Lifespan context manager ---
    @asynccontextmanager
    async def _lifespan(self, app: FastAPI):
        """
        Manages server startup and shutdown. This runs inside the server's event loop.
        """
        logger.info("Server is starting up...")
        self.loop = asyncio.get_running_loop()
        self._store = ServerDataStore()  # Initialize data store here
        self.startup_event.set()  # Signal that the server is ready

        yield

        logger.info("Server is shutting down.")
        self._store = None
        self.startup_event.clear()  # Clear the startup event
        self.loop = None

    async def _check_and_requeue_stale_tasks(self):
        """
        Check for stale tasks and requeue them. Called reactively during get_next_task.
        """
        current_time = time.time()
        # Ensure store is initialized before checking
        if not self._store:
            return
        processing_tasks = self._store.get_processing_tasks()

        for rollout_id, task in processing_tasks.items():
            if task.last_claim_time and current_time - task.last_claim_time > self._task_timeout_seconds:
                await self._store.requeue_task(task)
                logger.warning(
                    f"Task {task.rollout_id} timed out after {self._task_timeout_seconds}s, requeued (attempt {task.num_claims})"
                )

    def _setup_routes(self):
        """Setup FastAPI routes."""

        @self._app.get("/task", response_model=TaskIfAny)
        async def next_task() -> TaskIfAny:
            """Endpoint for clients to poll for the next available task."""
            await self._check_and_requeue_stale_tasks()

            if not self._store:
                return TaskIfAny(is_available=False)

            task = await self._store.get_next_task()
            if task:
                logger.debug(f"Serving task {task.rollout_id} to a client.")
                return TaskIfAny(is_available=True, task=task)
            else:
                logger.debug("No task available for client.")
                return TaskIfAny(is_available=False)

        @self._app.get("/resources/latest", response_model=ResourcesUpdate)
        async def fetch_latest_resources() -> ResourcesUpdate:
            """Endpoint for clients to poll for the latest available resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_latest_resources()
            if not resources_update:
                raise HTTPException(status_code=404, detail="No resources have been set on the server.")
            logger.debug(f"Serving latest resources '{resources_update.resources_id}' to a client.")
            return resources_update

        @self._app.get("/resources/{resource_id}", response_model=ResourcesUpdate)
        async def fetch_resources_by_id(
            resource_id: str = Path(..., description="The unique identifier for the resource version.")
        ) -> ResourcesUpdate:
            """Endpoint for clients to fetch a specific version of resources."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            resources_update = await self._store.get_resources_by_id(resource_id)
            if not resources_update:
                raise HTTPException(status_code=404, detail=f"Resource ID '{resource_id}' not found.")
            logger.debug(f"Serving resources for ID '{resource_id}' to a client.")
            return resources_update

        @self._app.post("/rollout", response_model=GenericResponse)
        async def post_rollout(payload: Rollout) -> GenericResponse:
            """Endpoint for clients to report a completed rollout."""
            if not self._store:
                raise HTTPException(status_code=503, detail="Server not fully initialized.")
            await self._store.store_rollout(payload)
            return GenericResponse(
                status="ok",
                message=f"Rollout {payload.rollout_id} received and stored.",
            )

    async def start(self):
        """Starts the FastAPI server in the background."""
        logger.info(f"Starting server at {self.endpoint}")
        asyncio.create_task(self._uvicorn_server.serve())
        await asyncio.sleep(1)  # Allow time for server to start up.

    async def stop(self):
        """Gracefully stops the running FastAPI server."""
        if self._uvicorn_server.started:
            logger.info("Stopping server...")
            self._uvicorn_server.should_exit = True
            await asyncio.sleep(1)  # Allow time for graceful shutdown.
            logger.info("Server stopped.")

    async def run_forever(self):
        """
        Runs the server indefinitely until stopped.
        This is useful when async start and stop methods do not work.
        """
        await self._uvicorn_server.serve()

    async def queue_task(
        self,
        sample: Any,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> str:
        """
        Adds a task to the queue for a client to process.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

    async def update_resources(self, resources: NamedResources) -> str:
        """
        Updates the resources, creating a new version and setting it as the latest.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        resources_id = f"res-{uuid.uuid4()}"
        update = ResourcesUpdate(resources_id=resources_id, resources=resources)
        await self._store.update_resources(update)
        return resources_id

    async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
        """
        Retrieves a specific completed rollout by its ID.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_rollout(rollout_id)

    async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
        """
        Polls for a completed rollout by its ID, waiting up to `timeout` seconds.
        """
        start_time = time.time()
        while True:
            rollout = await self.get_completed_rollout(rollout_id)
            if rollout:
                return rollout
            if timeout and (time.time() - start_time) >= timeout:
                return None
            await asyncio.sleep(1)

    async def retrieve_completed_rollouts(self) -> List[Rollout]:
        """
        Retrieves all available completed trajectories and clears the internal store.
        """
        if not self._store:
            raise RuntimeError("Store not initialized. The server may not be running.")
        return await self._store.retrieve_completed_rollouts()

__init__(host='127.0.0.1', port=8000, task_timeout_seconds=300.0)

Initializes the server controller.

Parameters:

    host (str): The host to bind the server to. Default: '127.0.0.1'.
    port (int): The port to bind the server to. Default: 8000.
    task_timeout_seconds (float): Time in seconds after which a claimed task is considered stale and requeued. Default: 300.0.
Source code in agentlightning/server.py
def __init__(self, host: str = "127.0.0.1", port: int = 8000, task_timeout_seconds: float = 300.0):
    """
    Initializes the server controller.

    Args:
        host: The host to bind the server to.
        port: The port to bind the server to.
        task_timeout_seconds: Time in seconds after which a claimed task is considered stale and requeued.
    """
    self.host = host
    self.port = port
    self.endpoint = f"http://{host}:{port}"
    self._task_timeout_seconds = task_timeout_seconds

    # Defer initialization and use event for cross-thread communication
    self._store: Optional[ServerDataStore] = None
    self.loop: Optional[asyncio.AbstractEventLoop] = None
    self.startup_event = threading.Event()

    # Create FastAPI app instance with a lifespan manager
    self._app = FastAPI(lifespan=self._lifespan)
    self._setup_routes()

    self._uvicorn_config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
    self._uvicorn_server = uvicorn.Server(self._uvicorn_config)
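
A task claimed by a worker but not reported back within task_timeout_seconds is requeued the next time any client polls /task. AgentModeDaemon (documented below) constructs the server this way, binding to all interfaces and reusing its LLM request timeout; a hedged sketch, with a placeholder port:

# Tasks with no reported rollout after 10 minutes are considered stale and requeued.
server = AgentLightningServer(host="0.0.0.0", port=9999, task_timeout_seconds=600.0)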

get_completed_rollout(rollout_id) async

Retrieves a specific completed rollout by its ID.

Source code in agentlightning/server.py
async def get_completed_rollout(self, rollout_id: str) -> Optional[Rollout]:
    """
    Retrieves a specific completed rollout by its ID.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.retrieve_rollout(rollout_id)

poll_completed_rollout(rollout_id, timeout=None) async

Polls for a completed rollout by its ID, waiting up to `timeout` seconds.

Source code in agentlightning/server.py
async def poll_completed_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
    """
    Polls for a completed rollout by its ID, waiting up to `timeout` seconds.
    """
    start_time = time.time()
    while True:
        rollout = await self.get_completed_rollout(rollout_id)
        if rollout:
            return rollout
        if timeout and (time.time() - start_time) >= timeout:
            return None
        await asyncio.sleep(1)

queue_task(sample, mode=None, resources_id=None, metadata=None) async

Adds a task to the queue for a client to process.

Source code in agentlightning/server.py
async def queue_task(
    self,
    sample: Any,
    mode: Literal["train", "val", "test"] | None = None,
    resources_id: str | None = None,
    metadata: Dict[str, Any] | None = None,
) -> str:
    """
    Adds a task to the queue for a client to process.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata)

retrieve_completed_rollouts() async

Retrieves all available completed trajectories and clears the internal store.

Source code in agentlightning/server.py
async def retrieve_completed_rollouts(self) -> List[Rollout]:
    """
    Retrieves all available completed trajectories and clears the internal store.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    return await self._store.retrieve_completed_rollouts()

run_forever() async

Runs the server indefinitely until stopped. This is useful when async start and stop methods do not work.

Source code in agentlightning/server.py
async def run_forever(self):
    """
    Runs the server indefinitely until stopped.
    This is useful when async start and stop methods do not work.
    """
    await self._uvicorn_server.serve()
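
AgentModeDaemon.start (below) hosts the server exactly this way: it calls run_forever on a daemon thread and blocks on startup_event until the lifespan hook has initialized the data store. A condensed sketch of that pattern:

import asyncio
import threading

server = AgentLightningServer()
thread = threading.Thread(target=lambda: asyncio.run(server.run_forever()), daemon=True)
thread.start()

# The lifespan hook sets startup_event once the internal data store is ready.
if not server.startup_event.wait(timeout=20.0):
    raise RuntimeError("Server failed to start within the timeout period.")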

start() async

Starts the FastAPI server in the background.

Source code in agentlightning/server.py
async def start(self):
    """Starts the FastAPI server in the background."""
    logger.info(f"Starting server at {self.endpoint}")
    asyncio.create_task(self._uvicorn_server.serve())
    await asyncio.sleep(1)  # Allow time for server to start up.

stop() async

Gracefully stops the running FastAPI server.

Source code in agentlightning/server.py
async def stop(self):
    """Gracefully stops the running FastAPI server."""
    if self._uvicorn_server.started:
        logger.info("Stopping server...")
        self._uvicorn_server.should_exit = True
        await asyncio.sleep(1)  # Allow time for graceful shutdown.
        logger.info("Server stopped.")

update_resources(resources) async

Updates the resources, creating a new version and setting it as the latest.

Source code in agentlightning/server.py
async def update_resources(self, resources: NamedResources) -> str:
    """
    Updates the resources, creating a new version and setting it as the latest.
    """
    if not self._store:
        raise RuntimeError("Store not initialized. The server may not be running.")
    resources_id = f"res-{uuid.uuid4()}"
    update = ResourcesUpdate(resources_id=resources_id, resources=resources)
    await self._store.update_resources(update)
    return resources_id

AgentLightningTrainer

Bases: RayPPOTrainer

Specialized PPO trainer for agent-based reinforcement learning.

This trainer is designed specifically for scenarios where the model interacts with external environments, tools, or APIs through an AgentLightningServer. It simplifies the training loop by removing the complex conditional logic present in the original RayPPOTrainer and focusing on the agent mode workflow.

Key differences from RayPPOTrainer:

1. Uses AgentModeDaemon for server communication
2. Simplified data flow without pop/union operations
3. Direct batch processing through agent daemon
4. Streamlined validation using agent_mode validation

Source code in agentlightning/verl/trainer.py
class AgentLightningTrainer(RayPPOTrainer):
    """
    Specialized PPO trainer for agent-based reinforcement learning.

    This trainer is designed specifically for scenarios where the model interacts with
    external environments, tools, or APIs through an AgentLightningServer. It simplifies
    the training loop by removing the complex conditional logic present in the original
    RayPPOTrainer and focusing on the agent mode workflow.

    Key differences from RayPPOTrainer:
    1. Uses AgentModeDaemon for server communication
    2. Simplified data flow without pop/union operations
    3. Direct batch processing through agent daemon
    4. Streamlined validation using agent_mode validation
    """

    def _validate(self):
        assert len(self.val_dataloader) == 1, "Please set val_batch_size to None for better throughput."

        test_data = next(iter(self.val_dataloader))
        test_batch = DataProto.from_single_dict(test_data)

        self.async_rollout_manager.wake_up()
        self.agent_mode_daemon.set_up_data_and_server(
            test_batch.non_tensor_batch,
            self.async_rollout_manager.server_addresses,
            is_train=False,
        )
        self.agent_mode_daemon.run_until_all_finished()
        test_metrics = self.agent_mode_daemon.get_test_metrics()
        self.agent_mode_daemon.clear_data_and_server()
        self.async_rollout_manager.sleep()
        return test_metrics

    def _train_step(self, batch_dict: dict) -> dict:
        # Isolate in a separate method to automatically recycle the variables before validation.
        batch: DataProto = DataProto.from_single_dict(batch_dict)
        metrics = {}
        timing_raw = {}

        with _timer("step", timing_raw):

            # When agent mode is enabled, we read the batch as it is.
            gen_batch = batch

            # generate a batch
            with _timer("gen", timing_raw):
                self.async_rollout_manager.wake_up()
                self.agent_mode_daemon.set_up_data_and_server(
                    gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
                )
                self.agent_mode_daemon.run_until_all_finished()
                batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
                    max_prompt_length=self.config.data.max_prompt_length,
                    max_response_length=self.config.data.max_response_length,
                    device=gen_batch.batch["fake_ids"].device,
                )
                metrics.update(agent_metrics)
                self.agent_mode_daemon.clear_data_and_server()
                self.async_rollout_manager.sleep()

            if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                with _timer("gen_max", timing_raw):
                    gen_baseline_batch = deepcopy(gen_batch)
                    gen_baseline_batch.meta_info["do_sample"] = False
                    gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                    batch = batch.union(gen_baseline_output)
                    reward_baseline_tensor = self.reward_fn(batch)
                    reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                    batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                    batch.batch["reward_baselines"] = reward_baseline_tensor

                    del gen_baseline_batch, gen_baseline_output

            # uid is used for algorithm like GRPO, should be aligned to data id
            batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

            batch.batch["response_mask"] = compute_response_mask(batch)

            # compute global_valid tokens
            batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

            with _timer("reward", timing_raw):
                # compute reward model score
                if self.use_rm:
                    reward_tensor = self.rm_wg.compute_rm_score(batch)
                    batch = batch.union(reward_tensor)

                reward_extra_infos_dict = {}

            # for agent mode, pad the lengths to calculate old log prob, ref, and values
            batch, pad_size = pad_dataproto_to_divisor(batch, self.actor_rollout_wg.world_size)

            # recompute old_log_probs
            with _timer("old_log_prob", timing_raw):
                old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                entropys = old_log_prob.batch["entropys"]
                response_masks = batch.batch["response_mask"]
                loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
                metrics.update(old_log_prob_metrics)
                old_log_prob.batch.pop("entropys")
                batch = batch.union(old_log_prob)

            if self.use_reference_policy:
                # compute reference log_prob
                with _timer("ref", timing_raw):
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)

            # compute values
            if self.use_critic:
                with _timer("values", timing_raw):
                    values = self.critic_wg.compute_values(batch)
                    batch = batch.union(values)

            # for agent mode, unpad to calculate adv
            # it is important, as adv should be based on the raw traces
            batch = unpad_dataproto(batch, pad_size=pad_size)

            with _timer("adv", timing_raw):
                # if agent_mode is enabled, there is already token_level_scores
                # token_level_scores is not needed to compute here

                # compute rewards. apply_kl_penalty if available
                if self.config.algorithm.use_kl_in_reward:
                    batch, kl_metrics = apply_kl_penalty(
                        batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                    )
                    metrics.update(kl_metrics)
                else:
                    batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                # compute advantages, executed on the driver process

                norm_adv_by_std_in_grpo = self.config.algorithm.get(
                    "norm_adv_by_std_in_grpo", True
                )  # GRPO adv normalization factor

                batch = compute_advantage(
                    batch,
                    adv_estimator=self.config.algorithm.adv_estimator,
                    gamma=self.config.algorithm.gamma,
                    lam=self.config.algorithm.lam,
                    num_repeat=self.config.actor_rollout_ref.rollout.n,
                    norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                    config=self.config.algorithm,
                )

            # after advantages are assigned, drop (1) samples with over-long prompts and (2) floor the batch to the PPO mini-batch size
            keep_indices = (~batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
            metrics["agent_mode/n_dropped_sample_because_of_prompt"] = (
                batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
            )
            batch = batch[keep_indices]
            # next, round to minibatch size
            mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
            n_transition = len(batch)
            random_indices = list(range(n_transition))
            random.shuffle(random_indices)
            batch.reorder(torch.tensor(random_indices).type(torch.int32))
            n_remained_transition = n_transition // mini_batch_size * mini_batch_size
            batch = batch[list(range(n_remained_transition))]
            metrics["agent_mode/n_dropped_sample_because_of_mini_batch"] = n_transition - n_remained_transition

            # Agent mode note: Change the order of balance batch;
            #     1. first calculate advantage
            #     2. then drop the samples (too long prompt & floor to ppo minisize)
            #     3. balance
            # balance the number of valid tokens on each dp rank.
            # Note that this breaks the order of data inside the batch.
            # Please take care when you implement group based adv computation such as GRPO and rloo
            if self.config.trainer.balance_batch:
                self._balance_batch(batch, metrics=metrics)

            # update critic
            if self.use_critic:
                with _timer("update_critic", timing_raw):
                    critic_output = self.critic_wg.update_critic(batch)
                critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                metrics.update(critic_output_metrics)

            # implement critic warmup
            if self.config.trainer.critic_warmup <= self.global_steps:
                # update actor
                with _timer("update_actor", timing_raw):
                    batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                    actor_output = self.actor_rollout_wg.update_actor(batch)
                actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                metrics.update(actor_output_metrics)

            # Log rollout generations if enabled
            rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
            if rollout_data_dir:
                with _timer("dump_rollout_generations", timing_raw):
                    print(batch.batch.keys())
                    inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                    outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                    scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
                    self._dump_generations(
                        inputs=inputs,
                        outputs=outputs,
                        scores=scores,
                        reward_extra_infos_dict=reward_extra_infos_dict,
                        dump_path=rollout_data_dir,
                    )

        # compute training metrics
        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
        # TODO: implement actual tflpo and theoretical tflpo
        n_gpus = self.resource_pool_manager.get_n_gpus()
        metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

        return metrics

    def fit(self):
        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        assert self.async_rollout_mode, "If agent mode is enabled, async server must be enabled"
        self.agent_mode_daemon = AgentModeDaemon(
            self.config.agentlightning.port,
            self.config.actor_rollout_ref.rollout.n,
            train_information={
                # Note (Zhiyuan): To avoid patching further into the vLLM async server, we derive the model name with the same expression it uses internally.
                # However, verl may update this naming upstream and cause an incompatibility.
                # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
                "model": "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:]),
                "temperature": self.config.actor_rollout_ref.rollout.temperature,
            },
            tokenizer=self.tokenizer,
            mini_batch_size=self.config.actor_rollout_ref.actor.ppo_mini_batch_size,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        self.agent_mode_daemon.start()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                is_last_step = self.global_steps >= self.total_training_steps

                # train step
                metrics = self._train_step(batch_dict)

                # validate
                if (
                    self.val_reward_fn is not None
                    and self.config.trainer.test_freq > 0
                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
                ):
                    with _timer("validate", timing_raw):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and (
                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0
                ):
                    with _timer("save_checkpoint", timing_raw):
                        self._save_checkpoint()

                # step metrics
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()

                    # This exit logic is to ensure a robust CI.
                    pprint(f"Flush the logger...")
                    del logger  # Make sure the loggers are flushed and closed properly
                    pprint(f"Training finished at step {self.global_steps}.")
                    return

                progress_bar.update(1)
                self.global_steps += 1

AgentModeDaemon

AgentModeDaemon using the AgentLightningServer SDK.

This class manages the server lifecycle, task queueing, and results retrieval, while also running a proxy server for LLM requests. It maintains the original interface for compatibility with the RayPPOTrainer.
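
The per-batch cycle mirrors AgentLightningTrainer._train_step above: publish the batch and LLM endpoints, wait for all rollouts, collect the training batch, then reset. A condensed, hypothetical sketch (the port, tokenizer, addresses, and config values are placeholders):

daemon = AgentModeDaemon(
    port=9999,  # placeholder control-plane port
    train_rollout_n=4,  # rollouts generated per training sample
    train_information={"model": "llama3", "temperature": 0.7},
    tokenizer=tokenizer,  # placeholder for a HuggingFace-style tokenizer
    mini_batch_size=64,
    pad_token_id=tokenizer.pad_token_id,
)
daemon.start()

# One training batch: publish data and LLM endpoints, wait, collect, reset.
daemon.set_up_data_and_server(batch.non_tensor_batch, server_addresses)
daemon.run_until_all_finished()
data_proto, agent_metrics = daemon.get_train_data_batch(
    max_prompt_length=2048, max_response_length=1024, device="cuda"
)
daemon.clear_data_and_server()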

Source code in agentlightning/verl/daemon.py
class AgentModeDaemon:
    """
    AgentModeDaemon using the AgentLightningServer SDK.

    This class manages the server lifecycle, task queueing, and results
    retrieval, while also running a proxy server for LLM requests. It maintains
    the original interface for compatibility with the RayPPOTrainer.
    """

    def __init__(
        self,
        port,
        train_rollout_n,
        train_information,
        tokenizer,
        mini_batch_size,
        pad_token_id,
        reward_fillna_value=0.0,
        llm_timeout_seconds=600.0,
    ):
        # Server and Task Configuration
        self.server_port = port
        self.llm_timeout_seconds = llm_timeout_seconds
        self.server = AgentLightningServer(
            host="0.0.0.0", port=self.server_port, task_timeout_seconds=self.llm_timeout_seconds
        )
        self.proxy_port = _find_available_port()  # Run proxy on a different port

        # Training and Data Configuration
        self.train_rollout_n = train_rollout_n
        self.train_information = train_information
        self.mini_batch_size = mini_batch_size
        self.pad_token_id = pad_token_id
        self.tokenizer = tokenizer
        self.reward_fillna_value = reward_fillna_value

        # Internal State
        self.backend_llm_server_addresses: List[str] = []
        self._total_tasks_queued = 0
        self._completed_rollouts: Dict[str, Rollout] = {}
        self._task_id_to_original_sample: Dict[str, Dict] = {}
        self._server_thread: Optional[threading.Thread] = None
        self._proxy_thread: Optional[threading.Thread] = None
        self.is_train = True

    def _start_proxy_server(self):
        """
        Initializes and runs a Flask-based proxy server in a separate thread.
        This proxy load-balances requests to the actual backend LLM servers.
        """
        app = Flask(__name__)

        num_requests = 0
        last_request_time = 0

        @app.route("/v1/<path:path>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
        def proxy(path):
            if not self.backend_llm_server_addresses:
                abort(503, description="No backend LLM servers available.")

            # Randomly choose a backend server for load balancing
            target_server = random.choice(self.backend_llm_server_addresses)
            target_url = f"http://{target_server}/v1/{path}"

            # Copy client request headers, removing the Host header
            headers = {key: value for key, value in request.headers if key.lower() != "host"}

            # Log the request for debugging
            nonlocal num_requests, last_request_time
            current_time = time.time()
            num_requests += 1
            if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0:
                print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}")
            last_request_time = current_time

            try:
                # Forward the request to the target backend
                resp = requests.request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    params=request.args,
                    data=request.get_data(),
                    cookies=request.cookies,
                    allow_redirects=False,
                    timeout=self.llm_timeout_seconds,
                )
                # Filter out hop-by-hop headers before returning the response
                excluded_headers = [
                    "content-encoding",
                    "content-length",
                    "transfer-encoding",
                    "connection",
                    "keep-alive",
                    "proxy-authenticate",
                    "proxy-authorization",
                    "te",
                    "trailers",
                    "upgrade",
                ]
                response_headers = [
                    (name, value) for name, value in resp.raw.headers.items() if name.lower() not in excluded_headers
                ]
                if resp.status_code == 200:
                    # NOTE: from Zhiyuan's code.
                    # https://github.com/hzy46/verl_agent_mode/blob/2db65ea9858f645a914120357412a7540f8bd82d/verl/trainer/ppo/ray_trainer.py#L692-L711
                    # request_json = json.loads(request.get_data().decode("utf-8"))
                    response_json = json.loads(resp.content.decode("utf-8"))
                    # response_message = ChatCompletion(**response_json).choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
                    # tool_schemas = request_json.get("tools", None)
                    # prompt_ids = self.tokenizer.apply_chat_template(request_json["messages"], tools=tool_schemas, add_generation_prompt=True, tokenize=True)
                    # full_ids = self.tokenizer.apply_chat_template(request_json["messages"] + [response_message], tools=tool_schemas, add_generation_prompt=False, tokenize=True)
                    # TBD: response_ids sometimes ends with "<eos_id>\n", shall we keep the extra "\n"?
                    # sometimes it has some differences with the hacky method in the end, but this should align with ToolCompletionCallback
                    # response_ids = full_ids[len(prompt_ids):]

                    # NOTE (yuge): They are different. Don't know why.
                    # assert response_json['prompt_token_ids'] == prompt_ids
                    # patched_response_ids = response_json['response_token_ids'][0]
                    # assert patched_response_ids == response_ids[:len(patched_response_ids)], f"{patched_response_ids} != {response_ids[:len(patched_response_ids)]}"
                    # response_json['prompt_token_ids'] = prompt_ids
                    # response_json['response_token_ids'] = [response_ids]
                    replaced_return_content = json.dumps(response_json).encode("utf-8")
                    return Response(replaced_return_content, status=resp.status_code, headers=response_headers)
                return Response(resp.content, resp.status_code, response_headers)
            except requests.exceptions.RequestException as e:
                abort(500, description=f"Error proxying request: {e}")

        def run_app():
            app.run(host="0.0.0.0", port=self.proxy_port, threaded=True, debug=False)

        self._proxy_thread = threading.Thread(target=run_app, daemon=True)
        self._proxy_thread.start()
        print(f"Proxy server running on port {self.proxy_port}")

    def start(self):
        """Starts the main AgentLightningServer and the proxy server."""

        def run_server():
            """Run the AgentLightningServer in a separate thread."""
            asyncio.run(self.server.run_forever())

        self._server_thread = threading.Thread(target=run_server, daemon=True)
        self._server_thread.start()

        # Wait for the server's internal startup event to be set.
        print("Waiting for AgentLightningServer to start...")
        is_ready = self.server.startup_event.wait(timeout=20.0)  # Wait up to 20s
        if not is_ready:
            raise RuntimeError("AgentLightningServer failed to start within the timeout period.")

        print(f"AgentLightningServer control plane running on port {self.server_port}")

        self._start_proxy_server()

    async def _async_set_up(self, data, server_addresses, is_train=True):
        """Async helper to set up data and resources on the server."""
        self.clear_data_and_server()
        self.backend_llm_server_addresses = server_addresses
        self.is_train = is_train

        # 1. Update resources on the server for clients to use
        llm_resource = LLM(
            endpoint=f"http://127.0.0.1:{self.proxy_port}/v1",
            model=self.train_information.get("model", "default-model"),
            sampling_parameters={"temperature": self.train_information.get("temperature", 0.7)},
        )
        resources: NamedResources = {"main_llm": llm_resource}
        resources_id = await self.server.update_resources(resources)

        # 2. Queue tasks for agents to process
        keys = list(data.keys())
        num_samples = len(data[keys[0]])
        rollouts_per_sample = self.train_rollout_n if is_train else 1

        for i in range(num_samples):
            data_id = str(uuid.uuid4())
            original_sample = {key: data[key][i] for key in keys}
            original_sample["data_id"] = data_id

            # For training, each sample is rolled out multiple times
            for j in range(rollouts_per_sample):
                task_metadata = {"data_id": data_id, "is_train": is_train}

                # Data ID is different from Rollout ID, as one data can have multiple rollouts.
                rollout_id = await self.server.queue_task(
                    sample=original_sample,
                    mode="train" if is_train else "val",
                    resources_id=resources_id,
                    metadata=task_metadata,
                )
                # Store original sample data to reconstruct batch information later
                self._task_id_to_original_sample[rollout_id] = original_sample
                self._total_tasks_queued += 1

    def set_up_data_and_server(self, data, server_addresses, is_train=True):
        """Synchronous wrapper for setting up data and server resources."""
        if not self.server.loop or not self.server.startup_event.is_set():
            raise RuntimeError("Server is not running or ready.")

        coro = self._async_set_up(data, server_addresses, is_train)
        future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
        try:
            future.result(timeout=60)  # Wait for completion with a timeout
        except Exception as e:
            print(f"Failed to set up data on server: {e}")
            raise

    def _validate_data(self, rollout: Rollout):
        if rollout.final_reward is None:
            print(
                f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}."
            )
        if rollout.triplets is None:
            print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.")
        elif len(rollout.triplets) == 0:
            print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.")
        elif any(not r.response.get("token_ids", []) for r in rollout.triplets):
            print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}")
        elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
            print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}")

    async def _async_run_until_finished(self, verbose=True):
        """Async helper to wait for all tasks to complete."""
        while len(self._completed_rollouts) < self._total_tasks_queued:
            completed_batch = await self.server.retrieve_completed_rollouts()
            for rollout in completed_batch:
                self._validate_data(rollout)
                self._completed_rollouts[rollout.rollout_id] = rollout
            if verbose:
                print(f"Completed {len(self._completed_rollouts)}/{self._total_tasks_queued} tasks...")
            await asyncio.sleep(5)
        print("All tasks finished.")

    def run_until_all_finished(self, verbose=True):
        """Synchronously waits for all queued tasks to be completed and reported."""
        if self._total_tasks_queued == 0:
            print("Warning: No tasks were queued.")
            return

        if not self.server.loop or not self.server.startup_event.is_set():
            raise RuntimeError("Server is not running or ready.")

        coro = self._async_run_until_finished(verbose)
        future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
        try:
            future.result()  # Wait indefinitely for all tasks to complete
        except Exception as e:
            print(f"Error while waiting for tasks to finish: {e}")
            raise

    def get_test_metrics(self):
        """Calculates and returns metrics for a validation run."""
        assert not self.is_train, "This method should only be called during validation."
        assert len(self._completed_rollouts) == self._total_tasks_queued

        sample_stat_list = []
        for rollout_id, rollout in self._completed_rollouts.items():
            if not rollout.triplets:
                continue
            response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
            final_reward = self._fillna_reward(rollout)
            sample_stat_list.append(
                {
                    "sum_response_length": np.sum(response_length_list),
                    "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
                    "turn_count": len(rollout.triplets),
                    "reward": final_reward,
                }
            )

        return {
            "val/reward": np.mean([stat["reward"] for stat in sample_stat_list]),
            "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in sample_stat_list]),
            "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in sample_stat_list]),
            "val/turn_count": np.mean([stat["turn_count"] for stat in sample_stat_list]),
        }

    def get_train_data_batch(self, max_prompt_length, max_response_length, device):
        """
        Processes completed rollouts to generate a training data batch.

        This function reconstructs the logic from the original AgentModeDaemon,
        using data retrieved from the new server architecture. It handles padding,
        truncation, and tensor creation for the PPO training loop.
        """
        assert self.is_train, "This method should only be called during training."
        assert len(self._completed_rollouts) == self._total_tasks_queued

        # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
        finished_id_to_sample_info = {}
        for rollout_id, rollout in self._completed_rollouts.items():
            original_sample = self._task_id_to_original_sample[rollout_id]

            if not rollout.triplets:
                continue

            # The client should report triplets that contain prompt_ids and response_ids.
            # Example triplet.prompt: {"token_ids": [...]}
            # Example triplet.response: {"token_ids": [...]}
            trace_list = [
                {"prompt_ids": t.prompt.get("token_ids", []), "response_ids": t.response.get("token_ids", [])}
                for t in rollout.triplets
            ]

            final_reward = self._fillna_reward(rollout)
            info = {
                "reward": final_reward,
                "trace_list": trace_list,
                "data_id": original_sample["data_id"],
            }
            finished_id_to_sample_info[rollout_id] = info
        #
        # --- Data processing and tensor creation logic ---
        # Get all the reported data.
        # prompt_ids are left-padded.
        # response_ids are right-padded.
        # They are concatenated in the middle.
        # Discard handling:
        #   - Those exceeding max_prompt_length will be marked for discard, but not
        #     discarded here. They are only truncated and marked, to be discarded later.
        #     This is for the correctness of the advantage calculation.
        #   - The discard for the PPO mini-batch should also be handled this way.
        input_ids_list, input_attention_mask_list = [], []
        response_ids_list, response_attention_mask_list = [], []
        reward_list, data_id_list, rollout_id_list, turn_index_list, is_drop_list = [], [], [], [], []
        n_trunc_sample_because_of_response = 0

        for rollout_id, sample_info in finished_id_to_sample_info.items():
            for turn_index, trace in enumerate(sample_info["trace_list"]):

                reward_list.append(sample_info["reward"])
                prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

                # Mark samples with prompts exceeding max_prompt_length to be dropped later
                if len(prompt_ids) > max_prompt_length:
                    prompt_ids = prompt_ids[:max_prompt_length]
                    is_drop_list.append(True)
                else:
                    is_drop_list.append(False)

                # Truncate responses that exceed max_response_length
                if len(response_ids) > max_response_length:
                    response_ids = response_ids[:max_response_length]
                    n_trunc_sample_because_of_response += 1

                # Pad prompts to the left and responses to the right
                one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
                    prompt_ids, max_prompt_length, self.pad_token_id
                )
                one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
                    response_ids, max_response_length, self.pad_token_id
                )

                input_ids_list.append(one_input_ids)
                input_attention_mask_list.append(one_input_attention_mask)
                response_ids_list.append(one_response_ids)
                response_attention_mask_list.append(one_response_attention_mask)
                data_id_list.append(sample_info["data_id"])
                rollout_id_list.append(rollout_id)
                turn_index_list.append(turn_index)

        n_transition = len(input_ids_list)
        batch_input_ids = torch.LongTensor(input_ids_list).to(device)
        input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
        batch_response_ids = torch.LongTensor(response_ids_list).to(device)
        response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)

        # Concatenate prompts and responses to form the full sequence
        batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
        attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
        position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
        is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
        scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)

        # Create token-level scores by placing the final reward at the last token position
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
        # At the eos_mask_idx position of each sample, fill in the corresponding scores.
        # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension.
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
        # Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
        token_level_scores = token_level_scores[:, -max_response_length:]

        # Form the final batch using TensorDict
        batch = TensorDict(
            {
                "prompts": batch_input_ids,
                "responses": batch_response_ids,
                "input_ids": batch_seq,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "is_drop_mask": is_drop_mask,
                "token_level_scores": token_level_scores.contiguous(),
            },
            batch_size=n_transition,
        )
        data_proto = DataProto(batch=batch)

        data_metrics = {
            "agent_mode/n_trunc_sample_because_of_response": n_trunc_sample_because_of_response,
            "agent_mode/n_sample_to_train": n_transition,
        }

        # Add non-tensor data for advantage calculation and logging
        data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)
        data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)
        data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)

        return data_proto, data_metrics

    def clear_data_and_server(self):
        """Resets the internal state of the daemon for the next run."""
        self.backend_llm_server_addresses = []
        self._completed_rollouts.clear()
        self._task_id_to_original_sample.clear()
        self._total_tasks_queued = 0
        # For a true reset, the server's internal queues would also need clearing.
        # This implementation assumes that `set_up_data_and_server` is called
        # for each new run, effectively starting a fresh batch.

    def _fillna_reward(self, rollout):
        if rollout.final_reward is None:
            if self.reward_fillna_value is not None:
                final_reward = self.reward_fillna_value
            else:
                raise ValueError(f"Reward is None for rollout {rollout.rollout_id}, please check the reward function.")
        else:
            final_reward = rollout.final_reward
        return final_reward

clear_data_and_server()

Resets the internal state of the daemon for the next run.

Source code in agentlightning/verl/daemon.py
def clear_data_and_server(self):
    """Resets the internal state of the daemon for the next run."""
    self.backend_llm_server_addresses = []
    self._completed_rollouts.clear()
    self._task_id_to_original_sample.clear()
    self._total_tasks_queued = 0

get_test_metrics()

Calculates and returns metrics for a validation run.

Source code in agentlightning/verl/daemon.py
def get_test_metrics(self):
    """Calculates and returns metrics for a validation run."""
    assert not self.is_train, "This method should only be called during validation."
    assert len(self._completed_rollouts) == self._total_tasks_queued

    sample_stat_list = []
    for rollout_id, rollout in self._completed_rollouts.items():
        if not rollout.triplets:
            continue
        response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
        final_reward = self._fillna_reward(rollout)
        sample_stat_list.append(
            {
                "sum_response_length": np.sum(response_length_list),
                "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
                "turn_count": len(rollout.triplets),
                "reward": final_reward,
            }
        )

    return {
        "val/reward": np.mean([stat["reward"] for stat in sample_stat_list]),
        "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in sample_stat_list]),
        "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in sample_stat_list]),
        "val/turn_count": np.mean([stat["turn_count"] for stat in sample_stat_list]),
    }

get_train_data_batch(max_prompt_length, max_response_length, device)

Processes completed rollouts to generate a training data batch.

This function reconstructs the logic from the original AgentModeDaemon, using data retrieved from the new server architecture. It handles padding, truncation, and tensor creation for the PPO training loop.
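
For intuition, here is a tiny worked example of the padding scheme. The two helpers are reimplemented from their call sites purely for illustration; the actual implementations live in agentlightning and may differ in detail.

def get_left_padded_ids_and_attention_mask(ids, length, pad_id):
    # Left-pad so real prompt tokens sit flush against the response.
    n_pad = length - len(ids)
    return [pad_id] * n_pad + list(ids), [0] * n_pad + [1] * len(ids)

def get_right_padded_ids_and_attention_mask(ids, length, pad_id):
    # Right-pad so real response tokens start right after the prompt.
    n_pad = length - len(ids)
    return list(ids) + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad

# max_prompt_length=4, max_response_length=3, pad_token_id=0 (placeholders)
prompt_ids, prompt_mask = get_left_padded_ids_and_attention_mask([5, 6], 4, 0)
response_ids, response_mask = get_right_padded_ids_and_attention_mask([7], 3, 0)
assert prompt_ids + response_ids == [0, 0, 5, 6, 7, 0, 0]
assert prompt_mask + response_mask == [0, 0, 1, 1, 1, 0, 0]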

Source code in agentlightning/verl/daemon.py
def get_train_data_batch(self, max_prompt_length, max_response_length, device):
    """
    Processes completed rollouts to generate a training data batch.

    This function reconstructs the logic from the original AgentModeDaemon,
    using data retrieved from the new server architecture. It handles padding,
    truncation, and tensor creation for the PPO training loop.
    """
    assert self.is_train, "This method should only be called during training."
    assert len(self._completed_rollouts) == self._total_tasks_queued

    # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
    finished_id_to_sample_info = {}
    for rollout_id, rollout in self._completed_rollouts.items():
        original_sample = self._task_id_to_original_sample[rollout_id]

        if not rollout.triplets:
            continue

        # The client should report triplets that contain prompt_ids and response_ids.
        # Example triplet.prompt: {"token_ids": [...]}
        # Example triplet.response: {"token_ids": [...]}
        trace_list = [
            {"prompt_ids": t.prompt.get("token_ids", []), "response_ids": t.response.get("token_ids", [])}
            for t in rollout.triplets
        ]

        final_reward = self._fillna_reward(rollout)
        info = {
            "reward": final_reward,
            "trace_list": trace_list,
            "data_id": original_sample["data_id"],
        }
        finished_id_to_sample_info[rollout_id] = info
    #
    # --- Data processing and tensor creation logic ---
    # Get all the reported data.
    # prompt_ids are left-padded.
    # response_ids are right-padded.
    # They are concatenated in the middle.
    # Discard handling:
    #   - Those exceeding max_prompt_length will be marked for discard, but not
    #     discarded here. They are only truncated and marked, to be discarded later.
    #     This is for the correctness of the advantage calculation.
    #   - The discard for the PPO mini-batch should also be handled this way.
    input_ids_list, input_attention_mask_list = [], []
    response_ids_list, response_attention_mask_list = [], []
    reward_list, data_id_list, rollout_id_list, turn_index_list, is_drop_list = [], [], [], [], []
    n_trunc_sample_because_of_response = 0

    for rollout_id, sample_info in finished_id_to_sample_info.items():
        for turn_index, trace in enumerate(sample_info["trace_list"]):

            reward_list.append(sample_info["reward"])
            prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

            # Mark samples with prompts exceeding max_prompt_length to be dropped later
            if len(prompt_ids) > max_prompt_length:
                prompt_ids = prompt_ids[:max_prompt_length]
                is_drop_list.append(True)
            else:
                is_drop_list.append(False)

            # Truncate responses that exceed max_response_length
            if len(response_ids) > max_response_length:
                response_ids = response_ids[:max_response_length]
                n_trunc_sample_because_of_response += 1

            # Pad prompts to the left and responses to the right
            one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
                prompt_ids, max_prompt_length, self.pad_token_id
            )
            one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
                response_ids, max_response_length, self.pad_token_id
            )

            input_ids_list.append(one_input_ids)
            input_attention_mask_list.append(one_input_attention_mask)
            response_ids_list.append(one_response_ids)
            response_attention_mask_list.append(one_response_attention_mask)
            data_id_list.append(sample_info["data_id"])
            rollout_id_list.append(rollout_id)
            turn_index_list.append(turn_index)

    n_transition = len(input_ids_list)
    batch_input_ids = torch.LongTensor(input_ids_list).to(device)
    input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
    batch_response_ids = torch.LongTensor(response_ids_list).to(device)
    response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)

    # Concatenate prompts and responses to form the full sequence
    batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
    attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
    position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
    is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
    scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)

    # Create token-level scores by placing the final reward at the last token position
    token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
    # At the eos_mask_idx position of each sample, fill in the corresponding scores.
    # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension.
    eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
    token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
    # Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
    token_level_scores = token_level_scores[:, -max_response_length:]

    # Form the final batch using TensorDict
    batch = TensorDict(
        {
            "prompts": batch_input_ids,
            "responses": batch_response_ids,
            "input_ids": batch_seq,  # here input_ids become the whole sentences
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "is_drop_mask": is_drop_mask,
            "token_level_scores": token_level_scores.contiguous(),
        },
        batch_size=n_transition,
    )
    data_proto = DataProto(batch=batch)

    data_metrics = {
        "agent_mode/n_trunc_sample_because_of_response": n_trunc_sample_because_of_response,
        "agent_mode/n_sample_to_train": n_transition,
    }

    # Add non-tensor data for advantage calculation and logging
    data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)
    data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)
    data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)

    return data_proto, data_metrics
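
To make the padding layout concrete, here is a small worked example with hypothetical token IDs, using the two helper functions documented below (pad_token_id=0, max_prompt_length=4, max_response_length=3). Prompts are left-padded, responses are right-padded, and each batch row is their concatenation:

prompt_ids, prompt_mask = get_left_padded_ids_and_attention_mask([11, 12], 4, 0)
# -> [0, 0, 11, 12], [0, 0, 1, 1]
response_ids, response_mask = get_right_padded_ids_and_attention_mask([21, 22], 3, 0)
# -> [21, 22, 0], [1, 1, 0]

row_ids = prompt_ids + response_ids      # [0, 0, 11, 12, 21, 22, 0]
row_mask = prompt_mask + response_mask   # [0, 0, 1, 1, 1, 1, 0]
# token_level_scores places the rollout's final reward at the last attended
# position (index 5 here) and then keeps only the response slice.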

run_until_all_finished(verbose=True)

Synchronously waits for all queued tasks to be completed and reported.

Source code in agentlightning/verl/daemon.py
def run_until_all_finished(self, verbose=True):
    """Synchronously waits for all queued tasks to be completed and reported."""
    if self._total_tasks_queued == 0:
        print("Warning: No tasks were queued.")
        return

    if not self.server.loop or not self.server.startup_event.is_set():
        raise RuntimeError("Server is not running or ready.")

    coro = self._async_run_until_finished(verbose)
    future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
    try:
        future.result()  # Wait indefinitely for all tasks to complete
    except Exception as e:
        print(f"Error while waiting for tasks to finish: {e}")
        raise

set_up_data_and_server(data, server_addresses, is_train=True)

Synchronous wrapper for setting up data and server resources.

Source code in agentlightning/verl/daemon.py
def set_up_data_and_server(self, data, server_addresses, is_train=True):
    """Synchronous wrapper for setting up data and server resources."""
    if not self.server.loop or not self.server.startup_event.is_set():
        raise RuntimeError("Server is not running or ready.")

    coro = self._async_set_up(data, server_addresses, is_train)
    future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
    try:
        future.result(timeout=60)  # Wait for completion with a timeout
    except Exception as e:
        print(f"Failed to set up data on server: {e}")
        raise

start()

Starts the main AgentLightningServer and the proxy server.

Source code in agentlightning/verl/daemon.py
def start(self):
    """Starts the main AgentLightningServer and the proxy server."""

    def run_server():
        """Run the AgentLightningServer in a separate thread."""
        asyncio.run(self.server.run_forever())

    self._server_thread = threading.Thread(target=run_server, daemon=True)
    self._server_thread.start()

    # Wait for the server's internal startup event to be set.
    print("Waiting for AgentLightningServer to start...")
    is_ready = self.server.startup_event.wait(timeout=20.0)  # Wait up to 20s
    if not is_ready:
        raise RuntimeError("AgentLightningServer failed to start within the timeout period.")

    print(f"AgentLightningServer control plane running on port {self.server_port}")

    self._start_proxy_server()
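
Putting these methods together, one training iteration typically drives the daemon as sketched below; the data records and server address are hypothetical placeholders:

# A sketch of one training iteration, assuming daemon is an already
# constructed AgentModeDaemon and train_batch is a list of sample records
# in whatever shape the daemon's setup expects (hypothetical here).
daemon.start()  # boot the control plane and proxy server once

daemon.set_up_data_and_server(
    data=train_batch, server_addresses=["127.0.0.1:9000"], is_train=True
)
daemon.run_until_all_finished()  # block until every rollout is reported

data_proto, data_metrics = daemon.get_train_data_batch(
    max_prompt_length=2048, max_response_length=1024, device="cuda"
)
daemon.clear_data_and_server()  # reset internal state before the next batch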

LLM

Bases: Resource

Provide an LLM endpoint and model name as a resource.

Attributes:

    endpoint (str): The URL of the LLM API endpoint.
    model (str): The identifier for the model to be used (e.g., 'gpt-4o').
    sampling_parameters (SamplingParameters): A dictionary of hyperparameters for model inference, such as temperature, top_p, etc.

Source code in agentlightning/types.py
class LLM(Resource):
    """
    Provide an LLM endpoint and model name as a resource.

    Attributes:
        endpoint (str): The URL of the LLM API endpoint.
        model (str): The identifier for the model to be used (e.g., 'gpt-4o').
        sampling_parameters (SamplingParameters): A dictionary of hyperparameters
            for model inference, such as temperature, top_p, etc.
    """

    resource_type: Literal["llm"] = "llm"
    endpoint: str
    model: str
    sampling_parameters: Dict[str, Any] = Field(default_factory=dict)

Rollout

Bases: BaseModel

The standard reporting object from client to server.

Source code in agentlightning/types.py
class Rollout(BaseModel):
    """The standard reporting object from client to server."""

    rollout_id: str

    # Primary, high-level feedback
    final_reward: Optional[float] = None

    # Structured, sequential feedback for RL-style optimization
    triplets: Optional[List[Triplet]] = None

    # Optional, rich-context data for deep analysis
    trace: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="A list of spans that conform to the OpenTelemetry JSON format. "
        "Users of the opentelemetry-sdk can generate this by calling "
        "json.loads(readable_span.to_json()).",
    )
    logs: Optional[List[str]] = None

    # A bucket for any other relevant information
    metadata: Dict[str, Any] = Field(default_factory=dict)
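
A minimal reporting sketch with hypothetical IDs and values; the Triplet constructor arguments are assumed here to mirror the shape the daemon reads, i.e. prompt and response dicts carrying token_ids:

from agentlightning.types import Rollout, Triplet  # import path assumed

rollout = Rollout(
    rollout_id="rollout-0001",  # hypothetical ID
    final_reward=1.0,
    triplets=[
        Triplet(
            prompt={"token_ids": [1, 2, 3]},
            response={"token_ids": [4, 5]},
        ),
    ],
)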

get_left_padded_ids_and_attention_mask(ids, max_length, pad_token_id)

Left-pad (or truncate) a sequence of token IDs to a fixed length, and build the corresponding attention mask.

Parameters:

    ids (List[int], required): the original list of token IDs.
    max_length (int, required): desired total length after padding/truncation.
    pad_token_id (int, required): ID to use for padding.

Returns:

    padded_ids (any): list of length == max_length.
    attention_mask (any): list of the same length: 1 for non-pad tokens, 0 for pads.

Source code in agentlightning/verl/daemon.py
def get_left_padded_ids_and_attention_mask(ids: List[int], max_length: int, pad_token_id: int):
    """
    Left-pad (or truncate) a sequence of token IDs to a fixed length,
    and build the corresponding attention mask.

    Args:
        ids:             the original list of token IDs.
        max_length:      desired total length after padding/truncation.
        pad_token_id:    ID to use for padding.

    Returns:
        padded_ids (any):      list of length == max_length.
        attention_mask (any):  list of same length: 1 for non-pad tokens, 0 for pads.
    """
    seq_len = len(ids)

    if seq_len >= max_length:
        # too long → truncate from the left, keep the last max_length tokens
        trimmed = ids[-max_length:]
        attention_mask = [1] * max_length
        return trimmed, attention_mask

    # too short → pad on the left
    pad_len = max_length - seq_len
    padded_ids = [pad_token_id] * pad_len + ids
    attention_mask = [0] * pad_len + [1] * seq_len
    return padded_ids, attention_mask
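
A quick worked example with hypothetical token IDs:

get_left_padded_ids_and_attention_mask([7, 8, 9], 5, 0)
# -> ([0, 0, 7, 8, 9], [0, 0, 1, 1, 1])

# Truncation keeps the last max_length tokens:
get_left_padded_ids_and_attention_mask([7, 8, 9], 2, 0)
# -> ([8, 9], [1, 1])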

get_right_padded_ids_and_attention_mask(ids, max_length, pad_token_id)

Right-pad (or truncate) a sequence of token IDs to a fixed length, and build the corresponding attention mask.

Parameters:

    ids (List[int], required): the original list of token IDs.
    max_length (int, required): desired total length after padding/truncation.
    pad_token_id (int, required): ID to use for padding.

Returns:

    padded_ids (any): list of length == max_length.
    attention_mask (any): list of the same length: 1 for non-pad tokens, 0 for pads.

Source code in agentlightning/verl/daemon.py
def get_right_padded_ids_and_attention_mask(ids: List[int], max_length: int, pad_token_id: int):
    """
    Right-pad (or truncate) a sequence of token IDs to a fixed length,
    and build the corresponding attention mask.

    Args:
        ids:            the original list of token IDs.
        max_length:     desired total length after padding/truncation.
        pad_token_id:   ID to use for padding.

    Returns:
        padded_ids (any):     list of length == max_length.
        attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads.
    """
    seq_len = len(ids)

    if seq_len >= max_length:
        # too long → truncate to the first max_length tokens
        trimmed = ids[:max_length]
        attention_mask = [1] * max_length
        return trimmed, attention_mask

    # too short → pad on the right
    pad_len = max_length - seq_len
    padded_ids = ids + [pad_token_id] * pad_len
    attention_mask = [1] * seq_len + [0] * pad_len
    return padded_ids, attention_mask
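
And the mirror-image example with hypothetical token IDs:

get_right_padded_ids_and_attention_mask([7, 8, 9], 5, 0)
# -> ([7, 8, 9, 0, 0], [1, 1, 1, 0, 0])

# Truncation keeps the first max_length tokens:
get_right_padded_ids_and_attention_mask([7, 8, 9], 2, 0)
# -> ([7, 8], [1, 1])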