diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionRegressor.java b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionRegressor.java index 27af094f6..c7f2adc3f 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionRegressor.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/expression/ExpressionRegressor.java @@ -46,6 +46,7 @@ public ExpressionRegressor(String module, String name){ @Override public RegressionModel encodeModel(Schema schema){ Expression expr = getExpr(); + RegressionModel.NormalizationMethod normalizationMethod = parseNormalizationMethod(getNormalizationMethod()); PMMLEncoder encoder = schema.getEncoder(); @@ -61,7 +62,7 @@ public RegressionModel encodeModel(Schema schema){ RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(exprFeature), Collections.singletonList(1d), 0d); RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel), null) - .setNormalizationMethod(RegressionModel.NormalizationMethod.NONE) + .setNormalizationMethod(normalizationMethod) .addRegressionTables(regressionTable); return regressionModel; @@ -70,4 +71,27 @@ public RegressionModel encodeModel(Schema schema){ public Expression getExpr(){ return get("expr", Expression.class); } + + public String getNormalizationMethod(){ + + if(!containsKey("normalization_method")){ + return "none"; + } + + // SkLearn2PMML 0.105.0+ + return getString("normalization_method"); + } + + static + private RegressionModel.NormalizationMethod parseNormalizationMethod(String normalizationMethod){ + + switch(normalizationMethod){ + case "none": + return RegressionModel.NormalizationMethod.NONE; + case "exp": + return RegressionModel.NormalizationMethod.EXP; + default: + throw new IllegalArgumentException(normalizationMethod); + } + } } \ No newline at end of file diff --git a/pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py b/pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py index 9d850c1dd..675cd91c4 100644 --- a/pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py +++ b/pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py @@ -317,7 +317,7 @@ def build_expr_auto(auto_df, name): expr = Expression("-1.724 * _scale_displacement(X['displacement']) + 4.879 * _scale_weight(X['weight']) + 23.45", function_defs = [_scale_displacement, _scale_weight]) pipeline = PMMLPipeline([ - ("regressor", ExpressionRegressor(expr)) + ("regressor", ExpressionRegressor(expr, normalization_method = "none")) ]) pipeline.fit(auto_X, auto_y) store_pkl(pipeline, name) diff --git a/pmml-sklearn/src/test/resources/pkl/ExpressionAuto.pkl b/pmml-sklearn/src/test/resources/pkl/ExpressionAuto.pkl index c2536acaf..6d4a3bc52 100644 Binary files a/pmml-sklearn/src/test/resources/pkl/ExpressionAuto.pkl and b/pmml-sklearn/src/test/resources/pkl/ExpressionAuto.pkl differ