Skip to content

Commit

Permalink
Refactored the encoding of utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jun 14, 2024
1 parent d2065a9 commit 1a0956e
Showing 1 changed file with 188 additions and 34 deletions.
222 changes: 188 additions & 34 deletions pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,26 @@
import java.util.Objects;
import java.util.function.Function;

import com.google.common.collect.Iterables;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ConstantFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
Expand All @@ -51,7 +57,9 @@
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.converter.transformations.ExpTransformation;

public class MaxLikConverter extends ModelConverter<RGenericVector> {

Expand All @@ -71,6 +79,8 @@ public class MaxLikConverter extends ModelConverter<RGenericVector> {

private Map<?, Feature> utilityFunctionFeatures = null;

private Map<?, Feature> expUtilityFunctionFeatures = null;


public MaxLikConverter(RGenericVector maxLik){
super(maxLik);
Expand Down Expand Up @@ -134,20 +144,41 @@ public Feature apply(RExp rexp){
}

Map<Object, Feature> utilityFunctionFeatures = new LinkedHashMap<>();
Map<Object, Feature> expUtilityFunctionFeatures = new LinkedHashMap<>();

for(Object choice : choices){
RFunctionCall functionCall = utilityFunctions.get(choice);

Expression expression = toPMML(functionCall, variables, estimates, encoder);
Model model = encodeUtilityFunction(choice, functionCall, variables, estimates, encoder);

encoder.addTransformer(model);

Output output = model.getOutput();

DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("utility", choice), OpType.CONTINUOUS, DataType.DOUBLE, expression);
List<OutputField> outputFields = output.getOutputFields();
for(OutputField outputField : outputFields){
DerivedField derivedField = encoder.createDerivedField(model, outputField, true);

Feature feature = new ContinuousFeature(encoder, derivedField);
Feature feature = new ContinuousFeature(encoder, derivedField);

utilityFunctionFeatures.put(choice, feature);
ResultFeature resultFeature = outputField.getResultFeature();
switch(resultFeature){
case PREDICTED_VALUE:
{
utilityFunctionFeatures.put(choice, feature);
}
break;
case TRANSFORMED_VALUE:
{
expUtilityFunctionFeatures.put(choice, feature);
}
break;
default:
throw new IllegalArgumentException();
}
}

// XXX
encoder.addFeature(feature);
outputFields.clear();
}

String modelType = modelTypeList.getValue(0);
Expand Down Expand Up @@ -215,29 +246,29 @@ public Feature apply(RExp rexp){
Apply apply = ExpressionUtil.createApply(PMMLFunctions.SUM);

for(Object childChoice : childChoices){
Feature feature = utilityFunctionFeatures.get(childChoice);

Apply choiceApply;
Expression expression;

if(lambda.doubleValue() != 1d){
choiceApply = ExpressionUtil.createApply(PMMLFunctions.EXP,
Feature feature = utilityFunctionFeatures.get(childChoice);

expression = ExpressionUtil.createApply(PMMLFunctions.EXP,
ExpressionUtil.createApply(PMMLFunctions.DIVIDE, feature.ref(), ExpressionUtil.createConstant(lambda))
);
} else

{
choiceApply = ExpressionUtil.createApply(PMMLFunctions.EXP,
feature.ref()
);
Feature expFeature = expUtilityFunctionFeatures.get(childChoice);

expression = expFeature.ref();
} // End if

if(availabilities != null){
Feature availabilityFeature = availabilityFeatures.get(childChoice);

choiceApply = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, availabilityFeature.ref(), choiceApply);
expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, availabilityFeature.ref(), expression);
}

apply.addExpressions(choiceApply);
apply.addExpressions(expression);
}

apply = ExpressionUtil.createApply(PMMLFunctions.LN, apply);
Expand All @@ -251,6 +282,14 @@ public Feature apply(RExp rexp){
Feature feature = new ContinuousFeature(encoder, derivedField);

utilityFunctionFeatures.put(choice, feature);

Apply expApply = ExpressionUtil.createApply(PMMLFunctions.EXP, feature.ref());

DerivedField expDerivedField = encoder.createDerivedField(FieldNameUtil.create("exp", feature), OpType.CONTINUOUS, DataType.DOUBLE, expApply);

Feature expFeature = new ContinuousFeature(encoder, expDerivedField);

expUtilityFunctionFeatures.put(choice, expFeature);
}
}
}
Expand All @@ -260,7 +299,9 @@ public Feature apply(RExp rexp){
}

this.availabilityFeatures = availabilityFeatures;

this.utilityFunctionFeatures = utilityFunctionFeatures;
this.expUtilityFunctionFeatures = expUtilityFunctionFeatures;
}

@Override
Expand All @@ -274,7 +315,9 @@ public Model encodeModel(Schema schema){
List<? extends Feature> features = schema.getFeatures();

Map<?, Feature> availabilityFeatures = this.availabilityFeatures;

Map<?, Feature> utilityFunctionFeatures = this.utilityFunctionFeatures;
Map<?, Feature> expUtilityFunctionFeatures = this.expUtilityFunctionFeatures;

List<RegressionTable> regressionTables = new ArrayList<>();

Expand All @@ -285,7 +328,7 @@ public Model encodeModel(Schema schema){
for(int i = 0; i < categoricalLabel.size(); i++){
Object choice = categoricalLabel.getValue(i);

Feature feature = toExpFeature(utilityFunctionFeatures.get(choice), encoder);
Feature feature = expUtilityFunctionFeatures.get(choice);

if(availabilityFeatures != null && !availabilityFeatures.isEmpty()){
Feature availabilityFeature = availabilityFeatures.get(choice);
Expand Down Expand Up @@ -325,35 +368,49 @@ public Model encodeModel(Schema schema){
for(int i = 0; i < categoricalLabel.size(); i++){
Object choice = categoricalLabel.getValue(i);

Apply apply = ExpressionUtil.createApply(PMMLFunctions.PRODUCT);
List<Expression> expressions = new ArrayList<>();

for(Object currentChoice = choice, nextChoice = nlTree.get(currentChoice); nextChoice != null; currentChoice = nextChoice, nextChoice = nlTree.get(currentChoice)){
Number lambda = lambdas.get(nextChoice);

Feature currentFeature = toExpFeature(utilityFunctionFeatures.get(currentChoice), encoder);
Feature nextFeature = toExpFeature(utilityFunctionFeatures.get(nextChoice), encoder);
Feature currentFeature = expUtilityFunctionFeatures.get(currentChoice);
Feature nextFeature = expUtilityFunctionFeatures.get(nextChoice);

DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create("term", currentChoice, nextChoice), OpType.CONTINUOUS, DataType.DOUBLE, () -> {
DerivedField derivedField = encoder.ensureDerivedField(FieldNameUtil.create("decisionFunction", currentChoice, nextChoice), OpType.CONTINUOUS, DataType.DOUBLE, () -> {
// Can't divide by zero.
// A division by integer zero raises an invalid result error.
// However, a division by floating-point zero succeeds - the result is a (positive-) infinity.
Apply choiceApply = ExpressionUtil.createApply(PMMLFunctions.IF,
Expression expression = ExpressionUtil.createApply(PMMLFunctions.IF,
ExpressionUtil.createApply(PMMLFunctions.EQUAL, nextFeature.ref(), ExpressionUtil.createConstant(0d)),
ExpressionUtil.createConstant(0d),
ExpressionUtil.createApply(PMMLFunctions.DIVIDE, currentFeature.ref(), nextFeature.ref())
);

if(lambda.doubleValue() != 1d){
choiceApply = ExpressionUtil.createApply(PMMLFunctions.POW, choiceApply, ExpressionUtil.createConstant(1d / lambda.doubleValue()));
expression = ExpressionUtil.createApply(PMMLFunctions.POW, expression, ExpressionUtil.createConstant(1d / lambda.doubleValue()));
}

return choiceApply;
return expression;
});

apply.addExpressions(new FieldRef(derivedField));
expressions.add(new FieldRef(derivedField));
}

DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("term", choice), OpType.CONTINUOUS, DataType.DOUBLE, apply);
Expression expression;

if(expressions.size() == 1){
expression = Iterables.getOnlyElement(expressions);
} else

{
Apply apply = ExpressionUtil.createApply(PMMLFunctions.PRODUCT);

(apply.getExpressions()).addAll(expressions);

expression = apply;
}

DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("decisionFunction", choice), OpType.CONTINUOUS, DataType.DOUBLE, expression);

Feature feature = new ContinuousFeature(encoder, derivedField);

Expand Down Expand Up @@ -381,6 +438,109 @@ public Model encodeModel(Schema schema){
return regressionModel;
}

/**
 * Encodes the utility function of a single choice as a regression model.
 *
 * <p>
 * The function call is flattened by {@link #encodeTerm} into parallel lists of features and coefficients,
 * which become the sole regression table of the model (no intercept is passed).
 * The model output is created under the name <code>utility(choice)</code> together with an {@link ExpTransformation}.
 * </p>
 *
 * @param choice The choice that this utility function belongs to; used for naming fields.
 * @param functionCall The R expression of the utility function.
 * @param variables Forwarded to {@link #encodeTerm}.
 * @param estimates Forwarded to {@link #encodeTerm}.
 * @param encoder Forwarded to {@link #encodeTerm}.
 *
 * @return A regression model with normalization method {@code NONE}.
 */
private RegressionModel encodeUtilityFunction(Object choice, RFunctionCall functionCall, Map<String, RExp> variables, Map<String, Double> estimates, RExpEncoder encoder){
	List<Feature> termFeatures = new ArrayList<>();
	List<Number> termCoefficients = new ArrayList<>();

	// The top-level expression always enters with a positive sign
	encodeTerm(choice, functionCall, MaxLikConverter.SIGN_PLUS, variables, estimates, termFeatures, termCoefficients, encoder);

	RegressionModel utilityModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null), null);
	utilityModel.setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);

	RegressionTable utilityTable = RegressionModelUtil.createRegressionTable(termFeatures, termCoefficients, null);
	utilityModel.addRegressionTables(utilityTable);

	Output utilityOutput = ModelUtil.createPredictedOutput(FieldNameUtil.create("utility", choice), OpType.CONTINUOUS, DataType.DOUBLE, new ExpTransformation());
	utilityModel.setOutput(utilityOutput);

	return utilityModel;
}

/**
 * Flattens an additive R expression into parallel lists of features and coefficients.
 *
 * <p>
 * Calls to <code>+</code> and <code>-</code> are descended into recursively; every other expression becomes one entry:
 * </p>
 * <ul>
 *   <li>A name that is a key of <code>estimates</code> becomes a {@link ConstantFeature} holding the estimate, with the sign as its coefficient.</li>
 *   <li>A product whose first argument is a key of <code>estimates</code> contributes the second argument as the feature, and the signed estimate as the coefficient.</li>
 *   <li>Anything else is translated using <code>toPMML</code> and added with the sign as its coefficient.</li>
 * </ul>
 *
 * @param choice The choice being encoded; used for naming derived fields.
 * @param rexp The (sub-)expression to encode.
 * @param sign The accumulated sign of this (sub-)expression, either {@link #SIGN_PLUS} or {@link #SIGN_MINUS}.
 * @param variables Forwarded to <code>toPMML</code>.
 * @param estimates Estimated parameter values, keyed by parameter name.
 * @param features The output list of features; appended to.
 * @param coefficients The output list of coefficients; appended to in lockstep with <code>features</code>.
 * @param encoder The encoder that owns the fields.
 */
private void encodeTerm(Object choice, RExp rexp, int sign, Map<String, RExp> variables, Map<String, Double> estimates, List<Feature> features, List<Number> coefficients, RExpEncoder encoder){

	if(rexp instanceof RFunctionCall){
		RFunctionCall functionCall = (RFunctionCall)rexp;

		boolean plus = functionCall.hasValue("+");
		boolean minus = (!plus && functionCall.hasValue("-"));

		if(plus || minus){
			Iterator<RExp> it = functionCall.argumentValues();

			RExp firstArgValue = it.next();

			// The incoming sign must be carried through the operator.
			// Otherwise a nested additive expression that is being subtracted would have its operands added.
			int lastSign = (minus ? -sign : sign);

			// Unary operator; the only argument takes the role of the right-hand side operand
			if(!it.hasNext()){
				encodeTerm(choice, firstArgValue, lastSign, variables, estimates, features, coefficients, encoder);

				return;
			}

			RExp secondArgValue = it.next();

			encodeTerm(choice, firstArgValue, sign, variables, estimates, features, coefficients, encoder);
			encodeTerm(choice, secondArgValue, lastSign, variables, estimates, features, coefficients, encoder);

			return;
		}
	} else

	if(rexp instanceof RString){
		RString string = (RString)rexp;

		// A standalone estimate, such as an alternative-specific constant
		if(estimates.containsKey(string.getValue())){
			Number beta = estimates.get(string.getValue());

			features.add(new ConstantFeature(encoder, beta));
			coefficients.add(sign);

			return;
		}
	}

	Feature feature;
	Number coefficient = null;

	if(rexp instanceof RFunctionCall){
		RFunctionCall functionCall = (RFunctionCall)rexp;

		if(functionCall.hasValue("*")){
			Iterator<RExp> it = functionCall.argumentValues();

			RExp firstArgValue = it.next();
			RExp secondArgValue = it.next();

			if(firstArgValue instanceof RString){
				RString string = (RString)firstArgValue;

				// Split "estimate * expression" into a coefficient and a feature
				if(estimates.containsKey(string.getValue())){
					coefficient = estimates.get(string.getValue());

					rexp = secondArgValue;
				}
			}
		}
	}

	Expression expression = toPMML(rexp, variables, estimates, encoder);

	if(expression instanceof FieldRef){
		FieldRef fieldRef = (FieldRef)expression;

		// A plain field reference can be used directly, without an intermediate derived field
		Field<?> field = encoder.getField(fieldRef.requireField());

		feature = new ContinuousFeature(encoder, field);
	} else

	{
		// The current size of the features list serves as the index of this term within the choice
		DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("term", choice, features.size()), OpType.CONTINUOUS, DataType.DOUBLE, expression);

		feature = new ContinuousFeature(encoder, derivedField);
	} // End if

	if(coefficient != null){
		coefficient = ValueUtil.multiply(MathContext.DOUBLE, sign, coefficient);
	} else

	{
		coefficient = sign;
	}

	features.add(feature);
	coefficients.add(coefficient);
}

private void parseApolloProbabilities(){
RGenericVector maxLik = getObject();

Expand Down Expand Up @@ -540,15 +700,6 @@ private Object matchListAssignment(RExp argValue, String variableName){
return null;
}

/**
 * Wraps a feature into its exponential, <code>exp(feature)</code>.
 *
 * <p>
 * The derived field is obtained via {@link PMMLEncoder#ensureDerivedField}, under a name composed from the <code>exp</code> function name and the feature.
 * NOTE(review): presumably an already existing field of the same name is reused rather than re-created - confirm against the encoder.
 * </p>
 *
 * @param feature The feature to exponentiate.
 * @param encoder The encoder that owns the derived field.
 *
 * @return A continuous feature backed by the derived field.
 */
static
private ContinuousFeature toExpFeature(Feature feature, PMMLEncoder encoder){
	DerivedField expField = encoder.ensureDerivedField(
		FieldNameUtil.create(PMMLFunctions.EXP, feature),
		OpType.CONTINUOUS, DataType.DOUBLE,
		() -> ExpressionUtil.createApply(PMMLFunctions.EXP, feature.ref())
	);

	ContinuousFeature expFeature = new ContinuousFeature(encoder, expField);

	return expFeature;
}

static
private Expression toPMML(RExp argumentValue, Map<String, RExp> variables, Map<String, Double> estimates, RExpEncoder encoder){

Expand Down Expand Up @@ -700,4 +851,7 @@ private List<?> parseVector(RFunctionCall functionCall){

// Model type identifiers.
// NOTE(review): presumably matched against the model type string of the fitted object; the usage is not visible here - confirm.
private static final String TYPE_MNL = "MNL";
private static final String TYPE_NL = "NL";

// Sign multipliers, passed as the "sign" argument of encodeTerm
private static final int SIGN_MINUS = -1;
private static final int SIGN_PLUS = 1;
}

0 comments on commit 1a0956e

Please sign in to comment.