From e89c0db13fff222541cdd74440cbcd7649ebffe1 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 31 Mar 2024 13:22:59 +0300 Subject: [PATCH] Cleaned up code --- .../src/main/java/category_encoders/CountEncoder.java | 2 +- .../src/main/java/category_encoders/MapEncoder.java | 2 +- .../src/main/java/category_encoders/MeanEncoder.java | 2 +- .../src/main/java/category_encoders/OrdinalEncoder.java | 2 +- .../main/java/sklearn/feature_extraction/DictVectorizer.java | 2 +- .../main/java/sklearn2pmml/preprocessing/LookupTransformer.java | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pmml-sklearn-extension/src/main/java/category_encoders/CountEncoder.java b/pmml-sklearn-extension/src/main/java/category_encoders/CountEncoder.java index d625f86ff..ac3e4cbf5 100644 --- a/pmml-sklearn-extension/src/main/java/category_encoders/CountEncoder.java +++ b/pmml-sklearn-extension/src/main/java/category_encoders/CountEncoder.java @@ -176,7 +176,7 @@ public Boolean getNormalize(){ @SuppressWarnings("unchecked") public Map> getMinGroupCategories(){ - Map minGroupCategories = get("_min_group_categories", Map.class); + Map minGroupCategories = (Map)getDict("_min_group_categories"); return CategoryEncoderUtil.toTransformedMap(minGroupCategories, key -> ScalarUtil.decode(key), value -> (Map)value); } diff --git a/pmml-sklearn-extension/src/main/java/category_encoders/MapEncoder.java b/pmml-sklearn-extension/src/main/java/category_encoders/MapEncoder.java index 54b2c18da..366e5e747 100644 --- a/pmml-sklearn-extension/src/main/java/category_encoders/MapEncoder.java +++ b/pmml-sklearn-extension/src/main/java/category_encoders/MapEncoder.java @@ -34,7 +34,7 @@ public MapEncoder(String module, String name){ public String functionName(); public Map getMapping(){ - Map mapping = get("mapping", Map.class); + Map mapping = getDict("mapping"); return CategoryEncoderUtil.toTransformedMap(mapping, key -> ScalarUtil.decode(key), value -> (Series)value); } diff --git a/pmml-sklearn-extension/src/main/java/category_encoders/MeanEncoder.java b/pmml-sklearn-extension/src/main/java/category_encoders/MeanEncoder.java index eb0540b67..7345be0fc 100644 --- a/pmml-sklearn-extension/src/main/java/category_encoders/MeanEncoder.java +++ b/pmml-sklearn-extension/src/main/java/category_encoders/MeanEncoder.java @@ -129,7 +129,7 @@ public String getDerivedName(){ @Override public Map getMapping(){ - Map mapping = get("mapping", Map.class); + Map mapping = getDict("mapping"); return CategoryEncoderUtil.toTransformedMap(mapping, key -> ScalarUtil.decode(key), value -> toMeanSeries((DataFrame)value, createFunction())); } diff --git a/pmml-sklearn-extension/src/main/java/category_encoders/OrdinalEncoder.java b/pmml-sklearn-extension/src/main/java/category_encoders/OrdinalEncoder.java index 260aa3828..554e6766a 100644 --- a/pmml-sklearn-extension/src/main/java/category_encoders/OrdinalEncoder.java +++ b/pmml-sklearn-extension/src/main/java/category_encoders/OrdinalEncoder.java @@ -134,7 +134,7 @@ public Map getCategoryMapping(){ return SeriesUtil.toMap(mapping, Functions.identity(), ValueUtil::asInteger); } catch(IllegalArgumentException iae){ - return get("mapping", Map.class); + return (Map)getDict("mapping"); } } } diff --git a/pmml-sklearn/src/main/java/sklearn/feature_extraction/DictVectorizer.java b/pmml-sklearn/src/main/java/sklearn/feature_extraction/DictVectorizer.java index 6c4076d62..a3e91002c 100644 --- a/pmml-sklearn/src/main/java/sklearn/feature_extraction/DictVectorizer.java +++ b/pmml-sklearn/src/main/java/sklearn/feature_extraction/DictVectorizer.java @@ -105,6 +105,6 @@ public String getSeparator(){ } public Map getVocabulary(){ - return get("vocabulary_", Map.class); + return (Map)getDict("vocabulary_"); } } diff --git a/pmml-sklearn/src/main/java/sklearn2pmml/preprocessing/LookupTransformer.java b/pmml-sklearn/src/main/java/sklearn2pmml/preprocessing/LookupTransformer.java index 4f7482ff3..3adc71d68 100644 --- a/pmml-sklearn/src/main/java/sklearn2pmml/preprocessing/LookupTransformer.java +++ b/pmml-sklearn/src/main/java/sklearn2pmml/preprocessing/LookupTransformer.java @@ -137,7 +137,7 @@ protected Map> parseMapping(List inputColumns, Stri } public Map getMapping(){ - return get("mapping", Map.class); + return getDict("mapping"); } public Object getDefaultValue(){