Skip to content

Commit

Permalink
Refactored LightGBMConverter class
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Aug 4, 2024
1 parent 687b640 commit 0fb0078
Showing 1 changed file with 49 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,78 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.rexp.Converter;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.REnvironment;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RRaw;

public class LightGBMConverter extends Converter<REnvironment> {
public class LightGBMConverter extends ModelConverter<REnvironment> {

private GBDT gbdt = null;


public LightGBMConverter(REnvironment environment){
super(environment);
}

@Override
public PMML encodePMML(RExpEncoder encoder){
public void encodeSchema(RExpEncoder encoder){
GBDT gbdt = ensureGBDT();

Schema schema = gbdt.encodeSchema(null, null, encoder);

Label label = schema.getLabel();
List<? extends Feature> features = schema.getFeatures();

encoder.setLabel(label);

for(Feature feature : features){
encoder.addFeature(feature);
}
}

@Override
public MiningModel encodeModel(Schema schema){
GBDT gbdt = ensureGBDT();

// XXX
Map<String, Object> options = Collections.emptyMap();

Schema lgbmSchema = gbdt.toLightGBMSchema(schema);

return gbdt.encodeModel(options, lgbmSchema);
}

private GBDT ensureGBDT(){

if(this.gbdt == null){
this.gbdt = loadGBDT();
}

return this.gbdt;
}

private GBDT loadGBDT(){
REnvironment environment = getObject();

RRaw raw = (RRaw)environment.findVariable("raw");
if(raw == null){
throw new IllegalArgumentException();
}

GBDT gbdt;

try(InputStream is = new ByteArrayInputStream(raw.getValue())){
gbdt = LightGBMUtil.loadGBDT(is);
return LightGBMUtil.loadGBDT(is);
} catch(IOException ioe){
throw new IllegalArgumentException(ioe);
}

return gbdt.encodePMML(Collections.emptyMap(), null, null);
}
}

0 comments on commit 0fb0078

Please sign in to comment.