Skip to content

Commit

Permalink
Diagnostic tool to look for potential evaluation errors (#1268)
Browse files Browse the repository at this point in the history
* diagnostic tool to look for potential evaluation errors

* working on diagnostic tool to detect possible evaluation errors

* working on diagnostic tool to detect possible evaluation errors

* run black

* working on tests for evaluation error detection

* working on tests for evaluation error detection

* run black

* evaluation error detection

* evaluation error detection

* fix pylint issues

* run black

* fix pylint issues

* fix tests

---------

Co-authored-by: Michael Bynum <[email protected]>
  • Loading branch information
michaelbynum and michaelbynum authored Nov 15, 2023
1 parent 210db70 commit d2f24ad
Show file tree
Hide file tree
Showing 2 changed files with 541 additions and 2 deletions.
244 changes: 242 additions & 2 deletions idaes/core/util/model_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from operator import itemgetter
import sys
from inspect import signature
import math
from math import log
from typing import List

import numpy as np
from scipy.linalg import svd
Expand All @@ -29,6 +31,7 @@

from pyomo.environ import (
Binary,
Integers,
Block,
check_optimal_termination,
ConcreteModel,
Expand All @@ -40,9 +43,21 @@
value,
Var,
)
from pyomo.core.expr.numeric_expr import (
DivisionExpression,
NPV_DivisionExpression,
PowExpression,
NPV_PowExpression,
UnaryFunctionExpression,
NPV_UnaryFunctionExpression,
NumericExpression,
)
from pyomo.core.base.block import _BlockData
from pyomo.core.base.var import _VarData
from pyomo.core.base.var import _GeneralVarData, _VarData
from pyomo.core.base.constraint import _ConstraintData
from pyomo.repn.standard_repn import ( # pylint: disable=no-name-in-module
generate_standard_repn,
)
from pyomo.common.collections import ComponentSet
from pyomo.common.config import (
ConfigDict,
Expand All @@ -52,9 +67,10 @@
)
from pyomo.util.check_units import identify_inconsistent_units
from pyomo.contrib.incidence_analysis import IncidenceGraphInterface
from pyomo.core.expr.visitor import identify_variables
from pyomo.core.expr.visitor import identify_variables, StreamBasedExpressionVisitor
from pyomo.contrib.pynumero.interfaces.pyomo_nlp import PyomoNLP
from pyomo.contrib.pynumero.asl import AmplInterface
from pyomo.contrib.fbbt.fbbt import compute_bounds_on_expr
from pyomo.common.deprecation import deprecation_warning
from pyomo.common.errors import PyomoException

Expand Down Expand Up @@ -251,6 +267,14 @@ def svd_sparse(jacobian, number_singular_values):
description="Tolerance for raising a warning for small Jacobian values.",
),
)
CONFIG.declare(
"warn_for_evaluation_error_at_bounds",
ConfigValue(
default=True,
domain=bool,
description="If False, warnings will not be generated for things like log(x) with x >= 0",
),
)


SVDCONFIG = ConfigDict()
Expand Down Expand Up @@ -936,6 +960,13 @@ def _collect_structural_warnings(self):
if any(len(x) > 0 for x in [oc_var, oc_con]):
next_steps.append(self.display_overconstrained_set.__name__ + "()")

eval_warnings = self._collect_potential_eval_errors()
if len(eval_warnings) > 0:
warnings.append(
f"WARNING: Found {len(eval_warnings)} potential evaluation errors."
)
next_steps.append(self.display_potential_evaluation_errors.__name__ + "()")

return warnings, next_steps

def _collect_structural_cautions(self):
Expand Down Expand Up @@ -1303,6 +1334,51 @@ def report_numerical_issues(self, stream=None):
footer="=",
)

def _collect_potential_eval_errors(self) -> List[str]:
warnings = list()
for con in self.model.component_data_objects(
Constraint, active=True, descend_into=True
):
walker = _EvalErrorWalker(self.config)
con_warnings = walker.walk_expression(con.body)
for msg in con_warnings:
msg = f"{con.name}: " + msg
warnings.append(msg)
for obj in self.model.component_data_objects(
Objective, active=True, descend_into=True
):
walker = _EvalErrorWalker(self.config)
obj_warnings = walker.walk_expression(obj.expr)
for msg in obj_warnings:
msg = f"{obj.name}: " + msg
warnings.append(msg)

return warnings

def display_potential_evaluation_errors(self, stream=None):
"""
Prints constraints that may be prone to evaluation errors
(e.g., log of a negative number) based on variable bounds.
Args:
stream: an I/O object to write the output to (default = stdout)
Returns:
None
"""
if stream is None:
stream = sys.stdout

warnings = self._collect_potential_eval_errors()
_write_report_section(
stream=stream,
lines_list=warnings,
title=f"{len(warnings)} WARNINGS",
line_if_empty="No warnings found!",
header="=",
footer="=",
)

@document_kwargs_from_configdict(SVDCONFIG)
def prepare_svd_toolbox(self, **kwargs):
"""
Expand Down Expand Up @@ -1623,6 +1699,170 @@ def display_variables_in_constraint(self, constraint, stream=None):
)


def _get_bounds_with_inf(node: NumericExpression):
lb, ub = compute_bounds_on_expr(node)
if lb is None:
lb = -math.inf
if ub is None:
ub = math.inf
return lb, ub


def _check_eval_error_division(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[1])
if (config.warn_for_evaluation_error_at_bounds and (lb <= 0 <= ub)) or (
lb < 0 < ub
):
msg = f"Potential division by 0 in {node}; Denominator bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_pow(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
arg1, arg2 = node.args
lb1, ub1 = _get_bounds_with_inf(arg1)
lb2, ub2 = _get_bounds_with_inf(arg2)

integer_domains = ComponentSet([Binary, Integers])

integer_exponent = False
# if the exponent is an integer, there should not be any evaluation errors
if isinstance(arg2, _GeneralVarData) and arg2.domain in integer_domains:
# The exponent is an integer variable
# check if the base can be zero
integer_exponent = True
if lb2 == ub2 and lb2 == round(lb2):
# The exponent is fixed to an integer
integer_exponent = True
repn = generate_standard_repn(arg2, quadratic=True)
if (
repn.nonlinear_expr is None
and repn.constant == round(repn.constant)
and all(i.domain in integer_domains for i in repn.linear_vars)
and all(i[0].domain in integer_domains for i in repn.quadratic_vars)
and all(i[1].domain in integer_domains for i in repn.quadratic_vars)
and all(i == round(i) for i in repn.linear_coefs)
and all(i == round(i) for i in repn.quadratic_coefs)
):
# The exponent is a linear or quadratic expression containing
# only integer variables with integer coefficients
integer_exponent = True

if integer_exponent and (
(lb1 > 0 or ub1 < 0)
or (not config.warn_for_evaluation_error_at_bounds and (lb1 >= 0 or ub1 <= 0))
):
# life is good; the exponent is an integer and the base is nonzero
return None
elif integer_exponent and lb2 >= 0:
# life is good; the exponent is a nonnegative integer
return None

# if the base is positive, there should not be any evaluation errors
if lb1 > 0 or (not config.warn_for_evaluation_error_at_bounds and lb1 >= 0):
return None
if lb1 >= 0 and lb2 >= 0:
return None

msg = f"Potential evaluation error in {node}; "
msg += f"base bounds are ({lb1}, {ub1}); "
msg += f"exponent bounds are ({lb2}, {ub2})"
warn_list.append(msg)


def _check_eval_error_log(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if (config.warn_for_evaluation_error_at_bounds and lb <= 0) or lb < 0:
msg = f"Potential log of a non-positive number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_tan(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node)
if not (math.isfinite(lb) and math.isfinite(ub)):
msg = f"{node} may evaluate to -inf or inf; Argument bounds are {_get_bounds_with_inf(node.args[0])}"
warn_list.append(msg)


def _check_eval_error_asin(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of asin outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_acos(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < -1 or ub > 1:
msg = f"Potential evaluation of acos outside [-1, 1] in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


def _check_eval_error_sqrt(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
lb, ub = _get_bounds_with_inf(node.args[0])
if lb < 0:
msg = f"Potential square root of a negative number in {node}; Argument bounds are ({lb}, {ub})"
warn_list.append(msg)


_unary_eval_err_handler = dict()
_unary_eval_err_handler["log"] = _check_eval_error_log
_unary_eval_err_handler["log10"] = _check_eval_error_log
_unary_eval_err_handler["tan"] = _check_eval_error_tan
_unary_eval_err_handler["asin"] = _check_eval_error_asin
_unary_eval_err_handler["acos"] = _check_eval_error_acos
_unary_eval_err_handler["sqrt"] = _check_eval_error_sqrt


def _check_eval_error_unary(
node: NumericExpression, warn_list: List[str], config: ConfigDict
):
if node.getname() in _unary_eval_err_handler:
_unary_eval_err_handler[node.getname()](node, warn_list, config)


_eval_err_handler = dict()
_eval_err_handler[DivisionExpression] = _check_eval_error_division
_eval_err_handler[NPV_DivisionExpression] = _check_eval_error_division
_eval_err_handler[PowExpression] = _check_eval_error_pow
_eval_err_handler[NPV_PowExpression] = _check_eval_error_pow
_eval_err_handler[UnaryFunctionExpression] = _check_eval_error_unary
_eval_err_handler[NPV_UnaryFunctionExpression] = _check_eval_error_unary


class _EvalErrorWalker(StreamBasedExpressionVisitor):
def __init__(self, config: ConfigDict):
super().__init__()
self._warn_list = list()
self._config = config

def exitNode(self, node, data):
"""
callback to be called as the visitor moves from the leaf
nodes back to the root node.
Args:
node: a pyomo expression node
data: not used in this walker
"""
if type(node) in _eval_err_handler:
_eval_err_handler[type(node)](node, self._warn_list, self._config)
return self._warn_list


# TODO: Rename and redirect once old DegeneracyHunter is removed
@document_kwargs_from_configdict(DHCONFIG)
class DegeneracyHunter2:
Expand Down
Loading

0 comments on commit d2f24ad

Please sign in to comment.