Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of the preprocessing DSL (#416) #425

Merged
merged 14 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.jetbrains.kotlinx.dl.api.inference.keras.loaders

import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import java.io.File

/**
* Basic interface for models loaded from S3.
Expand Down Expand Up @@ -46,4 +45,4 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
public fun model(modelHub: ModelHub): T {
return modelHub.loadModel(this)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

package org.jetbrains.kotlinx.dl.api.inference.imagerecognition

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.core.util.loadImageNetClassLabels
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.predictTopKImageNetLabels
import org.jetbrains.kotlinx.dl.dataset.DataLoader
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.preprocessor.*
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.InterpolationType
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.toFloatArray
import java.awt.image.BufferedImage
import java.io.File

/**
Expand Down Expand Up @@ -90,16 +95,14 @@ public class ImageRecognitionModel(
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])

val preprocessing: Preprocessing = preprocess {
transformImage {
resize {
outputHeight = height.toInt()
outputWidth = width.toInt()
interpolation = InterpolationType.BILINEAR
}
convert { colorMode = modelType.inputColorMode }
val preprocessing = pipeline<BufferedImage>()
.resize {
outputHeight = height.toInt()
outputWidth = width.toInt()
interpolation = InterpolationType.BILINEAR
}
}
.convert { colorMode = modelType.inputColorMode }
.toFloatArray {}

return modelType.preprocessInput(imageFile, preprocessing)
}
Expand Down Expand Up @@ -137,7 +140,7 @@ public class ImageRecognitionModel(
*
* It takes preprocessing pipeline, invoke it and applied the specific preprocessing to the given data.
*/
public fun ModelType<*, *>.preprocessInput(imageFile: File, preprocessing: Preprocessing): FloatArray {
public fun ModelType<*, *>.preprocessInput(imageFile: File, preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>): FloatArray {
val (data, shape) = preprocessing.dataLoader().load(imageFile)
return preprocessInput(data, shape.dims())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,9 @@ public fun getDimsOfArray(data: kotlin.Array<*>): LongArray {
/**
* @see getDimsOfArray
*/
public val Array<*>.tensorShape: TensorShape get() = TensorShape(getDimsOfArray(this))
public val Array<*>.tensorShape: TensorShape get() = TensorShape(getDimsOfArray(this))

/**
* Wraps an IntArray to TensorShape.
*/
public fun IntArray.toTensorShape(): TensorShape = TensorShape(this.map(Int::toLong).toLongArray())
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape

/**
* The aim of this class is to provide common functionality for all [Operation]s that can be applied to Pair<FloatArray, TensorShape>
* and simplify the implementation of a new [Operation]s.
*/
public abstract class FloatArrayOperation: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> {
protected abstract fun applyImpl(data: FloatArray, shape: TensorShape): FloatArray

override fun apply(input: Pair<FloatArray, TensorShape>): Pair<FloatArray, TensorShape> {
val (data, shape) = input
return applyImpl(data, shape) to getOutputShape(shape)
}

override fun getOutputShape(inputShape: TensorShape): TensorShape = inputShape
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape

/** Applies [Normalizing] preprocessor to the tensor to normalize it with given mean and std values. */
public fun<I> Operation<I, Pair<FloatArray, TensorShape>>.normalize(block: Normalizing.() -> Unit): Operation<I, Pair<FloatArray, TensorShape>> {
return PreprocessingPipeline(this, Normalizing().apply(block))
}

/** Applies [Rescaling] preprocessor to the tensor to scale each value by a given coefficient. */
public fun<I> Operation<I, Pair<FloatArray, TensorShape>>.rescale(block: Rescaling.() -> Unit): Operation<I, Pair<FloatArray, TensorShape>> {
return PreprocessingPipeline(this, Rescaling().apply(block))
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dl.dataset.preprocessor
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import kotlin.math.sqrt

/**
Expand All @@ -11,12 +12,11 @@ import kotlin.math.sqrt
* @property [mean] an array of mean values for each channel.
* @property [std] an array of std values for each channel.
*/
public class Normalizing : Preprocessor {
public class Normalizing : FloatArrayOperation() {
public lateinit var mean: FloatArray
public lateinit var std: FloatArray

override fun apply(data: FloatArray, inputShape: ImageShape): FloatArray {
val channels = inputShape.channels!!.toInt()
override fun applyImpl(data: FloatArray, shape: TensorShape): FloatArray {
val channels = shape.tail().last().toInt()
require(mean.size == channels) {
"Expected to get one mean value for each image channel. " +
"However ${mean.size} values was given for image with $channels channels."
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape

/**
* Interface for preprocessing operations.
* @param I is a type of an input.
* @param O is a type of an output.
*/
public interface Operation<I, O> {
/**
* Performs preprocessing operation on the input.
* @param input is an input to the operation of type [I].
* @return an output of the operation of type [O].
*/
public fun apply(input: I): O
ermolenkodev marked this conversation as resolved.
Show resolved Hide resolved
/**
* Returns shape of the output of the operation having input of shape [inputShape].
* @param inputShape is a shape of the input.
*/
public fun getOutputShape(inputShape: TensorShape): TensorShape
}

/**
* Identity operation which does nothing.
*/
public class Identity<I> : Operation<I, I> {
override fun apply(input: I): I = input
override fun getOutputShape(inputShape: TensorShape): TensorShape = inputShape
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape

/**
* Convenience functions for executing custom logic after applying [Operation].
* Could be useful for debugging purposes.
*/
public fun <I, O> Operation<I,O>.onResult(block: (O) -> Unit): Operation<I, O> {
return PreprocessingPipeline(this, object : Operation<O, O> {
override fun apply(input: O): O {
try {
block(input)
} finally {
return input
}
}
override fun getOutputShape(inputShape: TensorShape): TensorShape = inputShape
})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape

/**
* This class is a special type of Operation which is used to build typesafe pipeline of preprocessing operations.
* Kudos to [@juliabeliaeva](https://github.com/juliabeliaeva) for the idea.
ermolenkodev marked this conversation as resolved.
Show resolved Hide resolved
*/
public class PreprocessingPipeline<I, M, O>(
private val firstOp: Operation<I, M>,
private val secondOp: Operation<M, O>
) : Operation<I, O> {
override fun apply(input: I): O = secondOp.apply(firstOp.apply(input))
override fun getOutputShape(inputShape: TensorShape): TensorShape {
return secondOp.getOutputShape(firstOp.getOutputShape(inputShape))
}
}

/**
* An entry point for building preprocessing pipeline.
*/
public fun<I> pipeline(): Identity<I> = Identity()
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.dataset.preprocessor
package org.jetbrains.kotlinx.dl.dataset.preprocessing

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape


/**
* This preprocessor defines the Rescaling operation.
* It scales each pixel pixel_i = pixel_i / [scalingCoefficient].
*
* @property [scalingCoefficient] Scaling coefficient.
*/
public class Rescaling(public var scalingCoefficient: Float = 255f) : Preprocessor {
override fun apply(data: FloatArray, inputShape: ImageShape): FloatArray {
public class Rescaling(public var scalingCoefficient: Float = 255f) : FloatArrayOperation() {
override fun applyImpl(data: FloatArray, shape: TensorShape): FloatArray {
for (i in data.indices) {
data[i] = data[i] / scalingCoefficient
}

return data
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

package org.jetbrains.kotlinx.dl.dataset

import org.jetbrains.kotlinx.dl.dataset.preprocessor.Preprocessing
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessor.dataLoader
import org.jetbrains.kotlinx.dl.dataset.preprocessor.generator.LabelGenerator
import org.jetbrains.kotlinx.dl.dataset.preprocessor.generator.LabelGenerator.Companion.prepareY
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.ConvertToFloatArray
import java.awt.image.BufferedImage
import java.io.File
import java.io.IOException
import java.nio.FloatBuffer
Expand Down Expand Up @@ -115,7 +118,7 @@ public class OnFlyImageDataset<D> internal constructor(
public fun create(
pathToData: File,
labels: FloatArray,
preprocessing: Preprocessing = Preprocessing()
preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>> = ConvertToFloatArray()
): OnFlyImageDataset<File> {
return try {
OnFlyImageDataset(OnHeapDataset.prepareFileNames(pathToData), labels, preprocessing.dataLoader())
Expand All @@ -131,7 +134,7 @@ public class OnFlyImageDataset<D> internal constructor(
public fun create(
pathToData: File,
labelGenerator: LabelGenerator<File>,
preprocessing: Preprocessing = Preprocessing()
preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>> = ConvertToFloatArray()
): OnFlyImageDataset<File> {
return try {
val xFiles = OnHeapDataset.prepareFileNames(pathToData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.jetbrains.kotlinx.dl.dataset

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.dataset.DataLoader.Companion.prepareX
import org.jetbrains.kotlinx.dl.dataset.preprocessor.Preprocessing
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessor.dataLoader
import org.jetbrains.kotlinx.dl.dataset.preprocessor.generator.LabelGenerator
import org.jetbrains.kotlinx.dl.dataset.preprocessor.generator.LabelGenerator.Companion.prepareY
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.ConvertToFloatArray
import java.awt.image.BufferedImage
import java.io.File
import java.io.IOException
import java.nio.FloatBuffer
Expand Down Expand Up @@ -212,7 +215,7 @@ public class OnHeapDataset internal constructor(public val x: Array<FloatArray>,
public fun create(
pathToData: File,
labels: FloatArray,
preprocessing: Preprocessing = Preprocessing()
preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>
): OnHeapDataset {
return try {
val xFiles = prepareFileNames(pathToData)
Expand All @@ -231,7 +234,7 @@ public class OnHeapDataset internal constructor(public val x: Array<FloatArray>,
public fun create(
pathToData: File,
labelGenerator: LabelGenerator<File>,
preprocessing: Preprocessing = Preprocessing()
preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>> = ConvertToFloatArray()
): OnHeapDataset {
return try {
val xFiles = prepareFileNames(pathToData)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dl.dataset.image

import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape
import org.jetbrains.kotlinx.dl.dataset.preprocessor.ImageShape
import java.awt.Graphics2D
import java.awt.image.BufferedImage
Expand All @@ -21,4 +22,8 @@ internal fun BufferedImage.copy(): BufferedImage {

internal fun BufferedImage.getShape(): ImageShape {
return ImageShape(width.toLong(), height.toLong(), colorModel.numComponents.toLong())
}
}

internal fun BufferedImage.getTensorShape(): TensorShape {
return TensorShape(width.toLong(), height.toLong(), colorModel.numComponents.toLong())
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,18 @@ public data class ImageShape(
}
}
}

/**
* Convenience function to create an [ImageShape] from [TensorShape].
*/
public fun TensorShape.toImageShape(): ImageShape {
ermolenkodev marked this conversation as resolved.
Show resolved Hide resolved
val width = if (this[0] == -1L) null else this[0]
val height = if (this[1] == -1L) null else this[1]
val channels = if (this[2] == -1L) null else this[2]

return when (this.rank()) {
2 -> ImageShape(width, height, 1)
3 -> ImageShape(width, height, channels)
else -> throw IllegalArgumentException("Tensor shape must be 2D or 3D to be converted to ImageShape")
}
}
Loading