Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup plan class #91

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class Plan {
//! axes for fft
axis_type<DIM> m_axes;

//! Shape of the transformed axis of the output
shape_type<DIM> m_shape;

//! directions of fft
KokkosFFT::Direction m_direction;

Expand All @@ -140,10 +143,15 @@ class Plan {
/// \param out [in] Ouput data
/// \param direction [in] Direction of FFT (forward/backward)
/// \param axis [in] Axis over which FFT is performed
/// \param n [in] Length of the transformed axis of the output (optional)
//
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction, int axis)
: m_fft_size(1), m_is_transpose_needed(false), m_direction(direction) {
OutViewType& out, KokkosFFT::Direction direction, int axis,
std::optional<std::size_t> n = std::nullopt)
: m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
m_axes({axis}) {
static_assert(Kokkos::is_view<InViewType>::value,
"Plan::Plan: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand Down Expand Up @@ -172,7 +180,6 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

m_axes = {axis};
m_in_extents = KokkosFFT::Impl::extract_extents(in);
m_out_extents = KokkosFFT::Impl::extract_extents(out);
std::tie(m_map, m_map_inv) = KokkosFFT::Impl::get_map_axes(in, axis);
Expand All @@ -188,10 +195,11 @@ class Plan {
/// \param out [in] Ouput data
/// \param direction [in] Direction of FFT (forward/backward)
/// \param axes [in] Axes over which FFT is performed
/// \param s [in] Shape of the transformed axis of the output (optional)
//
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction,
axis_type<DIM> axes)
axis_type<DIM> axes, shape_type<DIM> s = {0})
: m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
Expand Down Expand Up @@ -238,6 +246,11 @@ class Plan {
_destroy_plan<ExecutionSpace, fft_plan_type>(m_plan);
}

Plan(const Plan&) = delete;
Plan& operator=(const Plan&) = delete;
Plan& operator=(Plan&&) = delete;
Plan(Plan&&) = delete;

/// \brief Sanity check of the plan used to call FFT interface with
/// pre-defined FFT plan. This raises an error if there is an
/// incosistency between FFT function and plan
Expand Down
11 changes: 7 additions & 4 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,19 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
"ifft: execution_space cannot access data in OutViewType");

InViewType _in;
// [TO DO] Modify crop_or_pad to perform the following lines
// KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, n);
if (n) {
std::size_t _n = n.value();
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, shape_type<1>({_n}));

/* [FIX THIS] Shallow copy should be sufficient
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
*/
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
Expand Down Expand Up @@ -393,17 +395,18 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
"ifft: execution_space cannot access data in OutViewType");

InViewType _in;
// [TO DO] Modify crop_or_pad to perform the following lines
// KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, n);
if (n) {
std::size_t _n = n.value();
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, shape_type<1>({_n}));
/* [FIX THIS] Shallow copy should be sufficient
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
*/
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
Expand Down
89 changes: 89 additions & 0 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,86 @@ void test_fft1_1dihfft_1dview() {
EXPECT_TRUE(allclose(out2_f, x_herm_ref, 1.e-5, 1.e-6));
}

template <typename T, typename LayoutType>
void test_fft1_shape(T atol = 1.0e-12) {
const int n = 32;
using RealView1DType = Kokkos::View<T*, LayoutType, execution_space>;
using ComplexView1DType =
Kokkos::View<Kokkos::complex<T>*, LayoutType, execution_space>;

RealView1DType xr("xr", n), xr_ref("xr_ref", n);
ComplexView1DType x("x", n / 2 + 1), x_ref("x_ref", n / 2 + 1);

const Kokkos::complex<T> I(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(/*seed=*/12345);
Kokkos::fill_random(xr, random_pool, 1.0);
Kokkos::fill_random(x, random_pool, I);

// Since HIP FFT destructs the input data, we need to keep the input data in
// different place
Kokkos::deep_copy(x_ref, x);
Kokkos::deep_copy(xr_ref, xr);
Kokkos::fence();

std::vector<int> shapes = {n / 2, n, n * 2};
for (auto&& shape : shapes) {
// Real to comple
ComplexView1DType outr("outr", shape / 2 + 1),
outr_b("outr_b", shape / 2 + 1), outr_o("outr_o", shape / 2 + 1),
outr_f("outr_f", shape / 2 + 1);

Kokkos::deep_copy(xr, xr_ref);
KokkosFFT::rfft(execution_space(), xr, outr, KokkosFFT::Normalization::none,
-1, shape);

Kokkos::deep_copy(xr, xr_ref);
KokkosFFT::rfft(execution_space(), xr, outr_b,
KokkosFFT::Normalization::backward, -1, shape);

Kokkos::deep_copy(xr, xr_ref);
KokkosFFT::rfft(execution_space(), xr, outr_o,
KokkosFFT::Normalization::ortho, -1, shape);

Kokkos::deep_copy(xr, xr_ref);
KokkosFFT::rfft(execution_space(), xr, outr_f,
KokkosFFT::Normalization::forward, -1, shape);

multiply(outr_o, sqrt(static_cast<T>(shape)));
multiply(outr_f, static_cast<T>(shape));

EXPECT_TRUE(allclose(outr_b, outr, 1.e-5, atol));
EXPECT_TRUE(allclose(outr_o, outr, 1.e-5, atol));
EXPECT_TRUE(allclose(outr_f, outr, 1.e-5, atol));

// Complex to real
RealView1DType out("out", shape), out_b("out_b", shape),
out_o("out_o", shape), out_f("out_f", shape);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfft(execution_space(), x, out, KokkosFFT::Normalization::none,
-1, shape);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfft(execution_space(), x, out_b,
KokkosFFT::Normalization::backward, -1, shape);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfft(execution_space(), x, out_o,
KokkosFFT::Normalization::ortho, -1, shape);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfft(execution_space(), x, out_f,
KokkosFFT::Normalization::forward, -1, shape);

multiply(out_o, sqrt(static_cast<T>(shape)));
multiply(out_b, static_cast<T>(shape));

EXPECT_TRUE(allclose(out_b, out, 1.e-5, atol));
EXPECT_TRUE(allclose(out_o, out, 1.e-5, atol));
EXPECT_TRUE(allclose(out_f, out, 1.e-5, atol));
}
}

template <typename T, typename LayoutType>
void test_fft1_1dfft_2dview(T atol = 1.e-12) {
const int n0 = 10, n1 = 12;
Expand Down Expand Up @@ -1218,6 +1298,15 @@ TYPED_TEST(FFT1D, IHFFT_1DView) {
test_fft1_1dihfft_1dview<float_type, layout_type>();
}

// fft1 on 1D Views with shape argument
TYPED_TEST(FFT1D, FFT_1DView_shape) {
using float_type = typename TestFixture::float_type;
using layout_type = typename TestFixture::layout_type;

float_type atol = std::is_same_v<float_type, float> ? 1.0e-6 : 1.0e-12;
test_fft1_shape<float_type, layout_type>(atol);
}

// batced fft1 on 2D Views
TYPED_TEST(FFT1D, FFT_batched_2DView) {
using float_type = typename TestFixture::float_type;
Expand Down
Loading