From 5a4ac1ebf0e30b415b90a57904e4c2cb32f35068 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sat, 6 Aug 2022 10:30:55 -0400 Subject: [PATCH] work around apfloat bug in FMA by using host floats instead --- src/shims/foreign_items.rs | 18 +++++++++--------- src/shims/intrinsics/mod.rs | 34 ++++++++++++++++++++-------------- src/shims/intrinsics/simd.rs | 19 +++++++++++++++---- tests/pass/intrinsics-math.rs | 2 ++ tests/pass/portable-simd.rs | 10 ++++++++++ 5 files changed, 56 insertions(+), 27 deletions(-) diff --git a/src/shims/foreign_items.rs b/src/shims/foreign_items.rs index 208e7ea788..9a985b2450 100644 --- a/src/shims/foreign_items.rs +++ b/src/shims/foreign_items.rs @@ -588,7 +588,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let [f] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; // FIXME: Using host floats. let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); - let f = match link_name.as_str() { + let res = match link_name.as_str() { "cbrtf" => f.cbrt(), "coshf" => f.cosh(), "sinhf" => f.sinh(), @@ -598,7 +598,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx "atanf" => f.atan(), _ => bug!(), }; - this.write_scalar(Scalar::from_u32(f.to_bits()), dest)?; + this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; } #[rustfmt::skip] | "_hypotf" @@ -611,12 +611,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f1 = f32::from_bits(this.read_scalar(f1)?.to_u32()?); let f2 = f32::from_bits(this.read_scalar(f2)?.to_u32()?); - let n = match link_name.as_str() { + let res = match link_name.as_str() { "_hypotf" | "hypotf" => f1.hypot(f2), "atan2f" => f1.atan2(f2), _ => bug!(), }; - this.write_scalar(Scalar::from_u32(n.to_bits()), dest)?; + this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; } #[rustfmt::skip] | "cbrt" @@ -630,7 +630,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let [f] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; // FIXME: Using host floats. let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); - let f = match link_name.as_str() { + let res = match link_name.as_str() { "cbrt" => f.cbrt(), "cosh" => f.cosh(), "sinh" => f.sinh(), @@ -640,7 +640,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx "atan" => f.atan(), _ => bug!(), }; - this.write_scalar(Scalar::from_u64(f.to_bits()), dest)?; + this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; } #[rustfmt::skip] | "_hypot" @@ -651,12 +651,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f1 = f64::from_bits(this.read_scalar(f1)?.to_u64()?); let f2 = f64::from_bits(this.read_scalar(f2)?.to_u64()?); - let n = match link_name.as_str() { + let res = match link_name.as_str() { "_hypot" | "hypot" => f1.hypot(f2), "atan2" => f1.atan2(f2), _ => bug!(), }; - this.write_scalar(Scalar::from_u64(n.to_bits()), dest)?; + this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; } #[rustfmt::skip] | "_ldexp" @@ -668,7 +668,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let x = this.read_scalar(x)?.to_f64()?; let exp = this.read_scalar(exp)?.to_i32()?; - // Saturating cast to i16. Even those are outside the valid exponent range to + // Saturating cast to i16. Even those are outside the valid exponent range so // `scalbn` below will do its over/underflow handling. let exp = if exp > i32::from(i16::MAX) { i16::MAX diff --git a/src/shims/intrinsics/mod.rs b/src/shims/intrinsics/mod.rs index 4c2d08ffce..08a6e0fcc0 100644 --- a/src/shims/intrinsics/mod.rs +++ b/src/shims/intrinsics/mod.rs @@ -285,7 +285,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); let f2 = f32::from_bits(this.read_scalar(f2)?.to_u32()?); - this.write_scalar(Scalar::from_u32(f.powf(f2).to_bits()), dest)?; + let res = f.powf(f2); + this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; } "powf64" => { @@ -293,25 +294,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); let f2 = f64::from_bits(this.read_scalar(f2)?.to_u64()?); - this.write_scalar(Scalar::from_u64(f.powf(f2).to_bits()), dest)?; + let res = f.powf(f2); + this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; } "fmaf32" => { let [a, b, c] = check_arg_count(args)?; - let a = this.read_scalar(a)?.to_f32()?; - let b = this.read_scalar(b)?.to_f32()?; - let c = this.read_scalar(c)?.to_f32()?; - let res = a.mul_add(b, c).value; - this.write_scalar(Scalar::from_f32(res), dest)?; + // FIXME: Using host floats, to work around https://github.com/rust-lang/miri/issues/2468. + let a = f32::from_bits(this.read_scalar(a)?.to_u32()?); + let b = f32::from_bits(this.read_scalar(b)?.to_u32()?); + let c = f32::from_bits(this.read_scalar(c)?.to_u32()?); + let res = a.mul_add(b, c); + this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; } "fmaf64" => { let [a, b, c] = check_arg_count(args)?; - let a = this.read_scalar(a)?.to_f64()?; - let b = this.read_scalar(b)?.to_f64()?; - let c = this.read_scalar(c)?.to_f64()?; - let res = a.mul_add(b, c).value; - this.write_scalar(Scalar::from_f64(res), dest)?; + // FIXME: Using host floats, to work around https://github.com/rust-lang/miri/issues/2468. + let a = f64::from_bits(this.read_scalar(a)?.to_u64()?); + let b = f64::from_bits(this.read_scalar(b)?.to_u64()?); + let c = f64::from_bits(this.read_scalar(c)?.to_u64()?); + let res = a.mul_add(b, c); + this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; } "powif32" => { @@ -319,7 +323,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); let i = this.read_scalar(i)?.to_i32()?; - this.write_scalar(Scalar::from_u32(f.powi(i).to_bits()), dest)?; + let res = f.powi(i); + this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; } "powif64" => { @@ -327,7 +332,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx // FIXME: Using host floats. let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); let i = this.read_scalar(i)?.to_i32()?; - this.write_scalar(Scalar::from_u64(f.powi(i).to_bits()), dest)?; + let res = f.powi(i); + this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; } "float_to_int_unchecked" => { diff --git a/src/shims/intrinsics/simd.rs b/src/shims/intrinsics/simd.rs index d467c3c509..0c3241683a 100644 --- a/src/shims/intrinsics/simd.rs +++ b/src/shims/intrinsics/simd.rs @@ -238,14 +238,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx let dest = this.mplace_index(&dest, i)?; // Works for f32 and f64. + // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468. let ty::Float(float_ty) = dest.layout.ty.kind() else { span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) }; let val = match float_ty { - FloatTy::F32 => - Scalar::from_f32(a.to_f32()?.mul_add(b.to_f32()?, c.to_f32()?).value), - FloatTy::F64 => - Scalar::from_f64(a.to_f64()?.mul_add(b.to_f64()?, c.to_f64()?).value), + FloatTy::F32 => { + let a = f32::from_bits(a.to_u32()?); + let b = f32::from_bits(b.to_u32()?); + let c = f32::from_bits(c.to_u32()?); + let res = a.mul_add(b, c); + Scalar::from_u32(res.to_bits()) + } + FloatTy::F64 => { + let a = f64::from_bits(a.to_u64()?); + let b = f64::from_bits(b.to_u64()?); + let c = f64::from_bits(c.to_u64()?); + let res = a.mul_add(b, c); + Scalar::from_u64(res.to_bits()) + } }; this.write_scalar(val, &dest.into())?; } diff --git a/tests/pass/intrinsics-math.rs b/tests/pass/intrinsics-math.rs index 0cb42580fc..fad01047b9 100644 --- a/tests/pass/intrinsics-math.rs +++ b/tests/pass/intrinsics-math.rs @@ -60,6 +60,8 @@ pub fn main() { assert_eq!(0.0f32.mul_add(-2.0, f32::consts::E), f32::consts::E); assert_approx_eq!(3.0f64.mul_add(2.0, 5.0), 11.0); assert_eq!(0.0f64.mul_add(-2.0f64, f64::consts::E), f64::consts::E); + assert_eq!((-3.2f32).mul_add(2.4, f32::NEG_INFINITY), f32::NEG_INFINITY); + assert_eq!((-3.2f64).mul_add(2.4, f64::NEG_INFINITY), f64::NEG_INFINITY); assert_approx_eq!((-1.0f32).abs(), 1.0f32); assert_approx_eq!(34.2f64.abs(), 34.2f64); diff --git a/tests/pass/portable-simd.rs b/tests/pass/portable-simd.rs index ec70eea6b1..173ac654b0 100644 --- a/tests/pass/portable-simd.rs +++ b/tests/pass/portable-simd.rs @@ -18,6 +18,11 @@ fn simd_ops_f32() { assert_eq!(a.mul_add(b, a), (a * b) + a); assert_eq!(b.mul_add(b, a), (b * b) + a); + assert_eq!(a.mul_add(b, b), (a * b) + b); + assert_eq!( + f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)), + f32x4::splat(f32::NEG_INFINITY) + ); assert_eq!((a * a).sqrt(), a); assert_eq!((b * b).sqrt(), b.abs()); @@ -67,6 +72,11 @@ fn simd_ops_f64() { assert_eq!(a.mul_add(b, a), (a * b) + a); assert_eq!(b.mul_add(b, a), (b * b) + a); + assert_eq!(a.mul_add(b, b), (a * b) + b); + assert_eq!( + f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)), + f64x4::splat(f64::NEG_INFINITY) + ); assert_eq!((a * a).sqrt(), a); assert_eq!((b * b).sqrt(), b.abs());