Commit

Remove debug stuff
DubiousCactus committed Mar 25, 2022
1 parent bda9cff commit baf38e0
Showing 2 changed files with 9 additions and 25 deletions.
4 changes: 1 addition & 3 deletions examples/vision/maml_omniglot.py
@@ -36,7 +36,7 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
     # Adapt the model
     for step in range(adaptation_steps):
         train_error = loss(learner(adaptation_data), adaptation_labels)
-        learner.adapt(train_error)
+        learner.adapt(train_error, step=step)
 
     # Evaluate the adapted model
     predictions = learner(evaluation_data)
@@ -99,7 +99,6 @@ def main(
                shots,
                ways,
                device)
-            print("Train ", evaluation_error)
             evaluation_error.backward()
             meta_train_error += evaluation_error.item()
             meta_train_accuracy += evaluation_accuracy.item()
@@ -114,7 +113,6 @@ def main(
                shots,
                ways,
                device)
-            print("Eval ", evaluation_error)
             meta_valid_error += evaluation_error.item()
             meta_valid_accuracy += evaluation_accuracy.item()
 
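With the debug prints gone, the one functional change to the example is that adapt() now receives the inner-loop index: the LSLR wrapper (see lslr.py below) asserts on a 'step' keyword and uses it to pick that step's learning rates. The following toy stand-in is written for this page only and is not part of the library; ToyLearner and ToyLSLR are invented names that only sketch how the interception works.

    class ToyLearner:
        def adapt(self, error):
            # Stands in for MAML's adapt(): takes only the loss.
            print("adapting with loss", error)

    class ToyLSLR:
        # Minimal imitation of the adapt() override in lslr.py below.
        def __init__(self, model):
            self.model = model
            self._current_step = 0

        def __getattr__(self, name):
            if name == "adapt":
                def override(*args, **kwargs):
                    assert 'step' in kwargs, "Keyword argument 'step' not passed to the adapt() method"
                    self._current_step = kwargs.pop('step')  # remember which inner-loop step we are on
                    return getattr(self.model, name)(*args, **kwargs)
                return override
            return getattr(self.model, name)

    learner = ToyLSLR(ToyLearner())
    for step in range(3):
        train_error = 0.5  # placeholder for loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error, step=step)  # 'step' is now mandatory, as in the example above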
30 changes: 8 additions & 22 deletions learn2learn/optim/lslr.py
@@ -52,9 +52,9 @@ class LSLR:
     def __init__(self, module: torch.nn.Module, adaptation_steps: int, init_lr: float):
         module.update_func = self.update_with_lslr
         self.model = module
-        # self.model.lslr = self._init_lslr_parameters(
-        #     adaptation_steps=adaptation_steps, init_lr=init_lr
-        # )
+        self.model.lslr = self._init_lslr_parameters(
+            adaptation_steps=adaptation_steps, init_lr=init_lr
+        )
         self._current_step = 0
 
     def _init_lslr_parameters(
@@ -79,7 +79,7 @@ def _init_lslr_parameters(
             )
         return lslr
 
-    def update_with_lslr(self, model: torch.nn.Module, lr=0.01, grads=None, **kwargs):
+    def update_with_lslr(self, model: torch.nn.Module, lr=None, grads=None, **kwargs):
         # TODO: Turn this into a GBML gradient transform instead?
         """
@@ -98,17 +98,6 @@ def update_with_lslr(self, model: torch.nn.Module, lr=0.01, grads=None, **kwargs
         * **grads** (list, *optional*, default=None) - A list of gradients for each layer
             of the model. If None, will use the gradients in .grad attributes.
         """
-        if grads is not None:
-            params = list(model.parameters())
-            if not len(grads) == len(list(params)):
-                msg = 'WARNING:maml_update(): Parameters and gradients have different length. ('
-                msg += str(len(params)) + ' vs ' + str(len(grads)) + ')'
-                print(msg)
-            for p, g in zip(params, grads):
-                if g is not None:
-                    p.update = - lr * g
-            return update_module(model)
-
         if grads is not None:
             params = list(model.parameters())
             if not len(grads) == len(list(params)):
@@ -128,8 +117,7 @@ def update_with_lslr(self, model: torch.nn.Module, lr=0.01, grads=None, **kwargs
             layer_name = name[: name.rfind(".")].replace(
                 ".", "-"
             )  # Extract the layer name from the named parameter
-            # lr = self.model.lslr[layer_name][self._current_step]
-            lr = 0.01
+            lr = self.model.lslr[layer_name][self._current_step]
             assert (
                 lr is not None
             ), f"Parameter {name} does not have a learning rate in LSLR dict!"
@@ -162,23 +150,21 @@ def __getattr__(self, name):
         if name == "clone":
             def override(*args, **kwargs):
                 method = object.__getattribute__(self.model, name)
-                # lslr = {k: p.detach() for k, p in self.model.lslr.items()}
+                lslr = {k: p.clone() for k, p in self.model.lslr.items()}
                 # Original clone method
                 self.model = method(*args, **kwargs)
                 # Override the update function to LSLR
                 with torch.no_grad():
                     self.model.update_func = self.update_with_lslr
-                    # self.model.lslr = lslr
+                    self.model.lslr = lslr
                 return self
             attr = override
-        elif name == "adapt_":
-            print("Overriding adapt()")
+        elif name == "adapt":
             def override(*args, **kwargs):
                 assert 'step' in kwargs, "Keyword argument 'step' not passed to the adapt() method"
                 with torch.no_grad():
                     self._current_step = kwargs['step']
                     del kwargs['step']
-                print(f"Setting step to {self._current_step}")
                 method = object.__getattribute__(self.model, name)
                 method(*args, **kwargs)
             attr = override
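For orientation, here is a self-contained sketch (plain PyTorch, no learn2learn imports; the model, data, and hyper-parameters are invented) of the update rule the restored lookup implements: every layer owns one learnable learning rate per adaptation step, and the value for step k scales that step's gradient update, in the spirit of p.update = - lr * g with lr = self.model.lslr[layer_name][self._current_step].

    import torch

    torch.manual_seed(0)
    adaptation_steps = 3
    model = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))

    # One learnable learning rate per parameter tensor and per adaptation step,
    # analogous to what _init_lslr_parameters builds per layer.
    lslr = {
        name: torch.nn.Parameter(0.01 * torch.ones(adaptation_steps))
        for name, _ in model.named_parameters()
    }

    # Functional inner loop: step k uses lslr[name][k] as its learning rate.
    params = {name: p.clone() for name, p in model.named_parameters()}
    for step in range(adaptation_steps):
        logits = torch.nn.functional.linear(x, params["weight"], params["bias"])
        error = torch.nn.functional.cross_entropy(logits, y)
        grads = torch.autograd.grad(error, list(params.values()), create_graph=True)
        params = {
            name: p - lslr[name][step] * g  # the per-layer, per-step update
            for (name, p), g in zip(params.items(), grads)
        }
    # Because create_graph=True, a meta-loss computed from `params` can
    # backpropagate into both the initial weights and the lslr values.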
