Skip to content

Commit

Permalink
Merge pull request #1296 from stan-dev/fix/int-real-funcall-promotion
Browse files Browse the repository at this point in the history
Explicitly promote integers to reals for UDF calls
  • Loading branch information
WardBrian authored Mar 20, 2023
2 parents 2f20a44 + 0dd667d commit a490a09
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/stan_math_backend/Cpp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ module Exprs = struct

(** Helper for [std::numeric_limits<int>::min()] *)
let int_min = fun_call "std::numeric_limits<int>::min" []

let static_cast type_ expr = FunCall ("static_cast", [type_], [expr])
end

module Expression_syntax = struct
Expand Down
15 changes: 11 additions & 4 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ and lower_user_defined_fun f suffix es =
let extra_args =
suffix_args suffix @ ["pstream__"] |> List.map ~f:Exprs.to_var in
Exprs.templated_fun_call f (templates true suffix)
(lower_exprs es @ extra_args)
((lower_exprs ~promote_reals:true) es @ extra_args)

and lower_compiler_internal ad ut f es =
let open Expression_syntax in
Expand Down Expand Up @@ -483,15 +483,21 @@ and lower_indexed_simple e idcs =
List.fold idcs ~init:(lower_expr e) ~f:(fun e id ->
Index (e, idx_minus_one id) )

and lower_expr (Expr.Fixed.{pattern; meta} : Expr.Typed.t) : Cpp.expr =
and lower_expr ?(promote_reals = false)
(Expr.Fixed.{pattern; meta} : Expr.Typed.t) : Cpp.expr =
let open Exprs in
match pattern with
| Var s -> Var s
| Lit (Str, s) -> literal_string s
| Lit (Imaginary, s) ->
fun_call "stan::math::to_complex" [Literal "0"; Literal s]
| Lit ((Real | Int), s) -> Literal s
| Promotion (expr, UReal, _) when is_scalar expr -> lower_expr expr
| Promotion (expr, UReal, _) when is_scalar expr ->
if promote_reals then
(* this can be important for e.g. templated function calls
where we might generate an incorrect specification for int *)
static_cast Cpp.Double (lower_expr expr)
else lower_expr expr
| Promotion (expr, UComplex, DataOnly) when is_scalar expr ->
(* this is in principle a little better than promote_scalar since it is constexpr *)
fun_call "stan::math::to_complex" [lower_expr expr; Literal "0"]
Expand Down Expand Up @@ -538,7 +544,8 @@ and lower_expr (Expr.Fixed.{pattern; meta} : Expr.Typed.t) : Cpp.expr =
lower_indexed_simple e idx
| _ -> lower_indexed e idx (Fmt.to_to_string Expr.Typed.pp e) )

and lower_exprs = List.map ~f:lower_expr
and lower_exprs ?(promote_reals = false) =
List.map ~f:(lower_expr ~promote_reals)

module Testing = struct
(* these functions are just for testing *)
Expand Down
7 changes: 2 additions & 5 deletions src/stan_math_backend/Lower_program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ let gen_validate_data name st =
if String.is_suffix ~suffix:"__" name then []
else
let vector args =
let cast x =
Exprs.templated_fun_call "static_cast" [Types.size_t] [lower_expr x]
in
let cast x = Exprs.static_cast Types.size_t (lower_expr x) in
InitializerExpr (Types.std_vector Types.size_t, List.map ~f:cast args)
in
let open Expression_syntax in
Expand Down Expand Up @@ -344,8 +342,7 @@ let gen_get_param_names {Program.output_vars; _} =
~body ~cv_qualifiers:[Const] () )

let gen_get_dims {Program.output_vars; _} =
let cast x =
Exprs.templated_fun_call "static_cast" [Types.size_t] [lower_expr x] in
let cast x = Exprs.static_cast Types.size_t (lower_expr x) in
let pack inner_dims =
Exprs.std_vector_init_expr Types.size_t
(List.map ~f:cast (SizedType.get_dims_io inner_dims)) in
Expand Down
Loading

0 comments on commit a490a09

Please sign in to comment.