Skip to content

Commit

Permalink
Added support for the 'TfidfVectorizer' transformation type
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 25, 2017
1 parent 909b71d commit 65636a7
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Java library and command-line application for converting [Scikit-Learn] (http://
* [`ensemble.VotingClassifier`] (http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html)
* Feature Extraction:
* [`feature_extraction.text.CountVectorizer`] (http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html)
* [`feature_extraction.text.TfidfVectorizer`] (http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html)
* Feature Selection:
* [`feature_selection.GenericUnivariateSelect`] (http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.GenericUnivariateSelect.html) (only via `sklearn2pmml.SelectorProxy`)
* [`feature_selection.RFE`] (http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFE.html) (only via `sklearn2pmml.SelectorProxy`)
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/scipy/sparse/CSRMatrixUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ private CSRMatrixUtil(){
public int[] getShape(CSRMatrix matrix){
Object[] shape = matrix.getShape();

if(shape.length == 1){
return new int[]{ValueUtil.asInt((Number)shape[0])};
} else

if(shape.length == 2){
return new int[]{ValueUtil.asInt((Number)shape[0]), ValueUtil.asInt((Number)shape[1])};
}

List<? extends Number> values = (List)Arrays.asList(shape);

return Ints.toArray(ValueUtil.asIntegers(values));
Expand Down
20 changes: 13 additions & 7 deletions src/main/java/sklearn/feature_extraction/text/CountVectorizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import com.google.common.collect.HashBiMap;
import numpy.DType;
import numpy.core.Scalar;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
Expand Down Expand Up @@ -92,7 +92,7 @@ public List<Feature> encodeFeatures(List<String> ids, List<Feature> features, Sk
} // End if

if(tokenPattern != null && !("(?u)\\b\\w\\w+\\b").equals(tokenPattern)){
throw new IllegalArgumentException();
throw new IllegalArgumentException(tokenPattern);
}

BiMap<String, Integer> termIndexMap = HashBiMap.create(vocabulary.size());
Expand Down Expand Up @@ -140,32 +140,38 @@ public ContinuousFeature toContinuousFeature(){
for(int i = 0, max = indexTermMap.size(); i < max; i++){
String term = indexTermMap.get(i);

Expression termFrequency = PMMLUtil.createApply(defineFunction.getName(), feature.ref(), PMMLUtil.createConstant(term));

final
Apply apply = PMMLUtil.createApply(defineFunction.getName(), feature.ref(), PMMLUtil.createConstant(term));
Expression weightedTermFrequency = encodeWeight(termFrequency, i);

Feature termFrequencyFeature = new Feature(encoder, FieldName.create("tf(" + term + ")"), dtype != null ? dtype.getDataType() : DataType.DOUBLE){
Feature termFeature = new Feature(encoder, FieldName.create("tf" + ((weightedTermFrequency != termFrequency) ? "-idf" : "") + "(" + term + ")"), dtype != null ? dtype.getDataType() : DataType.DOUBLE){

@Override
public ContinuousFeature toContinuousFeature(){
PMMLEncoder encoder = ensureEncoder();

DerivedField derivedField = encoder.getDerivedField(getName());
if(derivedField == null){
derivedField = encoder.createDerivedField(getName(), OpType.CONTINUOUS, getDataType(), apply);
derivedField = encoder.createDerivedField(getName(), OpType.CONTINUOUS, getDataType(), weightedTermFrequency);
}

return new ContinuousFeature(encoder, derivedField);
}
};

ids.add((termFrequencyFeature.getName()).getValue());
ids.add((termFeature.getName()).getValue());

result.add(termFrequencyFeature);
result.add(termFeature);
}

return result;
}

public Expression encodeWeight(Expression expression, int index){
return expression;
}

public String getAnalyzer(){
return (String)get("analyzer");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.feature_extraction.text;

import java.util.List;

import net.razorvine.pickle.objects.ClassDict;
import scipy.sparse.CSRMatrix;

public class TfidfTransformer extends ClassDict {

public TfidfTransformer(String module, String name){
super(module, name);
}

public Number getWeight(int index){
CSRMatrix idfDiag = (CSRMatrix)get("_idf_diag");

List<?> data = idfDiag.getData();

return (Number)data.get(index);
}

public String getNorm(){
return (String)get("norm");
}

public Boolean getSublinearTf(){
return (Boolean)get("sublinear_tf");
}

public Boolean getUseIdf(){
return (Boolean)get("use_idf");
}
}
68 changes: 68 additions & 0 deletions src/main/java/sklearn/feature_extraction/text/TfidfVectorizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.feature_extraction.text;

import java.util.List;

import org.dmg.pmml.Expression;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.sklearn.SkLearnEncoder;

public class TfidfVectorizer extends CountVectorizer {

public TfidfVectorizer(String module, String name){
super(module, name);
}

@Override
public List<Feature> encodeFeatures(List<String> ids, List<Feature> features, SkLearnEncoder encoder){
TfidfTransformer transformer = getTransformer();

String norm = transformer.getNorm();
if(norm != null){
throw new IllegalArgumentException(norm);
}

return super.encodeFeatures(ids, features, encoder);
}

@Override
public Expression encodeWeight(Expression expression, int index){
TfidfTransformer transformer = getTransformer();

Boolean sublinearTf = transformer.getSublinearTf();
if(sublinearTf){
expression = PMMLUtil.createApply("+", PMMLUtil.createApply("log", expression), PMMLUtil.createConstant(1d));
} // End if

Boolean useIdf = transformer.getUseIdf();
if(useIdf){
Number weight = transformer.getWeight(index);

expression = PMMLUtil.createApply("*", expression, PMMLUtil.createConstant(weight));
}

return expression;
}

public TfidfTransformer getTransformer(){
return (TfidfTransformer)get("_tfidf");
}
}
2 changes: 2 additions & 0 deletions src/main/resources/META-INF/sklearn2pmml.properties
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ sklearn.ensemble.weight_boosting.AdaBoostRegressor = sklearn.ensemble.weight_boo
sklearn.ensemble.voting_classifier.VotingClassifier = sklearn.ensemble.voting_classifier.VotingClassifier
sklearn.externals.joblib.numpy_pickle.NumpyArrayWrapper = joblib.NumpyArrayWrapper
sklearn.feature_extraction.text.CountVectorizer = sklearn.feature_extraction.text.CountVectorizer
sklearn.feature_extraction.text.TfidfTransformer = sklearn.feature_extraction.text.TfidfTransformer
sklearn.feature_extraction.text.TfidfVectorizer = sklearn.feature_extraction.text.TfidfVectorizer
sklearn.feature_selection.from_model.SelectFromModel = sklearn.feature_selection.SelectFromModel
sklearn.feature_selection.univariate_selection.SelectKBest = sklearn.feature_selection.SelectKBest
sklearn.linear_model.base.LinearRegression = sklearn.linear_model.base.LinearRegression
Expand Down

0 comments on commit 65636a7

Please sign in to comment.