Skip to content

Commit

Permalink
Changes per review
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Dec 16, 2022
1 parent 844439e commit 55eb5dd
Show file tree
Hide file tree
Showing 15 changed files with 2,322 additions and 2,329 deletions.
46 changes: 21 additions & 25 deletions src/stan_math_backend/Cpp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ module Stmts = struct

(** Supress warnings for a variable which may not be used. *)
let unused s =
[Comment "supress unused var warning"; Expression (Cast (Void, Var s))]
[Comment "suppress unused var warning"; Expression (Cast (Void, Var s))]
end

module Decls = struct
Expand Down Expand Up @@ -389,14 +389,12 @@ module Printing = struct
pf ppf "@[<2>%s<%a>@]" s (list ~sep:comma pp_type_) ts

let pp_requires ~default ppf requires =
match requires with
| [] -> ()
| _ ->
let pp_require ppf (trait, name) = pf ppf "%s<%s>" trait name in
pf ppf ",@ stan::require_all_t<@[%a@]>*%s"
(list ~sep:comma pp_require)
requires
(if default then " = nullptr" else "")
if not (List.is_empty requires) then
let pp_require ppf (trait, name) = pf ppf "%s<%s>" trait name in
pf ppf ",@ stan::require_all_t<@[%a@]>*%s"
(list ~sep:comma pp_require)
requires
(if default then " = nullptr" else "")

(**
Pretty print a list of templates as [template <parameter-list>].name
Expand All @@ -409,19 +407,17 @@ module Printing = struct
| `Require (requirement, args) ->
pf ppf "%s<%a>*%s" requirement (list ~sep:comma string) args
(if default then " = nullptr" else "") in
match template_parameters with
| [] -> ()
| _ ->
let templates, requires =
List.partition_map template_parameters ~f:(function
| RequireIs (trait, name) -> Second (trait, name)
| Typename name -> First (`Typename name)
| Bool name -> First (`Bool name)
| Require (requirement, args) -> First (`Require (requirement, args)) )
in
pf ppf "template <@[%a%a@]>@ "
(list ~sep:comma pp_basic_template)
templates (pp_requires ~default) requires
if not (List.is_empty template_parameters) then
let templates, requires =
List.partition_map template_parameters ~f:(function
| RequireIs (trait, name) -> Second (trait, name)
| Typename name -> First (`Typename name)
| Bool name -> First (`Bool name)
| Require (requirement, args) -> First (`Require (requirement, args)) )
in
pf ppf "template <@[%a%a@]>@ "
(list ~sep:comma pp_basic_template)
templates (pp_requires ~default) requires

let pp_operator ppf = function
| Multiply -> string ppf "*"
Expand All @@ -438,8 +434,8 @@ module Printing = struct

let rec pp_expr ppf e =
let maybe_templates ppf ts =
if List.length ts = 0 then nop ppf ()
else pf ppf "<@,%a>" (list ~sep:comma pp_type_) ts in
if not (List.is_empty ts) then
pf ppf "<@,%a>" (list ~sep:comma pp_type_) ts in
match e with
| Literal s -> pf ppf "%s" s
| Var id -> string ppf id
Expand Down Expand Up @@ -609,7 +605,7 @@ module Printing = struct
; destructor_body
; public_members } =
pf ppf
"@[<v 1>class %s%s : public %a{@,\
"@[<v 1>class %s%s : public %a {@,\
@[<v 1>private:@,\
%a@]@,\
@[<v 1>public:@,\
Expand Down
19 changes: 8 additions & 11 deletions src/stan_math_backend/Lower_functions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,11 @@ let lower_fun_def (functors : (string, struct_defn) Hashtbl.t)
@ ( if String.Set.mem funs_used_in_reduce_sum fdname then
[register_functor `ReduceSum]
else [] )
@
if String.Map.mem variadic_fns fdname then
(* Produces the variadic functors that has the pstream argument
@ (* Produces the variadic functors that has the pstream argument
as not the last argument. For DAEs this is the 4th, for ODEs the 3rd *)
List.map
(List.stable_dedup @@ String.Map.find_exn variadic_fns fdname)
~f:(fun i -> register_functor (`VariadicHOF i))
else []
List.map
(List.stable_dedup @@ Map.find_multi variadic_fns fdname)
~f:(fun i -> register_functor (`VariadicHOF i))

let is_fun_used_with_reduce_sum (p : Program.Numbered.t) =
let rec find_functors_expr accum Expr.Fixed.{pattern; _} =
Expand Down Expand Up @@ -427,10 +424,10 @@ module Testing = struct
const auto& x = stan::math::to_ref(x_arg__);
const auto& y = stan::math::to_ref(y_arg__);
static constexpr bool propto__ = true;
// supress unused var warning
// suppress unused var warning
(void) propto__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
return stan::math::add(x, 1);
Expand Down Expand Up @@ -494,10 +491,10 @@ module Testing = struct
const auto& y = stan::math::to_ref(y_arg__);
const auto& z = stan::math::to_ref(z_arg__);
static constexpr bool propto__ = true;
// supress unused var warning
// suppress unused var warning
(void) propto__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
return stan::math::add(x, 1);
Expand Down
22 changes: 11 additions & 11 deletions test/integration/cli-args/filename-in-msg/filename_good.expected
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ static constexpr std::array<const char*, 3> locations_array__ =
{" (found before start of program)",
" (in 'filename_good.stan', line 2, column 4 to column 11)",
" (in 'filename_good.stan', line 3, column 4 to column 19)"};
class filename_good_model final : public model_base_crtp<filename_good_model>{
class filename_good_model final : public model_base_crtp<filename_good_model> {
private:
double p;
double q;
Expand All @@ -22,14 +22,14 @@ class filename_good_model final : public model_base_crtp<filename_good_model>{
using local_scalar_t__ = double;
boost::ecuyer1988 base_rng__ =
stan::services::util::create_rng(random_seed__, 0);
// supress unused var warning
// suppress unused var warning
(void) base_rng__;
static constexpr const char* function__ =
"filename_good_model_namespace::filename_good_model";
// supress unused var warning
// suppress unused var warning
(void) function__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
int pos__ = std::numeric_limits<int>::min();
Expand Down Expand Up @@ -65,11 +65,11 @@ class filename_good_model final : public model_base_crtp<filename_good_model>{
stan::io::deserializer<local_scalar_t__> in__(params_r__, params_i__);
int current_statement__ = 0;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
static constexpr const char* function__ =
"filename_good_model_namespace::log_prob";
// supress unused var warning
// suppress unused var warning
(void) function__;
try {

Expand All @@ -94,20 +94,20 @@ class filename_good_model final : public model_base_crtp<filename_good_model>{
stan::io::deserializer<local_scalar_t__> in__(params_r__, params_i__);
stan::io::serializer<local_scalar_t__> out__(vars__);
static constexpr bool propto__ = true;
// supress unused var warning
// suppress unused var warning
(void) propto__;
double lp__ = 0.0;
// supress unused var warning
// suppress unused var warning
(void) lp__;
int current_statement__ = 0;
stan::math::accumulator<double> lp_accum__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
constexpr bool jacobian__ = false;
static constexpr const char* function__ =
"filename_good_model_namespace::write_array";
// supress unused var warning
// suppress unused var warning
(void) function__;
try {
if (stan::math::logical_negation(
Expand All @@ -133,7 +133,7 @@ class filename_good_model final : public model_base_crtp<filename_good_model>{
stan::io::serializer<local_scalar_t__> out__(vars__);
int current_statement__ = 0;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
int pos__ = std::numeric_limits<int>::min();
Expand Down
22 changes: 11 additions & 11 deletions test/integration/good/code-gen/cl.expected
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ static constexpr std::array<const char*, 217> locations_array__ =
" (in 'optimize_glm.stan', line 32, column 9 to column 10)",
" (in 'optimize_glm.stan', line 32, column 12 to column 13)",
" (in 'optimize_glm.stan', line 33, column 13 to column 14)"};
class optimize_glm_model final : public model_base_crtp<optimize_glm_model>{
class optimize_glm_model final : public model_base_crtp<optimize_glm_model> {
private:
int k;
int n;
Expand Down Expand Up @@ -265,14 +265,14 @@ class optimize_glm_model final : public model_base_crtp<optimize_glm_model>{
using local_scalar_t__ = double;
boost::ecuyer1988 base_rng__ =
stan::services::util::create_rng(random_seed__, 0);
// supress unused var warning
// suppress unused var warning
(void) base_rng__;
static constexpr const char* function__ =
"optimize_glm_model_namespace::optimize_glm_model";
// supress unused var warning
// suppress unused var warning
(void) function__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
int pos__ = std::numeric_limits<int>::min();
Expand Down Expand Up @@ -523,11 +523,11 @@ class optimize_glm_model final : public model_base_crtp<optimize_glm_model>{
stan::io::deserializer<local_scalar_t__> in__(params_r__, params_i__);
int current_statement__ = 0;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
static constexpr const char* function__ =
"optimize_glm_model_namespace::log_prob";
// supress unused var warning
// suppress unused var warning
(void) function__;
try {
Eigen::Matrix<local_scalar_t__,-1,1> alpha_v =
Expand Down Expand Up @@ -1238,20 +1238,20 @@ class optimize_glm_model final : public model_base_crtp<optimize_glm_model>{
stan::io::deserializer<local_scalar_t__> in__(params_r__, params_i__);
stan::io::serializer<local_scalar_t__> out__(vars__);
static constexpr bool propto__ = true;
// supress unused var warning
// suppress unused var warning
(void) propto__;
double lp__ = 0.0;
// supress unused var warning
// suppress unused var warning
(void) lp__;
int current_statement__ = 0;
stan::math::accumulator<double> lp_accum__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
constexpr bool jacobian__ = false;
static constexpr const char* function__ =
"optimize_glm_model_namespace::write_array";
// supress unused var warning
// suppress unused var warning
(void) function__;
try {
Eigen::Matrix<double,-1,1> alpha_v =
Expand Down Expand Up @@ -1327,7 +1327,7 @@ class optimize_glm_model final : public model_base_crtp<optimize_glm_model>{
stan::io::serializer<local_scalar_t__> out__(vars__);
int current_statement__ = 0;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// supress unused var warning
// suppress unused var warning
(void) DUMMY_VAR__;
try {
int pos__ = std::numeric_limits<int>::min();
Expand Down
Loading

0 comments on commit 55eb5dd

Please sign in to comment.