Skip to content

Commit

Permalink
Speed improvements (#4)
Browse files Browse the repository at this point in the history
* Add fast paths for numpy numeric arrays

* Optional: ujson dependency
  • Loading branch information
WardBrian authored Feb 12, 2024
1 parent 6554807 commit e584bc2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ on:

jobs:
required:
name: "${{matrix.os}} / ${{matrix.python-version}} / numpy_nightly: ${{matrix.numpy_nightly}}"
name: "${{matrix.os}} / ${{matrix.python-version}} / ujson: ${{matrix.ujson}}"
runs-on: ${{matrix.os}}
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8, "3.12"]
numpy_nightly: [false, true]
ujson: [false, true]
steps:
- name: Check out github
uses: actions/checkout@v4
Expand All @@ -35,10 +35,10 @@ jobs:
run: |
python -m pip install .[test]
- name: Install numpy
if: ${{ matrix.numpy_nightly }}
- name: Install ujson
if: ${{ matrix.ujson }}
run: |
pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple -U numpy
pip install -U ujson
- name: Run tests
run: |
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ test = [
"pytest",
"pytest-cov",
]
ujson = [
"ujson>=5.5.0"
]

[tool.isort]
profile = "black"
Expand Down
31 changes: 27 additions & 4 deletions stanio/json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
"""
Utilities for writing Stan Json files
"""
import json
try:
import ujson as json

uj_version = tuple(map(int, json.__version__.split(".")))
if uj_version < (5, 5, 0):
raise ImportError("ujson version too old")
UJSON_AVAILABLE = True
except:
UJSON_AVAILABLE = False
import json

from typing import Any, Mapping

import numpy as np
Expand Down Expand Up @@ -31,7 +41,17 @@ def process_value(val: Any) -> Any:
or "xarray" in original_module
or "pandas" in original_module
):
return process_value(np.asanyarray(val).tolist())
numpy_val = np.asanyarray(val)
# fast paths for numeric types
if numpy_val.dtype.kind in "iuf":
return numpy_val.tolist()
if numpy_val.dtype.kind == "c":
return np.stack([numpy_val.real, numpy_val.imag], axis=-1).tolist()
if numpy_val.dtype.kind == "b":
return numpy_val.astype(int).tolist()

# should only be object arrays (tuples, etc)
return process_value(numpy_val.tolist())

return val

Expand Down Expand Up @@ -75,5 +95,8 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
copied before type conversion, not modified
"""
with open(path, "w") as fd:
for chunk in json.JSONEncoder().iterencode(process_dictionary(data)):
fd.write(chunk)
if UJSON_AVAILABLE:
json.dump(process_dictionary(data), fd)
else:
for chunk in json.JSONEncoder().iterencode(process_dictionary(data)):
fd.write(chunk)
20 changes: 17 additions & 3 deletions test/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def test_basic_array(TMPDIR) -> None:


def test_bool(TMPDIR) -> None:
dict_bool = {"a": False, "b": True}
dict_bool = {"a": False, "b": True, "c": np.array([True, False])}
file_bool = os.path.join(TMPDIR, "bool.json")
dict_exp = {"a": 0, "b": 1}
dict_exp = {"a": 0, "b": 1, "c": [1, 0]}
after = compare_before_after(file_bool, dict_bool, dict_exp)
assert isinstance(after["a"], int)
assert not isinstance(after["a"], bool)
Expand Down Expand Up @@ -135,18 +135,32 @@ def test_special_values(TMPDIR) -> None:
]
)
}

# we want very specific values here
json_string = dump_stan_json(dict_inf_nan)
assert json_string.count("Infinity") == 8
assert json_string.count("NaN") == 4
assert json_string.count("-Infinity") == 4

dict_inf_nan_exp = {"a": [[-np.inf, np.inf, np.nan]] * 4}
file_fin = os.path.join(TMPDIR, "inf.json")
compare_before_after(file_fin, dict_inf_nan, dict_inf_nan_exp)


def test_complex_numbers(TMPDIR) -> None:
dict_complex = {"a": np.array([np.complex64(3), 3 + 4j])}
dict_complex = {"a": [3 + 0j, 3 + 4j]}
dict_complex_exp = {"a": [[3, 0], [3, 4]]}
file_complex = os.path.join(TMPDIR, "complex.json")
compare_before_after(file_complex, dict_complex, dict_complex_exp)


def test_complex_numbers_np(TMPDIR) -> None:
dict_complex = {"a": np.array([np.complex64(3), 3 + 4j])}
dict_complex_exp = {"a": [[3, 0], [3, 4]]}
file_complex = os.path.join(TMPDIR, "complex_np.json")
compare_before_after(file_complex, dict_complex, dict_complex_exp)


def test_tuples(TMPDIR) -> None:
dict_tuples = {
"a": (1, 2, 3),
Expand Down

0 comments on commit e584bc2

Please sign in to comment.