diff --git a/core/src/main/scala/vapors/algebra/Expr.scala b/core/src/main/scala/vapors/algebra/Expr.scala index 0b5a64037..4c56d380d 100644 --- a/core/src/main/scala/vapors/algebra/Expr.scala +++ b/core/src/main/scala/vapors/algebra/Expr.scala @@ -52,6 +52,7 @@ object Expr { def visitDivideOutputs[R : Division](expr: DivideOutputs[V, R, P]): G[R] def visitEmbed[R](expr: Embed[V, R, P]): G[R] def visitExistsInOutput[M[_] : Foldable, U](expr: ExistsInOutput[V, M, U, P]): G[Boolean] + def visitExponentiateOutputs(expr: ExponentiateOutputs[V, P]): G[Double] def visitFilterOutput[M[_] : Foldable : FunctorFilter, R](expr: FilterOutput[V, M, R, P]): G[M[R]] def visitFlatMapOutput[M[_] : Foldable : FlatMap, U, X](expr: FlatMapOutput[V, M, U, X, P]): G[M[X]] def visitGroupOutput[M[_] : Foldable, U : Order, K](expr: GroupOutput[V, M, U, K, P]): G[MapView[K, Seq[U]]] @@ -531,6 +532,17 @@ object Expr { override def visit[G[_]](v: Visitor[V, P, G]): G[R] = v.visitDivideOutputs(this) } + /** + * Raise the given base expression to the result of the given power exponent expression. + */ + final case class ExponentiateOutputs[V, P]( + baseExpr: Expr[V, Double, P], + exponentExpr: Expr[V, Double, P], + capture: CaptureP[V, Double, P], + ) extends Expr[V, Double, P] { + override def visit[G[_]](v: Visitor[V, P, G]): G[Double] = v.visitExponentiateOutputs(this) + } + /** * Returns the negative result value of [[inputExpr]] using the provided definition for [[Negative]]. */ diff --git a/core/src/main/scala/vapors/algebra/ExprResult.scala b/core/src/main/scala/vapors/algebra/ExprResult.scala index d5e173724..4ecbd5196 100644 --- a/core/src/main/scala/vapors/algebra/ExprResult.scala +++ b/core/src/main/scala/vapors/algebra/ExprResult.scala @@ -53,6 +53,7 @@ object ExprResult { def visitDivideOutputs[R](result: DivideOutputs[V, R, P]): G[R] def visitEmbed[R](result: Embed[V, R, P]): G[R] def visitExistsInOutput[M[_] : Foldable, U](result: ExistsInOutput[V, M, U, P]): G[Boolean] + def visitExponentiateOutputs(result: ExponentiateOutputs[V, P]): G[Double] def visitFilterOutput[M[_] : Foldable : FunctorFilter, R](result: FilterOutput[V, M, R, P]): G[M[R]] def visitFlatMapOutput[M[_], U, R](result: FlatMapOutput[V, M, U, R, P]): G[M[R]] def visitGroupOutput[M[_] : Foldable, U : Order, K](result: GroupOutput[V, M, U, K, P]): G[MapView[K, Seq[U]]] @@ -313,6 +314,15 @@ object ExprResult { override def visit[G[_]](v: Visitor[V, P, G]): G[R] = v.visitDivideOutputs(this) } + final case class ExponentiateOutputs[V, P]( + expr: Expr.ExponentiateOutputs[V, P], + context: Context[V, Double, P], + baseResult: ExprResult[V, Double, P], + exponentResult: ExprResult[V, Double, P], + ) extends ExprResult[V, Double, P] { + override def visit[G[_]](v: Visitor[V, P, G]): G[Double] = v.visitExponentiateOutputs(this) + } + final case class WrapOutput[V, L, R, P]( expr: Expr.WrapOutput[V, L, R, P], context: Context[V, R, P], diff --git a/core/src/main/scala/vapors/dsl/ExprDsl.scala b/core/src/main/scala/vapors/dsl/ExprDsl.scala index 7a7763a64..9dded77df 100644 --- a/core/src/main/scala/vapors/dsl/ExprDsl.scala +++ b/core/src/main/scala/vapors/dsl/ExprDsl.scala @@ -261,6 +261,14 @@ trait ExprDsl extends TimeFunctions with WrapExprSyntax with WrapEachExprSyntax ): Expr.DivideOutputs[V, R, P] = Expr.DivideOutputs(NonEmptyList.of(lhs, rhs), capture) + def pow[V, P]( + base: Expr[V, Double, P], + exp: Expr[V, Double, P], + )(implicit + capture: CaptureP[V, Double, P], + ): Expr.ExponentiateOutputs[V, P] = + Expr.ExponentiateOutputs(base, exp, capture) + def negative[V, R : Negative, P]( inputExpr: Expr[V, R, P], )(implicit diff --git a/core/src/main/scala/vapors/interpreter/InterpretExprAsResultFn.scala b/core/src/main/scala/vapors/interpreter/InterpretExprAsResultFn.scala index 6c0e28378..fafad6ef4 100644 --- a/core/src/main/scala/vapors/interpreter/InterpretExprAsResultFn.scala +++ b/core/src/main/scala/vapors/interpreter/InterpretExprAsResultFn.scala @@ -167,6 +167,21 @@ final class InterpretExprAsResultFn[V, P] extends Expr.Visitor[V, P, Lambda[r => } } + override def visitExponentiateOutputs( + expr: Expr.ExponentiateOutputs[V, P], + ): ExprInput[V] => ExprResult[V, Double, P] = { input => + val baseResult = expr.baseExpr.visit(this)(input) + val exponentResult = expr.exponentExpr.visit(this)(input) + val base = baseResult.output.value + val exponent = exponentResult.output.value + val output = Math.pow(base, exponent) + val allEvidence = baseResult.output.evidence ++ exponentResult.output.evidence + val allParams = List(baseResult.param, exponentResult.param) + resultOfManySubExpr(expr, input, output, allEvidence, allParams) { + ExprResult.ExponentiateOutputs(_, _, baseResult, exponentResult) + } + } + override def visitFilterOutput[M[_] : Foldable : FunctorFilter, U]( expr: Expr.FilterOutput[V, M, U, P], ): ExprInput[V] => ExprResult[V, M[U], P] = { input => diff --git a/core/src/main/scala/vapors/interpreter/VisitGenericExprWithProxyFn.scala b/core/src/main/scala/vapors/interpreter/VisitGenericExprWithProxyFn.scala index 3ce6fbb4b..c73955549 100644 --- a/core/src/main/scala/vapors/interpreter/VisitGenericExprWithProxyFn.scala +++ b/core/src/main/scala/vapors/interpreter/VisitGenericExprWithProxyFn.scala @@ -62,6 +62,9 @@ abstract class VisitGenericExprWithProxyFn[V, P, G[_]] extends Visitor[V, P, Lam expr: Expr.ExistsInOutput[V, M, U, P], ): ExprInput[V] => G[Boolean] = visitGeneric(expr, _) + override def visitExponentiateOutputs(expr: Expr.ExponentiateOutputs[V, P]): ExprInput[V] => G[Double] = + visitGeneric(expr, _) + override def visitFilterOutput[M[_] : Foldable : FunctorFilter, R]( expr: Expr.FilterOutput[V, M, R, P], ): ExprInput[V] => G[M[R]] = visitGeneric(expr, _) diff --git a/core/src/test/scala/vapors/interpreter/ExponentiateOutputsSpec.scala b/core/src/test/scala/vapors/interpreter/ExponentiateOutputsSpec.scala new file mode 100644 index 000000000..a94a107c6 --- /dev/null +++ b/core/src/test/scala/vapors/interpreter/ExponentiateOutputsSpec.scala @@ -0,0 +1,31 @@ +package com.rallyhealth + +package vapors.interpreter + +import vapors.dsl +import vapors.dsl._ + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.freespec.AnyFreeSpec + +import scala.reflect.classTag + +class ExponentiateOutputsSpec extends AnyFreeSpec { + + "Expr.ExponentiateOutputs" - { + + "reasonable positive number raised to a reasonable positive exponent" in { + VaporsEvalTestHelpers.producesTheSameResultOrException[Double, Double, Double, ArithmeticException]( + Math.pow, + (b, e) => dsl.pow(const(b), const(e)), + )(Arbitrary(Gen.choose(0d, 100d)), Arbitrary(Gen.choose(1d, 10d)), classTag[ArithmeticException]) + } + + "arbitrary number raised to a arbitrary exponent" in { + VaporsEvalTestHelpers.producesTheSameResultOrException[Double, Double, Double, ArithmeticException]( + Math.pow, + (b, e) => dsl.pow(const(b), const(e)), + ) + } + } +} diff --git a/core/src/test/scala/vapors/interpreter/VaporsEvalTestHelpers.scala b/core/src/test/scala/vapors/interpreter/VaporsEvalTestHelpers.scala index b629ed60c..b2a9d067d 100644 --- a/core/src/test/scala/vapors/interpreter/VaporsEvalTestHelpers.scala +++ b/core/src/test/scala/vapors/interpreter/VaporsEvalTestHelpers.scala @@ -97,8 +97,18 @@ object VaporsEvalTestHelpers { }, expectedValue => { val result = eval(FactTable.empty)(query) - assertResult(expectedValue) { - result.output.value + expectedValue match { + case num: Double if num.isNaN => + result.output.value match { + case obs: Double => + assert(obs.isNaN, "Expected output to be NaN") + case _ => + fail("Expected output to be a Double") + } + case _ => + assertResult(expectedValue) { + result.output.value + } } }, )