Add metadata field to Learner
#121
base: master
Looks good to me 👍 Out of curiosity, what kind of training do you want to use this for? This is meant for state that belongs to the training loop and not to any callback, right? Can you add a short CHANGELOG entry?
Not state, but a hyper-parameter that belongs to the training loop and not to a callback. I basically have a setup where each batch is a time series, and a loss and pseudo-gradient are computed at each time step. These parameters control how my method updates the weights within this loop. Normally, hyper-parameters are part of either the loss or the optimizer, and can be either statically closed over or scheduled. In this case, the hyper-parameter belongs to neither the loss nor the optimizer but to the actual training step code.
Are you using a custom training step? Then it's also possible to add a field to the phase. Well, this will be useful anyway.
Ah, I didn't think of that. At least for my case, adding a field to the phase will be much more intuitive, so I probably won't use this feature. I'll leave it up to you whether it is still worth adding.
Hm. I think I'll leave this unmerged until someone comes with a use case where adding a field to the phase doesn't work. Where possible, that should be the preferred way.
I found a potential use case for this: anything stored in the phase struct can't be scheduled as a hyper-parameter. Either hyper-parameters should be extended to include the phase or the learner will need to store this information.
I'd prefer extending sethyperparameter!(learner, ::Type{<:HyperParameter}, value) to sethyperparameter!(learner, ::Type{<:HyperParameter}, ::Phase, value), with a default method to make it non-breaking:
sethyperparameter!(learner, T::Type{<:HyperParameter}, ::Phase, value) = sethyperparameter!(learner, T, value)
The only other thing that would need to be changed is this line to add the
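For illustration, here is a minimal sketch of how the non-breaking default method suggested above could behave. Everything in it is a hypothetical stand-in (the abstract types, `LearningRate`, `StepSize`, `MyPhase`, `ToyLearner`); it does not reproduce FluxTraining.jl's actual definitions, only the dispatch pattern.

```julia
# Hypothetical stand-ins for illustration; not FluxTraining.jl's real types.
abstract type HyperParameter end
abstract type Phase end

struct LearningRate <: HyperParameter end  # a "classic" hyper-parameter
struct StepSize <: HyperParameter end      # hypothetical phase-level hyper-parameter

mutable struct MyPhase <: Phase
    stepsize::Float64
end

mutable struct ToyLearner
    lr::Float64
end

# Existing-style three-argument method: knows nothing about phases.
sethyperparameter!(learner, ::Type{LearningRate}, value) = (learner.lr = value; learner)

# Non-breaking default: drop the phase and fall back to the three-argument method.
sethyperparameter!(learner, T::Type{<:HyperParameter}, ::Phase, value) =
    sethyperparameter!(learner, T, value)

# A phase-aware hyper-parameter overrides the default and writes into the phase.
function sethyperparameter!(learner, ::Type{StepSize}, phase::MyPhase, value)
    phase.stepsize = value
    return learner
end
```

With this layout, existing hyper-parameters keep working when called with a phase, and only hyper-parameters that live in a phase struct need a four-argument method.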
This adds a "metadata" PropDict to Learner for storing information that is required for training but extraneous to the training state or callback state. This is useful for unconventional training methods (an issue that I am currently dealing with). In the same way that the loss function is a "parameter" that needs to be specified for standard supervised training, the metadata field holds parameters that need to be specified for unconventional training. Unlike in standard training, we can't know in advance what these parameters will be, so instead of explicit names, we provide a container to hold them.
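As a rough sketch of the idea, assuming a dictionary with property-style access: `ToyPropDict`, `SketchLearner`, `step!`, and the `decay` parameter below are all hypothetical names for illustration, not the library's `PropDict` or `Learner`.

```julia
# Hypothetical stand-in for a property-accessible dictionary.
struct ToyPropDict
    d::Dict{Symbol,Any}
end
ToyPropDict(; kwargs...) = ToyPropDict(Dict{Symbol,Any}(kwargs))

# Forward `p.name` reads and writes to the wrapped dictionary.
Base.getproperty(p::ToyPropDict, name::Symbol) = getfield(p, :d)[name]
Base.setproperty!(p::ToyPropDict, name::Symbol, value) = (getfield(p, :d)[name] = value)

# Hypothetical learner with a free-form metadata container.
struct SketchLearner
    model
    lossfn
    metadata::ToyPropDict
end
SketchLearner(model, lossfn; metadata = ToyPropDict()) =
    SketchLearner(model, lossfn, metadata)

# A custom step that reads a training-loop parameter from the metadata
# rather than from the loss, the optimizer, or a callback.
function step!(learner, xs)
    total = 0.0
    for x in xs
        total += learner.metadata.decay * learner.lossfn(learner.model(x))
    end
    return total
end
```

A caller would construct the learner with, for example, `metadata = ToyPropDict(decay = 0.5)` and could later update it with `learner.metadata.decay = 0.25`.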