From 1a0956e9293678f1ea9f575edf1ab4d271fca1ad Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 14 Jun 2024 14:25:44 +0300 Subject: [PATCH] Refactored the encoding of utility functions --- .../java/org/jpmml/rexp/MaxLikConverter.java | 222 +++++++++++++++--- 1 file changed, 188 insertions(+), 34 deletions(-) diff --git a/pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java b/pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java index 7f945d8..cabac5f 100644 --- a/pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java +++ b/pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java @@ -28,6 +28,7 @@ import java.util.Objects; import java.util.function.Function; +import com.google.common.collect.Iterables; import org.dmg.pmml.Apply; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; @@ -35,13 +36,18 @@ import org.dmg.pmml.Expression; import org.dmg.pmml.Field; import org.dmg.pmml.FieldRef; +import org.dmg.pmml.MathContext; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.Model; import org.dmg.pmml.OpType; +import org.dmg.pmml.Output; +import org.dmg.pmml.OutputField; import org.dmg.pmml.PMMLFunctions; +import org.dmg.pmml.ResultFeature; import org.dmg.pmml.regression.RegressionModel; import org.dmg.pmml.regression.RegressionTable; import org.jpmml.converter.CategoricalLabel; +import org.jpmml.converter.ConstantFeature; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.ExpressionUtil; import org.jpmml.converter.Feature; @@ -51,7 +57,9 @@ import org.jpmml.converter.PMMLEncoder; import org.jpmml.converter.Schema; import org.jpmml.converter.TypeUtil; +import org.jpmml.converter.ValueUtil; import org.jpmml.converter.regression.RegressionModelUtil; +import org.jpmml.converter.transformations.ExpTransformation; public class MaxLikConverter extends ModelConverter { @@ -71,6 +79,8 @@ public class MaxLikConverter extends ModelConverter { private Map utilityFunctionFeatures = null; + private Map expUtilityFunctionFeatures = null; + public MaxLikConverter(RGenericVector maxLik){ super(maxLik); @@ -134,20 +144,41 @@ public Feature apply(RExp rexp){ } Map utilityFunctionFeatures = new LinkedHashMap<>(); + Map expUtilityFunctionFeatures = new LinkedHashMap<>(); for(Object choice : choices){ RFunctionCall functionCall = utilityFunctions.get(choice); - Expression expression = toPMML(functionCall, variables, estimates, encoder); + Model model = encodeUtilityFunction(choice, functionCall, variables, estimates, encoder); + + encoder.addTransformer(model); + + Output output = model.getOutput(); - DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("utility", choice), OpType.CONTINUOUS, DataType.DOUBLE, expression); + List outputFields = output.getOutputFields(); + for(OutputField outputField : outputFields){ + DerivedField derivedField = encoder.createDerivedField(model, outputField, true); - Feature feature = new ContinuousFeature(encoder, derivedField); + Feature feature = new ContinuousFeature(encoder, derivedField); - utilityFunctionFeatures.put(choice, feature); + ResultFeature resultFeature = outputField.getResultFeature(); + switch(resultFeature){ + case PREDICTED_VALUE: + { + utilityFunctionFeatures.put(choice, feature); + } + break; + case TRANSFORMED_VALUE: + { + expUtilityFunctionFeatures.put(choice, feature); + } + break; + default: + throw new IllegalArgumentException(); + } + } - // XXX - encoder.addFeature(feature); + outputFields.clear(); } String modelType = modelTypeList.getValue(0); @@ -215,29 +246,29 @@ public Feature apply(RExp rexp){ Apply apply = ExpressionUtil.createApply(PMMLFunctions.SUM); for(Object childChoice : childChoices){ - Feature feature = utilityFunctionFeatures.get(childChoice); - - Apply choiceApply; + Expression expression; if(lambda.doubleValue() != 1d){ - choiceApply = ExpressionUtil.createApply(PMMLFunctions.EXP, + Feature feature = utilityFunctionFeatures.get(childChoice); + + expression = ExpressionUtil.createApply(PMMLFunctions.EXP, ExpressionUtil.createApply(PMMLFunctions.DIVIDE, feature.ref(), ExpressionUtil.createConstant(lambda)) ); } else { - choiceApply = ExpressionUtil.createApply(PMMLFunctions.EXP, - feature.ref() - ); + Feature expFeature = expUtilityFunctionFeatures.get(childChoice); + + expression = expFeature.ref(); } // End if if(availabilities != null){ Feature availabilityFeature = availabilityFeatures.get(childChoice); - choiceApply = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, availabilityFeature.ref(), choiceApply); + expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, availabilityFeature.ref(), expression); } - apply.addExpressions(choiceApply); + apply.addExpressions(expression); } apply = ExpressionUtil.createApply(PMMLFunctions.LN, apply); @@ -251,6 +282,14 @@ public Feature apply(RExp rexp){ Feature feature = new ContinuousFeature(encoder, derivedField); utilityFunctionFeatures.put(choice, feature); + + Apply expApply = ExpressionUtil.createApply(PMMLFunctions.EXP, feature.ref()); + + DerivedField expDerivedField = encoder.createDerivedField(FieldNameUtil.create("exp", feature), OpType.CONTINUOUS, DataType.DOUBLE, expApply); + + Feature expFeature = new ContinuousFeature(encoder, expDerivedField); + + expUtilityFunctionFeatures.put(choice, expFeature); } } } @@ -260,7 +299,9 @@ public Feature apply(RExp rexp){ } this.availabilityFeatures = availabilityFeatures; + this.utilityFunctionFeatures = utilityFunctionFeatures; + this.expUtilityFunctionFeatures = expUtilityFunctionFeatures; } @Override @@ -274,7 +315,9 @@ public Model encodeModel(Schema schema){ List features = schema.getFeatures(); Map availabilityFeatures = this.availabilityFeatures; + Map utilityFunctionFeatures = this.utilityFunctionFeatures; + Map expUtilityFunctionFeatures = this.expUtilityFunctionFeatures; List regressionTables = new ArrayList<>(); @@ -285,7 +328,7 @@ public Model encodeModel(Schema schema){ for(int i = 0; i < categoricalLabel.size(); i++){ Object choice = categoricalLabel.getValue(i); - Feature feature = toExpFeature(utilityFunctionFeatures.get(choice), encoder); + Feature feature = expUtilityFunctionFeatures.get(choice); if(availabilityFeatures != null && !availabilityFeatures.isEmpty()){ Feature availabilityFeature = availabilityFeatures.get(choice); @@ -325,35 +368,49 @@ public Model encodeModel(Schema schema){ for(int i = 0; i < categoricalLabel.size(); i++){ Object choice = categoricalLabel.getValue(i); - Apply apply = ExpressionUtil.createApply(PMMLFunctions.PRODUCT); + List expressions = new ArrayList<>(); for(Object currentChoice = choice, nextChoice = nlTree.get(currentChoice); nextChoice != null; currentChoice = nextChoice, nextChoice = nlTree.get(currentChoice)){ Number lambda = lambdas.get(nextChoice); - Feature currentFeature = toExpFeature(utilityFunctionFeatures.get(currentChoice), encoder); - Feature nextFeature = toExpFeature(utilityFunctionFeatures.get(nextChoice), encoder); + Feature currentFeature = expUtilityFunctionFeatures.get(currentChoice); + Feature nextFeature = expUtilityFunctionFeatures.get(nextChoice); - DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create("term", currentChoice, nextChoice), OpType.CONTINUOUS, DataType.DOUBLE, () -> { + DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create("decisionFunction", currentChoice, nextChoice), OpType.CONTINUOUS, DataType.DOUBLE, () -> { // Can't divide by zero. // A division by integer zero raises an invalid result error. // However, a division by floating-point zero succeeds - the result is a (positive-) infinity. - Apply choiceApply = ExpressionUtil.createApply(PMMLFunctions.IF, + Expression expression = ExpressionUtil.createApply(PMMLFunctions.IF, ExpressionUtil.createApply(PMMLFunctions.EQUAL, nextFeature.ref(), ExpressionUtil.createConstant(0d)), ExpressionUtil.createConstant(0d), ExpressionUtil.createApply(PMMLFunctions.DIVIDE, currentFeature.ref(), nextFeature.ref()) ); if(lambda.doubleValue() != 1d){ - choiceApply = ExpressionUtil.createApply(PMMLFunctions.POW, choiceApply, ExpressionUtil.createConstant(1d / lambda.doubleValue())); + expression = ExpressionUtil.createApply(PMMLFunctions.POW, expression, ExpressionUtil.createConstant(1d / lambda.doubleValue())); } - return choiceApply; + return expression; }); - apply.addExpressions(new FieldRef(derivedField)); + expressions.add(new FieldRef(derivedField)); } - DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("term", choice), OpType.CONTINUOUS, DataType.DOUBLE, apply); + Expression expression; + + if(expressions.size() == 1){ + expression = Iterables.getOnlyElement(expressions); + } else + + { + Apply apply = ExpressionUtil.createApply(PMMLFunctions.PRODUCT); + + (apply.getExpressions()).addAll(expressions); + + expression = apply; + } + + DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("decisionFunction", choice), OpType.CONTINUOUS, DataType.DOUBLE, expression); Feature feature = new ContinuousFeature(encoder, derivedField); @@ -381,6 +438,109 @@ public Model encodeModel(Schema schema){ return regressionModel; } + private RegressionModel encodeUtilityFunction(Object choice, RFunctionCall functionCall, Map variables, Map estimates, RExpEncoder encoder){ + List features = new ArrayList<>(); + List coefficients = new ArrayList<>(); + + encodeTerm(choice, functionCall, MaxLikConverter.SIGN_PLUS, variables, estimates, features, coefficients, encoder); + + RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null), null) + .setNormalizationMethod(RegressionModel.NormalizationMethod.NONE) + .addRegressionTables(RegressionModelUtil.createRegressionTable(features, coefficients, null)) + .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("utility", choice), OpType.CONTINUOUS, DataType.DOUBLE, new ExpTransformation())); + + return regressionModel; + } + + private void encodeTerm(Object choice, RExp rexp, int sign, Map variables, Map estimates, List features, List coefficients, RExpEncoder encoder){ + + if(rexp instanceof RFunctionCall){ + RFunctionCall functionCall = (RFunctionCall)rexp; + + if(functionCall.hasValue("+")){ + Iterator it = functionCall.argumentValues(); + + encodeTerm(choice, it.next(), MaxLikConverter.SIGN_PLUS, variables, estimates, features, coefficients, encoder); + encodeTerm(choice, it.next(), MaxLikConverter.SIGN_PLUS, variables, estimates, features, coefficients, encoder); + + return; + } else + + if(functionCall.hasValue("-")){ + Iterator it = functionCall.argumentValues(); + + encodeTerm(choice, it.next(), MaxLikConverter.SIGN_PLUS, variables, estimates, features, coefficients, encoder); + encodeTerm(choice, it.next(), MaxLikConverter.SIGN_MINUS, variables, estimates, features, coefficients, encoder); + + return; + } + } else + + if(rexp instanceof RString){ + RString string = (RString)rexp; + + if(estimates.containsKey(string.getValue())){ + Number beta = estimates.get(string.getValue()); + + features.add(new ConstantFeature(encoder, beta)); + coefficients.add(sign); + + return; + } + } + + Feature feature; + Number coefficient = null; + + if(rexp instanceof RFunctionCall){ + RFunctionCall functionCall = (RFunctionCall)rexp; + + if(functionCall.hasValue("*")){ + Iterator it = functionCall.argumentValues(); + + RExp firstArgValue = it.next(); + RExp secondArgValue = it.next(); + + if(firstArgValue instanceof RString){ + RString string = (RString)firstArgValue; + + if(estimates.containsKey(string.getValue())){ + coefficient = estimates.get(string.getValue()); + + rexp = secondArgValue; + } + } + } + } + + Expression expression = toPMML(rexp, variables, estimates, encoder); + + if(expression instanceof FieldRef){ + FieldRef fieldRef = (FieldRef)expression; + + Field field = encoder.getField(fieldRef.requireField()); + + feature = new ContinuousFeature(encoder, field); + } else + + { + DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("term", choice, features.size()), OpType.CONTINUOUS, DataType.DOUBLE, expression); + + feature = new ContinuousFeature(encoder, derivedField); + } // End if + + if(coefficient != null){ + coefficient = ValueUtil.multiply(MathContext.DOUBLE, sign, coefficient); + } else + + { + coefficient = sign; + } + + features.add(feature); + coefficients.add(coefficient); + } + private void parseApolloProbabilities(){ RGenericVector maxLik = getObject(); @@ -540,15 +700,6 @@ private Object matchListAssignment(RExp argValue, String variableName){ return null; } - static - private ContinuousFeature toExpFeature(Feature feature, PMMLEncoder encoder){ - DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create(PMMLFunctions.EXP, feature), OpType.CONTINUOUS, DataType.DOUBLE, () -> { - return ExpressionUtil.createApply(PMMLFunctions.EXP, feature.ref()); - }); - - return new ContinuousFeature(encoder, derivedField); - } - static private Expression toPMML(RExp argumentValue, Map variables, Map estimates, RExpEncoder encoder){ @@ -700,4 +851,7 @@ private List parseVector(RFunctionCall functionCall){ private static final String TYPE_MNL = "MNL"; private static final String TYPE_NL = "NL"; + + private static final int SIGN_MINUS = -1; + private static final int SIGN_PLUS = 1; } \ No newline at end of file