Skip to content

Commit

Permalink
Updated JPMML-Python dependency
Browse files: browse the repository at this point in the history
  • Loading branch information
vruusmann committed Aug 2, 2024
1 parent 6b68ab3 commit 3679e74
Show file tree
Hide file tree
Showing 17 changed files with 37 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
Expand All @@ -66,7 +67,7 @@ public <E extends Estimator & HasExplainableBooster> List<Feature> encodeExplain
ClassDictUtil.checkSize(bins, featureTypesIn);
ClassDictUtil.checkSize(termFeatures, termScores);

PMMLEncoder encoder = schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

Expand Down Expand Up @@ -95,7 +96,7 @@ public <E extends Estimator & HasExplainableBooster> List<Feature> encodeExplain
}

static
private List<CategoricalFeature> encodeBinLevelFeatures(Feature feature, List<?> binLevels, String featureType, PMMLEncoder encoder){
private List<CategoricalFeature> encodeBinLevelFeatures(Feature feature, List<?> binLevels, String featureType, ModelEncoder encoder){
List<CategoricalFeature> result = new ArrayList<>();

for(int i = 0; i < binLevels.size(); i++){
Expand All @@ -121,7 +122,7 @@ private List<CategoricalFeature> encodeBinLevelFeatures(Feature feature, List<?>
}

static
private IndexFeature binContinuous(Feature feature, HasArray binLevel, Integer binLevelIndex, PMMLEncoder encoder){
private IndexFeature binContinuous(Feature feature, HasArray binLevel, Integer binLevelIndex, ModelEncoder encoder){
ContinuousFeature continuousFeature = feature.toContinuousFeature();

Discretize discretize = new Discretize(continuousFeature.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ private Feature encodeFeature(String name, List<Vector> key, SkLearnEncoder enco

{
if(minusExpression != null){
expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, ExpressionUtil.createConstant(-1d), minusExpression);
expression = ExpressionUtil.toNegative(minusExpression);
} else

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
import org.jpmml.converter.FeatureList;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
Expand Down Expand Up @@ -200,7 +200,7 @@ public Label encodeLabel(List<String> names, SkLearnEncoder encoder){
public Model encodeModel(Schema schema){
Converter<?> converter = createConverter();

PMMLEncoder encoder = schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package lightgbm.sklearn;

import java.io.Reader;
import java.io.StringReader;
import java.util.List;

Expand Down Expand Up @@ -55,7 +56,7 @@ public GBDT getGBDT(){
private GBDT loadGBDT(){
String handle = getHandle();

try(StringReader reader = new StringReader(handle)){
try(Reader reader = new StringReader(handle)){
List<String> lines = CharStreams.readLines(reader);

return LightGBMUtil.loadGBDT(lines.iterator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public <E extends Estimator & HasBooster> ObjectiveFunction getObjectiveFunction
public <E extends Estimator & HasBooster & HasLightGBMOptions> MiningModel encodeModel(E estimator, Schema schema){
GBDT gbdt = getGBDT(estimator);

ModelEncoder encoder = (ModelEncoder)schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();

Map<String, ?> options = getOptions(gbdt, estimator);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import org.dmg.pmml.PMML;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.statsmodels.InterceptFeature;
import org.jpmml.statsmodels.StatsModelsEncoder;
import sklearn.Estimator;
Expand All @@ -47,7 +47,7 @@ public <E extends Estimator & HasResults> PMML encodePMML(E estimator){

static
public Schema addConstant(Schema schema){
SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
Label label = schema.getLabel();
List<Feature> features = (List)schema.getFeatures();

Expand Down
5 changes: 2 additions & 3 deletions pmml-sklearn/src/main/java/sklearn/Calibrator.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sklearn.SkLearnEncoder;
Expand All @@ -47,11 +46,11 @@ public Calibrator(String module, String name){

@Override
public RegressionModel encodeModel(Schema schema){
PMMLEncoder encoder = schema.getEncoder();
SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

features = encodeFeatures((List)features, (SkLearnEncoder)encoder);
features = encodeFeatures((List)features, encoder);

Feature feature = Iterables.getOnlyElement(features);

Expand Down
2 changes: 1 addition & 1 deletion pmml-sklearn/src/main/java/sklearn/Estimator.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public void addFeatureImportances(Model model, Schema schema){
featureImportances = getFeatureImportances();
}

ModelEncoder encoder = (ModelEncoder)schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
List<? extends Feature> features = schema.getFeatures();

if(featureImportances != null){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod

Feature feature = features.get(0);

Apply apply = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY,
ExpressionUtil.createConstant(-1),
Apply apply = (Apply)ExpressionUtil.toNegative(
ExpressionUtil.createApply(PMMLFunctions.ADD,
ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, ExpressionUtil.createConstant(a), feature.ref()),
ExpressionUtil.createConstant(b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public boolean isFinalResult(){
/**
 * Builds the PMML expression <code>(-offset) - 2^(-fieldRef)</code>.
 *
 * The offset is read from the <code>estimator</code> that is in scope for this method.
 *
 * NOTE(review): this method is shown inside a diff rendering, so two return statements appear below.
 * Presumably the first is the pre-change form (explicit multiplication by -1) and the second is its
 * replacement (<code>ExpressionUtil.toNegative</code>, assumed to negate its argument - confirm against
 * the converter library); only one of them is expected to exist in the actual source file.
 *
 * @param fieldRef Reference to the field whose value is negated and used as the exponent of 2.
 *
 * @return The subtraction expression described above.
 */
public Expression createExpression(FieldRef fieldRef){
Number offset = estimator.getOffset();

// Form with an explicit "-1 * fieldRef" exponent
return ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, ExpressionUtil.createConstant(-offset.doubleValue()), ExpressionUtil.createApply(PMMLFunctions.POW, ExpressionUtil.createConstant(2d), ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, ExpressionUtil.createConstant(-1d), fieldRef)));
// Form that delegates the negation of the exponent to ExpressionUtil.toNegative
return ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, ExpressionUtil.createConstant(-offset.doubleValue()), ExpressionUtil.createApply(PMMLFunctions.POW, ExpressionUtil.createConstant(2d), ExpressionUtil.toNegative(fieldRef)));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
Expand Down Expand Up @@ -109,7 +109,7 @@ private Model encodeMultinomialModel(Schema schema){
List<Number> coef = getCoef();
List<Number> intercept = getIntercept();

PMMLEncoder encoder = schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,9 @@ private Apply encodeYeoJohnsonTransform(ContinuousFeature continuousFeature, Num

{
// "-ln(-$name + 1)"
falseApply = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY,
ExpressionUtil.createConstant(-1d),
falseApply = (Apply)ExpressionUtil.toNegative(
ExpressionUtil.createApply(PMMLFunctions.LN1P,
ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, ExpressionUtil.createConstant(-1d), continuousFeature.ref())
ExpressionUtil.toNegative(continuousFeature.ref())
)
);
}
Expand Down
25 changes: 4 additions & 21 deletions pmml-sklearn/src/main/java/sklearn2pmml/CustomizationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

import com.google.common.collect.Iterables;
import jakarta.xml.bind.Binder;
import jakarta.xml.bind.JAXBContext;
import jakarta.xml.bind.JAXBElement;
import jakarta.xml.bind.Marshaller;
import jakarta.xml.bind.annotation.XmlRootElement;
Expand All @@ -68,8 +67,6 @@ private CustomizationUtil(){

static
public void customize(Model model, List<? extends Customization> customizations) throws Exception {
JAXBContext jaxbContext = JAXBUtil.getContext();

DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
documentBuilderFactory.setNamespaceAware(true);

Expand All @@ -81,7 +78,7 @@ public void customize(Model model, List<? extends Customization> customizations)

NamespaceContext namespaceContext = new DocumentNamespaceContext(document);

Binder<Node> binder = jaxbContext.createBinder(Node.class);
Binder<Node> binder = JAXBUtil.createBinder(Node.class);
binder.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.FALSE);
binder.setProperty(Marshaller.JAXB_FRAGMENT, Boolean.TRUE);

Expand Down Expand Up @@ -302,28 +299,14 @@ private Field findField(PMMLObject parent, PMMLObject child){

static
private void addListElement(Field field, PMMLObject parent, PMMLObject child) throws ReflectiveOperationException {
Class<? extends PMMLObject> parentClazz = parent.getClass();
Class<? extends PMMLObject> childClazz = child.getClass();
@SuppressWarnings("unused")
List<?> fieldValue = (List<?>)ReflectionUtil.getFieldValue(field, parent);

ParameterizedType listType = (ParameterizedType)field.getGenericType();

Class<?> listElementType = (Class<?>)listType.getActualTypeArguments()[0];

Method getterMethod = ReflectionUtil.getGetterMethod(field);

String name = getterMethod.getName();
if(name.startsWith("get")){
name = "add" + name.substring(3);
} else

{
throw new IllegalArgumentException();
}

// See https://stackoverflow.com/a/1679444
Class<?> valueArrayClazz = Class.forName("[L" + childClazz.getName() + ";");

Method appenderMethod = parentClazz.getMethod(name, valueArrayClazz);
Method appenderMethod = ReflectionUtil.getAppenderMethod(field);

// See https://stackoverflow.com/a/36125994
Object[] valueArray = (Object[])Array.newInstance(listElementType, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public VisitorAction visit(Node node){
treeModelVisitor.applyTo(treeModel);
}

ModelEncoder encoder = (ModelEncoder)schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
Label label = schema.getLabel();

ContinuousLabel continuousLabel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.DataFrameScope;
Expand All @@ -55,7 +55,7 @@ public RegressionModel encodeModel(Schema schema){
Map<?, ?> classExprs = getClassExprs();
RegressionModel.NormalizationMethod normalizationMethod = parseNormalizationMethod(getNormalizationMethod());

PMMLEncoder encoder = schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.DataFrameScope;
Expand All @@ -48,7 +48,7 @@ public RegressionModel encodeModel(Schema schema){
Object expr = getExpr();
RegressionModel.NormalizationMethod normalizationMethod = parseNormalizationMethod(getNormalizationMethod());

PMMLEncoder encoder = schema.getEncoder();
ModelEncoder encoder = schema.getEncoder();
ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

Expand Down
14 changes: 7 additions & 7 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -98,42 +98,42 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-testing</artifactId>
<version>1.6.4</version>
<version>1.6.5</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-h2o</artifactId>
<version>1.2.12</version>
<version>1.2.13</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-lightgbm</artifactId>
<version>1.5.3</version>
<version>1.5.4</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-python</artifactId>
<version>1.2.2</version>
<version>1.2.4</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-python-testing</artifactId>
<version>1.2.2</version>
<version>1.2.4</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-statsmodels</artifactId>
<version>1.1.0</version>
<version>1.1.1</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-xgboost</artifactId>
<version>1.8.5</version>
<version>1.8.6</version>
</dependency>

<dependency>
Expand Down

0 comments on commit 3679e74

Please sign in to comment.