Skip to content

Commit

Permalink
Added the 'input_float' tree model transformation option
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 28, 2024
1 parent 5d07e83 commit 7e1292b
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ public class Main {
)
private Boolean flat = null;

@Parameter (
names = {"--X-" + HasTreeOptions.OPTION_INPUT_FLOAT},
description = "Allow field data type updates",
arity = 1
)
private Boolean inputFloat = null;

@Parameter (
names = {"--X-" + HasTreeOptions.OPTION_NODE_ID},
description = "Keep SkLearn node identifiers",
Expand Down Expand Up @@ -219,6 +226,7 @@ public void run() throws Exception {
options.put(HasTreeOptions.OPTION_ALLOW_MISSING, this.allowMissing);
options.put(HasTreeOptions.OPTION_COMPACT, this.compact);
options.put(HasTreeOptions.OPTION_FLAT, this.flat);
options.put(HasTreeOptions.OPTION_INPUT_FLOAT, this.inputFloat);
options.put(HasTreeOptions.OPTION_NODE_ID, this.nodeId);
options.put(HasTreeOptions.OPTION_NODE_SCORE, this.nodeScore);
options.put(HasTreeOptions.OPTION_NUMERIC, this.numeric);
Expand Down
6 changes: 6 additions & 0 deletions pmml-sklearn/src/main/java/sklearn/tree/HasTreeOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.LinkedHashMap;
import java.util.Map;

import org.dmg.pmml.DerivedField;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
Expand Down Expand Up @@ -48,6 +49,11 @@ public interface HasTreeOptions extends HasSkLearnOptions, HasNativeConfiguratio
*/
String OPTION_FLAT = "flat";

/**
* @see DerivedField
*/
String OPTION_INPUT_FLOAT = "input_float";

/**
* @see Node#hasExtensions()
* @see Node#getExtensions()
Expand Down
49 changes: 45 additions & 4 deletions pmml-sklearn/src/main/java/sklearn/tree/TreeUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import com.google.common.primitives.Doubles;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.HasContinuousDomain;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
Expand Down Expand Up @@ -345,8 +348,9 @@ public Number get(int index){
static
public <E extends Estimator & HasTreeOptions> Schema configureSchema(E estimator, Schema schema){
Boolean numeric = (Boolean)estimator.getOption(HasTreeOptions.OPTION_NUMERIC, Boolean.TRUE);
Boolean inputFloat = (Boolean)estimator.getOption(HasTreeOptions.OPTION_INPUT_FLOAT, null);

return toTreeModelSchema(estimator.getDataType(), numeric, schema);
return toTreeModelSchema(numeric, inputFloat, schema);
}

static
Expand Down Expand Up @@ -567,7 +571,7 @@ private void encodeNodeId(Estimator estimator, Model model){
}

static
private Schema toTreeModelSchema(DataType dataType, boolean numeric, Schema schema){
private Schema toTreeModelSchema(Boolean numeric, Boolean inputFloat, Schema schema){
Function<Feature, Feature> function = new Function<Feature, Feature>(){

@Override
Expand All @@ -585,14 +589,51 @@ public Feature apply(Feature feature){
return missingValueFeature;
} else

if(feature instanceof ThresholdFeature && !numeric){
if(feature instanceof ThresholdFeature && (numeric != null && !numeric)){
ThresholdFeature thresholdFeature = (ThresholdFeature)feature;

return thresholdFeature;
} else

if(inputFloat != null && inputFloat){
ContinuousFeature continuousFeature = feature.toContinuousFeature();

DataType dataType = continuousFeature.getDataType();
if(dataType != DataType.FLOAT){
Field<?> field = continuousFeature.getField();

field.setDataType(DataType.FLOAT);

// XXX
if(field instanceof HasContinuousDomain){
HasContinuousDomain<?> hasContinuousDomain = (HasContinuousDomain<?>)field;

if(hasContinuousDomain.hasIntervals()){
List<Interval> intervals = hasContinuousDomain.getIntervals();

for(Interval interval : intervals){
Number leftMargin = interval.getLeftMargin();
Number rightMargin = interval.getRightMargin();

if(leftMargin != null){
interval.setLeftMargin((double)leftMargin.floatValue());
} // End if

if(rightMargin != null){
interval.setRightMargin((double)rightMargin.floatValue());
}
}
}
}

return new ContinuousFeature(continuousFeature.getEncoder(), field);
}

return continuousFeature;
} else

{
ContinuousFeature continuousFeature = feature.toContinuousFeature(dataType);
ContinuousFeature continuousFeature = feature.toContinuousFeature(DataType.FLOAT);

return continuousFeature;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,22 @@ public List<Map<String, Object>> getOptionsMatrix(){
String algorithm = getAlgorithm();
String dataset = getDataset();

if((AUDIT).equals(dataset) || (IRIS).equals(dataset)){

if((DECISION_TREE).equals(algorithm) || (RANDOM_FOREST).equals(algorithm)){
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasTreeOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});

return OptionsUtil.generateOptionsMatrix(options);
}
} else

if((AUDIT_NA).equals(dataset) || (IRIS_NA).equals(dataset)){

if((RANDOM_FOREST).equals(algorithm)){
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasTreeOptions.OPTION_ALLOW_MISSING, Boolean.TRUE);
options.put(HasTreeOptions.OPTION_COMPACT, new Boolean[]{true, false});
options.put(HasTreeOptions.OPTION_COMPACT, new Boolean[]{false, true});

return OptionsUtil.generateOptionsMatrix(options);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public List<Map<String, Object>> getOptionsMatrix(){

if((ISOLATION_FOREST).equals(algorithm)){
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasTreeOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});
options.put(HasTreeOptions.OPTION_PRUNE, new Boolean[]{false, true});

return OptionsUtil.generateOptionsMatrix(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,22 @@ public List<Map<String, Object>> getOptionsMatrix(){
String algorithm = getAlgorithm();
String dataset = getDataset();

if((AUTO).equals(dataset)){

if((DECISION_TREE).equals(algorithm) || (RANDOM_FOREST).equals(algorithm)){
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasTreeOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});

return OptionsUtil.generateOptionsMatrix(options);
}
} else

if((AUTO_NA).equals(dataset)){

if((RANDOM_FOREST).equals(algorithm)){
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasTreeOptions.OPTION_ALLOW_MISSING, Boolean.TRUE);
options.put(HasTreeOptions.OPTION_COMPACT, new Boolean[]{true, false});
options.put(HasTreeOptions.OPTION_COMPACT, new Boolean[]{false, true});

return OptionsUtil.generateOptionsMatrix(options);
}
Expand Down

0 comments on commit 7e1292b

Please sign in to comment.