Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saiftyfirst/imbalanced class weighting #255

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@


class LossFunctions:
def __init__(self, reduction, weight):
self.reduction = reduction
self.weight = weight

def __call__(self, spk_out, targets):
loss = self._compute_loss(spk_out, targets)
return self._reduce(loss)

def _prediction_check(self, spk_out):
device = spk_out.device

Expand Down Expand Up @@ -51,6 +59,14 @@ def _population_code(self, spk_out, num_classes, num_outputs):
)
return pop_code

def _intermediate_reduction(self):
return self.reduction if self.weight is None else 'none'

def _reduce(self, loss):
# if reduction was delayed due to weight
requires_reduction = self.weight is not None and self.reduction == 'mean'
return loss.mean() if requires_reduction else loss


class ce_rate_loss(LossFunctions):
"""Cross Entropy Spike Rate Loss.
Expand Down Expand Up @@ -78,16 +94,19 @@ class ce_rate_loss(LossFunctions):

"""

def __init__(self):
def __init__(self, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.__name__ = "ce_rate_loss"

def __call__(self, spk_out, targets):
def _compute_loss(self, spk_out, targets):
device, num_steps, _ = self._prediction_check(spk_out)
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

log_p_y = log_softmax_fn(spk_out)
loss = torch.zeros((1), dtype=dtype, device=device)

loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)

for step in range(num_steps):
loss += loss_fn(log_p_y[step], targets)
Expand Down Expand Up @@ -138,14 +157,15 @@ class ce_count_loss(LossFunctions):

"""

def __init__(self, population_code=False, num_classes=False):
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_count_loss"

def __call__(self, spk_out, targets):
def _compute_loss(self, spk_out, targets):
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

if self.population_code:
_, _, num_outputs = self._prediction_check(spk_out)
Expand Down Expand Up @@ -190,12 +210,13 @@ class ce_max_membrane_loss(LossFunctions):

"""

def __init__(self):
def __init__(self, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.__name__ = "ce_max_membrane_loss"

def __call__(self, mem_out, targets):
def _compute_loss(self, mem_out, targets):
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

max_mem_out, _ = torch.max(mem_out, 0)
log_p_y = log_softmax_fn(max_mem_out)
Expand Down Expand Up @@ -256,16 +277,19 @@ def __init__(
incorrect_rate=0,
population_code=False,
num_classes=False,
reduction='mean',
weight=None
):
super().__init__(reduction=reduction, weight=weight)
self.correct_rate = correct_rate
self.incorrect_rate = incorrect_rate
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "mse_count_loss"

def __call__(self, spk_out, targets):
def _compute_loss(self, spk_out, targets):
_, num_steps, num_outputs = self._prediction_check(spk_out)
loss_fn = nn.MSELoss()
loss_fn = nn.MSELoss(reduction=self._intermediate_reduction())

if not self.population_code:

Expand Down Expand Up @@ -303,6 +327,10 @@ def __call__(self, spk_out, targets):
)

loss = loss_fn(spike_count, spike_count_target)

if self.weight is not None:
loss = loss * self.weight[targets]

return loss / num_steps


Expand Down Expand Up @@ -345,29 +373,36 @@ class mse_membrane_loss(LossFunctions):

# to-do: add **kwargs to modify other keyword args in
# spikegen.targets_convert
def __init__(self, time_var_targets=False, on_target=1, off_target=0):
def __init__(self, time_var_targets=False, on_target=1, off_target=0, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.time_var_targets = time_var_targets
self.on_target = on_target
self.off_target = off_target
self.__name__ = "mse_membrane_loss"

def __call__(self, mem_out, targets):
def _compute_loss(self, mem_out, targets):
device, num_steps, num_outputs = self._prediction_check(mem_out)
targets = spikegen.targets_convert(
targets_spikes = spikegen.targets_convert(
targets,
num_classes=num_outputs,
on_target=self.on_target,
off_target=self.off_target,
)
loss = torch.zeros((1), dtype=dtype, device=device)
loss_fn = nn.MSELoss()

loss_shape = mem_out[0].shape if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)

loss_fn = nn.MSELoss(reduction=self._intermediate_reduction())

if self.time_var_targets:
for step in range(num_steps):
loss += loss_fn(mem_out[step], targets[step])
loss += loss_fn(mem_out[step], targets_spikes[step])
else:
for step in range(num_steps):
loss += loss_fn(mem_out[step], targets)
loss += loss_fn(mem_out[step], targets_spikes)

if self.weight is not None:
loss = loss * self.weight[targets]

return loss / num_steps

Expand Down Expand Up @@ -735,23 +770,32 @@ def __init__(
off_target=-1,
tolerance=0,
multi_spike=False,
reduction='mean',
weight=None
):
super(mse_temporal_loss, self).__init__()

self.loss_fn = nn.MSELoss()
self.reduction = reduction
self.weight = weight
self.loss_fn = nn.MSELoss(reduction=('none' if self.weight is not None else self.reduction))
self.spk_time_fn = SpikeTime(
target_is_time, on_target, off_target, tolerance, multi_spike
)
self.__name__ = "mse_temporal_loss"

def __call__(self, spk_rec, targets):
spk_time, targets = self.spk_time_fn(
spk_time, target_time = self.spk_time_fn(
spk_rec, targets
) # return encoded targets
loss = self.loss_fn(
spk_time / spk_rec.size(0), targets / spk_rec.size(0)
spk_time / spk_rec.size(0), target_time / spk_rec.size(0)
) # spk_time_final: num_spikes x B x Nc. # Same with targets.

if self.weight is not None:
loss = loss * self.weight[targets]
if self.reduction == 'mean':
loss = loss.mean()

return loss


Expand Down Expand Up @@ -799,10 +843,12 @@ class decreases the first spike time (i.e., earlier spike).

"""

def __init__(self, inverse="negate"):
def __init__(self, inverse="negate", reduction='mean', weight=None):
super(ce_temporal_loss, self).__init__()

self.loss_fn = nn.CrossEntropyLoss()
self.reduction = reduction
self.weight = weight
self.loss_fn = nn.CrossEntropyLoss(reduction=self.reduction, weight=self.weight)
self.spk_time_fn = SpikeTime(target_is_time=False)
self.inverse = inverse
self._ce_temporal_cases()
Expand Down
Loading
Loading