Skip to content

Commit

Permalink
Add validation of input shape (Kotlin#385)
Browse files Browse the repository at this point in the history
* Check if the number of elements in the input matches the model's input shape

* Trying to guess if the user used a grayscale image when 3-channels expected
  • Loading branch information
ermolenkodev committed Oct 11, 2022
1 parent efebf7a commit 329b017
Showing 1 changed file with 36 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ public open class OnnxInferenceModel private constructor(private val modelSource
dataType: OnnxJavaType,
shape: LongArray
): OnnxTensor {
checkTensorMatchesInputShape(data, shape)

val inputTensor = when (dataType) {
OnnxJavaType.FLOAT -> OnnxTensor.createTensor(this, FloatBuffer.wrap(data), shape)
OnnxJavaType.DOUBLE -> OnnxTensor.createTensor(
Expand Down Expand Up @@ -320,6 +322,40 @@ public open class OnnxInferenceModel private constructor(private val modelSource
return inputTensor
}

private fun checkTensorMatchesInputShape(data: FloatArray, inputShape: LongArray) {
val numOfElements = inputShape.reduce { acc, dim -> acc * dim }.toInt()

if (data.size == numOfElements) return

if (inputShape.matchNHWC(C = 3) || inputShape.matchNCHW(C = 3)) {
val (height, width) = when {
inputShape.matchNHWC(C = 3) -> inputShape[1] to inputShape[2]
else -> inputShape[2] to inputShape[3]
}

if (data.size.toLong() == height * width) {
throw IllegalArgumentException(
"The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
.plus("${inputShape.contentToString()}.")
.plus(" It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image.")
)
}
}

throw IllegalArgumentException(
"The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
.plus("${inputShape.contentToString()}.")
)
}

private fun LongArray.matchNHWC(C: Int): Boolean {
return this.size == 4 && this[0] == 1L && this[3] == C.toLong()
}

private fun LongArray.matchNCHW(C: Int): Boolean {
return this.size == 4 && this[0] == 1L && this[1] == C.toLong()
}

/**
* Loads model from serialized ONNX file.
*/
Expand Down

0 comments on commit 329b017

Please sign in to comment.