Skip to content

Commit

Permalink
Store frequencies in a tree
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Vasco committed Apr 21, 2022
1 parent 17e361c commit 2d80e26
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 251 deletions.
78 changes: 20 additions & 58 deletions src/main/kotlin/com/github/pemistahl/lingua/api/LanguageDetector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ import com.github.pemistahl.lingua.internal.Constant.NO_LETTER
import com.github.pemistahl.lingua.internal.Constant.NUMBERS
import com.github.pemistahl.lingua.internal.Constant.PUNCTUATION
import com.github.pemistahl.lingua.internal.Constant.isJapaneseAlphabet
import com.github.pemistahl.lingua.internal.JsonLanguageModel
import com.github.pemistahl.lingua.internal.Ngram
import com.github.pemistahl.lingua.internal.TestDataLanguageModel
import com.github.pemistahl.lingua.internal.TrainingDataLanguageModel
import com.github.pemistahl.lingua.internal.util.extension.containsAnyOf
import com.github.pemistahl.lingua.internal.util.extension.incrementCounter
import com.github.pemistahl.lingua.internal.util.extension.isLogogram
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import java.util.SortedMap
import java.util.TreeMap
import java.util.concurrent.Callable
Expand All @@ -41,12 +44,6 @@ import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import kotlin.math.ln

private val UNIGRAM_MODELS = mutableMapOf<Language, TrainingDataLanguageModel>()
private val BIGRAM_MODELS = mutableMapOf<Language, TrainingDataLanguageModel>()
private val TRIGRAM_MODELS = mutableMapOf<Language, TrainingDataLanguageModel>()
private val QUADRIGRAM_MODELS = mutableMapOf<Language, TrainingDataLanguageModel>()
private val FIVEGRAM_MODELS = mutableMapOf<Language, TrainingDataLanguageModel>()

/**
* Detects the language of given input text.
*/
Expand All @@ -55,11 +52,7 @@ class LanguageDetector internal constructor(
internal val minimumRelativeDistance: Double,
isEveryLanguageModelPreloaded: Boolean,
internal val numberOfLoadedLanguages: Int = languages.size,
internal val unigramLanguageModels: MutableMap<Language, TrainingDataLanguageModel> = UNIGRAM_MODELS,
internal val bigramLanguageModels: MutableMap<Language, TrainingDataLanguageModel> = BIGRAM_MODELS,
internal val trigramLanguageModels: MutableMap<Language, TrainingDataLanguageModel> = TRIGRAM_MODELS,
internal val quadrigramLanguageModels: MutableMap<Language, TrainingDataLanguageModel> = QUADRIGRAM_MODELS,
internal val fivegramLanguageModels: MutableMap<Language, TrainingDataLanguageModel> = FIVEGRAM_MODELS
internal val languageModels: MutableMap<Language, TrainingDataLanguageModel> = mutableMapOf()
) {
internal val threadPool = createThreadPool()

Expand Down Expand Up @@ -153,9 +146,7 @@ class LanguageDetector internal constructor(
val unigramCounts = if (i == 1) {
val languages = probabilities.keys
val unigramFilteredLanguages =
if (languages.isNotEmpty()) filteredLanguages.asSequence()
.filter { languages.contains(it) }
.toSet()
if (languages.isNotEmpty()) filteredLanguages.filterTo(mutableSetOf()) { languages.contains(it) }
else filteredLanguages
countUnigramsOfInputText(testDataModel, unigramFilteredLanguages)
} else {
Expand Down Expand Up @@ -194,13 +185,7 @@ class LanguageDetector internal constructor(
threadPool.shutdownNow()
}

for (language in languages) {
unigramLanguageModels.remove(language)
bigramLanguageModels.remove(language)
trigramLanguageModels.remove(language)
quadrigramLanguageModels.remove(language)
fivegramLanguageModels.remove(language)
}
languageModels.clear()
}

internal fun cleanUpInputText(text: String): String {
Expand Down Expand Up @@ -433,55 +418,32 @@ class LanguageDetector internal constructor(
language: Language,
ngram: Ngram
): Float {
val ngramLength = ngram.value.length
val languageModels = when (ngramLength) {
5 -> fivegramLanguageModels
4 -> quadrigramLanguageModels
3 -> trigramLanguageModels
2 -> bigramLanguageModels
1 -> unigramLanguageModels
0 -> throw IllegalArgumentException("Zerogram detected")
else -> throw IllegalArgumentException("unsupported ngram length detected: ${ngram.value.length}")
}

val model = loadLanguageModels(languageModels, language, ngramLength)
require(ngram.length > 0) { "Zerogram detected" }
require(ngram.length <= 5) { "unsupported ngram length detected: ${ngram.length}" }
val model = loadLanguageModels(languageModels, language)

return model.getRelativeFrequency(ngram)
}

private fun loadLanguageModels(
languageModels: MutableMap<Language, TrainingDataLanguageModel>,
language: Language,
ngramLength: Int
): TrainingDataLanguageModel {
if (languageModels.containsKey(language)) {
return languageModels.getValue(language)
language: Language
): TrainingDataLanguageModel =
languageModels.computeIfAbsent(language, ::loadLanguageModel)

private fun loadLanguageModel(language: Language): TrainingDataLanguageModel {
val jsonLanguageModels: Sequence<JsonLanguageModel> = (1..5).asSequence().map { ngramLength ->
val fileName = "${Ngram.getNgramNameByLength(ngramLength)}s.json"
val filePath = "/language-models/${language.isoCode639_1}/$fileName"
Json.decodeFromString(Language::class.java.getResourceAsStream(filePath).reader().use { it.readText() })
}
val model = loadLanguageModel(language, ngramLength)
languageModels[language] = model
return model
}

private fun loadLanguageModel(language: Language, ngramLength: Int): TrainingDataLanguageModel {
val fileName = "${Ngram.getNgramNameByLength(ngramLength)}s.json"
val filePath = "/language-models/${language.isoCode639_1}/$fileName"
val inputStream = Language::class.java.getResourceAsStream(filePath)
val jsonContent = inputStream.bufferedReader(Charsets.UTF_8).use { it.readText() }
return TrainingDataLanguageModel.fromJson(jsonContent)
return TrainingDataLanguageModel.fromJson(language, jsonLanguageModels)
}

private fun preloadLanguageModels() {
val tasks = mutableListOf<Callable<TrainingDataLanguageModel>>()

for (language in languages) {
tasks.add(Callable { loadLanguageModels(unigramLanguageModels, language, 1) })
tasks.add(Callable { loadLanguageModels(bigramLanguageModels, language, 2) })
tasks.add(Callable { loadLanguageModels(trigramLanguageModels, language, 3) })
tasks.add(Callable { loadLanguageModels(quadrigramLanguageModels, language, 4) })
tasks.add(Callable { loadLanguageModels(fivegramLanguageModels, language, 5) })
loadLanguageModels(languageModels, language)
}

threadPool.invokeAll(tasks)
}

private fun createThreadPool(): ExecutorService {
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/com/github/pemistahl/lingua/internal/Ngram.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ internal value class Ngram(val value: String) : Comparable<Ngram> {
}
}

inline val length get() = value.length

override fun toString() = value

override fun compareTo(other: Ngram) = value.length.compareTo(other.value.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal data class TestDataLanguageModel(val ngrams: Set<Ngram>) {
require(ngramLength in 1..5) {
"ngram length $ngramLength is not in range 1..5"
}
val ngrams = hashSetOf<Ngram>()
val ngrams = mutableSetOf<Ngram>()
for (i in 0..text.length - ngramLength) {
val textSlice = text.substring(i, i + ngramLength)
if (LETTER_REGEX.matches(textSlice)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,21 @@ internal data class TrainingDataLanguageModel(
)
}

fun fromJson(json: String): TrainingDataLanguageModel {
val jsonLanguageModel = Json.decodeFromString<JsonLanguageModel>(json)

val jsonDataSequence = sequence {
for ((fraction, ngrams) in jsonLanguageModel.ngrams) {
val fractionAsFloat = fraction.toFloat()
for (ngram in ngrams.split(' ')) {
yield(ngram to fractionAsFloat)
fun fromJson(language: Language, jsonLanguageModels: Sequence<JsonLanguageModel>): TrainingDataLanguageModel {
val jsonDataSequence =
sequence {
for (jsonLanguageModel in jsonLanguageModels) {
for ((fraction, ngrams) in jsonLanguageModel.ngrams) {
val fractionAsFloat = fraction.toFloat()
for (ngram in ngrams.split(' ')) {
yield(ngram to fractionAsFloat)
}
}
}
}
}

return TrainingDataLanguageModel(
language = jsonLanguageModel.language,
language = language,
absoluteFrequencies = emptyMap(),
relativeFrequencies = emptyMap(),
jsonRelativeFrequencies = RelativeFrequencies.build(jsonDataSequence)
Expand Down Expand Up @@ -145,88 +146,69 @@ internal data class TrainingDataLanguageModel(
}
}

internal class RelativeFrequencies private constructor(data: Map<UByte, Entries>) {
/**
* N-ary search tree.
*/
internal class RelativeFrequencies {

private val entries: Array<Entries?> = Array(256) { data[it.toUByte()] }
var frequency: Float = 0F

operator fun get(ngram: String): Float =
entries[computeHash(ngram).toInt()]?.get(ngram) ?: 0F
private var childKeys = emptyKeys

private class Entries(private val chars: CharArray, private val frequencies: FloatArray) {
private var childValues = emptyValues

val size get() = frequencies.size
operator fun get(ngram: String) = getImpl(ngram, depth = 0)

operator fun get(ngram: String): Float {
// check range before search
var cmp = compareNgram(0, ngram)
if (cmp == 0) return frequencies.first()
if (cmp < 0 || size == 1) return 0F
private operator fun set(ngram: String, frequency: Float) = setImpl(ngram, frequency, depth = 0)

cmp = compareNgram(frequencies.lastIndex, ngram)
if (cmp == 0) return frequencies.last()
if (cmp > 0) return 0F
private fun getImpl(ngram: String, depth: Int): Float {
if (depth == ngram.length) return frequency
val i = childKeys.binarySearch(ngram[depth])
return if (i >= 0) childValues[i].getImpl(ngram, depth + 1) else 0F
}

return search(ngram)
private fun setImpl(ngram: String, frequency: Float, depth: Int) {
if (depth == ngram.length) {
this.frequency = frequency
return
}

private fun search(ngram: String): Float {
// skip edges
var low = 1
var high = size - 2

while (low <= high) {
if (low + 8 < high) {
// bisection search
val middle = (low + high) / 2
val diff = compareNgram(middle, ngram)
if (diff < 0) low = middle + 1
else if (diff > 0) high = middle - 1
else return frequencies[middle]
} else {
// linear search
for (i in low..high) {
if (compareNgram(i, ngram) == 0) return frequencies[i]
return 0F
}
var i = childKeys.binarySearch(ngram[depth])
// insert a new child
if (i < 0) {
i = -i - 1
childKeys = CharArray(childKeys.size + 1) { idx ->
when {
idx < i -> childKeys[idx]
idx > i -> childKeys[idx - 1]
else -> ngram[depth]
}
}
return 0F
}

private fun compareNgram(pos: Int, ngram: String): Int {
val base = pos * ngram.length
repeat(ngram.length) { i ->
val diff = chars[base + i].compareTo(ngram[i])
if (diff != 0) return diff
childValues = Array(childValues.size + 1) { idx ->
when {
idx < i -> childValues[idx]
idx > i -> childValues[idx - 1]
else -> RelativeFrequencies()
}
}
return 0
}

// set value
childValues[i].setImpl(ngram, frequency, depth + 1)
}

companion object {

internal fun build(relativeFrequencies: Sequence<Pair<String, Float>>): RelativeFrequencies {
val entryMap = LinkedHashMap<UByte, MutableMap<String, Float>>()
relativeFrequencies.forEach { (ngram, frequency) ->
val map = entryMap.computeIfAbsent(computeHash(ngram)) { TreeMap() }
map[ngram] = frequency
}
private val emptyKeys = CharArray(0)

val data: Map<UByte, Entries> = entryMap.entries.associateTo(LinkedHashMap()) { (highHash, map) ->
val chars = map.keys.joinToString(separator = "").toCharArray()
val float = map.values.toFloatArray()
highHash to Entries(chars, float)
}
private val emptyValues = emptyArray<RelativeFrequencies>()

return RelativeFrequencies(data)
}

private fun computeHash(ngram: String): UByte {
var hash = ngram.first().code.shr(8).toUByte()
ngram.forEach { c ->
hash = hash.rotateRight(3) xor c.code.toUByte()
internal fun build(relativeFrequencies: Sequence<Pair<String, Float>>): RelativeFrequencies {
val frequencies = RelativeFrequencies()
for ((ngram, frequency) in relativeFrequencies) {
frequencies[ngram] = frequency
}
return hash
return frequencies
}
}
}
Expand Down
Loading

0 comments on commit 2d80e26

Please sign in to comment.