A Rust implementation of a RoBERTa classification model for the SNLI dataset, with support for fine-tuning, predicting, and serving. This is built on top of tch-rs and rust-bert.
This was the result of the AI2 2020 Employee Hackathon. The motivation for this project was to demonstrate that Rust is already a viable alternative to Python for deep learning projects in the context of both research and production.
RustBERTa-SNLI is packaged as a Rust binary, and will work on any operating system that PyTorch supports.
The only prerequisite is that you have the Rust toolchain installed. So if you're already a Rustacean, skip ahead to the "Additional setup for CUDA" section (optional).
Now, luckily, installing Rust is nothing like installing a proper Python environment, i.e. it doesn't require a PhD in system administration or the courage to blindly run every `sudo` command you can find on Stack Overflow until something either works or completely breaks your computer.
All you have to do is run this:
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
Then just make sure `~/.cargo/bin` is in your `$PATH`, and you're good to go.
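If it's not already there, you can add it with a line like the following in your shell profile (the exact file, e.g. `~/.bashrc` or `~/.zshrc`, depends on your shell):

```bash
# Put Cargo-installed binaries (rustup, cargo, etc.) on the PATH.
export PATH="$HOME/.cargo/bin:$PATH"
```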
You can test for success by running `rustup --version` and `cargo --version`.
`rustup` can be used to update your toolchain when a new version of Rust is released (which happens every six weeks). `cargo` is used to compile, run, and test your code, as well as to build documentation, publish your crate (the Rust term for a module/library) to crates.io, and install binaries from other crates on crates.io.
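For reference, here are a few everyday invocations of both tools (all standard commands, none specific to this project):

```bash
rustup update          # update the toolchain to the latest release
cargo build            # compile the current crate
cargo test             # run the test suite
cargo doc --open       # build the docs and open them in a browser
cargo install <crate>  # install a binary crate from crates.io
```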
If you have CUDA-enabled GPUs available on your machine, you'll probably want to compile this library with CUDA support.
To do that, you just need to download the right version of LibTorch from the PyTorch website: https://pytorch.org/get-started/locally/.
Then unzip the downloaded file to someplace safe, like `~/torch/libtorch`, and set these environment variables:

```bash
export LIBTORCH="$HOME/torch/libtorch"  # or wherever you unzipped it
export LD_LIBRARY_PATH="$HOME/torch/libtorch/lib:$LD_LIBRARY_PATH"
```
Note that these environment variables need to be set before you compile with `cargo`. But if you accidentally start the build before completing this step, just delete the `target/` directory and start over.
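One quick way to check that you grabbed a CUDA-enabled LibTorch build (and not the CPU-only one) is to look for the CUDA libraries it ships with. This is just a heuristic, not an official check:

```bash
# CUDA builds of LibTorch include CUDA-specific shared libraries
# such as libtorch_cuda.so. If this prints nothing, you probably
# downloaded the CPU-only build.
ls "$LIBTORCH/lib" | grep -i cuda
```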
To build the release binary, just run `make`.
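If you'd rather not use `make`, the build step presumably amounts to a plain release compile; this is an assumption, so check the repo's Makefile for the exact recipe:

```bash
# Build an optimized binary; cargo places it under target/release/.
cargo build --release
```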
To see all of the available commands, run
```bash
./roberta-snli --help
```
For example, to fine-tune a pretrained RoBERTa model, run
```bash
./roberta-snli train --out weights.ot
```
To interactively get predictions with a fine-tuned model, run
```bash
./roberta-snli predict --weights weights.ot
```
To evaluate a fine-tuned model on the test set, run
```bash
./roberta-snli evaluate
```
And to serve a fine-tuned model as a production-grade webservice with batched prediction, run
```bash
./roberta-snli serve
```
This will serve on port 3030 by default. You can then test it out by running:
```bash
curl \
  -d '{"premise":"A soccer game with multiple people playing.","hypothesis":"Some people are playing a sport."}' \
  -H "Content-Type: application/json" \
  http://localhost:3030/predict
```
You can also test the batching functionality by sending a bunch of requests at once with:
```bash
./scripts/test_server.sh
```
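If you'd rather exercise the batching by hand, a quick shell loop that fires off concurrent requests works too. This sketch just reuses the request body from the curl example above; the request count is arbitrary:

```bash
# Send 8 requests concurrently so the server has a chance to batch them.
for i in $(seq 1 8); do
  curl -s \
    -d '{"premise":"A soccer game with multiple people playing.","hypothesis":"Some people are playing a sport."}' \
    -H "Content-Type: application/json" \
    http://localhost:3030/predict &
done
wait  # wait for all background requests to finish
```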