Introduce KVariable data class to encapsulate all the information about a single parameter of the layer (#324)

* Introduce KVariable data class to encapsulate all the information about a single parameter of the layer

* Convert fanIn and fanOut properties to local variables

Co-authored-by: Veniamin Viflyantsev <[email protected]>
juliabeliaeva and knok16 authored Jan 12, 2022
1 parent ebfe99c commit 3dfb90e
Showing 14 changed files with 325 additions and 313 deletions.
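
In short: state that each layer previously tracked piecemeal for every parameter (a raw Variable<Float>, its name, its shape, plus the mutable fanIn/fanOut properties inherited from Layer) now travels together in one value object. A condensed before/after sketch, distilled from the diffs below rather than quoted from them:

// Before (sketch): one logical parameter scattered across several fields,
// with fanIn/fanOut living as mutable properties on the Layer base class.
private lateinit var alphaShape: Shape
private lateinit var alpha: Variable<Float>

// After (sketch): a single KVariable bundles name, shape, the TF variable,
// its initializer operand, and an optional regularizer; fanIn and fanOut
// become plain locals computed at the call site.
private lateinit var alpha: KVariable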
KVariable.kt
@@ -0,0 +1,59 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Variable

/**
* A class that keeps information about a single parameter of the [Layer].
*
* @property [name] name of the variable
* @property [shape] shape of the variable
* @property [variable] corresponding [Variable] object
* @property [initializerOperand] variable initializer
* @property [regularizer] variable regularizer
*/
public data class KVariable(
val name: String,
val shape: Shape,
val variable: Variable<Float>,
val initializerOperand: Operand<Float>,
val regularizer: Regularizer?
)

internal fun createVariable(
tf: Ops,
kGraph: KGraph,
variableName: String,
isTrainable: Boolean,
shape: Shape,
fanIn: Int,
fanOut: Int,
initializer: Initializer,
regularizer: Regularizer?
): KVariable {
val tfVariable = tf.withName(variableName).variable(shape, getDType())

val initOp = initializer.apply(fanIn, fanOut, tf, tfVariable, variableName)
kGraph.addLayerVariable(tfVariable, isTrainable)
kGraph.addInitializer(variableName, initOp)
if (regularizer != null) kGraph.addVariableRegularizer(tfVariable, regularizer)

return KVariable(
name = variableName,
shape = shape,
variable = tfVariable,
initializerOperand = initOp,
regularizer = regularizer
)
}
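
For orientation, a hedged sketch of how a layer might call this helper; the variable name, shape, and initializer below are illustrative placeholders, not taken from this commit (the real call site it introduces is in the PReLU diff further down):

import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal

// Hypothetical call site inside some layer's build(tf, kGraph, inputShape):
val kernel: KVariable = createVariable(
    tf, kGraph,
    variableName = "dense_kernel",   // hypothetical variable name
    isTrainable = true,
    shape = Shape.make(64, 32),      // hypothetical shape
    fanIn = 64,
    fanOut = 32,
    initializer = HeNormal(),        // any Initializer would do here
    regularizer = null
)
// Forward passes read the raw operand; weight export reads name and shape.
val w: Operand<Float> = kernel.variable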
Layer.kt
@@ -1,20 +1,17 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.api.core.layer

import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.TrainableModel
-import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
-import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.extension.convertTensorToMultiDimArray
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
-import org.tensorflow.op.core.Variable

/**
* Base abstract class for all layers.
@@ -34,12 +31,6 @@ public abstract class Layer(public var name: String) {
/** Model where this layer is used. */
public var parentModel: TrainableModel? = null

-/** Returns number of input parameters. */
-protected var fanIn: Int = Int.MIN_VALUE
-
-/** Returns number of output parameters. */
-protected var fanOut: Int = Int.MIN_VALUE

/** Returns inbound layers. */
public var inboundLayers: MutableList<Layer> = mutableListOf()

@@ -108,47 +99,24 @@ public abstract class Layer(public var name: String) {
return forward(tf, input[0], isTraining, numberOfLosses)
}

-/**
-* Adds a new weight tensor to the layer
-*
-* @param name variable name
-* @param variable variable to add
-* @return the created variable.
-*/
-protected fun addWeight(
-tf: Ops,
-kGraph: KGraph,
-name: String,
-variable: Variable<Float>,
-initializer: Initializer,
-regularizer: Regularizer? = null
-): Variable<Float> {
-// require(fanIn != Int.MIN_VALUE) { "fanIn should be calculated before initialization for variable $name" }
-// require(fanOut != Int.MIN_VALUE) { "fanOut should be calculated before initialization for variable $name" }
-
-val initOp = initializer.apply(fanIn, fanOut, tf, variable, name)
-kGraph.addLayerVariable(variable, isTrainable)
-kGraph.addInitializer(name, initOp)
-if (regularizer != null) kGraph.addVariableRegularizer(variable, regularizer)
-return variable
-}

/** Important part of functional API. It takes [layers] as input and saves them to the [inboundLayers] of the given layer. */
public operator fun invoke(vararg layers: Layer): Layer {
inboundLayers = layers.toMutableList()
return this
}

-/** Extract weights values by variable names. */
-protected fun extractWeights(variableNames: List<String>): Map<String, Array<*>> {
+/** Extract weights values for provided variables. */
+protected fun extractWeights(vararg variables: KVariable?): Map<String, Array<*>> {
require(parentModel != null) { "Layer $name is not related to any model!" }

val session = parentModel!!.session
val runner = session.runner()

+val variableNames = variables.mapNotNull { it?.name }
for (variableName in variableNames) {
runner.fetch(variableName)
}

val weights = runner.run().map { it.convertTensorToMultiDimArray() }
return variableNames.zip(weights).toMap()
}
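The practical effect on Layer subclasses: instead of assembling a list of variable names, a layer now hands its KVariables over directly, and parameters that were never created arrive as null and are skipped by the mapNotNull above. A hedged sketch, where gamma and beta are hypothetical parameters rather than code from this diff:

// Hypothetical subclass with one optional parameter.
private lateinit var gamma: KVariable
private var beta: KVariable? = null   // stays null when the option is off

override var weights: Map<String, Array<*>>
    get() = extractWeights(gamma, beta)   // a null beta is simply dropped
    set(value) = assignWeights(value)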
PReLU.kt
@@ -1,5 +1,5 @@
/*
- * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

@@ -8,14 +8,14 @@ package org.jetbrains.kotlinx.dl.api.core.layer.activation
import org.jetbrains.kotlinx.dl.api.core.KGraph
import org.jetbrains.kotlinx.dl.api.core.initializer.Initializer
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros
+import org.jetbrains.kotlinx.dl.api.core.layer.KVariable
+import org.jetbrains.kotlinx.dl.api.core.layer.createVariable
import org.jetbrains.kotlinx.dl.api.core.regularizer.Regularizer
import org.jetbrains.kotlinx.dl.api.core.shape.numElements
import org.jetbrains.kotlinx.dl.api.core.shape.toLongArray
-import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.Shape
import org.tensorflow.op.Ops
-import org.tensorflow.op.core.Variable

/**
* Parametric Rectified Linear Unit.
@@ -40,15 +40,16 @@ public class PReLU(
/**
* TODO: support for constraint (alphaConstraint) should be added
*/
-private lateinit var alphaShape: Shape
-private lateinit var alpha: Variable<Float>
-private val alphaVariableName = if (name.isNotEmpty()) name + "_" + "alpha" else "alpha"

+private lateinit var alpha: KVariable
+private fun alphaVariableName(): String =
+if (name.isNotEmpty()) "${name}_alpha" else "alpha"

override var weights: Map<String, Array<*>>
-get() = extractWeights(listOf(alphaVariableName))
+get() = extractWeights(alpha)
set(value) = assignWeights(value)
override val paramCount: Int
-get() = alphaShape.numElements().toInt()
+get() = alpha.shape.numElements().toInt()

init {
isTrainable = true
@@ -61,19 +62,28 @@ public class PReLU(
alphaShapeArray[axis - 1] = 1
}
}
-alphaShape = Shape.make(alphaShapeArray[0], *alphaShapeArray.drop(1).toLongArray())

-fanIn = inputShape.size(inputShape.numDimensions() - 1).toInt()
-fanOut = fanIn
+val fanIn = inputShape.size(inputShape.numDimensions() - 1).toInt()
+val fanOut = fanIn

-alpha = tf.withName(alphaVariableName).variable(alphaShape, getDType())
-alpha = addWeight(tf, kGraph, alphaVariableName, alpha, alphaInitializer, alphaRegularizer)
+val alphaShape = Shape.make(alphaShapeArray[0], *alphaShapeArray.drop(1).toLongArray())
+alpha = createVariable(
+tf,
+kGraph,
+alphaVariableName(),
+isTrainable,
+alphaShape,
+fanIn,
+fanOut,
+alphaInitializer,
+alphaRegularizer
+)
}

override fun forward(tf: Ops, input: Operand<Float>): Operand<Float> {
// It's equivalent to: `-alpha * relu(-x) + relu(x)`
val positive = tf.nn.relu(input)
-val negative = tf.math.mul(tf.math.neg(alpha), tf.nn.relu(tf.math.neg(input)))
+val negative = tf.math.mul(tf.math.neg(alpha.variable), tf.nn.relu(tf.math.neg(input)))
return tf.math.add(positive, negative)
}

(Diffs for the remaining 11 changed files are not shown.)