Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cannot apply sparse_switchnorm to 2D input? #6

Open
Ning5195 opened this issue Oct 24, 2020 · 0 comments
Open

Cannot apply sparse_switchnorm to 2D input? #6

Ning5195 opened this issue Oct 24, 2020 · 0 comments

Comments

@Ning5195
Copy link

Ning5195 commented Oct 24, 2020

When I try to apply sparse_switchnorm to a 2D tensor, it fails at self.var_weight — it seems to hit the same problem as #2.
I modified the code as follows.

class SSN(nn.Module):
    """Sparse Switchable Normalization adapted for 2D ``(N, C)`` input.

    Mixes layer-norm (per-sample) and batch-norm (per-feature) statistics.
    The two mixing weights are projected through an external ``sparsestmax``
    function (defined elsewhere in the project); once a projection saturates
    to a one-hot vector, it is frozen via the ``mean_fixed`` / ``var_fixed``
    buffers.

    Fixes applied to the pasted snippet:
    - ``def init`` restored to ``def __init__`` (underscores were stripped
      by markdown rendering);
    - ``super(SSN1d, self)`` corrected to ``super(SSN, self)`` to match the
      declared class name;
    - markdown ``**…**`` bold markers removed from the ``sparsestmax`` call;
    - class-body indentation restored.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.997,
                 using_moving_average=True, last_gamma=False):
        super(SSN, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        # Per-feature affine parameters, shaped (1, C) for 2D input.
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))

        # Two mixing weights each: index 0 = layer norm, index 1 = batch norm.
        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))

        # Flags that freeze the sparsestmax projection once it saturates,
        # plus the projection radius used by sparsestmax.
        self.register_buffer('mean_fixed', torch.LongTensor([0]))
        self.register_buffer('var_fixed', torch.LongTensor([0]))
        self.register_buffer('radius', torch.zeros(1))

        # NOTE(review): hard-coded CUDA tensors fail on CPU-only machines;
        # consider torch.ones(2) and moving the module with .to(device).
        self.mean_weight_ = torch.cuda.FloatTensor([1., 1.])
        self.var_weight_ = torch.cuda.FloatTensor([1., 1.])

        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics, switch flags and affine parameters."""
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.mean_fixed.data.fill_(0)
        self.var_fixed.data.fill_(0)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        """Raise ValueError unless ``input`` is a 2D (N, C) tensor."""
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)

        # Layer-norm statistics: per sample, across features.
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)

        if self.training:
            # Batch-norm statistics: per feature, across the batch.
            # NOTE(review): torch.var is unbiased by default, so a batch of
            # size 1 produces NaN here — a likely source of the NaN in
            # self.var_weight reported in this issue. Consider
            # x.var(0, unbiased=False, keepdim=True) if that matches the
            # original N-D implementation.
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            if self.using_moving_average:
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            # torch.autograd.Variable is a deprecated no-op wrapper; buffers
            # can be used directly.
            mean_bn = self.running_mean
            var_bn = self.running_var

        rad = self.radius.item()
        if not self.mean_fixed:
            self.mean_weight_ = sparsestmax(self.mean_weight, rad)
            # Once the projection reaches a one-hot (spread >= 1), freeze it.
            if max(self.mean_weight_) - min(self.mean_weight_) >= 1:
                self.mean_fixed.data.fill_(1)
                self.mean_weight.data = self.mean_weight_.data
                self.mean_weight_ = self.mean_weight.detach()
        else:
            self.mean_weight_ = self.mean_weight.detach()

        if not self.var_fixed:
            self.var_weight_ = sparsestmax(self.var_weight, rad)
            if max(self.var_weight_) - min(self.var_weight_) >= 1:
                self.var_fixed.data.fill_(1)
                self.var_weight.data = self.var_weight_.data
                self.var_weight_ = self.var_weight.detach()
        else:
            self.var_weight_ = self.var_weight.detach()

        # Mix LN and BN statistics with the (sparse) switch weights.
        mean = self.mean_weight_[0] * mean_ln + self.mean_weight_[1] * mean_bn
        var = self.var_weight_[0] * var_ln + self.var_weight_[1] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias

    def get_mean(self):
        """Return the current (projected) mean mixing weights."""
        return self.mean_weight_

    def get_var(self):
        """Return the current (projected) variance mixing weights."""
        return self.var_weight_

    def set_rad(self, rad):
        """Set the sparsestmax projection radius."""
        self.radius[0].fill_(rad)

    def get_rad(self):
        """Return the projection radius as a 0-d tensor."""
        return torch.squeeze(self.radius)

`

It seems that the value of self.var_weight in "self.var_weight_ = sparsestmax(self.var_weight, rad)" becomes NaN.
Is this caused by an error in my modified code,
or can sparse_switchnorm simply not be applied to 2D input?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant