diff --git a/src/leibnetz/nets/__init__.py b/src/leibnetz/nets/__init__.py
index 4c42c5b..e7a2941 100644
--- a/src/leibnetz/nets/__init__.py
+++ b/src/leibnetz/nets/__init__.py
@@ -1,12 +1,13 @@
 from .attentive_scalenet import build_attentive_scale_net
 from .scalenet import build_scale_net
 from .unet import build_unet
-from .bio import (
+from .local_learning import (
     convert_to_bio,
     convert_to_backprop,
     HebbsRule,
     KrotovsRule,
     OjasRule,
+    GeometricConsistencyRule,
 )
 
 # from .resnet import build_resnet
diff --git a/src/leibnetz/nets/bio.py b/src/leibnetz/nets/local_learning.py
similarity index 55%
rename from src/leibnetz/nets/bio.py
rename to src/leibnetz/nets/local_learning.py
index f730f2a..fbff567 100644
--- a/src/leibnetz/nets/bio.py
+++ b/src/leibnetz/nets/local_learning.py
@@ -20,6 +20,8 @@ class LearningRule(ABC):
     }
     """
 
+    name: str = "LearningRule"
+
     def __init__(self):
         self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
 
@@ -34,7 +36,68 @@ def update(self, x, w):
         pass
 
 
+class GeometricConsistencyRule(LearningRule):
+    """
+    Implements a geometric consistency local learning rule per module. Only implemented for convolutional layers.
+    """
+
+    name: str = "GeometricConsistencyRule"
+    requires_grad: bool = True
+
+    def __init__(
+        self,
+        learning_rate=0.1,
+        optimizer="RAdam",
+        optimizer_kwargs={},
+    ):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.optimizer = optimizer
+        self.optimizer_kwargs = optimizer_kwargs
+
+    def __str__(self):
+        return f"GeometricConsistencyRule(learning_rate={self.learning_rate}, optimizer={self.optimizer}, optimizer_kwargs={self.optimizer_kwargs})"
+
+    def init_layers(self, layer):
+        if hasattr(layer, "weight"):
+            layer.weight.data.normal_(mean=0.0, std=1.0)
+
+    def update(self, module, args, kwargs, output):
+        if module.training:
+            if not hasattr(module, "kernel_size"):
+                # Only implemented for convolutional layers
+                return
+            module.zero_grad()
+            if not hasattr(module, "optimizer"):
+                optimizer = getattr(torch.optim, self.optimizer)
+                setattr(
+                    module,
+                    "optimizer",
+                    optimizer(
+                        [module.weight], lr=self.learning_rate, **self.optimizer_kwargs
+                    ),
+                )
+            with torch.no_grad():
+                inputs = args[0]
+                ndims = len(module.kernel_size)
+                # Randomly permute the input tensor in the spatial dimensions
+                non_spatial_dims = len(inputs.shape) - ndims
+                dim_permutations = torch.randperm(ndims) + non_spatial_dims
+                dim_permutations = (
+                    list(range(non_spatial_dims)) + dim_permutations.tolist()
+                )
+                perm_outputs = output.permute(*dim_permutations)
+                perm_inputs = inputs.permute(*dim_permutations)
+            outputs_of_perm_inputs = module.forward(perm_inputs)
+            loss = torch.nn.functional.mse_loss(outputs_of_perm_inputs, perm_outputs)
+            loss.backward()
+            module.optimizer.step()
+            # TODO: Weights are exploding, need to normalize them
+
+
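The TODO above notes that this rule's weights grow without bound. One possible remedy, given here only as a sketch and not as part of the patch, is to rescale each output filter right after module.optimizer.step(); the helper name below is hypothetical and assumes a conv module whose weight has shape (out_channels, in_channels, *kernel_size).

import torch


def renormalize_conv_weight(module: torch.nn.Module, eps: float = 1e-12) -> None:
    """Hypothetical helper: L2-normalize each output filter of a conv layer in place."""
    with torch.no_grad():
        w = module.weight.data
        # Per-filter norms over all non-output dimensions, clamped to avoid division by zero
        norms = w.flatten(start_dim=1).norm(dim=1).clamp_min(eps)
        w.div_(norms.view(-1, *([1] * (w.ndim - 1))))

Called at the end of update, this keeps every filter at unit norm, so the consistency loss can no longer scale the weights up indefinitely.
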
 class HebbsRule(LearningRule):
+    name: str = "HebbsRule"
+    requires_grad: bool = False
 
     def __init__(self, learning_rate=0.1, normalize_kwargs={"dim": 0}):
         super().__init__()
@@ -50,6 +113,10 @@ def __init__(self, learning_rate=0.1, normalize_kwargs={"dim": 0}):
     def __str__(self):
         return f"HebbsRule(learning_rate={self.learning_rate}, normalize_kwargs={self.normalize_kwargs})"
 
+    def init_layers(self, layer):
+        if hasattr(layer, "weight"):
+            layer.weight.data.normal_(mean=0.0, std=1.0)
+
     @torch.no_grad()
     def update(self, module, args, kwargs, output):
         if module.training:
@@ -76,112 +143,14 @@ def update(self, module, args, kwargs, output):
                     elif ndims == 4:
                         d_W = d_W.permute(5, 0, 1, 2, 3, 4)
                 else:
-                    ndims = None
                     d_W = inputs * output  # = c1 x c2
                 d_W = self.normalize_fcn(d_W)
                 module.weight.data += d_W * self.learning_rate
 
 
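For orientation, the conv branch of HebbsRule.update above accumulates a Hebbian product over extracted kernel and image patches; on a plain linear layer, Hebb's rule reduces to the outer product of post- and pre-synaptic activity. A minimal sketch for torch.nn.Linear, as a hypothetical free function rather than the module's hook:

import torch


@torch.no_grad()
def hebbian_step(linear: torch.nn.Linear, x: torch.Tensor, y: torch.Tensor, learning_rate: float = 0.1) -> None:
    """Hebbian update dW = y^T x for inputs x of shape (N, in_features) and outputs y of shape (N, out_features)."""
    d_w = y.T @ x  # (out_features, in_features), same shape as linear.weight
    d_w = torch.nn.functional.normalize(d_w, dim=0)  # mirrors the default normalize_kwargs above
    linear.weight.data += learning_rate * d_w
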
-class RhoadesRule(LearningRule):
-    """Rule modifying Krotov-Hopfield Hebbian learning rule fast implementation.
-
-    Args:
-        precision: Numerical precision of the weight updates.
-        delta: Anti-hebbian learning strength.
-        norm: Lebesgue norm of the weights.
-        k_ratio: Ranking parameter
-    """
-
-    def __init__(
-        self, k_ratio=0.5, delta=0.4, norm=2, normalize=False, precision=1e-30
-    ):
-        super().__init__()
-        self.precision = precision
-        self.delta = delta
-        self.norm = norm
-        assert k_ratio <= 1, "k_ratio should be smaller or equal to 1"
-        self.k_ratio = k_ratio
-        self.normalize = normalize
-
-    def __str__(self):
-        return f"RhoadesRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize={self.normalize})"
-
-    def init_layers(self, layer):
-        if hasattr(layer, "weight"):
-            layer.weight.data.normal_(mean=0.0, std=1.0)
-
-    def update(self, inputs: torch.Tensor, module: torch.nn.Module):
-        # TODO: WIP
-        if hasattr(module, "kernel_size"):
-            # Extract patches for convolutional layers
-            inputs = extract_kernel_patches(
-                inputs, module.kernel_size, module.stride, module.dilation
-            )
-            weights = module.weight.view(
-                -1, torch.prod(torch.as_tensor(module.kernel_size))
-            )
-        else:
-            weights = module.weight
-            inputs = inputs.view(inputs.size(0), -1)
-
-        # TODO: needs re-implementation
-        batch_size = inputs.shape[0]
-        num_hidden_units = torch.prod(torch.as_tensor(weights.shape))
-        input_size = inputs[0].shape[0]
-        k = int(self.k_ratio * num_hidden_units)
-
-        # TODO: WIP
-        if self.normalize:
-            norm = torch.norm(inputs, dim=1)
-            norm[norm == 0] = 1
-            inputs = torch.div(inputs, norm.view(-1, 1))
-
-        inputs = torch.t(inputs)
-
-        # Calculate overlap for each hidden unit and input sample
-        tot_input = torch.matmul(
-            torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs
-        )
-
-        # Get the top k activations for each input sample (hidden units ranked per input sample)
-        _, indices = torch.topk(tot_input, k=k, dim=0)
-
-        # Apply the activation function for each input sample
-        activations = torch.zeros((num_hidden_units, batch_size), device=weights.device)
-        activations[indices[0], torch.arange(batch_size)] = 1.0
-        activations[indices[k - 1], torch.arange(batch_size)] = -self.delta
-
-        # Sum the activations for each hidden unit, the batch dimension is removed here
-        xx = torch.sum(torch.mul(activations, tot_input), 1)
-
-        # Apply the actual learning rule, from here on the tensor has the same dimension as the weights
-        norm_factor = torch.mul(
-            xx.view(xx.shape[0], 1).repeat((1, input_size)), weights
-        )
-        ds = torch.matmul(activations, torch.t(inputs)) - norm_factor
-
-        # Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate)
-        nc = torch.max(torch.abs(ds))
-        if nc < self.precision:
-            nc = self.precision
-        d_w = torch.true_divide(ds, nc)
-
-        return d_w
-
-
 class KrotovsRule(LearningRule):
     """Krotov-Hopfield Hebbian learning rule fast implementation.
 
-    This code is taken from https://github.com/Joxis/pytorch-hebbian.git
-    The code is licensed under the MIT license.
-
-    Please reference the following paper if you use this code:
-    @inproceedings{talloen2020pytorchhebbian,
-        author = {Jules Talloen and Joni Dambre and Alexander Vandesompele},
-        location = {Online},
-        title = {PyTorch-Hebbian: facilitating local learning in a deep learning framework},
-        year = {2020},
-    }
-
     Original source: https://github.com/DimaKrotov/Biological_Learning
 
     Args:
@@ -191,16 +160,32 @@ class KrotovsRule(LearningRule):
         k_ratio: Ranking parameter
     """
 
+    name: str = "KrotovsRule"
+    requires_grad: bool = False
+
     def __init__(
-        self, k_ratio=0.5, delta=0.4, norm=2, normalize=False, precision=1e-30
+        self,
+        learning_rate: float = 0.1,
+        k_ratio: float = 0.5,
+        delta: float = 0.4,
+        norm: int = 2,
+        normalize_kwargs: dict = {"dim": 0},
+        precision: float = 1e-30,
     ):
         super().__init__()
-        self.precision = precision
-        self.delta = delta
-        self.norm = norm
+        self.learning_rate = learning_rate
         assert k_ratio <= 1, "k_ratio should be smaller or equal to 1"
         self.k_ratio = k_ratio
-        self.normalize = normalize
+        self.delta = delta
+        self.norm = norm
+        self.normalize_kwargs = normalize_kwargs
+        if normalize_kwargs is not None:
+            self.normalize_fcn = lambda x: torch.nn.functional.normalize(
+                x, **normalize_kwargs
+            )
+        else:
+            self.normalize_fcn = lambda x: x
+        self.precision = precision
 
     def __str__(self):
-        return f"KrotovsRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize={self.normalize})"
+        return f"KrotovsRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize_kwargs={self.normalize_kwargs})"
 
@@ -209,64 +194,110 @@ def init_layers(self, layer):
         if hasattr(layer, "weight"):
             layer.weight.data.normal_(mean=0.0, std=1.0)
 
-    def update(self, inputs: torch.Tensor, module: torch.nn.Module):
-        if hasattr(module, "kernel_size"):
-            # Extract patches for convolutional layers
-            inputs = extract_kernel_patches(
-                inputs, module.kernel_size, module.stride, module.dilation
-            )
-            weights = module.weight.view(
-                -1, torch.prod(torch.as_tensor(module.kernel_size))
-            )
-        else:
-            weights = module.weight
-            inputs = inputs.view(inputs.size(0), -1)
-
-        batch_size = inputs.shape[0]
-        num_hidden_units = weights.shape[0]
-        input_size = inputs[0].shape[0]
-        k = int(self.k_ratio * num_hidden_units)
-
-        # TODO: WIP
-        if self.normalize:
-            norm = torch.norm(inputs, dim=1)
-            norm[norm == 0] = 1
-            inputs = torch.div(inputs, norm.view(-1, 1))
-
-        inputs = torch.t(inputs)
-
-        # Calculate overlap for each hidden unit and input sample
-        tot_input = torch.matmul(
-            torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs
-        )
-
-        # Get the top k activations for each input sample (hidden units ranked per input sample)
-        _, indices = torch.topk(tot_input, k=k, dim=0)
-
-        # Apply the activation function for each input sample
-        activations = torch.zeros((num_hidden_units, batch_size), device=weights.device)
-        activations[indices[0], torch.arange(batch_size)] = 1.0
-        activations[indices[k - 1], torch.arange(batch_size)] = -self.delta
-
-        # Sum the activations for each hidden unit, the batch dimension is removed here
-        xx = torch.sum(torch.mul(activations, tot_input), 1)
-
-        # Apply the actual learning rule, from here on the tensor has the same dimension as the weights
-        norm_factor = torch.mul(
-            xx.view(xx.shape[0], 1).repeat((1, input_size)), weights
-        )
-        ds = torch.matmul(activations, torch.t(inputs)) - norm_factor
-
-        # Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate)
-        nc = torch.max(torch.abs(ds))
-        if nc < self.precision:
-            nc = self.precision
-        d_w = torch.true_divide(ds, nc)
-
-        return d_w
+    @torch.no_grad()
+    def update(self, module, args, kwargs, output):
+        pass
+        # if module.training:
+        #     with torch.no_grad():
+        #         inputs = args[0]
+        #         N = inputs.shape[0]
+        #         weights = module.weight
+        #         num_hidden_units = weights.shape[0]
+        #         k = int(self.k_ratio * num_hidden_units)
+        #         if hasattr(module, "kernel_size"):
+        #             ndims = len(module.kernel_size)
+
+        #             # Extract patches for convolutional layers
+        #             X = extract_kernel_patches(
+        #                 inputs,
+        #                 module.in_channels,
+        #                 module.kernel_size,
+        #                 module.stride,
+        #                 module.dilation,
+        #             )  # = c1 x 3 x 3 x 3 x N
+
+        #             input_size = X.shape[-1]
+
+        #             # Calculate overlap for each hidden unit and input sample
+        #             tot_input = torch.tensordot(
+        #                 torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), X
+        #             )
+
+        #             # Get the top k activations for each input sample (hidden units ranked per input sample)
+        #             _, indices = torch.topk(tot_input, k=k, dim=0)
+
+        #             # Apply the activation function for each input sample
+        #             activations = torch.zeros(
+        #                 (num_hidden_units, N), device=weights.device
+        #             )
+        #             activations[indices[0], ...] = 1.0
+        #             activations[indices[k - 1], ...] = -self.delta
+        #             # ================== WIP ==================
+        #             Y = extract_image_patches(
+        #                 output, module.out_channels, ndims
+        #             ).T  # = N x c2
+
+        #             d_W = X @ Y  # = c1 x 3 x 3 x 3 x c2
+        #             if ndims == 2:
+        #                 d_W = d_W.permute(3, 0, 1, 2)
+        #             if ndims == 3:
+        #                 d_W = d_W.permute(4, 0, 1, 2, 3)
+        #             elif ndims == 4:
+        #                 d_W = d_W.permute(5, 0, 1, 2, 3, 4)
+
+        #             # TODO: WIP
+        #             weights = module.weight.view(
+        #                 -1, torch.prod(torch.as_tensor(module.kernel_size))
+        #             )
+        #         else:
+        #             # TODO: WIP
+        #             # ndims = None
+        #             # d_W = inputs * output  # = c1 x c2
+        #             # weights = module.weight
+
+        #             inputs = inputs.view(inputs.size(0), -1)
+
+        #             inputs[0].shape[0]
+
+        #             inputs = torch.t(inputs)
+
+        #         # Calculate overlap for each hidden unit and input sample
+        #         tot_input = torch.dot(
+        #             torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs
+        #         )
+
+        #         # Get the top k activations for each input sample (hidden units ranked per input sample)
+        #         _, indices = torch.topk(tot_input, k=k, dim=0)
+
+        #         # Apply the activation function for each input sample
+        #         activations = torch.zeros(
+        #             (num_hidden_units, batch_size), device=weights.device
+        #         )
+        #         activations[indices[0], torch.arange(batch_size)] = 1.0
+        #         activations[indices[k - 1], torch.arange(batch_size)] = -self.delta
+
+        #         # Sum the activations for each hidden unit, the batch dimension is removed here
+        #         xx = torch.sum(torch.mul(activations, tot_input), 1)
+
+        #         # Apply the actual learning rule, from here on the tensor has the same dimension as the weights
+        #         norm_factor = torch.mul(
+        #             xx.view(xx.shape[0], 1).repeat((1, input_size)), weights
+        #         )
+        #         ds = torch.dot(activations, torch.t(inputs)) - norm_factor
+
+        #         # Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate)
+        #         nc = torch.max(torch.abs(ds))
+        #         if nc < self.precision:
+        #             nc = self.precision
+        #         d_W = torch.true_divide(ds, nc)
+
+        #         d_W = self.normalize_fcn(d_W)
+        #         module.weight.data += d_W * self.learning_rate
 
 
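The new update above is currently a stub (its body is commented out as WIP). For reference, here is a compact sketch of the dense-layer form of Krotov's rule that the removed implementation computed: rank hidden units by their Lebesgue-norm-weighted overlap with each sample, push the winner toward the input, and push the k-th ranked unit away with strength delta. The free-function form and names are illustrative only; the normalization-by-largest-update and learning-rate steps follow the removed code and the new hook-based rules.

import torch


@torch.no_grad()
def krotov_step(
    weights: torch.Tensor,  # (hidden_units, in_features)
    inputs: torch.Tensor,  # (N, in_features)
    k: int,  # rank of the anti-Hebbian unit, must be >= 1
    delta: float = 0.4,
    norm: int = 2,
    precision: float = 1e-30,
    learning_rate: float = 0.1,
) -> torch.Tensor:
    """One Krotov/Hopfield update for a dense weight matrix (sketch)."""
    x = inputs.t()  # (in_features, N)
    # Overlap of every hidden unit with every sample, weighted by the Lebesgue norm of the weights
    tot_input = (torch.sign(weights) * torch.abs(weights) ** (norm - 1)) @ x  # (hidden, N)
    _, indices = torch.topk(tot_input, k=k, dim=0)
    activations = torch.zeros_like(tot_input)
    batch = torch.arange(inputs.shape[0], device=inputs.device)
    activations[indices[0], batch] = 1.0  # Hebbian winner
    activations[indices[k - 1], batch] = -delta  # anti-Hebbian k-th ranked unit
    xx = (activations * tot_input).sum(dim=1)  # per-unit summed activity
    ds = activations @ inputs - xx.unsqueeze(1) * weights
    nc = torch.clamp(ds.abs().max(), min=precision)  # scale so the largest update is 1
    return weights + learning_rate * (ds / nc)
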
hasattr(layer, "weight"): + layer.weight.data.normal_(mean=0.0, std=1.0) + @torch.no_grad() def update(self, module, args, kwargs, output): if module.training: @@ -317,7 +352,6 @@ def update(self, module, args, kwargs, output): elif ndims == 4: d_W = d_W.permute(5, 0, 1, 2, 3, 4) else: - ndims = None W = module.weight d_W = output @ (inputs - output @ W) @@ -361,6 +395,20 @@ def extract_image_patches(x, channels, ndims): return x.view(channels, -1) +def _add_learning_parts( + model, rule: LearningRule, hook: torch.utils.hooks.RemovableHandle | list +): + if not hasattr(model, "learning_hooks"): + setattr(model, "learning_hooks", []) + if not hasattr(model, "learning_rules"): + setattr(model, "learning_rules", []) + if isinstance(hook, list): + model.learning_hooks.extend(hook) + else: + model.learning_hooks.append(hook) + model.learning_rules.append(rule) + + def convert_to_bio(model: LeibNet, learning_rule: LearningRule, init_layers=True): """Converts a LeibNet model to use local bio-inspired learning rules. @@ -385,15 +433,11 @@ def convert_to_bio(model: LeibNet, learning_rule: LearningRule, init_layers=True learning_rule.init_layers(module) else: torch.nn.init.sparse_(module.weight, sparsity=0.5) - module.weight.requires_grad = False - setattr(module, "learning_rule", learning_rule) - setattr(module, "learning_rate", learning_rule.learning_rate) - setattr(module, "learning_hook", hooks[-1]) + module.weight.requires_grad = learning_rule.requires_grad + _add_learning_parts(module, learning_rule, hooks[-1]) - setattr(model, "learning_rule", learning_rule) - setattr(model, "learning_rate", learning_rule.learning_rate) - setattr(model, "learning_hooks", hooks) - model.requires_grad_(False) + _add_learning_parts(model, learning_rule, hooks) + model.requires_grad_(learning_rule.requires_grad) return model @@ -423,7 +467,7 @@ def convert_to_backprop(model: LeibNet): return model -# # %% +# # # %% # from leibnetz.nets import build_unet # unet = build_unet() @@ -459,3 +503,4 @@ def convert_to_backprop(model: LeibNet): # model = convert_to_bio(unet, OjasRule()) # batch = model(model.get_example_inputs()) # # %% +# %% diff --git a/src/leibnetz/nodes/conv_pass_node.py b/src/leibnetz/nodes/conv_pass_node.py index fa08044..54c4766 100644 --- a/src/leibnetz/nodes/conv_pass_node.py +++ b/src/leibnetz/nodes/conv_pass_node.py @@ -17,7 +17,7 @@ def __init__( output_key_channels=None, activation="ReLU", padding="valid", - residual=False, # TODO: Breaks with residual=True + residual=False, padding_mode="reflect", norm_layer=None, dropout_prob=None, diff --git a/src/leibnetz/nodes/factor_compressed_conv_node.py b/src/leibnetz/nodes/factor_compressed_conv_node.py new file mode 100644 index 0000000..f9e3d74 --- /dev/null +++ b/src/leibnetz/nodes/factor_compressed_conv_node.py @@ -0,0 +1,22 @@ +from leibnetz.nodes import Node + +class FactorCompressedConvNode(Node): + def __init__( + self, + input_keys, + output_keys, + input_nc, + output_nc, + kernel_sizes, + output_key_channels=None, + activation="ReLU", + padding="valid", + residual=False, + padding_mode="reflect", + norm_layer=None, + dropout_prob=None, + identifier=None, + ) -> None: + self.rank_A = ... + self.rank_B = ... 
diff --git a/src/leibnetz/nodes/node_ops.py b/src/leibnetz/nodes/node_ops.py
index 307e54c..254619f 100644
--- a/src/leibnetz/nodes/node_ops.py
+++ b/src/leibnetz/nodes/node_ops.py
@@ -97,10 +97,12 @@ def __init__(
                 )
             )
             if residual and i == 0:
-                if input_nc < output_nc:
+                if input_nc < output_nc and output_nc % input_nc == 0:
                     groups = input_nc
-                else:
+                elif input_nc % output_nc == 0:
                     groups = output_nc
+                else:
+                    groups = 1
                 self.x_init_map = conv(
                     input_nc,
                     output_nc,
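The node_ops.py hunk above guards the grouped 1x1 init map: PyTorch requires both in_channels and out_channels to be divisible by groups, so the added branches fall back to groups=1 when the channel counts are not evenly divisible. The same selection logic as a standalone sketch, with a hypothetical helper name:

def choose_groups(input_nc: int, output_nc: int) -> int:
    """Pick the largest valid group count for a channel-changing 1x1 convolution."""
    if input_nc < output_nc and output_nc % input_nc == 0:
        return input_nc  # expanding: each input channel feeds output_nc // input_nc outputs
    if input_nc % output_nc == 0:
        return output_nc  # reducing: each output channel pools input_nc // output_nc inputs
    return 1  # incompatible channel counts: fall back to a dense 1x1 convolution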