feat: add tests to example MNIST
Signed-off-by: Aaron Pham <[email protected]>
aarnphm committed Oct 24, 2022
1 parent 37caf53 commit c6d8f71
Showing 11 changed files with 143 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/flax/MNIST/BUILD
@@ -6,6 +6,7 @@ write_file(
content = [
"#!/usr/bin/env bash\n",
"cd $BUILD_WORKING_DIRECTORY\n",
"python -m pip install -r requirements.txt\n",
"python train.py $@",
],
)
23 changes: 23 additions & 0 deletions examples/flax/MNIST/README.md
@@ -0,0 +1,23 @@
# MNIST classifier

This project demonstrates a simple CNN-based MNIST classifier served with BentoML.

### Instructions

Run the training script:

```bash
bazel run :train -- --num-epochs 2
```

Serve the model with either gRPC or HTTP; for example, over gRPC:

```bash
bentoml serve-grpc --production --enable-reflection
```

Run the test suite:

```bash
pytest tests
```
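
For a quick smoke test over HTTP, something like the following should work, assuming the default HTTP port 3000 and the `predict` endpoint exercised by the test suite:

```bash
# Hypothetical smoke test: send a bundled sample digit to a running HTTP
# server (e.g. started with `bentoml serve`) and print the predicted class.
curl -X POST \
  -H "Content-Type: image/png" \
  --data-binary @samples/0.png \
  http://127.0.0.1:3000/predict
```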
1 change: 1 addition & 0 deletions examples/flax/MNIST/bentofile.yaml
@@ -6,4 +6,5 @@ labels:
include:
- "*.py"
python:
lock_packages: true
requirements_txt: ./requirements.txt
6 changes: 4 additions & 2 deletions examples/flax/MNIST/requirements.txt
@@ -1,9 +1,11 @@
optax>=0.1.3
flax>=0.6.1
jax>=0.3.23
jaxlib>=0.3.22
jax[cpu]
tensorflow;platform_system!="Darwin"
tensorflow-macos;platform_system=="Darwin"
tensorflow-datasets
cattrs>=22.1.0
attrs>=22.1.0
Pillow
pytest
pytest-asyncio
5 changes: 4 additions & 1 deletion examples/flax/MNIST/tests/conftest.py
@@ -44,7 +44,10 @@ def pytest_collection_modifyitems(
sys.executable,
f"{os.path.join(PROJECT_DIR, 'train.py')}",
"--num-epochs",
"2",
"2", # 2 epochs for faster testing
"--lr",
"0.22", # speed up training time
"--enable-tensorboard",
]
)

74 changes: 74 additions & 0 deletions examples/flax/MNIST/tests/test_endpoints.py
@@ -0,0 +1,74 @@
from __future__ import annotations

import io
from typing import TYPE_CHECKING

import numpy as np
import pytest

import bentoml
from bentoml.testing.grpc import create_channel
from bentoml.testing.grpc import async_client_call

if TYPE_CHECKING:
import jax.numpy as jnp


@pytest.fixture()
def img():
import PIL.Image

images = {}
digits = list(range(10))
for digit in digits:
img_path = f"samples/{digit}.png"
with open(img_path, "rb") as f:
img_bytes = f.read()
arr = np.array(PIL.Image.open(io.BytesIO(img_bytes)))
images[digit] = {
"bytes": img_bytes,
"array": arr,
}

return images


@pytest.fixture(name="client")
@pytest.mark.asyncio
async def fixture_client(host: str):
return await bentoml.client.from_url(host)


# TODO: update with bentoml.client once
# gRPC client is implemented.


@pytest.mark.asyncio
async def test_image_grpc(
host: str, img: dict[int, dict[str, bytes | jnp.ndarray]], enable_grpc: bool
):
if not enable_grpc:
pytest.skip("Skipping gRPC test when testing on HTTP.")
async with create_channel(host) as channel:
for digit, d in img.items():
img_bytes = d["bytes"]
await async_client_call(
"predict",
channel=channel,
data={"serialized_bytes": img_bytes},
assert_data=lambda resp: resp.ndarray.int32_values == [digit],
)


@pytest.mark.asyncio
async def test_image_http(
client: bentoml.client.Client,
img: dict[int, dict[str, bytes | jnp.ndarray]],
enable_grpc: bool,
):
if enable_grpc:
pytest.skip("Skipping HTTP test when testing on gRPC.")
for digit, d in img.items():
img_bytes = d["bytes"]
expected = f"{digit}".encode()
assert await client.predict(img_bytes) == expected
27 changes: 18 additions & 9 deletions examples/flax/MNIST/train.py
@@ -38,6 +38,7 @@ class ConfigDict:
batch_size: int = 128
num_epochs: int = 10
momentum: float = 0.9
enable_tensorboard: bool = True

def to_dict(self) -> dict[str, t.Any]:
return cattrs.unstructure(self)
@@ -152,9 +153,11 @@ def train_and_evaluate(
"""
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(0)
summary_writer: tensorboard.SummaryWriter | None = None

summary_writer = tensorboard.SummaryWriter(workdir)
summary_writer.hparams(config.to_dict())
if config.enable_tensorboard:
summary_writer = tensorboard.SummaryWriter(workdir)
summary_writer.hparams(config.to_dict())

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)
@@ -173,12 +176,16 @@
% (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
)

summary_writer.scalar("train_loss", train_loss, epoch)
summary_writer.scalar("train_accuracy", train_accuracy, epoch)
summary_writer.scalar("test_loss", test_loss, epoch)
summary_writer.scalar("test_accuracy", test_accuracy, epoch)
if config.enable_tensorboard:
assert summary_writer is not None
summary_writer.scalar("train_loss", train_loss, epoch)
summary_writer.scalar("train_accuracy", train_accuracy, epoch)
summary_writer.scalar("test_loss", test_loss, epoch)
summary_writer.scalar("test_accuracy", test_accuracy, epoch)

summary_writer.flush()
if config.enable_tensorboard:
assert summary_writer is not None
summary_writer.flush()

return state

@@ -209,9 +216,10 @@ def load_and_predict(path: str, idx: int = 0):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--momentum", type=float, default=0.94)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--num-epochs", type=int, default=10)
parser.add_argument("--enable-tensorboard", action="store_true")
args = parser.parse_args()

training_state = train_and_evaluate(
Expand All @@ -220,6 +228,7 @@ def load_and_predict(path: str, idx: int = 0):
momentum=args.momentum,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
enable_tensorboard=args.enable_tensorboard,
),
)

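With the new `--enable-tensorboard` flag, a short training run with TensorBoard summaries might look like this; the epoch count and learning rate mirror what the test suite passes to `train.py`:

```bash
# Illustrative invocation via the Bazel target from the README.
bazel run :train -- --num-epochs 2 --lr 0.22 --enable-tensorboard
```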
3 changes: 3 additions & 0 deletions src/bentoml/__init__.py
@@ -76,6 +76,7 @@

from . import io
from . import models
from . import client

# Prometheus metrics client
from . import metrics
@@ -121,6 +122,7 @@
io = _LazyLoader("bentoml.io", globals(), "bentoml.io")
models = _LazyLoader("bentoml.models", globals(), "bentoml.models")
metrics = _LazyLoader("bentoml.metrics", globals(), "bentoml.metrics")
client = _LazyLoader("bentoml.client", globals(), "bentoml.client")

del _LazyLoader

@@ -131,6 +133,7 @@
"Service",
"models",
"metrics",
"client",
"io",
"Tag",
"Model",
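`_LazyLoader` itself is internal to BentoML and not shown in this diff; as a rough, hypothetical sketch, a lazy module proxy along these lines defers the real import until the first attribute access, which keeps `import bentoml` cheap:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for a lazy loader: the wrapped module is imported
    only when one of its attributes is first accessed."""

    def __init__(self, name: str):
        super().__init__(name)
        self._module = None

    def __getattr__(self, item: str):
        if self._module is None:
            self._module = importlib.import_module(self.__name__)
        return getattr(self._module, item)


# Usage sketch: the submodule import runs only on first attribute access.
client = LazyModule("bentoml.client")
```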
5 changes: 5 additions & 0 deletions src/bentoml/_internal/bento/build_config.py
@@ -611,6 +611,11 @@ def with_defaults(self) -> PythonOptions:
if self.requirements_txt is None:
if self.lock_packages is None:
defaults["lock_packages"] = True
else:
logger.debug(
"'requirements_txt: %s' is passed, and bentoml won't lock PyPI package by default. If you wish to lock it pass 'python.lock_packages=true' to your 'bentofile.yaml'.",
self.requirements_txt,
)

return attr.evolve(self, **defaults)

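In `bentofile.yaml` terms, the behavior described by the debug message looks roughly like this, matching the example's own `bentofile.yaml` above:

```yaml
# With requirements_txt provided, packages are no longer locked by default;
# locking now has to be opted into explicitly.
python:
  requirements_txt: ./requirements.txt
  lock_packages: true
```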
9 changes: 9 additions & 0 deletions src/bentoml/client.py
@@ -16,6 +16,15 @@
from ._internal.service.inference_api import InferenceAPI


async def from_url(url: str) -> Client:
"""
Creates a client from a URL.
This function is a proxy to :meth:`Client.from_url` for convenience.
"""
return await Client.from_url(url)


class Client(ABC):
server_url: str

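A usage sketch for the new helper, assuming an MNIST service is already serving over HTTP and that `predict` accepts raw image bytes, as in the tests above:

```python
import asyncio

import bentoml


async def main() -> None:
    # Assumes a BentoML HTTP server is reachable at this URL.
    client = await bentoml.client.from_url("http://127.0.0.1:3000")
    with open("samples/0.png", "rb") as f:
        # The test suite awaits client.predict(...) with raw PNG bytes.
        prediction = await client.predict(f.read())
    print(prediction)


asyncio.run(main())
```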
1 change: 1 addition & 0 deletions tests/integration/frameworks/models/flax.py
@@ -0,0 +1 @@
from __future__ import annotations
