What is Batcher?

The Batcher class defines how multiple requests from users are merged into a batch request, and how the batch response is split into multiple responses for users. It exposes two interface functions that developers can implement to build a batcher that fits a specific scenario.

  • batch. Merge multiple requests into a batch request, and attach a context object for unbatching.

  • unbatch. Split a batch response into multiple responses, with the help of the context object.
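
Together, these two functions form the Batcher interface. As a rough sketch (the exact signatures on the batch_inference.batcher.Batcher base class may differ slightly from this illustration):

from typing import Any, List, Tuple


class Batcher:
    def batch(self, requests: List[Tuple]) -> Tuple[Any, Any]:
        """Merge n requests into 1 batched request, plus a context for unbatch."""
        raise NotImplementedError

    def unbatch(self, batched_response: Any, unbatch_ctx: Any) -> List[Any]:
        """Split 1 batched response into n responses, using the context."""
        raise NotImplementedError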

Built-in Batchers

The following built-in batchers are provided in the batch_inference.batcher module. Both numpy.ndarray and torch.Tensor are supported as input data types.

  • ConcatBatcher. Simply concatenate multiple requests into a single batch request.

  • SeqBatcher. Pad sequences of different lengths with padding tokens before concatenation.

  • BucketSeqBatcher. Group sequences of similar lengths into buckets, pad each bucket with padding tokens, and then concatenate each bucket into its own batch request.
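
For example, a built-in batcher can be passed to ModelHost in the same way as the customized batcher shown later on this page. MatMulModel below is an illustrative placeholder, and constructing ConcatBatcher with no arguments is an assumption for this sketch:

import numpy as np

from batch_inference import ModelHost
from batch_inference.batcher import ConcatBatcher


class MatMulModel:
    # x.shape: [batch_size, m, k], y.shape: [batch_size, k, n]
    def predict_batch(self, x, y):
        return np.matmul(x, y)


# ConcatBatcher concatenates each argument of the queued requests along axis 0
model_host = ModelHost(MatMulModel, batcher=ConcatBatcher(), max_batch_size=32)()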

Implementing a Customized Batcher

The following example shows how to implement a customized batcher for a matrix-multiplication model: batch concatenates the inputs of multiple requests along the batch dimension, and unbatch splits the batch response back into per-request responses using the recorded batch sizes.

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import List, Tuple

import numpy as np

from batch_inference import ModelHost
from batch_inference.batcher import Batcher


class MyModel:
    def __init__(self):
        self.op = np.matmul

    # x.shape: [batch_size, m, k], y.shape: [batch_size, k, n]
    def predict_batch(self, x, y):
        res = self.op(x, y)
        return res


class MyBatcher(Batcher):
    def __init__(self):
        super().__init__()

    def batch(self, requests: List[Tuple[np.ndarray, np.ndarray]]):
        """Batch n requests into 1 batched request

        Args:
            requests: [(x, y), ...], each request is an (x, y) tuple from the predict method

        Returns:
            batched requests: (x_batched, y_batched) for predict_batch method
            context for unbatch: List[int], the batch sizes of each original (x, y)
        """

        x_batched = np.concatenate([item[0] for item in requests], axis=0)
        y_batched = np.concatenate([item[1] for item in requests], axis=0)
        batch_sizes = [item[0].shape[0] for item in requests]
        return (x_batched, y_batched), batch_sizes

    def unbatch(
        self,
        batched_response: np.ndarray,
        unbatch_ctx: List[int],
    ):
        """Unbatch 1 batched response into n responses

        Args:
            batched_response: result of the predict_batch method,
                              batched_response = x_batched @ y_batched
            unbatch_ctx: the batch sizes of the n original requests

        Returns:
            responses: [res1, res2, ...], each res will be returned by the predict method,
                       res = x @ y
        """

        batch_sizes = unbatch_ctx
        responses = []
        start = 0
        for n in batch_sizes:
            responses.append(batched_response[start : start + n])
            start += n
        return responses


model_host = ModelHost(
    MyModel,
    batcher=MyBatcher(),
    max_batch_size=32,
)()
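
Once constructed, the host serves concurrent predict calls, which the library merges into batches behind the scenes. A minimal usage sketch, assuming ModelHost exposes start, predict, and stop methods as in the library's examples:

model_host.start()
try:
    # each predict call carries one request; concurrent requests are
    # merged by MyBatcher into batches of up to max_batch_size
    x = np.random.rand(1, 4, 5)
    y = np.random.rand(1, 5, 6)
    res = model_host.predict(x, y)  # res.shape: [1, 4, 6]
finally:
    model_host.stop()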

The MultiBatcher Class

In most cases, we merge multiple requests from users into a single batch request. ConcatBatcher and SeqBatcher are examples of this.

Sometimes, we want to group multiple requests first and then merge the requests in each group into a batch request. In this case, the batcher should inherit from the batcher.MultiBatcher class instead of batcher.Batcher. BucketSeqBatcher is an example of this.
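
The exact MultiBatcher contract is not shown on this page, so the following is only a hypothetical sketch of the group-then-merge idea: it assumes batch may return a list of batched requests along with a list of contexts, and unbatch receives the corresponding list of batched responses. Inputs are assumed to be 2-d arrays of shape [batch_size, seq_len].

from typing import Any, List, Tuple

import numpy as np

from batch_inference.batcher import MultiBatcher


class MyBucketBatcher(MultiBatcher):
    """Hypothetical sketch: bucket requests by sequence length, then pad and
    merge each bucket separately. The real MultiBatcher hooks may differ."""

    def batch(self, requests: List[Tuple[np.ndarray]]):
        # group request indices into buckets of similar seq_len
        buckets = {}
        for i, (x,) in enumerate(requests):
            buckets.setdefault(x.shape[1] // 16, []).append(i)

        batched_requests, contexts = [], []
        for indices in buckets.values():
            # pad every sequence in the bucket to the bucket's max length
            max_len = max(requests[i][0].shape[1] for i in indices)
            padded = [
                np.pad(requests[i][0], ((0, 0), (0, max_len - requests[i][0].shape[1])))
                for i in indices
            ]
            batched_requests.append((np.concatenate(padded, axis=0),))
            # context: (original index, batch size) for every request in the bucket
            contexts.append([(i, requests[i][0].shape[0]) for i in indices])
        return batched_requests, contexts

    def unbatch(self, batched_responses: List[np.ndarray], unbatch_ctx: List[Any]):
        responses = [None] * sum(len(ctx) for ctx in unbatch_ctx)
        for batched_response, ctx in zip(batched_responses, unbatch_ctx):
            start = 0
            for i, n in ctx:
                responses[i] = batched_response[start : start + n]
                start += n
        return responses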