
Fix initializers import and export #348

Merged: 10 commits, May 18, 2022
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -11,12 +11,10 @@ import org.tensorflow.op.Ops
/**
* Initializer that generates tensors with constant values.
*
- * NOTE: It does not work properly during model import/export, known issue: https://github.com/zaleslaw/Kotof/issues/4.
- *
* @property constantValue Constant value to fill the tensor.
* @constructor Creates a [Constant] initializer with a given [constantValue].
*/
- public class Constant(private val constantValue: Float) : Initializer() {
+ public class Constant(public val constantValue: Float) : Initializer() {
override fun initialize(
fanIn: Int,
fanOut: Int,
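The visibility change above (private to public on constantValue) is what enables export: the saver has to read the constant back out of the initializer instance. A minimal sketch, using only names that appear in this diff:

    val initializer = Constant(constantValue = 0.5f)
    // ModelSaver (further down in this diff) reads the property back when serializing:
    val config = KerasInitializerConfig(value = initializer.constantValue.toDouble())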
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -10,8 +10,6 @@ import org.tensorflow.op.Ops

/**
* Initializer that generates tensors initialized to 1.
- *
- * NOTE: It does not work properly during model import/export, known issue: https://github.com/zaleslaw/Kotof/issues/4.
*/
public class Ones : Initializer() {
override fun initialize(
@@ -22,8 +22,8 @@ import kotlin.math.min
*/

public class Orthogonal(
- private val gain: Float = 1.0f,
- private val seed: Long = 12L
+ public val gain: Float = 1.0f,
+ public val seed: Long = 12L
) : Initializer() {
override fun initialize(
fanIn: Int,
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -20,11 +20,11 @@ import org.tensorflow.op.random.ParameterizedTruncatedNormal
* @constructor Creates a [ParametrizedTruncatedNormal] initializer.
*/
public class ParametrizedTruncatedNormal(
- private val mean: Float = 0.0f,
- private val stdev: Float = 1.0f,
- private val p1: Float = -10.0f, // low level edge
- private val p2: Float = 10.0f, // high level edge
- private val seed: Long
+ internal val mean: Float = 0.0f,
+ internal val stdev: Float = 1.0f,
+ internal val p1: Float = -10.0f, // low level edge
+ internal val p2: Float = 10.0f, // high level edge
+ internal val seed: Long
) :
Initializer() {
override fun initialize(
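Note that these properties become internal rather than public: the export code that reads them (convertToParametrizedTruncatedNormalInitializer, added later in this diff) lives in the same module, so they do not need to join the public API. A minimal construction sketch, assuming the defaults shown above:

    // seed has no default value, so it must always be supplied
    val initializer = ParametrizedTruncatedNormal(seed = 12L)
    // defaults: mean = 0.0f, stdev = 1.0f, p1 = -10.0f (low edge), p2 = 10.0f (high edge)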
@@ -21,7 +21,7 @@ import org.tensorflow.op.random.TruncatedNormal
* @property seed Seed.
* @constructor Creates [TruncatedNormal] initializer.
*/
- public class TruncatedNormal(private val seed: Long = 12L) :
+ public class TruncatedNormal(public val seed: Long = 12L) :
Initializer() {
override fun initialize(
fanIn: Int,
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -11,8 +11,6 @@ import org.tensorflow.op.Ops

/**
* Initializer that generates tensors initialized to 0.
- *
- * NOTE: It does not work properly during model import/export, known issue: https://github.com/zaleslaw/Kotof/issues/4.
*/
public class Zeros : Initializer() {
override fun initialize(
@@ -87,6 +87,7 @@ internal const val INITIALIZER_ONES: String = "Ones"
internal const val INITIALIZER_RANDOM_NORMAL: String = "RandomNormal"
internal const val INITIALIZER_RANDOM_UNIFORM: String = "RandomUniform"
internal const val INITIALIZER_TRUNCATED_NORMAL: String = "TruncatedNormal"
+ internal const val INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL: String = "ParametrizedTruncatedNormal"
internal const val INITIALIZER_CONSTANT: String = "Constant"
internal const val INITIALIZER_VARIANCE_SCALING: String = "VarianceScaling"
internal const val INITIALIZER_IDENTITY: String = "Identity"
@@ -247,9 +247,8 @@ private fun convertToRegularizer(regularizer: KerasRegularizer?): Regularizer? {
}

private fun convertToInitializer(initializer: KerasInitializer): Initializer {
- val seed = if (initializer.config!!.seed != null) {
- initializer.config.seed!!.toLong()
- } else 12L
+ val config = initializer.config
+ val seed = config!!.seed?.toLong() ?: 12L

return when (initializer.class_name!!) {
INITIALIZER_GLOROT_UNIFORM -> GlorotUniform(seed = seed)
@@ -258,36 +257,30 @@ private fun convertToInitializer(initializer: KerasInitializer): Initializer {
INITIALIZER_HE_UNIFORM -> HeUniform(seed = seed)
INITIALIZER_LECUN_NORMAL -> LeCunNormal(seed = seed)
INITIALIZER_LECUN_UNIFORM -> LeCunUniform(seed = seed)
- INITIALIZER_ZEROS -> RandomUniform(
- seed = seed,
- minVal = 0.0f,
- maxVal = 0.0f
- ) // instead of real initializers, because it doesn't influence on nothing
- INITIALIZER_CONSTANT -> RandomUniform(
- seed = seed,
- minVal = 0.0f,
- maxVal = 0.0f
- ) // instead of real initializers, because it doesn't influence on nothing
- INITIALIZER_ONES -> RandomUniform(
- seed = seed,
- minVal = 1.0f,
- maxVal = 1.0f
- ) // instead of real initializers, because it doesn't influence on nothing*/
INITIALIZER_RANDOM_NORMAL -> RandomNormal(
seed = seed,
- mean = initializer.config.mean!!.toFloat(),
- stdev = initializer.config.stddev!!.toFloat()
+ mean = config.mean!!.toFloat(),
+ stdev = config.stddev!!.toFloat()
)
INITIALIZER_RANDOM_UNIFORM -> RandomUniform(
seed = seed,
- minVal = initializer.config.minval!!.toFloat(),
- maxVal = initializer.config.maxval!!.toFloat()
+ minVal = config.minval!!.toFloat(),
+ maxVal = config.maxval!!.toFloat()
)
- INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
INITIALIZER_VARIANCE_SCALING -> convertVarianceScalingInitializer(initializer)
- INITIALIZER_ORTHOGONAL -> Orthogonal(seed = seed, gain = initializer.config.gain!!.toFloat())
- /*INITIALIZER_CONSTANT -> Constant(initializer.config.value!!.toFloat())*/
- INITIALIZER_IDENTITY -> Identity(initializer.config.gain?.toFloat() ?: 1f)
+ INITIALIZER_TRUNCATED_NORMAL -> TruncatedNormal(seed = seed)
+ INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL -> ParametrizedTruncatedNormal(
+ mean = config.mean!!.toFloat(),
+ stdev = config.stddev!!.toFloat(),
+ p1 = config.p1!!.toFloat(),
+ p2 = config.p2!!.toFloat(),
+ seed = seed
+ )
+ INITIALIZER_ORTHOGONAL -> Orthogonal(seed = seed, gain = config.gain!!.toFloat())
+ INITIALIZER_ZEROS -> Zeros()
+ INITIALIZER_ONES -> Ones()
+ INITIALIZER_CONSTANT -> Constant(config.value!!.toFloat())
+ INITIALIZER_IDENTITY -> Identity(config.gain?.toFloat() ?: 1f)
else -> throw IllegalStateException("${initializer.class_name} is not supported yet!")
}
}
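The seed extraction at the top of this function replaces an explicit null check with an elvis default; the two forms behave identically. A self-contained sketch of that refactor (defaultSeed is a hypothetical name, not part of the codebase):

    // old: val seed = if (config.seed != null) config.seed!!.toLong() else 12L
    // new, equivalent form:
    fun defaultSeed(seed: Int?): Long = seed?.toLong() ?: 12L

    fun main() {
        println(defaultSeed(null)) // 12
        println(defaultSeed(42))   // 42
    }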
@@ -157,39 +157,37 @@ private fun convertToKerasRegularizer(regularizer: Regularizer?): KerasRegulariz
}

private fun convertToKerasInitializer(initializer: Initializer, isKerasFullyCompatible: Boolean): KerasInitializer {
- val className: String
- val config: KerasInitializerConfig
- if (isKerasFullyCompatible) {
- val (_className, _config) = when (initializer) {
- is GlorotUniform -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is GlorotNormal -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is HeNormal -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is HeUniform -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is LeCunNormal -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is LeCunUniform -> convertToVarianceScalingInitializer(initializer as VarianceScaling)
- is RandomUniform -> convertToRandomUniformInitializer(initializer)
- is RandomNormal -> convertToRandomNormalInitializer(initializer)
- is Identity -> convertToIdentityInitializer(initializer)
- else -> throw IllegalStateException("${initializer::class.simpleName} is not supported yet!")
- // TODO: support Constant initializer
+ val (className, config) = when (initializer) {
+ is VarianceScaling -> {
+ if (isKerasFullyCompatible) {
+ convertToVarianceScalingInitializer(initializer)
+ } else {
+ when (initializer) {
+ is GlorotUniform -> INITIALIZER_GLOROT_UNIFORM
+ is GlorotNormal -> INITIALIZER_GLOROT_NORMAL
+ is HeNormal -> INITIALIZER_HE_NORMAL
+ is HeUniform -> INITIALIZER_HE_UNIFORM
+ is LeCunNormal -> INITIALIZER_LECUN_NORMAL
+ is LeCunUniform -> INITIALIZER_LECUN_UNIFORM
+ else -> throw IllegalStateException("Exporting ${initializer::class.simpleName} is not supported yet.")
+ } to KerasInitializerConfig(seed = initializer.seed.toInt())
+ }
+ }

- className = _className
- config = _config
- } else {
- className = when (initializer) {
- is GlorotUniform -> INITIALIZER_GLOROT_UNIFORM
- is GlorotNormal -> INITIALIZER_GLOROT_NORMAL
- is HeNormal -> INITIALIZER_HE_NORMAL
- is HeUniform -> INITIALIZER_HE_UNIFORM
- is LeCunNormal -> INITIALIZER_LECUN_NORMAL
- is LeCunUniform -> INITIALIZER_LECUN_UNIFORM
- is Identity -> INITIALIZER_IDENTITY
- else -> throw IllegalStateException("${initializer::class.simpleName} is not supported yet!")
+ is RandomUniform -> convertToRandomUniformInitializer(initializer)
+ is RandomNormal -> convertToRandomNormalInitializer(initializer)
+ is TruncatedNormal -> INITIALIZER_TRUNCATED_NORMAL to KerasInitializerConfig(seed = initializer.seed.toInt())
+ is ParametrizedTruncatedNormal -> {
+ if (isKerasFullyCompatible) {
+ throw IllegalStateException("Exporting ${initializer::class.simpleName} is not supported in the fully compatible mode.")
+ } else convertToParametrizedTruncatedNormalInitializer(initializer)
+ }
- config = KerasInitializerConfig(seed = 12)
+ is Orthogonal -> convertToOrthogonalInitializer(initializer)
+ is Zeros -> INITIALIZER_ZEROS to KerasInitializerConfig()
+ is Ones -> INITIALIZER_ONES to KerasInitializerConfig()
+ is Constant -> INITIALIZER_CONSTANT to KerasInitializerConfig(value = initializer.constantValue.toDouble())
+ is Identity -> convertToIdentityInitializer(initializer)
+ else -> throw IllegalStateException("Exporting ${initializer::class.simpleName} is not supported yet.")
}

return KerasInitializer(class_name = className, config = config)
}

@@ -233,6 +231,18 @@ private fun convertToIdentityInitializer(initializer: Identity): Pair<String, Ke
)
}

+ private fun convertToParametrizedTruncatedNormalInitializer(initializer: ParametrizedTruncatedNormal): Pair<String, KerasInitializerConfig> {
+ return Pair(
+ INITIALIZER_PARAMETRIZED_TRUNCATED_NORMAL, KerasInitializerConfig(
+ mean = initializer.mean.toDouble(),
+ stddev = initializer.stdev.toDouble(),
+ p1 = initializer.p1.toDouble(),
+ p2 = initializer.p2.toDouble(),
+ seed = initializer.seed.toInt()
+ )
+ )
+ }

private fun convertDistribution(distribution: Distribution): String {
return when (distribution) {
Distribution.TRUNCATED_NORMAL -> "truncated_normal"
@@ -249,6 +259,15 @@ private fun convertMode(mode: Mode): String {
}
}

+ private fun convertToOrthogonalInitializer(initializer: Orthogonal): Pair<String, KerasInitializerConfig> {
+ return Pair(
+ INITIALIZER_ORTHOGONAL, KerasInitializerConfig(
+ gain = initializer.gain.toDouble(),
+ seed = initializer.seed.toInt()
+ )
+ )
+ }

private fun convertToKerasPadding(padding: ConvPadding): KerasPadding {
return when (padding) {
ConvPadding.SAME -> KerasPadding.Same
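With the new branches in place, initializers like Orthogonal can round-trip between KotlinDL and the Keras JSON format. A minimal sketch using only functions from this diff; note that both converters are private to their respective files (ModelSaver and ModelLoader), so this is illustrative rather than directly callable from outside:

    val keras = convertToKerasInitializer(Orthogonal(gain = 1.5f, seed = 42L), isKerasFullyCompatible = false)
    // keras.class_name == "Orthogonal"; keras.config.gain == 1.5; keras.config.seed == 42
    val restored = convertToInitializer(keras) // an Orthogonal with the same gain and seed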
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -25,7 +25,11 @@ internal data class KerasInitializerConfig(
@Json(serializeNull = false)
val stddev: Double? = null,
@Json(serializeNull = false)
- val value: Int? = null,
+ val value: Double? = null,
@Json(serializeNull = false)
- val gain: Double? = null
+ val gain: Double? = null,
+ @Json(serializeNull = false)
+ val p1: Double? = null,
+ @Json(serializeNull = false)
+ val p2: Double? = null
)
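The value field changes from Int? to Double? because Keras stores a Constant initializer's value as a floating-point number; the old Int type could not represent fractional constants at all. A sketch of the round trip under the new type, matching the ModelLoader and ModelSaver branches above:

    val config = KerasInitializerConfig(value = 0.5) // fractional values now survive serialization
    val restored = Constant(config.value!!.toFloat()) // as in convertToInitializer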