You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
// Seed handed to the HeNormal / HeUniform initializers of both Dense layers.
const val SEED: Long = 12L

// Batch size passed to evaluate() for the test split.
const val TEST_BATCH_SIZE: Int = 5

// Number of epochs passed to fit().
const val EPOCHS: Int = 100

// Batch size passed to fit() for the training split.
const val TRAINING_BATCH_SIZE: Int = 3

// Directory the model is saved into.
const val PATH_TO_MODEL: String = "models/irisNet"
// Network used by main(): a 4-value input, a Dense layer of size 5 with Relu,
// and a Dense layer of size 3 with Linear activation. Both Dense layers use the
// same SEED for their kernel and bias initializers.
val model = Sequential.of(
    // Input declared with a single dimension of size 4.
    Input(4),
    Dense(
        5,
        Activations.Relu,
        kernelInitializer = HeNormal(SEED),
        biasInitializer = HeUniform(SEED)
    ),
    // Linear activation here; main() compiles the model with a "with logits" loss.
    Dense(
        3,
        Activations.Linear,
        kernelInitializer = HeNormal(SEED),
        biasInitializer = HeUniform(SEED)
    )
)
/**
 * Shuffles the data, builds a dataset, splits it with a 0.8 ratio, then compiles,
 * fits, evaluates and saves [model], and finally prints the measured accuracy.
 *
 * `data`, `extractX` and `extractY` are declared outside this snippet.
 */
fun main() {
    data.shuffle()

    val dataset = OnHeapDataset.create(::extractX, ::extractY)
    val (trainingSet, testSet) = dataset.split(0.8)

    // `use` hands the model to the block as `net`; it is the same object as `model`.
    model.use { net ->
        net.compile(
            optimizer = Adam(),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
        net.summary()
        net.fit(dataset = trainingSet, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE)

        val evaluation = net.evaluate(dataset = testSet, batchSize = TEST_BATCH_SIZE)
        val accuracy = evaluation.metrics[Metrics.ACCURACY]

        // Saving happens before printing, so a failure in save() means the
        // accuracy line is never printed.
        net.save(
            modelDirectory = File(PATH_TO_MODEL),
            savingFormat = SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES,
            writingMode = WritingMode.OVERRIDE
        )
        println("Accuracy: $accuracy")
    }
}
It fails with the following exception:
Exception in thread "main" java.lang.IndexOutOfBoundsException: Index 1 out of bounds for length 1
at java.base/jdk.internal.util.Preconditions.outOfBounds(Preconditions.java:64)
at java.base/jdk.internal.util.Preconditions.outOfBoundsCheckIndex(Preconditions.java:70)
at java.base/jdk.internal.util.Preconditions.checkIndex(Preconditions.java:248)
at java.base/java.util.Objects.checkIndex(Objects.java:372)
at java.base/java.util.ArrayList.get(ArrayList.java:459)
at org.jetbrains.kotlinx.dl.api.inference.keras.ModelSaverKt.saveModelConfiguration(ModelSaver.kt:54)
at org.jetbrains.kotlinx.dl.api.inference.keras.ModelSaverKt.saveModelConfiguration$default(ModelSaver.kt:36)
at org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.saveModel(GraphTrainableModel.kt:794)
at org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.saveInKerasFormat(GraphTrainableModel.kt:788)
at org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.save(GraphTrainableModel.kt:783)
at org.jetbrains.kotlinx.dl.api.core.TrainableModel.save$default(TrainableModel.kt:259)
at IrisKt.main(Iris.kt:45)
at IrisKt.main(Iris.kt)
This code only worked for a specific number of dimensions and was unnecessary, since the createKerasInputLayer method already sets the batch_input_shape field correctly.
Kotlin#160
* Extract createInputLayer function
* Simplify loading model layers
* Allow loading models with any input layer dimensions
#160
* Simplify GraphTrainableModel#serializeModel
* Remove unnecessary code setting batch_input_shape
This code only worked for a specific number of dimensions and was unnecessary, since the createKerasInputLayer method already sets the batch_input_shape field correctly.
#160
* Test that input layer dimensions are saved and loaded correctly
#160
Saving a model trained on the Iris data fails with an IndexOutOfBoundsException due to an unfixed TODO in ModelSaver.
The text was updated successfully, but these errors were encountered: