
feat: Flax #3123

Merged: 25 commits, Feb 23, 2023
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ env:
LINES: 120
COLUMNS: 120
BENTOML_DO_NOT_TRACK: True
PYTEST_PLUGINS: bentoml.testing.pytest.plugin

# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
53 changes: 53 additions & 0 deletions .github/workflows/frameworks.yml
@@ -12,6 +12,7 @@ env:
LINES: 120
COLUMNS: 120
BENTOML_DO_NOT_TRACK: True
PYTEST_PLUGINS: bentoml.testing.pytest.plugin

# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
@@ -33,6 +34,7 @@ jobs:
pytorch: ${{ steps.filter.outputs.pytorch }}
pytorch_lightning: ${{ steps.filter.outputs.pytorch_lightning }}
sklearn: ${{ steps.filter.outputs.sklearn }}
flax: ${{ steps.filter.outputs.flax }}
tensorflow: ${{ steps.filter.outputs.tensorflow }}
torchscript: ${{ steps.filter.outputs.torchscript }}
transformers: ${{ steps.filter.outputs.transformers }}
@@ -94,6 +96,12 @@ jobs:
- src/bentoml/_internal/frameworks/pytorch.py
- src/bentoml/_internal/frameworks/common/pytorch.py
- tests/integration/frameworks/test_pytorch_unit.py
flax:
- *related
- src/bentoml/flax.py
- src/bentoml/_internal/frameworks/flax.py
- src/bentoml/_internal/frameworks/common/jax.py
- tests/integration/frameworks/models/flax.py
torchscript:
- *related
- src/bentoml/torchscript.py
@@ -224,6 +232,51 @@ jobs:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

flax_integration_tests:
needs: diff
if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.flax == 'true') || github.event_name == 'push' }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0 # fetch all tags and branches
- name: Setup python
uses: actions/setup-python@v4
with:
python-version: 3.8

- name: Get pip cache dir
id: cache-dir
run: |
echo ::set-output name=dir::$(pip cache dir)

- name: Cache pip dependencies
uses: actions/cache@v3
id: cache-pip
with:
path: ${{ steps.cache-dir.outputs.dir }}
key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }}

- name: Install dependencies
run: |
pip install .
pip install flax jax jaxlib chex tensorflow
pip install -r requirements/tests-requirements.txt

- name: Run tests and generate coverage report
run: |
OPTS=(--cov-config pyproject.toml --cov src/bentoml --cov-append --framework flax)
coverage run -m pytest tests/integration/frameworks/test_frameworks.py "${OPTS[@]}"

- name: Generate coverage
run: coverage xml

- name: Upload test coverage to Codecov
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}

fastai_integration_tests:
needs: diff
if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.fastai == 'true') || github.event_name == 'push' }}
21 changes: 21 additions & 0 deletions docs/source/reference/frameworks/flax.rst
@@ -0,0 +1,21 @@
====
Flax
====

.. admonition:: About this page

    This is an API reference for Flax in BentoML. Please refer to
    :doc:`/frameworks/flax` for more information about how to use Flax in BentoML.


.. note::

    You can find more examples for **Flax** in our `bentoml/examples <https://github.com/bentoml/BentoML/tree/main/examples>`_ directory.

.. currentmodule:: bentoml.flax

.. autofunction:: bentoml.flax.save_model

.. autofunction:: bentoml.flax.load_model

.. autofunction:: bentoml.flax.get
1 change: 1 addition & 0 deletions docs/source/reference/frameworks/index.rst
@@ -17,6 +17,7 @@ Framework APIs
onnx
sklearn
transformers
flax
tensorflow
xgboost
picklable_model
28 changes: 20 additions & 8 deletions examples/README.md
@@ -1,11 +1,16 @@
# BentoML Examples 🎨 [![Twitter Follow](https://img.shields.io/twitter/follow/bentomlai?style=social)](https://twitter.com/bentomlai) [![Slack](https://img.shields.io/badge/Slack-Join-4A154B?style=social)](https://l.linklyhq.com/l/ktO8)

BentoML is an open platform for machine learning in production. It simplifies model packaging and model management, optimizes model serving workloads to run at production scale, and accelerates the creation, deployment, and monitoring of prediction services.
BentoML is an open platform for machine learning in production. It simplifies
model packaging and model management, optimizes model serving workloads to run
at production scale, and accelerates the creation, deployment, and monitoring of
prediction services.

The repository contains a collection of example projects demonstrating [BentoML](https://github.com/bentoml/BentoML)
usage and best practices.
The repository contains a collection of example projects demonstrating
[BentoML](https://github.com/bentoml/BentoML) usage and best practices.

👉 [Pop into our Slack community!](https://join.slack.bentoml.org) We're happy to help with any issue you face or even just to meet you and hear what you're working on :)
👉 [Pop into our Slack community!](https://join.slack.bentoml.org) We're happy
to help with any issue you face or even just to meet you and hear what you're
working on :)

## Index

@@ -36,16 +41,23 @@ usage and best practices.
| [tensorflow2_keras](https://github.com/bentoml/BentoML/tree/main/examples/tensorflow2_keras) | TensorFlow, Keras | MNIST | Notebook |
| [tensorflow2_native](https://github.com/bentoml/BentoML/tree/main/examples/tensorflow2_native) | TensorFlow | MNIST | Notebook |
| [xgboost](https://github.com/bentoml/BentoML/tree/main/examples/xgboost) | XGBoost | DMatrix | |
| [flax/MNIST](https://github.com/bentoml/BentoML/tree/main/examples/flax/MNIST) | Flax | MNIST | gRPC, Testing |

## How to contribute

If you have issues running these projects or have suggestions for improvement, use [Github Issues 🐱](https://github.com/bentoml/BentoML/issues/new)
If you have issues running these projects or have suggestions for improvement,
use [GitHub Issues 🐱](https://github.com/bentoml/BentoML/issues/new)

If you are interested in contributing new projects to this repo, let's talk 🥰 - Join us on [Slack](https://join.slack.com/t/bentoml/shared_invite/enQtNjcyMTY3MjE4NTgzLTU3ZDc1MWM5MzQxMWQxMzJiNTc1MTJmMzYzMTYwMjQ0OGEwNDFmZDkzYWQxNzgxYWNhNjAxZjk4MzI4OGY1Yjg) and share your idea in #bentoml-contributors channel
If you are interested in contributing new projects to this repo, let's talk 🥰 -
Join us on
[Slack](https://join.slack.com/t/bentoml/shared_invite/enQtNjcyMTY3MjE4NTgzLTU3ZDc1MWM5MzQxMWQxMzJiNTc1MTJmMzYzMTYwMjQ0OGEwNDFmZDkzYWQxNzgxYWNhNjAxZjk4MzI4OGY1Yjg)
and share your idea in #bentoml-contributors channel

Before you create a Pull Request, make sure:

- Follow the basic structures and naming conventions of other existing example projects
- Follow the basic structures and naming conventions of other existing example
projects
- Ensure your project runs with the latest version of BentoML

For legacy version prior to v1.0.0, see the [0.13-LTS branch](https://github.com/bentoml/gallery/tree/0.13-LTS).
For legacy versions prior to v1.0.0, see the
[0.13-LTS branch](https://github.com/bentoml/gallery/tree/0.13-LTS).
4 changes: 4 additions & 0 deletions examples/flax/MNIST/.bentoignore
@@ -0,0 +1,4 @@
__pycache__/
*.py[cod]
*$py.class
.ipynb_checkpoints
2 changes: 2 additions & 0 deletions examples/flax/MNIST/.gitignore
@@ -0,0 +1,2 @@
events.out*
*.msgpack
18 changes: 18 additions & 0 deletions examples/flax/MNIST/BUILD
@@ -0,0 +1,18 @@
load("@bazel_skylib//rules:write_file.bzl", "write_file")

write_file(
name = "_train_sh",
out = "_train.sh",
content = [
"#!/usr/bin/env bash\n",
"cd $BUILD_WORKING_DIRECTORY\n",
"python -m pip install -r requirements.txt\n",
"python train.py $@",
],
)

sh_binary(
name = "train",
srcs = ["_train.sh"],
data = ["train.py"],
)
34 changes: 34 additions & 0 deletions examples/flax/MNIST/README.md
@@ -0,0 +1,34 @@
# MNIST classifier

This project demonstrates a simple CNN MNIST classifier served with BentoML.

### Instructions

Run training scripts:

```bash
# run with python3
pip install -r requirements.txt
python3 train.py --num-epochs 2

# run with bazel
bazel run :train -- --num-epochs 2
```

Serve with either gRPC or HTTP:

```bash
# serve over HTTP
bentoml serve --production

# or serve over gRPC
bentoml serve-grpc --production --enable-reflection
```

Run the test suite:

```bash
pytest tests
```

To containerize the Bento, run:

```bash
bentoml containerize mnist_flax --opt platform=linux/amd64
```
12 changes: 12 additions & 0 deletions examples/flax/MNIST/bentofile.yaml
@@ -0,0 +1,12 @@
service: "service.py:svc"
labels:
owner: bentoml-team
project: mnist-flax
experimental: true
include:
- "*.py"
python:
lock_packages: false
extra_index_url:
- https://storage.googleapis.com/jax-releases/jax_releases.html
requirements_txt: ./requirements.txt
10 changes: 10 additions & 0 deletions examples/flax/MNIST/requirements-gpu.txt
@@ -0,0 +1,10 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda]==0.4.4
flax>=0.6.1
optax>=0.1.3
bentoml[grpc,grpc-reflection]
tensorflow
tensorflow-datasets
Pillow
pytest
pytest-asyncio
10 changes: 10 additions & 0 deletions examples/flax/MNIST/requirements.txt
@@ -0,0 +1,10 @@
jax[cpu]==0.4.4
flax>=0.6.1
optax>=0.1.3
bentoml[grpc,grpc-reflection]
tensorflow;platform_system!="Darwin"
tensorflow-macos;platform_system=="Darwin"
tensorflow-datasets
Pillow
pytest
pytest-asyncio
Binary file added examples/flax/MNIST/samples/0.png
Binary file added examples/flax/MNIST/samples/1.png
Binary file added examples/flax/MNIST/samples/2.png
Binary file added examples/flax/MNIST/samples/3.png
Binary file added examples/flax/MNIST/samples/4.png
Binary file added examples/flax/MNIST/samples/5.png
Binary file added examples/flax/MNIST/samples/6.png
Binary file added examples/flax/MNIST/samples/7.png
Binary file added examples/flax/MNIST/samples/8.png
Binary file added examples/flax/MNIST/samples/9.png
24 changes: 24 additions & 0 deletions examples/flax/MNIST/service.py
@@ -0,0 +1,24 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

import jax.numpy as jnp
from PIL.Image import Image as PILImage

import bentoml

if TYPE_CHECKING:
from numpy.typing import NDArray

mnist_runner = bentoml.flax.get("mnist_flax").to_runner()

svc = bentoml.Service(name="mnist_flax", runners=[mnist_runner])


@svc.api(input=bentoml.io.Image(), output=bentoml.io.NumpyNdarray())
async def predict(f: PILImage) -> NDArray[t.Any]:
arr = jnp.array(f) / 255.0
arr = jnp.expand_dims(arr, (0, 3))
res = await mnist_runner.async_run(arr)
return res.argmax()
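The `predict` handler above reshapes each uploaded image before handing it to the runner. A minimal sketch of that preprocessing, with NumPy standing in for `jax.numpy` (the two share these array APIs) and the usual 28x28 grayscale MNIST input size assumed:

```python
import numpy as np

# A fake 28x28 grayscale MNIST-style image.
image = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)

# Mirror the service's jax.numpy calls:
arr = np.array(image) / 255.0      # scale pixel values into [0, 1]
arr = np.expand_dims(arr, (0, 3))  # add batch and channel axes -> (1, 28, 28, 1)

print(arr.shape)  # (1, 28, 28, 1)
```

The `(0, 3)` axes produce the `(batch, height, width, channels)` layout that convolutional Flax modules typically expect.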
91 changes: 91 additions & 0 deletions examples/flax/MNIST/tests/conftest.py
@@ -0,0 +1,91 @@
from __future__ import annotations

import os
import sys
import typing as t
import contextlib
import subprocess
from typing import TYPE_CHECKING

import psutil
import pytest

import bentoml
from bentoml.testing.server import host_bento
from bentoml._internal.configuration.containers import BentoMLContainer

if TYPE_CHECKING:
from contextlib import ExitStack

from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.config import Config
from _pytest.fixtures import FixtureRequest as _PytestFixtureRequest

class FixtureRequest(_PytestFixtureRequest):
param: str


PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def pytest_collection_modifyitems(
session: Session, config: Config, items: list[Item]
) -> None:
try:
m = bentoml.models.get("mnist_flax")
print(f"Model exists: {m}")
except bentoml.exceptions.NotFound:
subprocess.check_call(
[
sys.executable,
f"{os.path.join(PROJECT_DIR, 'train.py')}",
"--num-epochs",
"2", # 2 epochs for faster testing
"--lr",
"0.22", # speed up training time
"--enable-tensorboard",
]
)


@pytest.fixture(name="enable_grpc", params=[True, False], scope="session")
def fixture_enable_grpc(request: FixtureRequest) -> bool:
    return request.param


@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
stack = contextlib.ExitStack()
yield stack
stack.close()


@pytest.fixture(
name="deployment_mode",
params=["container", "distributed", "standalone"],
scope="session",
)
def fixture_deployment_mode(request: FixtureRequest) -> str:
return request.param


@pytest.mark.usefixtures("change_test_dir")
@pytest.fixture(scope="module")
def host(
deployment_mode: t.Literal["container", "distributed", "standalone"],
clean_context: ExitStack,
enable_grpc: bool,
) -> t.Generator[str, None, None]:
if enable_grpc and psutil.WINDOWS:
pytest.skip("gRPC is not supported on Windows.")

with host_bento(
"service:svc",
deployment_mode=deployment_mode,
project_path=PROJECT_DIR,
bentoml_home=BentoMLContainer.bentoml_home.get(),
clean_context=clean_context,
use_grpc=enable_grpc,
) as _host:
yield _host
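Taken together, the `deployment_mode` and `enable_grpc` fixtures parametrize `host_bento` over every combination of deployment mode and protocol. A plain-Python sketch of the resulting test matrix (variable names here are illustrative, not from the suite):

```python
from itertools import product

# The two session-scoped fixture parametrizations above.
deployment_modes = ["container", "distributed", "standalone"]
grpc_options = [True, False]

# Each (mode, use_grpc) pair becomes one server configuration, so the
# suite exercises up to 3 x 2 = 6 hosting setups (the gRPC cases are
# skipped on Windows).
matrix = list(product(deployment_modes, grpc_options))
for mode, use_grpc in matrix:
    print(f"{mode}/{'grpc' if use_grpc else 'http'}")

print(len(matrix))  # 6
```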