Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dft to be idempotent for already transformed coords #397

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions dascore/transform/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,24 @@ def _get_dft_attrs(patch, dims, new_coords):
new = dict(patch.attrs)
new["dims"] = new_coords.dims
new["data_units"] = _get_data_units_from_dims(patch, dims, mul)
# As per #390, we also want to remove data_type (eg the patch is no
# longer in strain rate after the dft)
new["_pre_dft_data_type"] = new.pop("data_type", None)
return PatchAttrs(**new)


def _get_untransformed_dims(patch, dims):
"""Return dimensions which have not been transformed."""
dim_set = set(patch.dims)
out = []
for dim in dims:
# This dim has already been transformed.
if (dim not in dim_set) and f"ft_{dim}" in dim_set:
continue
out.append(dim)
return out


@patch_function()
def dft(
patch: PatchType, dim: str | None | Sequence[str], *, real: str | bool | None = None
Expand Down Expand Up @@ -111,7 +126,7 @@ def dft(
- Non-dimensional coordiantes associated with transformed coordinates
will be dropped in the output.

- See the [FFT note](dascore.org/notes/fft_notes.html) in the Notes section
- See the [FFT note](`notes/dft_notes.qmd`) in the Notes section
of DASCore's documentation for more details.

See Also
Expand All @@ -131,11 +146,15 @@ def dft(
"""
dims = list(iterate(dim if dim is not None else patch.dims))
patch.assert_has_coords(dims)
real = dims[-1] if real is True else real # if true grab last dim
dims = _get_untransformed_dims(patch, dims)
real = real if real in dims else None # may need to reset real
if not dims: # no transformation needed.
return patch
# re-arrange list so real dim is last (if provided)
if isinstance(real, str):
assert real in dims, "real must be in provided dimensions."
dims.append(dims.pop(dims.index(real)))
real = dims[-1] if real is True else real # if true grab last dim
# get axes and spacing along desired dimensions.
dxs, axes = _get_dx_or_spacing_and_axes(patch, dims, require_evenly_spaced=True)
func = nft.rfftn if real is not None else nft.fftn
Expand Down Expand Up @@ -208,6 +227,9 @@ def _get_idft_attrs(patch, dims, new_coords):
new = dict(patch.attrs)
new["dims"] = new_coords.dims
new["data_units"] = _get_data_units_from_dims(patch, dims, mul)
# Restore the pre-dft datatype.
if "_pre_dft_data_type" in new:
new["data_type"] = new.pop("_pre_dft_data_type", None)
return PatchAttrs(**new)


Expand Down
44 changes: 41 additions & 3 deletions tests/test_transform/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
@pytest.fixture(scope="session")
def sin_patch():
"""Get the sine wave patch, set units for testing."""
patch = dc.get_example_patch("sin_wav", sample_rate=100, duration=3, frequency=F_0)
out = patch.set_units(get_quantity("1.0 V"), time="s", distance="m")
return out
patch = (
dc.get_example_patch("sin_wav", sample_rate=100, duration=3, frequency=F_0)
.set_units(get_quantity("1.0 V"), time="s", distance="m")
.update_attrs(data_type="strain_rate")
)
return patch


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -123,6 +126,36 @@ def test_parseval(self, sin_patch, fft_sin_patch_time):
vals2 = (pa2.abs() ** 2).integrate("ft_time", definite=True)
assert np.allclose(vals1.data, vals2.data)

def test_idempotent_single_dim(self, fft_sin_patch_time):
"""
Ensure dft is idempotent for a single dimension.
"""
out = fft_sin_patch_time.dft("time")
assert out.equals(fft_sin_patch_time)

def test_idempotent_all_dims(self, fft_sin_patch_all):
"""
Ensure dft is idempotent for transforms applied to all dims.
"""
out = fft_sin_patch_all.dft(dim=("time", "distance"))
assert out.equals(fft_sin_patch_all)

def test_transform_single_dim(
self, sin_patch, fft_sin_patch_time, fft_sin_patch_all
):
"""
Ensure dft is idempotent for time, but untransformed axis still gets
transformed.
"""
out = fft_sin_patch_time.dft(dim=("time", "distance"))
assert not out.equals(fft_sin_patch_time)
assert np.allclose(out.data, fft_sin_patch_all.data)

def test_datatype_removed(self, fft_sin_patch_time, sin_patch):
"""Ensure the data_type attr is removed after transform."""
assert sin_patch.attrs.data_type == "strain_rate"
assert fft_sin_patch_time.attrs.data_type == ""


class TestInverseDiscreteFourierTransform:
"""Inverse DFT suite."""
Expand Down Expand Up @@ -167,3 +200,8 @@ def test_partial_inverse(self, fft_sin_patch_all, sin_patch):
# and then if we reverse distance it should be the same as original
full_inverse = ift.idft("distance")
self._patches_about_equal(full_inverse, sin_patch)

def test_data_type_restored(self, fft_sin_patch_time, sin_patch):
"""Ensure data_type attr is restored."""
out = fft_sin_patch_time.idft("time")
assert out.attrs.data_type == sin_patch.attrs.data_type
Loading