Generating a FunctionProto

The example below shows how to define the Selu activation (scaled exponential linear unit) as an ONNX function using ONNXScript, and then export it as a FunctionProto.
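For reference, Selu applies the following rule element-wise, where alpha and gamma are float attributes of the function. This matches the ONNX Selu operator definition:

$$
\operatorname{Selu}(x) =
\begin{cases}
\gamma \, x, & x > 0 \\
\gamma \, (\alpha e^{x} - \alpha), & x \le 0
\end{cases}
$$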

First, import the ONNX opset whose operators the function body uses (opset 15 here), along with the script decorator.

from onnxscript import opset15 as op
from onnxscript import script

Next, define Selu as an ONNXScript function.

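@script()
def Selu(X, alpha: float, gamma: float):
    # Cast the float attributes to X's element type.
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    # Negative branch: gamma * (alpha * exp(X) - alpha).
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    # Positive branch: gamma * X.
    pos = gammaX * X
    # Select element-wise: neg where X <= 0, pos elsewhere.
    zero = op.CastLike(0, X)
    return op.Where(zero >= X, neg, pos)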
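To make the logic of the function body concrete, here is a minimal scalar sketch in plain Python that mirrors it step by step. The name selu_reference and the sample values alpha=2.0, gamma=3.0 are hypothetical and chosen only for illustration. The real ONNXScript function operates element-wise on tensors and uses CastLike to match X's element type, which this scalar version does not need.

```python
import math


def selu_reference(x: float, alpha: float, gamma: float) -> float:
    """Hypothetical scalar Selu that mirrors the ONNXScript body step by step."""
    # CastLike is unnecessary here: everything is already a Python float.
    neg = gamma * (alpha * math.exp(x) - alpha)  # branch used when x <= 0
    pos = gamma * x                               # branch used when x > 0
    # Equivalent of Where(zero >= X, neg, pos) for a single value.
    return neg if 0.0 >= x else pos


# Spot checks with illustrative attribute values alpha=2.0, gamma=3.0.
print(selu_reference(1.0, 2.0, 3.0))  # 3.0
print(selu_reference(0.0, 2.0, 3.0))  # 0.0
```

Both branches evaluate to 0 at x = 0, so the choice of zero >= X rather than zero > X does not change the result at that point.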

Calling to_function_proto() converts the ONNXScript function into an ONNX function, represented as a FunctionProto:

onnx_fun = Selu.to_function_proto()

To inspect the result, print the FunctionProto in ONNX's textual syntax using onnx.printer.to_text:

import onnx  # noqa: E402

print(onnx.printer.to_text(onnx_fun))

This prints:

<
  domain: "this",
  opset_import: ["" : 15]
>
Selu <alpha,gamma>(X) => (return_val)
{
   [n0] alpha = Constant <value_float: float = @alpha> ()
   [n1] alphaX = CastLike (alpha, X)
   [n2] gamma = Constant <value_float: float = @gamma> ()
   [n3] gammaX = CastLike (gamma, X)
   [n4] tmp = Exp (X)
   [n5] tmp_0 = Mul (alphaX, tmp)
   [n6] tmp_1 = Sub (tmp_0, alphaX)
   [n7] neg = Mul (gammaX, tmp_1)
   [n8] pos = Mul (gammaX, X)
   [n9] int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
   [n10] zero = CastLike (int64_0, X)
   [n11] tmp_2 = GreaterOrEqual (zero, X)
   [n12] return_val = Where (tmp_2, neg, pos)
}
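In the printed function, each line of the body is a node, and the nodes are listed in an order in which they can be executed. The notation @alpha and @gamma marks an attribute reference: the Constant node takes its value from the attribute supplied by whichever node calls Selu. The sketch below is a hypothetical mini-evaluator that walks this node list for a single scalar input. It is meant to show how the nodes compose; it is not ONNX Runtime. The names NODES, OPS and run_function are illustrative, and the operator implementations are scalar simplifications. For example, CastLike is approximated by a Python type conversion.

```python
import math

# The node list printed above, transcribed by hand as
# (op_type, inputs, output, constant_value) tuples.
NODES = [
    ("Constant", [], "alpha", "@alpha"),   # attribute reference
    ("CastLike", ["alpha", "X"], "alphaX", None),
    ("Constant", [], "gamma", "@gamma"),   # attribute reference
    ("CastLike", ["gamma", "X"], "gammaX", None),
    ("Exp", ["X"], "tmp", None),
    ("Mul", ["alphaX", "tmp"], "tmp_0", None),
    ("Sub", ["tmp_0", "alphaX"], "tmp_1", None),
    ("Mul", ["gammaX", "tmp_1"], "neg", None),
    ("Mul", ["gammaX", "X"], "pos", None),
    ("Constant", [], "int64_0", 0),        # literal constant
    ("CastLike", ["int64_0", "X"], "zero", None),
    ("GreaterOrEqual", ["zero", "X"], "tmp_2", None),
    ("Where", ["tmp_2", "neg", "pos"], "return_val", None),
]

# Scalar stand-ins for the ONNX operators used by the function.
OPS = {
    "CastLike": lambda v, like: type(like)(v),
    "Exp": math.exp,
    "Mul": lambda a, b: a * b,
    "Sub": lambda a, b: a - b,
    "GreaterOrEqual": lambda a, b: a >= b,
    "Where": lambda cond, a, b: a if cond else b,
}


def run_function(x: float, attrs: dict) -> float:
    """Evaluate NODES in order, binding '@name' constants from attrs."""
    env = {"X": x}
    for op_type, inputs, output, value in NODES:
        if op_type == "Constant":
            # "@alpha" means: use the caller-supplied attribute "alpha".
            env[output] = attrs[value[1:]] if isinstance(value, str) else value
        else:
            env[output] = OPS[op_type](*(env[name] for name in inputs))
    return env["return_val"]


print(run_function(1.0, {"alpha": 2.0, "gamma": 3.0}))  # 3.0
```

Because alpha and gamma are passed in at call time rather than baked into the node list, one FunctionProto can serve callers that use different attribute values.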