Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAML++: Per-Layer & Per-Layer Per-Step learning rate transforms #328

Open
wants to merge 15 commits into
base: master
Choose a base branch
from

Conversation

DubiousCactus
Copy link
Contributor

@DubiousCactus DubiousCactus commented Mar 25, 2022

Description

Learning Per-Layer Per-Step Learning Rates (LSLR, MAML++). These GBML transforms can be used to reproduce this MAML++ functionality for MAML, MetaSGD and other algorithms which can be reproduced with GBML.

Contribution Checklist

If your contribution modifies code in the core library (not docs, tests, or examples), please fill the following checklist.

  • My contribution is listed in CHANGELOG.md with attribution.
  • My contribution modifies code in the main library.
  • My modifications are tested.
  • My modifications are documented.

Optional

If you make major changes to the core library, please run make alltests and copy-paste the content of alltests.txt below.

make[1]: Entering directory '/home/cactus/Code/learn2learn'
OMP_NUM_THREADS=1 \
MKL_NUM_THREADS=1 \
python -W ignore -m unittest discover -s 'tests' -p '*_test.py' -v
test_final_accuracy (integration.maml_omniglot_test.MAMLOmniglotIntegrationTests) ... ok
test_adaptation (unit.algorithms.gbml_test.TestGBMLgorithm) ... ok
test_allow_nograd (unit.algorithms.gbml_test.TestGBMLgorithm) ... Traceback (most recent call last):
  File "/home/cactus/Code/learn2learn/learn2learn/optim/parameter_update.py", line 119, in forward
    gradients = torch.autograd.grad(
  File "/home/cactus/anaconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: One of the differentiated Tensors does not require grad
ok
test_allow_unused (unit.algorithms.gbml_test.TestGBMLgorithm) ... ok
test_clone_module (unit.algorithms.gbml_test.TestGBMLgorithm) ... ok
test_graph_connection (unit.algorithms.gbml_test.TestGBMLgorithm) ... ok
test_adaptation (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_allow_nograd (unit.algorithms.maml_test.TestMAMLAlgorithm) ... Traceback (most recent call last):
  File "/home/cactus/Code/learn2learn/learn2learn/algorithms/maml.py", line 160, in adapt
    gradients = grad(loss,
  File "/home/cactus/anaconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: One of the differentiated Tensors does not require grad
ok
test_allow_unused (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_clone_module (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_first_order_adaptation (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_graph_connection (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_memory_consumption (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_module_shared_params (unit.algorithms.maml_test.TestMAMLAlgorithm) ... ok
test_adaptation (unit.algorithms.metasgd_test.TestMetaSGDAlgorithm) ... ok
test_clone_module (unit.algorithms.metasgd_test.TestMetaSGDAlgorithm) ... ok
test_graph_connection (unit.algorithms.metasgd_test.TestMetaSGDAlgorithm) ... ok
test_memory_consumption (unit.algorithms.metasgd_test.TestMetaSGDAlgorithm) ... ok
test_meta_lr (unit.algorithms.metasgd_test.TestMetaSGDAlgorithm) ... ok
Files already downloaded and verified
Files already downloaded and verified
0 Meta Train Accuracy 0.3937500095926225
1 Meta Train Accuracy 0.5000000102445483
2 Meta Train Accuracy 0.5125000113621354
3 Meta Train Accuracy 0.5312500107102096
4 Meta Train Accuracy 0.5812500161118805
learn2learn: Maybe try with allow_nograd=True and/orallow_unused=True ?
learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to /tmp/datasets/omniglot-py/images_background.zip

  0%|          | 0/9464212 [00:00<?, ?it/s]
  3%|| 319488/9464212 [00:00<00:02, 3176299.16it/s]
 24%|██▍       | 2285568/9464212 [00:00<00:00, 12845223.70it/s]
 58%|█████▊    | 5480448/9464212 [00:00<00:00, 21536194.07it/s]
 91%|█████████ | 8569856/9464212 [00:00<00:00, 25217203.90it/s]
9464832it [00:00, 21873074.09it/s]                             
Extracting /tmp/datasets/omniglot-py/images_background.zip to /tmp/datasets/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to /tmp/datasets/omniglot-py/images_evaluation.zip

  0%|          | 0/6462886 [00:00<?, ?it/s]
  5%|| 327680/6462886 [00:00<00:01, 3260524.99it/s]
 36%|███▌      | 2326528/6462886 [00:00<00:00, 13038351.78it/s]
 85%|████████▍ | 5472256/6462886 [00:00<00:00, 21400815.70it/s]
6463488it [00:00, 19254143.16it/s]                             
test_data_labels_length (unit.data.metadataset_test.TestMetaDataset) ... ok
test_data_labels_values (unit.data.metadataset_test.TestMetaDataset) ... ok
test_data_length (unit.data.metadataset_test.TestMetaDataset) ... ok
test_fails_with_non_torch_dataset (unit.data.metadataset_test.TestMetaDataset) ... ok
test_filtered_metadataset (unit.data.metadataset_test.TestMetaDataset) ... Extracting /tmp/datasets/omniglot-py/images_evaluation.zip to /tmp/datasets/omniglot-py
Files already downloaded and verified
Files already downloaded and verified
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/datasets/MNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/9912422 [00:00<?, ?it/s]
  4%|| 374784/9912422 [00:00<00:02, 3743015.66it/s]
 18%|█▊        | 1787904/9912422 [00:00<00:00, 9836908.54it/s]
 50%|████▉     | 4931584/9912422 [00:00<00:00, 19689376.20it/s]
 80%|████████  | 7957504/9912422 [00:00<00:00, 23661710.36it/s]
9913344it [00:00, 21682749.88it/s]                             
Extracting /tmp/datasets/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/datasets/MNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s]
29696it [00:00, 49445832.31it/s]         
Extracting /tmp/datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/1648877 [00:00<?, ?it/s]
 24%|██▍       | 394240/1648877 [00:00<00:00, 3922912.19it/s]
1649664it [00:00, 9436849.35it/s]                            
Extracting /tmp/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/4542 [00:00<?, ?it/s]
5120it [00:00, 39116277.74it/s]         
ok
test_get_item (unit.data.metadataset_test.TestMetaDataset) ... ok
test_labels_to_indices (unit.data.metadataset_test.TestMetaDataset) ... ok
test_union_metadataset (unit.data.metadataset_test.TestMetaDataset) ... ok
test_dataloader (unit.data.task_dataset_test.TestTaskDataset) ... Extracting /tmp/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/datasets/MNIST/raw

Files already downloaded and verified
Files already downloaded and verified
ok
test_infinite_tasks (unit.data.task_dataset_test.TestTaskDataset) ... ok
test_instanciation (unit.data.task_dataset_test.TestTaskDataset) ... ok
test_task_caching (unit.data.task_dataset_test.TestTaskDataset) ... ok
test_task_transforms (unit.data.task_dataset_test.TestTaskDataset) ... ok
test_filter_labels (unit.data.transforms_test.TestTransforms) ... ok
test_k_shots (unit.data.transforms_test.TestTransforms) ... ok
test_load_data (unit.data.transforms_test.TestTransforms) ... ok
test_n_ways (unit.data.transforms_test.TestTransforms) ... ok
test_remap_labels (unit.data.transforms_test.TestTransforms) ... ok
test_infinite_iterator (unit.data.utils_test.DataUtilsTests) ... ok
test_partition_task (unit.data.utils_test.DataUtilsTests) ... ok
test_illegal_dimensions (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_illegal_dimensions_1d (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_m_edge (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_m_edge_1d (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_m_n_edge (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_m_n_edge_1d (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_n_edge (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_n_edge_1d (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_simple (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_simple_1d (unit.nn.kroneckers_test.KroneckerLinearTests) ... ok
test_cosine_distance (unit.nn.protonet_test.PrototypicalClassifierTests) ... ok
test_euclidean_distance (unit.nn.protonet_test.PrototypicalClassifierTests) ... ok
test_simple (unit.nn.protonet_test.PrototypicalClassifierTests) ... ok
test_clone_module_basics (unit.utils_test.UtilTests) ... ok
test_clone_module_models (unit.utils_test.UtilTests) ... ok
test_clone_module_nomodule (unit.utils_test.UtilTests) ... ok
test_distribution_clone (unit.utils_test.UtilTests) ... ok
test_distribution_detach (unit.utils_test.UtilTests) ... ok
test_module_clone_shared_params (unit.utils_test.UtilTests) ... ok
test_module_detach (unit.utils_test.UtilTests) ... ok
test_module_detach_keep_requires_grad (unit.utils_test.UtilTests) ... ok
test_module_update_shared_params (unit.utils_test.UtilTests) ... ok
test_rnn_clone (unit.utils_test.UtilTests) ... ok

----------------------------------------------------------------------
Ran 62 tests in 40.781s

OK
make lint
make[2]: Entering directory '/home/cactus/Code/learn2learn'
pycodestyle learn2learn/ --max-line-length=160
make[2]: Leaving directory '/home/cactus/Code/learn2learn'
make[1]: Leaving directory '/home/cactus/Code/learn2learn'
make[1]: Entering directory '/home/cactus/Code/learn2learn'
OMP_NUM_THREADS=1 \
MKL_NUM_THREADS=1 \
python -W ignore -m unittest discover -s 'tests' -p '*_test_notravis.py' -v
test_final_accuracy (integration.maml_miniimagenet_test_notravis.MAMLMiniImagenetIntegrationTests) ... ok
test_final_accuracy (integration.protonets_miniimagenet_test_notravis.ProtoNetMiniImageNetIntegrationTests) ... make[1]: *** [Makefile:37: notravis-tests] Killed
make[1]: Leaving directory '/home/cactus/Code/learn2learn'

@lgtm-com
Copy link

lgtm-com bot commented Mar 25, 2022

This pull request introduces 1 alert when merging baf38e0 into df3c329 - view on LGTM.com

new alerts:

  • 1 for Unused import

@DubiousCactus DubiousCactus marked this pull request as ready for review April 22, 2022 19:42
@DubiousCactus DubiousCactus changed the title [WIP] MAML++: LSLR MAML++: Per-Layer & Per-Layer Per-Step learning rate transforms Apr 22, 2022
* Add load_state_dict(), state_dict() for saving/restoring of LRs
* Add optional argument to PerLayerPerStepLRTransform to only attach per-step LRs to given layer names
@lgtm-com
Copy link

lgtm-com bot commented Jul 28, 2022

This pull request fixes 2 alerts when merging 98a118c into f099ddc - view on LGTM.com

fixed alerts:

  • 1 for Unused local variable
  • 1 for Unused import

@lgtm-com
Copy link

lgtm-com bot commented Jul 29, 2022

This pull request fixes 2 alerts when merging c4ea713 into f099ddc - view on LGTM.com

fixed alerts:

  • 1 for Unused local variable
  • 1 for Unused import

@lgtm-com
Copy link

lgtm-com bot commented Jul 30, 2022

This pull request fixes 2 alerts when merging 811beb0 into f099ddc - view on LGTM.com

fixed alerts:

  • 1 for Unused local variable
  • 1 for Unused import

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant