
[FeatureRequest] Option to reset already fitted model to its initial state #271

Closed
rychuelektryk opened this issue Oct 30, 2021 · 4 comments · Fixed by #338
Labels
enhancement New feature or request
Milestone

Comments

@rychuelektryk

Hi,

I would like to have some kind of a clear method on an already trained model that restores it to the untrained state. I guess it boils down to resetting the model weights. Do you think you could add such an option in the future?

My case:
In each iteration of my k-fold procedure I would like to start fitting the model from scratch. I could copy the model in each iteration without copying weights, but that loses the already set callback, and I would be forced to recompile the model, which requires information about the optimizer, metrics, and loss that is not accessible from the original model.

@zaleslaw
Collaborator

zaleslaw commented Nov 1, 2021

Yeah, I see your need. If it's not private, please share your K-Fold code.

I think it could be solved with a reset method or option, but could I ask: do you need to reset the model only because of the callback problem in #270?

@zaleslaw zaleslaw added the enhancement New feature or request label Nov 1, 2021
@zaleslaw zaleslaw added this to the 0.4 milestone Nov 1, 2021
@zaleslaw
Collaborator

@rychuelektryk could you please comment on this thread?

@rychuelektryk
Author

rychuelektryk commented Nov 10, 2021

Hi @zaleslaw, and sorry for the late response.

This issue is related to the callback issue I've described in the other thread. Being unable to reset model weights forces me to do an additional model compilation in my k-fold method, which is also where I set the callback.

The code snippets below are simplified for the sake of readability.

This is what my current k-fold looks like:

class KFoldEvaluator : IKFoldEvaluator {
    override fun evaluate(
        model: Functional,
        fitDataset: FitDataset,
        foldCount: Int,
        trainBatchSize: Int,
        validationBatchSize: Int,
        optimizer: Optimizer,
        loss: Losses,
        metric: Metrics
    ): KFoldResult {
        fitDataset.toKFoldDatasets(foldCount).forEach { (trainFitDataset, validationFitDataset) ->
            val dumbModel = model.copy(copyWeights = false)

            dumbModel.compile(optimizer, loss, metric)

            dumbModel.fit(
                trainFitDataset.toOnHeapDataset(),
                validationFitDataset.toOnHeapDataset(),
                epochs = Int.MAX_VALUE,
                trainBatchSize,
                validationBatchSize
            )
        }
    }
}

And here is how I would like it to look:

class KFoldEvaluator : IKFoldEvaluator {
    override fun evaluate(
        model: Functional,
        fitDataset: FitDataset,
        foldCount: Int,
        trainBatchSize: Int,
        validationBatchSize: Int
    ): KFoldResult {
        val dumbModel = model.copy(copyWeights = false)

        fitDataset.toKFoldDatasets(foldCount).forEach { (trainFitDataset, validationFitDataset) ->
            dumbModel.resetWeights()

            dumbModel.fit(
                trainFitDataset.toOnHeapDataset(),
                validationFitDataset.toOnHeapDataset(),
                epochs = Int.MAX_VALUE,
                trainBatchSize,
                validationBatchSize
            )
        }
    }
}

What do you think?

@zaleslaw
Collaborator

zaleslaw commented Nov 15, 2021

Regarding resetting weights: do you mean that KotlinDL should keep the initial weights somewhere (loaded from an h5 model or generated by initializers), or should resetting the weights just lead to a new random initialization with no relation to the previously initialized weights?

I guess that in the context of K-Fold usage you don't need a stable copy of the initial weights, and resetting just means re-running all the initializers to generate new weights. Correct me if I'm wrong.
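If resetting just means re-running the initializers, the idea could be sketched like this in plain Kotlin (this is not the KotlinDL API; `DenseSketch`, `fitStep`, and `resetWeights` are hypothetical names for illustration only):

```kotlin
import kotlin.random.Random

// Hypothetical sketch, not the KotlinDL API: a layer keeps a reference to its
// initializer, so "reset" can simply re-run it and produce fresh weights.
class DenseSketch(private val initializer: (Random) -> FloatArray) {
    var weights: FloatArray = initializer(Random(42))
        private set

    // Simulated training step: weights drift away from their initial values.
    fun fitStep() {
        for (i in weights.indices) weights[i] += 0.1f
    }

    // The second option from the question above: re-run the initializer with
    // a new seed, giving weights unrelated to the previous ones.
    fun resetWeights(seed: Long = System.nanoTime()) {
        weights = initializer(Random(seed))
    }
}

fun main() {
    val layer = DenseSketch { rng -> FloatArray(4) { rng.nextFloat() } }
    val initial = layer.weights.copyOf()
    layer.fitStep()
    check(!layer.weights.contentEquals(initial)) // training changed the weights
    layer.resetWeights(seed = 7)
    // Each K-Fold iteration can now start from an untrained state without
    // copying or recompiling the model.
}
```

Under this reading, the model would never need to store a snapshot of its initial weights; it only needs to keep its initializers around after compilation.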
