
Refactor branching structure for finetuning vs. resuming training #229

Merged
merged 29 commits into master on Sep 24, 2022

Conversation

@ejm714 (Collaborator) commented Sep 16, 2022

Refactoring in instantiate_model

Refactors the logic in instantiate_model to branch on is_subset, which is more explicit. The previous logic tried to group the cases for replacing the head vs. resuming training:

# replace the head
elif not predict_all_zamba_species or not is_subset:
...
# resume training
elif is_subset:
...

This change should not alter behavior, but it should be more readable. The branching structure is now:

if is_subset:
  if predict_all_zamba_species:
    resume_training()
  else:
    replace_head()
else:
  replace_head()
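
As a quick sanity check (illustration only, not part of the PR), the old and new branchings agree on all four boolean combinations:

for is_subset in (True, False):
    for predict_all_zamba_species in (True, False):
        # old: replace the head when not predict_all_zamba_species or not is_subset
        old = "replace" if (not predict_all_zamba_species or not is_subset) else "resume"
        # new: resume only when is_subset and predict_all_zamba_species
        new = ("resume" if predict_all_zamba_species else "replace") if is_subset else "replace"
        assert old == new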

Validation for is_subset and predict_all_zamba_species

This also lets us enforce that is_subset (derived from the model species and the label species) agrees with predict_all_zamba_species. Notes:

  • predict_all_zamba_species is True by default. We assume that if you're finetuning on a subset of data, you still want all the zamba species. If you don't want this (e.g. you want duikers, elephants, and blanks only), you need to set it to False. This behavior is the same as before.
  • We now only validate predict_all_zamba_species if you set a value (we make this field an optional boolean and use a pre validator).
  • If you provide labels that are not a subset (e.g. cats and dogs) and do not specify predict_all_zamba_species, we'll set predict_all_zamba_species to False for you, since your only option is a new head with a subset. This is the same result as before, where providing a non-subset yielded a new head. The difference is that before you could have set conflicting values, and not being a subset would have silently superseded the predict_all_zamba_species field. (See the sketch after this list.)
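
A minimal sketch of the resolution rule in the notes above (hypothetical function name; not the actual zamba validator):

from typing import Optional

def resolve_predict_all_zamba_species(
    is_subset: bool, predict_all_zamba_species: Optional[bool]
) -> bool:
    if predict_all_zamba_species is None:
        # unset: keep all zamba species only when the labels are a subset
        return is_subset
    if predict_all_zamba_species and not is_subset:
        # conflicting values now raise instead of being silently superseded
        raise ValueError(
            "predict_all_zamba_species=True requires the label species to be "
            "a subset of the model species."
        )
    return predict_all_zamba_species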

Additional changes

  • renames predict_all_zamba_species to use_default_model_labels so it applies to the blank/nonblank model as well
  • fixes some copy-paste errors in the train tutorial

Outstanding:

  • add test for new validator

Closes #212
Closes https://github.com/drivendataorg/pjmf-zamba/issues/130


@ejm714 (Collaborator, Author) commented Sep 16, 2022

This will unfortunately break a number of our tests because we use an African Forest label asset everywhere, which this validation check will flag as "not a subset", conflicting with predict_all_zamba_species = True. We'll need to decide the best way forward here, potentially using model-specific label files or mocking certain functions.

Edit: we can make predict_all_zamba_species an optional boolean and only validate when a value is provided, using a pre validator.
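
A sketch of that idea with a pydantic v1-style pre validator (hypothetical class and field handling; not zamba's actual config code):

from typing import Optional
from pydantic import BaseModel, validator

class TrainConfigSketch(BaseModel):
    # None means "the user did not set a value"
    predict_all_zamba_species: Optional[bool] = None

    @validator("predict_all_zamba_species", pre=True)
    def only_validate_if_set(cls, value):
        if value is None:
            # skip the subset check entirely when the field is unset
            return value
        # the subset check against the provided labels would go here
        return value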

@ejm714 force-pushed the refactor-instantiation-hierarchy branch 2 times, most recently from 1237dc2 to e559499 on September 20, 2022
@pjbull (Member) left a comment:

Logic looks good to me pending tests!

Review comments on zamba/models/config.py and zamba/models/model_manager.py (outdated, resolved)
@ejm714 force-pushed the refactor-instantiation-hierarchy branch from e559499 to 8d657be on September 23, 2022
@ejm714 changed the base branch from blank-non-blank to master on September 23, 2022
@netlify bot commented Sep 23, 2022

Deploy Preview for silly-keller-664934 ready!

Latest commit: 7ce62f1
Latest deploy log: https://app.netlify.com/sites/silly-keller-664934/deploys/632e52e0ecdd7d00098fb81b
Deploy Preview: https://deploy-preview-229--silly-keller-664934.netlify.app

@codecov-commenter commented Sep 23, 2022
Codecov Report

Merging #229 (7ce62f1) into master (291b34f) will increase coverage by 0.1%.
The diff coverage is 94.8%.

Additional details and impacted files
@@           Coverage Diff            @@
##           master    #229     +/-   ##
========================================
+ Coverage    87.0%   87.2%   +0.1%     
========================================
  Files          28      28             
  Lines        1931    1961     +30     
========================================
+ Hits         1680    1710     +30     
  Misses        251     251             
Impacted Files                    Coverage          Δ
zamba/models/publish_models.py     0.0% <ø>        (ø)
zamba/models/model_manager.py     84.7% <91.1%>    (+0.7%) ⬆️
zamba/models/config.py            96.9% <100.0%>   (+0.1%) ⬆️
zamba/models/utils.py            100.0% <100.0%>   (ø)

@ejm714 (Collaborator, Author) commented Sep 23, 2022

@pjbull this required quite a bit more fiddling to get everything working again. A couple of changes:

  • the validator is a regular validator instead of a pre validator, and the optional use_default_model_labels boolean defaults to None; we set it to True if it is unset and the labels are a subset
  • further simplified instantiate_model by doing the checkpoint download where we're already validating the checkpoint
  • first pass at standardizing how we get the default hparams for a model: use the checkpoint where we have it (since we assign the public checkpoint for default models in validation), otherwise use the hparams file (only relevant when training from scratch, where we explicitly don't set a checkpoint). This approach is useful in that the checkpoint consistently supersedes the base model (see the sketch after this list)
    • for example, you train a time distributed model with cats, dogs, and humans
    • then you go to finetune it with labels for cats and dogs; we should load the hparams from the checkpoint you specify, not from the base model, to check subsets and such
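
A sketch of that precedence rule, assuming a PyTorch Lightning-style checkpoint (which stores a hyper_parameters dict) and a YAML hparams file; the function and arguments are hypothetical:

from pathlib import Path
from typing import Optional

import torch
import yaml

def get_default_hparams(checkpoint: Optional[Path], hparams_file: Path) -> dict:
    # the checkpoint consistently supersedes the base model, since it reflects
    # any finetuning that has already happened (e.g. a cats/dogs/humans head)
    if checkpoint is not None:
        return torch.load(checkpoint, map_location="cpu")["hyper_parameters"]
    # only used when training from scratch, where we explicitly don't set a checkpoint
    with open(hparams_file) as f:
        return yaml.safe_load(f)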

There are still improvements that could be made (e.g. instantiate_model should probably just take the config object, and we could better centralize the logic for looking up hparams), but I think this refactor is already at risk of scope creep, and it accomplishes the main task: ensuring that existing configuration parameters work with the new binary model.

Ready for another look

@ejm714 ejm714 requested a review from pjbull September 23, 2022 22:59
Review comments on tests/test_config.py (resolved)
@ejm714 (Collaborator, Author) commented Sep 24, 2022

It's a little hard to see from the diff on the code review commit, but the change basically introduces parity in these tests: https://github.com/drivendataorg/zamba/blob/refactor-instantiation-hierarchy/tests/test_instantiate_model.py#L75-L171

The first two tests explicitly pass use_default_model_labels as True or False to instantiate_model. The next two tests set up the config, pass that object to instantiate_model, and ensure the right thing is done using our real models.

This change also removes a duplicative test. Previously we had test_finetune_new_labels and test_head_replaced_for_new_species, which both tested head replacement. This is now a single test.
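
A toy sketch of that parity (self-contained stand-in logic, not the real tests; see the link above for those):

import pytest

ZAMBA_SPECIES = ["blank", "cat", "dog", "duiker", "elephant"]  # hypothetical list

def head_size(labels: list, use_default_model_labels: bool) -> int:
    # resuming keeps the full zamba head; replacing resizes it to the labels
    return len(ZAMBA_SPECIES) if use_default_model_labels else len(labels)

@pytest.mark.parametrize("use_default_model_labels,expected", [(True, 5), (False, 2)])
def test_head_size_respects_flag(use_default_model_labels, expected):
    assert head_size(["cat", "dog"], use_default_model_labels) == expected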

@pjbull pjbull merged commit eba2cec into master Sep 24, 2022
@pjbull pjbull deleted the refactor-instantiation-hierarchy branch September 24, 2022 01:39
Successfully merging this pull request may close these issues: Update the train tutorial.

3 participants