Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable custom dtype arrays as return values #1

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stanio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"stan_variables",
]

__version__ = "0.3.1"
__version__ = "0.4.0"
44 changes: 39 additions & 5 deletions stanio/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@ class Variable:
# list of nested parameters
contents: List["Variable"]

def dtype(self, top=True):
if self.type == VariableType.TUPLE:
elts = [
(str(i + 1), param.dtype(top=False))
for i, param in enumerate(self.contents)
]
dtype = np.dtype(elts)
elif self.type == VariableType.SCALAR:
dtype = np.float64
elif self.type == VariableType.COMPLEX:
dtype = np.complex128

if top:
return dtype
else:
return np.dtype((dtype, self.dimensions))

def columns(self) -> Iterable[int]:
return range(self.start_idx, self.end_idx)

Expand Down Expand Up @@ -81,7 +98,7 @@ def _extract_helper(self, src: np.ndarray, offset: int = 0):
out[i, idx] = tuple(elt[i] for elt in elts)
return out.reshape(-1, *self.dimensions, order="F")

def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
def extract_reshape(self, src: np.ndarray, object=True) -> npt.NDArray[Any]:
"""
Given an array where the final dimension is the flattened output of a
Stan model, (e.g. one row of a Stan CSV file), extract the variable
Expand All @@ -98,6 +115,10 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
Indicies besides the final dimension are preserved
in the output.

object : bool
If True, the output of tuple types will be an object array,
otherwise it will use custom dtypes to represent tuples.

Returns
-------
npt.NDArray[Any]
Expand All @@ -106,10 +127,14 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
otherwise it will have a dtype of either float64 or complex128.
"""
out = self._extract_helper(src)
if not object:
out = out.astype(self.dtype())
if src.ndim > 1:
return out.reshape(*src.shape[:-1], *self.dimensions, order="F")
out = out.reshape(*src.shape[:-1], *self.dimensions, order="F")
else:
return out.squeeze(axis=0)
out = out.squeeze(axis=0)

return out


def _munge_first_tuple(tup: str) -> str:
Expand Down Expand Up @@ -194,7 +219,10 @@ def parse_header(header: str) -> Dict[str, Variable]:


def stan_variables(
parameters: Dict[str, Variable], source: npt.NDArray[np.float64]
parameters: Dict[str, Variable],
source: npt.NDArray[np.float64],
*,
object: bool = True,
) -> Dict[str, npt.NDArray[Any]]:
"""
Given a dictionary of :class:`Variable` objects and a source array,
Expand All @@ -208,11 +236,17 @@ def stan_variables(
like that returned by :func:`parse_header()`.
source : npt.NDArray[np.float64]
The array to extract from.
object : bool
If True, the output of tuple types will be an object array,
otherwise it will use custom dtypes to represent tuples.

Returns
-------
Dict[str, npt.NDArray[Any]]
A dictionary mapping the base name of each variable to the extracted
and reshaped data.
"""
return {param.name: param.extract_reshape(source) for param in parameters.values()}
return {
param.name: param.extract_reshape(source, object=object)
for param in parameters.values()
}
82 changes: 66 additions & 16 deletions test/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@


# see file data/rectangles/output.stan
@pytest.fixture(scope="module")
def rect_data():
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
def rect_data(request):
files = [DATA / "rectangles" / f"output_{i}.csv" for i in range(1, 5)]
header, data = read_csv(files)
params = parse_header(header)
yield stan_variables(params, data)
yield stan_variables(params, data, object=request.param)


def test_basic_shapes(rect_data):
Expand Down Expand Up @@ -91,43 +91,93 @@ def test_basic_values(rect_data):


# see file data/tuples/output.stan
@pytest.fixture(scope="module")
def tuple_data():
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
def tuple_data(request):
files = [DATA / "tuples" / f"output_{i}.csv" for i in range(1, 5)]
header, data = read_csv(files)
params = parse_header(header)
yield stan_variables(params, data)
yield stan_variables(params, data, object=request.param)


def test_tuple_shapes(tuple_data):
assert isinstance(tuple_data["pair"][0, 0], tuple)
assert len(tuple_data["pair"][0, 0]) == 2

assert isinstance(tuple_data["nested"][0, 0], tuple)
assert len(tuple_data["nested"][0, 0]) == 2
assert isinstance(tuple_data["nested"][0, 0][1], tuple)
assert len(tuple_data["nested"][0, 0][1]) == 2

assert tuple_data["arr_pair"].shape == (4, 1000, 2)
assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)

assert tuple_data["arr_very_nested"].shape == (4, 1000, 3)

assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)

assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)


def check_tuple_shapes_objects(tuple_data):
assert isinstance(tuple_data["pair"][0, 0], tuple)

assert isinstance(tuple_data["nested"][0, 0], tuple)
assert isinstance(tuple_data["nested"][0, 0][1], tuple)

assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)

assert isinstance(tuple_data["arr_very_nested"][0, 0, 0], tuple)
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0], tuple)
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0][1], tuple)

assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)
assert isinstance(tuple_data["arr_2d_pair"][0, 0, 0, 0], tuple)

assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0], tuple)
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0][0][0], tuple)
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)


def check_tuple_shapes_custom_dtypes(tuple_data):
for value in tuple_data.values():
assert not value.dtype.hasobject

pair_dtype = np.dtype([("1", "f8"), ("2", "f8")])
assert tuple_data["pair"].dtype == pair_dtype

nested_dtype = np.dtype([("1", "f8"), ("2", [("1", "f8"), ("2", "c16")])])
assert tuple_data["nested"].dtype == nested_dtype
assert tuple_data["nested"][0, 0][1].dtype == nested_dtype[1]

assert tuple_data["arr_pair"].dtype == pair_dtype

very_nested_dtype = np.dtype(
[
("1", nested_dtype),
("2", "f8"),
]
)
assert tuple_data["arr_very_nested"].dtype == very_nested_dtype
assert tuple_data["arr_very_nested"][0, 0, 0][0].dtype == nested_dtype
assert tuple_data["arr_very_nested"][0, 0, 0][0][1].dtype == nested_dtype[1]

ultimate_dtype = np.dtype(
[
("1", ([("1", "f8"), ("2", "(2,)f8")], (2,))),
("2", "(4,5)f8"),
]
)
assert tuple_data["ultimate"].dtype == ultimate_dtype


def test_tuple_dtypes(tuple_data):
if isinstance(tuple_data["pair"][0, 0], tuple):
check_tuple_shapes_objects(tuple_data)
else:
check_tuple_shapes_custom_dtypes(tuple_data)


def assert_tuple_equal(t1, t2):
if hasattr(t1, "dtype") and t1.dtype.kind == "V":
t1 = t1.tolist()

assert len(t1) == len(t2)
for x, y in zip(t1, t2):
if isinstance(x, tuple):
Expand All @@ -140,7 +190,7 @@ def check_tuples(tuple_data, chain, draw):
base = tuple_data["base"][chain, draw]
base_i = tuple_data["base_i"][chain, draw]
pair_exp = (base, 2 * base)
np.testing.assert_almost_equal(tuple_data["pair"][chain, draw], pair_exp)
assert_tuple_equal(tuple_data["pair"][chain, draw], pair_exp)
nested_exp = (base * 3, (base_i, 4j * base))
assert_tuple_equal(tuple_data["nested"][chain, draw], nested_exp)

Expand Down