Skip to content

Commit

Permalink
Added support for model settings
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jun 13, 2024
1 parent 57b946c commit b701fad
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
39 changes: 26 additions & 13 deletions pmml-rexp/src/main/java/org/jpmml/rexp/MaxLikConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@

public class MaxLikConverter extends ModelConverter<RGenericVector> {

private Map<String, RExp> variables = null;
private RFunctionCall settings = null;

private RFunctionCall availabilities = null;
private Map<String, RExp> variables = null;

private Map<?, RFunctionCall> utilityFunctions = null;

private Map<?, Feature> availabilityFeatures = null;

private Map<?, Feature> utilityFunctionFeatures = null;

private RFunctionCall nlNests = null;

private Map<?, RFunctionCall> nlStructures = null;

private Map<String, Double> estimates = null;

private Map<?, Feature> availabilityFeatures = null;

private Map<?, Feature> utilityFunctionFeatures = null;


public MaxLikConverter(RGenericVector maxLik){
super(maxLik);
Expand All @@ -85,9 +85,15 @@ public void encodeSchema(RExpEncoder encoder){

RStringVector modelTypeList = maxLik.getStringElement("modelTypeList");

Map<String, RExp> variables = this.variables;
RFunctionCall settings = this.settings;

Map<String, RExp> settingsMap = parseList(settings, (value) -> value);

RFunctionCall alternatives = (RFunctionCall)settingsMap.get("alternatives");
RFunctionCall availabilities = (RFunctionCall)settingsMap.get("avail");
RString choiceVar = (RString)settingsMap.get("choiceVar");

RFunctionCall availabilities = this.availabilities;
Map<String, RExp> variables = this.variables;
Map<?, RFunctionCall> utilityFunctions = this.utilityFunctions;

if(utilityFunctions.isEmpty()){
Expand All @@ -98,7 +104,7 @@ public void encodeSchema(RExpEncoder encoder){

List<?> choices = new ArrayList<>(utilityFunctions.keySet());

DataField choiceField = encoder.createDataField("choice", OpType.CATEGORICAL, TypeUtil.getDataType(choices, DataType.STRING), choices);
DataField choiceField = encoder.createDataField(choiceVar.getValue(), OpType.CATEGORICAL, TypeUtil.getDataType(choices, DataType.STRING), choices);

encoder.setLabel(choiceField);

Expand Down Expand Up @@ -386,8 +392,11 @@ private void parseApolloProbabilities(){
throw new IllegalArgumentException();
}

RFunctionCall settings = null;

Map<String, RExp> variables = new LinkedHashMap<>();

RFunctionCall alternatives = null;
RFunctionCall availabilities = null;
Map<Object, RFunctionCall> utilityFunctions = new LinkedHashMap<>();

Expand All @@ -409,8 +418,8 @@ private void parseApolloProbabilities(){
if(firstArgValue instanceof RString){
RString string = (RString)firstArgValue;

if(matchVariable(firstArgValue, "avail")){
availabilities = (RFunctionCall)secondArgValue;
if(matchVariable(firstArgValue, "mnl_settings") || matchVariable(firstArgValue, "nl_settings")){
settings = (RFunctionCall)secondArgValue;

continue;
} else
Expand Down Expand Up @@ -449,13 +458,17 @@ private void parseApolloProbabilities(){
}
}

if(settings == null){
throw new IllegalArgumentException();
} // End if

if(!Collections.disjoint(utilityFunctions.keySet(), nlStructures.keySet())){
throw new IllegalArgumentException();
}

this.variables = variables;
this.settings = settings;

this.availabilities = availabilities;
this.variables = variables;
this.utilityFunctions = utilityFunctions;

this.nlNests = nlNests;
Expand Down
14 changes: 4 additions & 10 deletions pmml-rexp/src/test/R/apollo.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ generateMNLModeChoice = function(){
V[["air"]] = asc_air + b_tt_air * time_air + b_access * access_air + b_cost * cost_air + b_no_frills * ( service_air == 1 ) + b_wifi * ( service_air == 2 ) + b_food * ( service_air == 3 )
V[["rail"]] = asc_rail + b_tt_rail * time_rail + b_access * access_rail + b_cost * cost_rail + b_no_frills * ( service_rail == 1 ) + b_wifi * ( service_rail == 2 ) + b_food * ( service_rail == 3 )

alternatives = c(car=1, bus=2, air=3, rail=4)
avail = list(car=av_car, bus=av_bus, air=av_air, rail=av_rail)

### Define settings for MNL model component
mnl_settings = list(
alternatives = alternatives,
avail = avail,
alternatives = c(car=1, bus=2, air=3, rail=4),
avail = list(car=av_car, bus=av_bus, air=av_air, rail=av_rail),
choiceVar = choice,
utilities = V
)
Expand Down Expand Up @@ -169,13 +166,10 @@ generateNLModeChoice = function(){
nlStructure[["PT"]] = c("bus","fastPT")
nlStructure[["fastPT"]] = c("air","rail")

alternatives = c(car=1, bus=2, air=3, rail=4)
avail = list(car=av_car, bus=av_bus, air=av_air, rail=av_rail)

### Define settings for NL model
nl_settings = list(
alternatives = alternatives,
avail = avail,
alternatives = c(car=1, bus=2, air=3, rail=4),
avail = list(car=av_car, bus=av_bus, air=av_air, rail=av_rail),
choiceVar = choice,
utilities = V,
nlNests = nlNests,
Expand Down
Binary file modified pmml-rexp/src/test/resources/rds/MNLModeChoice.rds
Binary file not shown.
Binary file modified pmml-rexp/src/test/resources/rds/NLModeChoice.rds
Binary file not shown.

0 comments on commit b701fad

Please sign in to comment.