Skip to content

Commit

Permalink
Fix trigonometric functions to not raise on f32 inputs (#1051)
Browse files Browse the repository at this point in the history
* Fix trigonometric functions to not raise on f32 inputs

* fix precision issues in test cases
  • Loading branch information
sasikumar87 authored Jan 10, 2025
1 parent d2dfe83 commit 672f111
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 29 deletions.
160 changes: 144 additions & 16 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1588,50 +1588,178 @@ pub fn s_clip_float(s: ExSeries, min: f64, max: f64) -> Result<ExSeries, Explore

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_sin(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.sin()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.sin()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.sin())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_cos(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.cos()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.cos()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.cos())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_tan(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.tan()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.tan()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.tan())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_asin(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.asin()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.asin()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.asin())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_acos(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.acos()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.acos()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.acos())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_atan(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.atan()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.atan()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.atan())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_degrees(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.to_degrees()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.to_degrees()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.to_degrees())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_radians(s: ExSeries) -> Result<ExSeries, ExplorerError> {
let s1 = s.f64()?.apply_values(|o| o.to_radians()).into();
Ok(ExSeries::new(s1))
match s.dtype() {
DataType::Float64 => {
let s1 = s.f64()?.apply_values(|o| o.to_radians()).into();
Ok(ExSeries::new(s1))
}
DataType::Float32 => {
let s1 = s
.f32()?
.cast(&DataType::Float64)?
.f64()?
.apply_values(|o| o.to_radians())
.into();
Ok(ExSeries::new(s1))
}
_ => Err(ExplorerError::Other(
"Only f32 and f64 dtypes are supported".into(),
)),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
Expand Down
Loading

0 comments on commit 672f111

Please sign in to comment.