Skip to main content
Version: 0.11.1

ONNX Inference on Spark

In this example, we train a LightGBM model, convert it to ONNX format, and use the converted model to run inference on test data in Spark.

Python dependencies:

  • onnxmltools==1.7.0
  • lightgbm==3.2.1

Load training data

from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

from synapse.ml.core.platform import *

from synapse.ml.core.platform import materializing_display as display

# Read the training CSV. The header row supplies the column names, and
# inferSchema asks Spark to detect column types instead of reading every
# column as a string.
# NOTE(review): the `.load(...)` call and closing parenthesis were lost in
# extraction; the dataset path below is reconstructed — confirm it.
df = (
    spark.read.format("csv")
    .option("header", True)
    .option("inferSchema", True)
    .load(
        "wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv"
    )
)

display(df)


Use LightGBM to train a model

from pyspark.ml.feature import VectorAssembler
from synapse.ml.lightgbm import LightGBMClassifier

# Every column except the first is treated as a feature.
# Assumes the first column is the "Bankrupt?" label — TODO confirm.
feature_cols = df.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Keep only the label and the assembled feature vector for training.
train_data = featurizer.transform(df)["Bankrupt?", "features"]

# NOTE(review): the original chain of hyperparameter setters after the
# constructor was lost in extraction; library defaults are used here.
model = LightGBMClassifier(featuresCol="features", labelCol="Bankrupt?")

model = model.fit(train_data)

Export the trained model to a LightGBM booster, then convert it to ONNX format.

from synapse.ml.core.platform import running_on_binder

if running_on_binder():
    # On Binder, install the pinned LightGBM version at runtime.
    # `!pip install ...` is notebook-only shell syntax and is a SyntaxError in a
    # plain .py file; `get_ipython().system(...)` is the equivalent IPython call.
    from IPython import get_ipython

    get_ipython().system("pip install lightgbm==3.2.1")
import lightgbm as lgb
from lightgbm import Booster, LGBMClassifier

def convertModel(lgbm_model: "LGBMClassifier | Booster", input_size: int) -> bytes:
    """Convert a LightGBM model to a serialized ONNX model.

    Args:
        lgbm_model: A trained LightGBM ``LGBMClassifier`` or ``Booster``.
        input_size: Number of input features (the width of the ONNX input
            tensor).

    Returns:
        The converted ONNX model, serialized with ``SerializeToString()``.
    """
    # Imported inside the function, so the converter packages are only needed
    # when a conversion actually runs.
    from onnxmltools.convert import convert_lightgbm
    from onnxconverter_common.data_types import FloatTensorType

    # A single float input named "input". The -1 leaves the batch dimension
    # unbounded. The name must match the key used in the ONNXModel FeedDict.
    initial_types = [("input", FloatTensorType([-1, input_size]))]
    onnx_model = convert_lightgbm(
        lgbm_model, initial_types=initial_types, target_opset=9
    )
    return onnx_model.SerializeToString()

# Extract the native LightGBM model text from the trained Spark model.
# `.get()` presumably unwraps an optional value — TODO confirm.
booster_model_str = (
    model.getLightGBMBooster()
    .modelStr()
    .get()
)
# Rebuild a Python-side LightGBM Booster from that model text.
booster = lgb.Booster(model_str=booster_model_str)
# Convert to ONNX, declaring one input column per training feature.
model_payload_ml = convertModel(booster, input_size=len(feature_cols))

Load the ONNX payload into an ONNXModel, and inspect the model inputs and outputs.

from synapse.ml.onnx import ONNXModel

# Wrap the serialized ONNX bytes in a Spark transformer.
onnx_ml = ONNXModel().setModelPayload(model_payload_ml)

# Show the model's declared inputs and outputs so they can be mapped to
# DataFrame columns in the next step.
print("Model inputs:" + str(onnx_ml.getModelInputs()))
print("Model outputs:" + str(onnx_ml.getModelOutputs()))

Map the model input to the input DataFrame's column name (FeedDict), and map the output DataFrame's column names to the model outputs (FetchDict).

# FeedDict: ONNX input name -> DataFrame column supplying it.
# FetchDict: output DataFrame column -> ONNX output it is read from.
onnx_ml = (
    onnx_ml.setFeedDict({"input": "features"})
    .setFetchDict({"probability": "probabilities", "prediction": "label"})
)

Create some test data and transform it with the ONNX model.

from pyspark.ml.feature import VectorAssembler
import pandas as pd
import numpy as np

# Random test rows. The width is tied to the training features because the
# ONNX input tensor was declared with exactly that many columns.
# Note: 1,000,000 rows x ~95 float64 columns is roughly 760 MB in pandas.
n = 1000 * 1000
m = len(feature_cols)
test = np.random.rand(n, m)
testPdf = pd.DataFrame(test)
cols = list(map(str, testPdf.columns))
testDf = spark.createDataFrame(testPdf)
# Double the data and spread it over 200 partitions for parallel scoring.
testDf = testDf.union(testDf).repartition(200)
testDf = (