Skip to content

Commit

Permalink
Convert all the model types without constructor parameters to objects (
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva authored Dec 1, 2022
1 parent ee899bf commit fc72cce
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import java.io.File
fun efficientNet4LitePrediction() {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

val modelType = ONNXModels.CV.EfficientNet4Lite()
val modelType = ONNXModels.CV.EfficientNet4Lite
val model = modelHub.loadModel(modelType)
model.printSummary()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels
* - Special preprocessing (used in ResNet'18 during training on ImageNet dataset) is applied to each image before prediction.
*/
fun resnet18prediction() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet18())
runImageRecognitionPrediction(ONNXModels.CV.ResNet18)
}

/** */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fun resnet18LightAPIPrediction() {
val modelHub =
ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

val model = ONNXModels.CV.ResNet18().pretrainedModel(modelHub)
val model = ONNXModels.CV.ResNet18.pretrainedModel(modelHub)
model.printSummary()

model.use {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ import java.io.File
class ExecutionProvidersTestSuite {
private fun resnetModelsInference(executionProvider: ExecutionProvider) {
val modelsToTest = listOf(
ONNXModels.CV.ResNet101(),
ONNXModels.CV.ResNet101v2(),
ONNXModels.CV.ResNet152(),
ONNXModels.CV.ResNet152v2(),
ONNXModels.CV.ResNet18(),
ONNXModels.CV.ResNet18v2(),
ONNXModels.CV.ResNet34(),
ONNXModels.CV.ResNet34v2(),
ONNXModels.CV.ResNet50(),
ONNXModels.CV.ResNet50v2(),
ONNXModels.CV.ResNet101,
ONNXModels.CV.ResNet101v2,
ONNXModels.CV.ResNet152,
ONNXModels.CV.ResNet152v2,
ONNXModels.CV.ResNet18,
ONNXModels.CV.ResNet18v2,
ONNXModels.CV.ResNet34,
ONNXModels.CV.ResNet34v2,
ONNXModels.CV.ResNet50,
ONNXModels.CV.ResNet50v2,
ONNXModels.CV.ResNet50custom,
)

Expand Down Expand Up @@ -78,7 +78,7 @@ class ExecutionProvidersTestSuite {
@Test
fun executionProvidersDuplicatesTest() {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
val model = modelHub.loadModel(ONNXModels.CV.ResNet18())
val model = modelHub.loadModel(ONNXModels.CV.ResNet18)

model.use {
assertDoesNotThrow {
Expand All @@ -90,7 +90,7 @@ class ExecutionProvidersTestSuite {
@Test
fun twoCpuExecutorsWithDifferentAllocatorsTest() {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
val model = modelHub.loadModel(ONNXModels.CV.ResNet18())
val model = modelHub.loadModel(ONNXModels.CV.ResNet18)

model.use {
assertThrows<IllegalArgumentException> {
Expand Down
22 changes: 11 additions & 11 deletions examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ModelCopyingTestSuite {
ONNXModels.CV.EfficientNetB5(),
ONNXModels.CV.EfficientNetB6(),
ONNXModels.CV.EfficientNetB7(),
ONNXModels.CV.EfficientNet4Lite()
ONNXModels.CV.EfficientNet4Lite
),
"datasets/vgg/image0.jpg",
ImageRecognitionModel::predictObject
Expand All @@ -45,17 +45,17 @@ class ModelCopyingTestSuite {
fun resNetCopyTest() {
runCopyTest(
listOf(
ONNXModels.CV.ResNet18(),
ONNXModels.CV.ResNet18v2(),
ONNXModels.CV.ResNet34(),
ONNXModels.CV.ResNet34v2(),
ONNXModels.CV.ResNet50(),
ONNXModels.CV.ResNet50v2(),
ONNXModels.CV.ResNet18,
ONNXModels.CV.ResNet18v2,
ONNXModels.CV.ResNet34,
ONNXModels.CV.ResNet34v2,
ONNXModels.CV.ResNet50,
ONNXModels.CV.ResNet50v2,
ONNXModels.CV.ResNet50custom,
ONNXModels.CV.ResNet101(),
ONNXModels.CV.ResNet101v2(),
ONNXModels.CV.ResNet152(),
ONNXModels.CV.ResNet152v2()
ONNXModels.CV.ResNet101,
ONNXModels.CV.ResNet101v2,
ONNXModels.CV.ResNet152,
ONNXModels.CV.ResNet152v2
),
"datasets/vgg/image0.jpg",
ImageRecognitionModel::predictObject
Expand Down
18 changes: 9 additions & 9 deletions examples/src/test/kotlin/examples/onnx/cv/OnnxResNetTestSuite.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@ class OnnxResNetTestSuite {

@Test
fun resnet18v2predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet18v2())
runImageRecognitionPrediction(ONNXModels.CV.ResNet18v2)
}

@Test
fun resnet34predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet34())
runImageRecognitionPrediction(ONNXModels.CV.ResNet34)
}

@Test
fun resnet34v2predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet34v2())
runImageRecognitionPrediction(ONNXModels.CV.ResNet34v2)
}

@Test
fun resnet50predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet50())
runImageRecognitionPrediction(ONNXModels.CV.ResNet50)
}

@Test
fun resnet50v2predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet50v2())
runImageRecognitionPrediction(ONNXModels.CV.ResNet50v2)
}

@Test
Expand All @@ -61,22 +61,22 @@ class OnnxResNetTestSuite {

@Test
fun resnet101predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet101())
runImageRecognitionPrediction(ONNXModels.CV.ResNet101)
}

@Test
fun resnet101v2predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet101v2())
runImageRecognitionPrediction(ONNXModels.CV.ResNet101v2)
}

@Test
fun resnet152predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet152())
runImageRecognitionPrediction(ONNXModels.CV.ResNet152)
}

@Test
fun resnet152v2predictionTest() {
runImageRecognitionPrediction(ONNXModels.CV.ResNet152v2())
runImageRecognitionPrediction(ONNXModels.CV.ResNet152v2)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4">
* Official EfficientNet4Lite model from ONNX Github.</a>
*/
public class EfficientNet4Lite : CV("efficientnet_lite4", channelsFirst = false) {
public object EfficientNet4Lite : CV("efficientnet_lite4", channelsFirst = false) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = InputType.TF.preprocessing(channelsLast = !channelsFirst)
}
Expand All @@ -83,7 +83,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4">
* Official EfficientNet4Lite model from ONNX Github.</a>
*/
public class MobilenetV1 : CV("mobilenet_v1", channelsFirst = false) {
public object MobilenetV1 : CV("mobilenet_v1", channelsFirst = false) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = pipeline<Pair<FloatArray, TensorShape>>()
.rescale { scalingCoefficient = 255f }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet18 : CV("models/onnx/cv/resnet/resnet18-v1", channelsFirst = true) {
public object ResNet18 : CV("models/onnx/cv/resnet/resnet18-v1", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
}
Expand All @@ -93,7 +93,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet34 : CV("models/onnx/cv/resnet/resnet34-v1", channelsFirst = true) {
public object ResNet34 : CV("models/onnx/cv/resnet/resnet34-v1", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
}
Expand All @@ -115,7 +115,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet50 :
public object ResNet50 :
CV("models/onnx/cv/resnet/resnet50-v1", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -138,7 +138,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet101 :
public object ResNet101 :
CV("models/onnx/cv/resnet/resnet101-v1", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -161,7 +161,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet152 :
public object ResNet152 :
CV("models/onnx/cv/resnet/resnet152-v1", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -184,7 +184,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet18v2 :
public object ResNet18v2 :
CV("models/onnx/cv/resnet/resnet18-v2", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -207,7 +207,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet34v2 :
public object ResNet34v2 :
CV("models/onnx/cv/resnet/resnet34-v2", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -230,7 +230,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet50v2 :
public object ResNet50v2 :
CV("models/onnx/cv/resnet/resnet50-v2", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -253,7 +253,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet101v2 :
public object ResNet101v2 :
CV("models/onnx/cv/resnet/resnet101-v2", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -276,7 +276,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/resnet">
* Official ResNet model from ONNX Github.</a>
*/
public class ResNet152v2 :
public object ResNet152v2 :
CV("models/onnx/cv/resnet/resnet152-v2", channelsFirst = true) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = resNetOnnxPreprocessing()
Expand All @@ -299,7 +299,7 @@ public object ONNXModels {
* @see <a href="https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4">
* Official EfficientNet4Lite model from ONNX Github.</a>
*/
public class EfficientNet4Lite :
public object EfficientNet4Lite :
CV("models/onnx/cv/efficientnet/efficientnet-lite4", channelsFirst = false) {
override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
get() = InputType.TF.preprocessing(channelsLast = !channelsFirst)
Expand Down

0 comments on commit fc72cce

Please sign in to comment.