Skip to content

Commit

Permalink
Extracted ModelTranslator#ensureFunctionalInterface(Class, Translatio…
Browse files Browse the repository at this point in the history
…nContext) utility method
  • Loading branch information
vruusmann committed Jun 26, 2023
1 parent 30c025e commit 9ec3b8f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JClassAlreadyExistsException;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
Expand Down Expand Up @@ -488,6 +490,32 @@ public JDefinedClass ensureArgumentsType(TranslationContext context){
return argumentsClazz;
}

static
public JDefinedClass ensureFunctionalInterface(Class<?> returnType, TranslationContext context){
JDefinedClass owner = context.getOwner(JavaModel.class);

JDefinedClass funcInterface = JCodeModelUtil.getNestedClass(owner, "JavaModelFunction");
if(funcInterface != null){
return funcInterface;
}

try {
funcInterface = owner._interface("JavaModelFunction");
} catch(JClassAlreadyExistsException jcaee){
throw new IllegalArgumentException(jcaee);
}

funcInterface.annotate(FunctionalInterface.class);

JClass argumentsClazz = ensureArgumentsType(context);

JMethod method = funcInterface.method(Modifiers.PUBLIC_ABSTRACT, returnType, "apply");

method.param(argumentsClazz, "arguments");

return funcInterface;
}

static
private void enhanceFieldInfo(FieldInfo fieldInfo, MiningSchema miningSchema, Map<String, Field<?>> bodyFields, FieldInfoMap fieldInfos, FunctionInvocationContext context){
Field<?> field = fieldInfo.getField();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@

import com.google.common.collect.Iterables;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClassAlreadyExistsException;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JForLoop;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
Expand All @@ -60,7 +58,6 @@
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.JCodeModelUtil;
import org.jpmml.translator.JDirectInitializer;
import org.jpmml.translator.JVarBuilder;
import org.jpmml.translator.MethodScope;
Expand Down Expand Up @@ -276,7 +273,7 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra
pullUpDerivedFields(miningModel, treeModel);
}

JDefinedClass treeFunctionInterface = ensureTreeModelFuncInterface(context);
JDefinedClass modelFuncInterface = ensureFunctionalInterface(int.class, context);

JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(Segmentation.class.getSimpleName(), segmentation) + ".data", context);

Expand All @@ -296,7 +293,7 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra

JDirectInitializer codeInitializer = new JDirectInitializer(context);

JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), treeFunctionInterface.narrow(ensureArgumentsType(context)), methods);
JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), modelFuncInterface, methods);

JBlock block = context.block();

Expand Down Expand Up @@ -445,7 +442,7 @@ public ValueFactory<Number> getValueFactory(){
pullUpDerivedFields(miningModel, treeModel);
}

JDefinedClass treeModelFuncInterface = ensureTreeModelFuncInterface(context);
JDefinedClass modelFuncInterface = ensureFunctionalInterface(int.class, context);

JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(Segmentation.class.getSimpleName(), segmentation) + ".data", context);

Expand All @@ -465,7 +462,7 @@ public ValueFactory<Number> getValueFactory(){

JDirectInitializer codeInitializer = new JDirectInitializer(context);

JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), treeModelFuncInterface.narrow(ensureArgumentsType(context)), methods);
JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), modelFuncInterface, methods);

JFieldVar categoriesVar = codeInitializer.initTargetCategories("targetCategories", Arrays.asList(categories));

Expand Down Expand Up @@ -524,31 +521,6 @@ public ValueFactory<Number> getValueFactory(){
context._return(context._new(ProbabilityDistribution.class, valueMapInit));
}

private JDefinedClass ensureTreeModelFuncInterface(TranslationContext context){
JDefinedClass owner = context.getOwner();

JDefinedClass definedClazz = JCodeModelUtil.getNestedClass(owner, "TreeModelFunction");
if(definedClazz != null){
return definedClazz;
}

try {
definedClazz = owner._interface("TreeModelFunction");
} catch(JClassAlreadyExistsException jcaee){
throw new IllegalArgumentException(jcaee);
}

definedClazz.annotate(FunctionalInterface.class);

JTypeVar typeVar = definedClazz.generify("T");

JMethod method = definedClazz.method(Modifiers.PUBLIC_ABSTRACT, int.class, "apply");

method.param(typeVar, "value");

return definedClazz;
}

private <S> JMethod createEvaluatorMethod(TreeModel treeModel, Node node, Scorer<S> scorer, FieldInfoMap fieldInfos, TranslationContext context){
JDefinedClass treeModelClazz = PMMLObjectUtil.createMemberClass(Modifiers.PRIVATE_STATIC_FINAL, IdentifierUtil.create(TreeModel.class.getSimpleName(), treeModel), context);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@
import com.google.common.collect.Multimaps;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JClassAlreadyExistsException;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JForEach;
import com.sun.codemodel.JForLoop;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
Expand Down Expand Up @@ -73,7 +71,6 @@
import org.jpmml.translator.FunctionInvocation;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.JCodeModelUtil;
import org.jpmml.translator.JDirectInitializer;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
Expand Down Expand Up @@ -321,7 +318,7 @@ public ValueBuilder translateRegressionTable(RegressionTable regressionTable, Fi
Map<String, List<CategoricalPredictor>> fieldCategoricalPredictors = regressionTable.getCategoricalPredictors().stream()
.collect(Collectors.groupingBy(categoricalPredictor -> categoricalPredictor.requireField(), Collectors.toList()));

JDefinedClass regressionModelFunctionInterface = ensureRegressionModelFuncInterface(context);
JDefinedClass modelFuncInterface = ensureFunctionalInterface(Number.class, context);

List<JMethod> evaluateCategoryMethods = new ArrayList<>();

Expand Down Expand Up @@ -367,7 +364,7 @@ public ValueBuilder translateRegressionTable(RegressionTable regressionTable, Fi

JDirectInitializer codeInitializer = new JDirectInitializer(context);

JFieldVar categoryMethodsVar = codeInitializer.initLambdas(IdentifierUtil.create("categoryMethods", regressionTable), regressionModelFunctionInterface.narrow(ensureArgumentsType(context)), evaluateCategoryMethods);
JFieldVar categoryMethodsVar = codeInitializer.initLambdas(IdentifierUtil.create("categoryMethods", regressionTable), modelFuncInterface, evaluateCategoryMethods);

JBlock block = context.block();

Expand Down Expand Up @@ -566,32 +563,6 @@ public Number apply(FunctionInvocationPredictor tfTerm){
}
}

static
private JDefinedClass ensureRegressionModelFuncInterface(TranslationContext context){
JDefinedClass owner = context.getOwner();

JDefinedClass definedClazz = JCodeModelUtil.getNestedClass(owner, "RegressionModelFunction");
if(definedClazz != null){
return definedClazz;
}

try {
definedClazz = owner._interface("RegressionModelFunction");
} catch(JClassAlreadyExistsException jcaee){
throw new IllegalArgumentException(jcaee);
}

definedClazz.annotate(FunctionalInterface.class);

JTypeVar typeVar = definedClazz.generify("T");

JMethod method = definedClazz.method(Modifiers.PUBLIC_ABSTRACT, Number.class, "apply");

method.param(typeVar, "value");

return definedClazz;
}

static
private class FunctionInvocationPredictor {

Expand Down

0 comments on commit 9ec3b8f

Please sign in to comment.