Skip to content

Commit

Permalink
Merge pull request #1293 from stan-dev/cleanup/promotion-robustness
Browse files Browse the repository at this point in the history
Make Promotion.get_type_promotion_exn more robust to bad inputs
  • Loading branch information
WardBrian authored Mar 17, 2023
2 parents 2eb708a + b6141bc commit 2f20a44
Showing 1 changed file with 46 additions and 32 deletions.
78 changes: 46 additions & 32 deletions src/frontend/Promotion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,39 +73,53 @@ let rec promote (exp : Ast.typed_expression) prom =
let promote_list es promotions = List.map2_exn es promotions ~f:promote

(** Get the promotion needed to make the second type into the first.
Types NEED to have previously been checked to be promotable
Types NEED to have previously been checked to be promotable or
else a fatal error will be thrown.
*)
let rec get_type_promotion_exn (ad_orig, ty_orig) (ad_expect, ty_expect) =
match (ty_orig, ty_expect) with
| UnsizedType.(
( UReal, (UReal | UInt)
| UVector, UVector
| URowVector, URowVector
| UMatrix, UMatrix ))
when ad_orig <> ad_expect ->
ToVar
| UComplex, (UReal | UInt | UComplex)
|UComplexMatrix, (UMatrix | UComplexMatrix)
|UComplexVector, (UVector | UComplexVector)
|UComplexRowVector, (URowVector | UComplexRowVector)
when ad_orig <> ad_expect ->
ToComplexVar
| UReal, UInt -> IntToReal
| UComplex, UInt -> IntToComplex
| UComplex, UReal
|UComplexMatrix, UMatrix
|UComplexVector, UVector
|UComplexRowVector, URowVector ->
RealToComplex
| UArray nt1, UArray nt2 ->
get_type_promotion_exn (ad_orig, nt1) (ad_expect, nt2)
| t1, t2 when t1 = t2 -> NoPromotion
| _, _ ->
Common.FatalError.fatal_error_msg
[%message
"Tried to get promotion of mismatched types!"
(ty_orig : UnsizedType.t)
(ty_expect : UnsizedType.t)]
let rec get_type_promotion_exn (ad_requested, ty_requested)
(ad_current, ty_current) =
if UnsizedType.autodifftype_can_convert ad_requested ad_current then
match (ty_requested, ty_current) with
| UnsizedType.(
( UReal, (UReal | UInt)
| UVector, UVector
| URowVector, URowVector
| UMatrix, UMatrix ))
when ad_current <> ad_requested ->
ToVar
| UComplex, (UReal | UInt | UComplex)
|UComplexMatrix, (UMatrix | UComplexMatrix)
|UComplexVector, (UVector | UComplexVector)
|UComplexRowVector, (URowVector | UComplexRowVector)
when ad_current <> ad_requested ->
ToComplexVar
| UReal, UInt -> IntToReal
| UComplex, UInt -> IntToComplex
| UComplex, UReal
|UComplexMatrix, UMatrix
|UComplexVector, UVector
|UComplexRowVector, URowVector ->
RealToComplex
| UArray nt1, UArray nt2 ->
get_type_promotion_exn (ad_requested, nt1) (ad_current, nt2)
| UInt, UInt -> NoPromotion
| t1, t2 when t1 = t2 && ad_requested = ad_current -> NoPromotion
| _, _ ->
Common.FatalError.fatal_error_msg
[%message
"Tried to get promotion of mismatched types!"
(ty_current : UnsizedType.t)
(ad_current : UnsizedType.autodifftype)
"cannot be promoted to "
(ty_requested : UnsizedType.t)
(ad_requested : UnsizedType.autodifftype)]
else
Common.FatalError.fatal_error_msg
[%message
"Tried to get promotion incompatible autodifftypes!"
(ad_current : UnsizedType.autodifftype)
"cannot be promoted to "
(ad_requested : UnsizedType.autodifftype)]

(** Calculate the "cost"/number of promotions performed.
Used to disambiguate function signatures
Expand Down

0 comments on commit 2f20a44

Please sign in to comment.