Skip to content

Commit

Permalink
Added version downgrade transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Nov 19, 2024
1 parent 36459a4 commit b1295d5
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,18 @@
import java.lang.reflect.Field;
import java.util.Objects;

import org.dmg.pmml.Apply;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.Version;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.time_series.TrendExpoSmooth;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.annotations.Added;
import org.jpmml.model.annotations.Optional;
import org.jpmml.model.annotations.Removed;
Expand Down Expand Up @@ -71,4 +80,74 @@ public void handleOptional(PMMLObject object, AnnotatedElement element, Optional
@Override
public void handleRequired(PMMLObject object, AnnotatedElement element, Required required){
}

@Override
public VisitorAction visit(Apply apply){
Object defaultValue = apply.getDefaultValue();

if(defaultValue != null){

if(this.version.compareTo(Version.PMML_4_1) == 0){
Object mapMissingTo = apply.getMapMissingTo();

if(mapMissingTo != null){
throw new UnsupportedAttributeException(apply, PMMLAttributes.APPLY_DEFAULTVALUE, defaultValue);
}

apply
.setDefaultValue(null)
.setMapMissingTo(defaultValue);
}
}

return super.visit(apply);
}

@Override
public VisitorAction visit(MiningField miningField){
MiningField.UsageType usageType = miningField.getUsageType();

switch(usageType){
case TARGET:
if(this.version.compareTo(Version.PMML_4_2) < 0){
miningField.setUsageType(MiningField.UsageType.PREDICTED);
}
break;
default:
break;
}

return super.visit(miningField);
}

@Override
public VisitorAction visit(PMML pmml){
pmml.setVersion(this.version.getVersion());

return super.visit(pmml);
}

@Override
public VisitorAction visit(TargetValue targetValue){
String displayValue = targetValue.getDisplayValue();

if(displayValue != null){

if(this.version.compareTo(Version.PMML_3_2) <= 0){
throw new UnsupportedAttributeException(targetValue, PMMLAttributes.TARGETVALUE_DISPLAYVALUE, displayValue);
}
}

return super.visit(targetValue);
}

@Override
public VisitorAction visit(TrendExpoSmooth trendExpoSmooth){

if(this.version.compareTo(Version.PMML_4_0) == 0){
throw new UnsupportedElementException(trendExpoSmooth);
}

return super.visit(trendExpoSmooth);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
*/
package org.jpmml.model.visitors;

import org.dmg.pmml.Apply;
import org.dmg.pmml.ClusteringModelQuality;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Header;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.ModelExplanation;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.Version;
import org.dmg.pmml.clustering.ClusteringModel;
import org.dmg.pmml.clustering.PMMLAttributes;
import org.dmg.pmml.time_series.TrendExpoSmooth;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
Expand All @@ -23,7 +30,55 @@
public class VersionDowngraderTest {

@Test
public void downgrade(){
public void downgradeApply(){
Apply apply = new Apply()
.setDefaultValue(0)
.setMapMissingTo(null);

apply = downgrade(apply, Version.PMML_4_2);

assertEquals(0, apply.getDefaultValue());
assertNull(apply.getMapMissingTo());

apply = downgrade(apply, Version.PMML_4_1);

assertNull(apply.getDefaultValue());
assertEquals(0, apply.getMapMissingTo());

apply = new Apply()
.setDefaultValue(0)
.setMapMissingTo(-999);

apply = downgrade(apply, Version.PMML_4_2);

assertEquals(0, apply.getDefaultValue());
assertEquals(-999, apply.getMapMissingTo());

try {
downgrade(apply, Version.PMML_4_1);

fail();
} catch(UnsupportedAttributeException uae){
// Ignored
}
}

@Test
public void downgradeMiningField(){
MiningField miningField = new MiningField()
.setUsageType(MiningField.UsageType.TARGET);

miningField = downgrade(miningField, Version.PMML_4_2);

assertEquals(MiningField.UsageType.TARGET, miningField.getUsageType());

miningField = downgrade(miningField, Version.PMML_4_1);

assertEquals(MiningField.UsageType.PREDICTED, miningField.getUsageType());
}

@Test
public void downgradePMML(){
Extension extension = new Extension();

ClusteringModelQuality clusteringModelQuality = new ClusteringModelQuality()
Expand All @@ -43,47 +98,77 @@ public void downgrade(){
.setHeader(header)
.addModels(clusteringModel);

VersionInspector inspector;

try {
inspector = new VersionDowngrader(Version.XPMML);

fail();
} catch(IllegalArgumentException iae){
// Ignored
}

inspector = new VersionDowngrader(Version.PMML_4_4);
inspector.applyTo(pmml);
pmml = downgrade(pmml, Version.PMML_4_4);

assertTrue(clusteringModelQuality.hasExtensions());

inspector = new VersionDowngrader(Version.PMML_4_3);
inspector.applyTo(pmml);
assertEquals("4.4", pmml.getVersion());

pmml = downgrade(pmml, Version.PMML_4_3);

assertFalse(clusteringModelQuality.hasExtensions());

assertNotNull(header.getModelVersion());

inspector = new VersionDowngrader(Version.PMML_4_1);
inspector.applyTo(pmml);
pmml = downgrade(pmml, Version.PMML_4_1);

assertNull(header.getModelVersion());

assertFalse(clusteringModel.isScorable());

inspector = new VersionDowngrader(Version.PMML_4_0);
inspector.applyTo(pmml);
pmml = downgrade(pmml, Version.PMML_4_0);

assertTrue(clusteringModel.isScorable());

assertNull(ReflectionUtil.getFieldValue(PMMLAttributes.CLUSTERINGMODEL_SCORABLE, clusteringModel));

assertNotNull(clusteringModel.getModelExplanation());

inspector = new VersionDowngrader(Version.PMML_3_2);
inspector.applyTo(pmml);
pmml = downgrade(pmml, Version.PMML_3_2);

assertNull(clusteringModel.getModelExplanation());

assertEquals("3.2", pmml.getVersion());
}

@Test
public void downgradeTargetValue(){
TargetValue targetValue = new TargetValue()
.setValue(1)
.setDisplayValue("one");

targetValue = downgrade(targetValue, Version.PMML_4_0);

assertEquals(1, targetValue.getValue());
assertEquals("one", targetValue.getDisplayValue());

try {
downgrade(targetValue, Version.PMML_3_2);

fail();
} catch(UnsupportedAttributeException uae){
// Ignored
}
}

@Test
public void downgradeTrendExpoSmooth(){
TrendExpoSmooth trendExpoSmooth = new TrendExpoSmooth();

trendExpoSmooth = downgrade(trendExpoSmooth, Version.PMML_4_1);

try {
downgrade(trendExpoSmooth, Version.PMML_4_0);
} catch(IllegalArgumentException iae){
// Ignored
}
}

static
private <E extends PMMLObject> E downgrade(E object, Version version){
VersionDowngrader inspector = new VersionDowngrader(version);
inspector.applyTo(object);

return object;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
public class VersionStandardizerTest {

@Test
public void standardize(){
public void standardizePMML(){
PMML pmml = new PMML()
.setVersion("4.4")
.setBaseVersion("4.3");
Expand Down

0 comments on commit b1295d5

Please sign in to comment.