Skip to content

Commit

Permalink
fix: rename equal_shape
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <[email protected]>
  • Loading branch information
aarnphm committed Feb 17, 2023
1 parent 8ad79c9 commit 300959a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/bentoml/_internal/frameworks/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

@attr.define
class FlaxOptions(ModelOptions):
"""Options for the Keras model."""
"""Options for the Flax model."""

partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict)
device: str = attr.field(default="cpu")
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/frameworks/models/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def init_mlp_state():
return params


def is_close(model: nn.Module, state_dict: dict[str, t.Any], arr: jnp.ndarray):
def assert_equal_shape(
model: nn.Module, state_dict: dict[str, t.Any], arr: jnp.ndarray
):
def check(out: jnp.ndarray) -> bool:
logit = model.apply({"params": state_dict["params"]}, arr)
chex.assert_equal_shape([logit, out])
Expand All @@ -84,7 +86,9 @@ def check(out: jnp.ndarray) -> bool:
"__call__": [
Input(
input_args=[ones_array],
expected=is_close(MLP(), init_mlp_state(), ones_array),
expected=assert_equal_shape(
MLP(), init_mlp_state(), ones_array
),
)
]
},
Expand Down

0 comments on commit 300959a

Please sign in to comment.