Skip to content

Commit

Permalink
Added support for the 'CountVectorizer.token_pattern' attribute. Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 12, 2021
1 parent 9bd46fb commit 00ea7b3
Show file tree
Hide file tree
Showing 12 changed files with 3,068 additions and 2,050 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;
import sklearn2pmml.feature_extraction.text.Matcher;
import sklearn2pmml.feature_extraction.text.Tokenizer;

public class CountVectorizer extends Transformer {
Expand Down Expand Up @@ -157,6 +158,13 @@ public DefineFunction encodeDefineFunction(Feature feature, SkLearnEncoder encod

if(stripAccents != null){
throw new IllegalArgumentException(stripAccents);
} // End if

if(tokenizer == null){
String tokenPattern = getTokenPattern();

tokenizer = new Matcher()
.setWordRE(tokenPattern);
}

ParameterField documentField = new ParameterField(FieldName.create("document"));
Expand Down Expand Up @@ -237,7 +245,11 @@ public String getStripAccents(){
}

/**
 * Looks up the {@code tokenizer} attribute as a {@link Tokenizer}.
 *
 * The lookup is optional rather than mandatory, because {@code encodeDefineFunction}
 * falls back to a {@code token_pattern}-based {@code Matcher} when no tokenizer is available.
 *
 * @return the tokenizer; presumably {@code null} when the attribute is unset
 * (NOTE(review): confirm against the contract of {@code getOptional}).
 */
public Tokenizer getTokenizer(){
	// Only the optional lookup is kept; a preceding mandatory get(...) made this statement unreachable.
	return getOptional("tokenizer", Tokenizer.class);
}

/**
 * Reads the {@code token_pattern} attribute as a string.
 *
 * @return the value of the {@code token_pattern} attribute.
 */
public String getTokenPattern(){
	String tokenPattern = getString("token_pattern");

	return tokenPattern;
}

public Map<String, ?> getVocabulary(){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,11 @@ public TextIndex configure(TextIndex textIndex){
/**
 * Builds a regular expression that matches any of the given stop words
 * between {@code \b} word boundaries.
 *
 * @param stopWords the stop words; they are joined verbatim with {@code |}, without any regex escaping.
 * @return the stop words alternation, prefixed with the {@code (?u)} flag
 * when the word pattern of this object starts with that flag.
 */
public String formatStopWordsRE(List<String> stopWords){
	String wordRE = getWordRE();

	// Any word pattern is accepted here; only its leading Unicode flag is carried over to the stop words pattern.
	boolean unicode = wordRE.startsWith("(?u)");

	Joiner joiner = Joiner.on("|");

	return (unicode ? "(?u)" : "") + "\\b(" + joiner.join(stopWords) + ")\\b";
}

public void __setstate__(String wordRE){
Expand Down
5 changes: 5 additions & 0 deletions src/test/java/org/jpmml/sklearn/TokenizerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public void split() throws Exception {
@Test
public void match() throws Exception {
Matcher matcher = new Matcher()
.setWordRE("(?u)\\b\\w\\w+\\b");

evaluate("CountVectorizer", "Sentiment", matcher);

matcher = new Matcher()
.setWordRE("\\w+");

evaluate("Matcher", "Sentiment", matcher);
Expand Down
Loading

0 comments on commit 00ea7b3

Please sign in to comment.