Welcome to Knossos!¶
Warning
Knossos is very much a work in progress. Pretty much anything may not work, so we encourage you to say hello at https://github.com/microsoft/knossos-ksc/discussions before even starting to play :)
Knossos compiles (a subset of) PyTorch (and Julia, and F#) code into C++ (and MLIR and ONNX and other stuff). By which we mean actual C++, which you can deploy completely runtime-free if you like, or link against ATen, MLAS, ONNX Runtime, whatever you have.
But that’s not all – it also performs source-to-source automatic differentiation, so you can get gradients for free.
The canonical use case is to write custom PyTorch extensions.
Suppose you’ve invented a great new activation function, which you call relu3:
def relu3(x: float) -> float:
    """
    Like ReLU, but smoother
    Like GELU, but cheaper
    """
    if x < 0.0:
        return 0.0
    elif x < 1.0:
        return 1 / 3 * x ** 3
    else:
        return x - 2 / 3
Defining a new kernel, taking a float and returning a float:

t = np.arange(-3, 3, step=0.1)
plt.plot(t, [relu3(x) for x in t], "b")
It must be better, right? Smoother than relu, cheaper than gelu. So we want to test it out. Before diving into an MNIST example, let’s spend a little time looking at the function.
The natural way to define this elementwise function is to just write the float-to-float version as above, but of course we probably want it to work over tensors too. As in JAX (or functorch), we provide vmap, so you can simply write:
vrelu3 = knossos.vmap(relu3) # Turn float->float into Tensor->Tensor
and then the plotting code above could just use vrelu3 instead of the list comprehension:
t = torch.arange(-3, 3, step=0.1)
plt.plot(t, vrelu3(t), "b")
Compilation¶
So how is it different from PyTorch or JAX? Well, as we said above, it’s compiled. The default option is to compile to C++ source, and one reason for that is simply to be able to inspect that source, to reassure ourselves that it’s good.
Here’s the C++ for relu3 (slightly prettified, so auto c2; won’t compile, but look in build/torch_extensions to find the real code).
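If you want to find the generated source on disk, something like the following will list it (a hypothetical snippet; the exact file names and layout depend on your build, and build/torch_extensions is just the default location mentioned above):

# List the generated C++ sources after the extension has been built.
import glob

for path in glob.glob("build/torch_extensions/**/*.cpp", recursive=True):
    print(path)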
So, this is not that different from TorchScript, but note that ks::Float is a real C++ float, rather than a wrapper. And the aten::* calls are in fact calls to simple inlined C++ functions, which you can inspect in the prelude.
The real difference from other systems comes when we take derivatives.
Derivatives¶
Before we use the function for deep learning, let’s just examine it a bit more.
We said it’s smooth, meaning, in this case, that its derivative is continuous.
Let’s check that by plotting the derivative. As vrelu3 takes a
vector to a vector, its Jacobian is a square matrix.
And because it’s operating elementwise,
i.e. independently on each element of the vector, the Jacobian is diagonal,
and a vector-Jacobian product (vjp) with a vector of all ones
will compute the derivative at each element.
dfdt = vrelu3.vjp(t, torch.ones_like(t)) # Vector-Jacobian product
plt.plot(t, dfdt, "b")
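If you’d rather check the numbers than eyeball the plot, you can compare against the analytic derivative of relu3. Here’s a small sketch, using a helper drelu3 that is not part of the Knossos API and assuming vjp returns a Tensor with the same shape as t:

# Analytic derivative of relu3: 0 for x < 0, x**2 for 0 <= x < 1, and 1 for x >= 1.
def drelu3(x: float) -> float:
    if x < 0.0:
        return 0.0
    elif x < 1.0:
        return x ** 2
    else:
        return 1.0

expected = torch.tensor([drelu3(float(x)) for x in t])
print((dfdt - expected).abs().max())  # should be close to zero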
So what? I can do that in PyTorch¶
Not really, for a few reasons:
Knossos compiles your code directly to C++/CUDA. You can literally look at the output.
This means that Knossos can easily deal with control flow like if statements, internal function calls, etc.
It also means it’s efficient for small float->float functions like this. The C++ code really deals in floats, not 1x1 tensors.
And the derivative code is also C++, taking plain ol’ floats.
So if you try the above example with vmap from functorch or JAX, it just won’t work. Now, if you’re an experienced PyTorch programmer, you know that the above is inefficient, and you naturally code it “vectorized”. So you write:
def vrelu3_pytorch(x: torch.Tensor):
    mask1_inf = x > 1.0
    mask0_1 = (x > 0.0) & ~mask1_inf
    val_0_1 = 1 / 3 * x ** 3
    val_1_inf = x - 2 / 3
    return mask0_1 * val_0_1 + mask1_inf * val_1_inf
We argue that while performant, this is not the most natural way to write this code. Let’s look at the two options side by side:
@knossos.vmap
def vrelu3_knossos(x: float) -> float:
    if x < 0.0:
        return 0.0
    elif x < 1.0:
        return 1 / 3 * x ** 3
    else:
        return x - 2 / 3

Knossos: define the kernel, compile with vmap.

def vrelu3_pytorch(x: torch.Tensor):
    mask1_inf = x > 1.0
    mask0_1 = (x > 0.0) & ~mask1_inf
    val_0_1 = 1 / 3 * x ** 3
    val_1_inf = x - 2 / 3
    return mask0_1 * val_0_1 + mask1_inf * val_1_inf

PyTorch: “Thinking in tensors”. It’s fun for a while, but it gets old.
The vectorized method may not even be the most efficient. First, x ** 3 is computed for all inputs, even those outside [0, 1). That allows for parallelism on massively parallel hardware, but is wasteful on small, power-constrained devices. Second, as written, the computation produces 10 temporary tensors, with a working set of up to 3x that of the optimal computation.
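You can dodge some of that waste with indexed assignment, computing each branch only on the elements that need it, at the cost of yet more mask bookkeeping. Here’s an illustrative sketch (not code from the Knossos docs):

def vrelu3_pytorch_indexed(x: torch.Tensor):
    # Compute each branch only where its mask is true.
    out = torch.zeros_like(x)        # x < 0 stays at 0
    mid = (x > 0.0) & (x < 1.0)
    high = x >= 1.0
    out[mid] = 1 / 3 * x[mid] ** 3
    out[high] = x[high] - 2 / 3
    return out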
Integrating with PyTorch¶
So, now that we have a kernel taking Tensors to Tensors, let’s try it in a machine learning model. We’ll make a simple DNN with vrelu3 activations, as in this tutorial:
# Initialize the model using nn.Sequential
model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(784, 256)),
    ('activation1', vrelu3.nnModule),
    ('fc2', nn.Linear(256, 128)),
    ('bn2', nn.BatchNorm1d(num_features=128)),
    ('activation2', vrelu3.nnModule),
    ('dropout', nn.Dropout(0.3)),
    ('fc3', nn.Linear(128, 64)),
    ('bn3', nn.BatchNorm1d(num_features=64)),
    ('activation3', vrelu3.nnModule),
    ('logits', nn.Linear(64, 10)),
    ('logsoftmax', nn.LogSoftmax(dim=1))]))

# Run training
train_model(model)
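Here train_model stands in for the training loop from the linked tutorial. A minimal sketch of what such a loop might look like (assuming an MNIST-style DataLoader called train_loader; the function name, hyperparameters, and structure are illustrative, not the tutorial’s exact code):

import torch.nn as nn
import torch.optim as optim

def train_model(model, epochs=3):
    criterion = nn.NLLLoss()                 # pairs with the LogSoftmax output layer
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:  # assumed MNIST-style DataLoader
            images = images.view(images.shape[0], -1)  # flatten 28x28 -> 784
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"epoch {epoch}: loss {running_loss / len(train_loader):.4f}")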
Yay, it just works. In this case it’s about as fast as the vectorized PyTorch (and much, much faster than a vmap version), but of course it’s easier to write, read, modify, and deploy.
Limitations¶
So, it’s a general-purpose Python compiler! How cool!
No, it’s not. It is intended to allow PyTorch programmers to write more complex models without falling off performance cliffs. As mentioned above, a lot of stuff won’t work, but here’s what we want to get working well:
Float-to-float kernels as above
Small tensor-to-tensor kernels (see sqrl)
Messy computer vision models like those in ADBench
Simple concatenations of operators like LSTM
And we want them all to be as fast as reasonable human-written C++.

