diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index c34c5f529..9b89e628f 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1588,50 +1588,178 @@ pub fn s_clip_float(s: ExSeries, min: f64, max: f64) -> Result Result { - let s1 = s.f64()?.apply_values(|o| o.sin()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.sin()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.sin()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_cos(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.cos()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.cos()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.cos()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_tan(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.tan()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.tan()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.tan()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_asin(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.asin()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.asin()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.asin()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_acos(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.acos()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.acos()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.acos()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_atan(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.atan()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.atan()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.atan()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_degrees(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.to_degrees()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.to_degrees()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.to_degrees()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] pub fn s_radians(s: ExSeries) -> Result { - let s1 = s.f64()?.apply_values(|o| o.to_radians()).into(); - Ok(ExSeries::new(s1)) + match s.dtype() { + DataType::Float64 => { + let s1 = s.f64()?.apply_values(|o| o.to_radians()).into(); + Ok(ExSeries::new(s1)) + } + DataType::Float32 => { + let s1 = s + .f32()? + .cast(&DataType::Float64)? + .f64()? + .apply_values(|o| o.to_radians()) + .into(); + Ok(ExSeries::new(s1)) + } + _ => Err(ExplorerError::Other( + "Only f32 and f64 dtypes are supported".into(), + )), + } } #[rustler::nif(schedule = "DirtyCpu")] diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index d400573e1..75bd2c337 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3202,7 +3202,28 @@ defmodule Explorer.SeriesTest do series = Series.sin(s) - assert Series.to_list(series) == [0.0, 1.0, 1.2246467991473532e-16, -2.4492935982947064e-16] + expected_series = + Series.from_list([0.0, 1.0, 1.2246467991473532e-16, -2.4492935982947064e-16]) + + assert all_close?(series, expected_series) + end + + test "calculates the sine of all elements in the series for f32 input and outputs f64" do + pi = :math.pi() + s = Explorer.Series.from_list([0, pi / 2, pi, 2 * pi], dtype: :f32) + + series = Series.sin(s) + + expected_series = + Series.from_list([ + 0.0, + 0.999999999999999, + -8.742278000372475e-8, + 1.7484556000744883e-7 + ]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3212,8 +3233,27 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([0, pi / 2, pi, 2 * pi]) series = Series.cos(s) + expected_series = Series.from_list([1.0, 6.123233995736766e-17, -1.0, 1.0]) - assert Series.to_list(series) == [1.0, 6.123233995736766e-17, -1.0, 1.0] + assert all_close?(series, expected_series) + end + + test "calculates the cosine of all elements in the series for f32 input and outputs f64" do + pi = :math.pi() + s = Explorer.Series.from_list([0, pi / 2, pi, 2 * pi], dtype: :f32) + + series = Series.cos(s) + + expected_series = + Series.from_list([ + 1.0, + -4.371139000186241e-8, + -0.9999999999999962, + 0.9999999999999847 + ]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3224,12 +3264,33 @@ defmodule Explorer.SeriesTest do series = Series.tan(s) - assert Series.to_list(series) == [ - 0.0, - 1.633123935319537e16, - -1.2246467991473532e-16, - -2.4492935982947064e-16 - ] + expected_series = + Series.from_list([ + 0.0, + 1.633123935319537e16, + -1.2246467991473532e-16, + -2.4492935982947064e-16 + ]) + + assert all_close?(series, expected_series) + end + + test "calculates the tangent of all elements in the series for f32 input and outputs f64" do + pi = :math.pi() + s = Explorer.Series.from_list([0, pi / 2, pi, 2 * pi], dtype: :f32) + + series = Series.tan(s) + + expected_series = + Series.from_list([ + 0.0, + -22_877_332.42885646, + 8.742278000372508e-8, + 1.7484556000745148e-7 + ]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3238,8 +3299,19 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([0.0, 1.0]) series = Series.asin(s) + expected_series = Series.from_list([0.0, 1.5707963267948966]) - assert Series.to_list(series) == [0.0, 1.5707963267948966] + assert all_close?(series, expected_series) + end + + test "calculates the arcsine of all elements in the series for f32 input and outputs f64" do + s = Explorer.Series.from_list([0.0, 1.0], dtype: :f32) + + series = Series.asin(s) + expected_series = Series.from_list([0.0, 1.5707963267948966]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3248,8 +3320,19 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([0.0, 1.0]) series = Series.acos(s) + expected_series = Series.from_list([1.5707963267948966, 0.0]) - assert Series.to_list(series) == [1.5707963267948966, 0.0] + assert all_close?(series, expected_series) + end + + test "calculates the arccosine of all elements in the series for f32 input and outputs f64" do + s = Explorer.Series.from_list([0.0, 1.0], dtype: :f32) + + series = Series.acos(s) + expected_series = Series.from_list([1.5707963267948966, 0.0]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3258,8 +3341,19 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([0.0, 1.0]) series = Series.atan(s) + expected_series = Series.from_list([0.0, 0.7853981633974483]) + + assert all_close?(series, expected_series) + end + + test "calculates the arctangent of all elements in the series for f32 input and outputs f64" do + s = Explorer.Series.from_list([0.0, 1.0], dtype: :f32) + + series = Series.atan(s) + expected_series = Series.from_list([0.0, 0.7853981633974483]) - assert Series.to_list(series) == [0.0, 0.7853981633974483] + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3269,8 +3363,30 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([-2 * pi, -pi, -pi / 2, 0, pi / 2, pi, 2 * pi]) series = Series.degrees(s) + expected_series = Series.from_list([-360.0, -180.0, -90.0, 0.0, 90.0, 180.0, 360.0]) - assert Series.to_list(series) == [-360.0, -180.0, -90.0, 0.0, 90.0, 180.0, 360.0] + assert all_close?(series, expected_series) + end + + test "converts the given series of radians to degrees for f32 input and outputs f64" do + pi = :math.pi() + s = Explorer.Series.from_list([-2 * pi, -pi, -pi / 2, 0, pi / 2, pi, 2 * pi], dtype: :f32) + + series = Series.degrees(s) + + expected_series = + Series.from_list([ + -360.00001001791264, + -180.00000500895632, + -90.00000250447816, + 0.0, + 90.00000250447816, + 180.00000500895632, + 360.00001001791264 + ]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -3280,8 +3396,20 @@ defmodule Explorer.SeriesTest do s = Explorer.Series.from_list([-360.0, -180.0, -90.0, 0.0, 90.0, 180.0, 360.0]) series = Series.radians(s) + expected_series = Series.from_list([-2 * pi, -pi, -pi / 2, 0, pi / 2, pi, 2 * pi]) - assert Series.to_list(series) == [-2 * pi, -pi, -pi / 2, 0, pi / 2, pi, 2 * pi] + assert all_close?(series, expected_series) + end + + test "converts the given series of degrees to radians for f32 input and outputs f64" do + pi = :math.pi() + s = Explorer.Series.from_list([-360.0, -180.0, -90.0, 0.0, 90.0, 180.0, 360.0], dtype: :f32) + + series = Series.radians(s) + expected_series = Series.from_list([-2 * pi, -pi, -pi / 2, 0, pi / 2, pi, 2 * pi]) + + assert all_close?(series, expected_series) + assert Series.dtype(series) == {:f, 64} end end @@ -6624,4 +6752,11 @@ defmodule Explorer.SeriesTest do assert Series.to_list(sj) == [%{"n" => 1.0}] end end + + defp all_close?(a, b, tol \\ 1.0e-8) do + Series.subtract(a, b) + |> Series.abs() + |> Series.less_equal(tol) + |> Series.all?() + end end