Skip to content

Commit

Permalink
fixing *NLLCriterion for non mini-batch cases
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Oct 1, 2016
1 parent 688b7a7 commit 5ade793
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
11 changes: 9 additions & 2 deletions ClassNLLCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ end

function ClassNLLCriterion:updateOutput(input, target)
if type(target) == 'number' then
if input:type() ~= 'torch.CudaTensor' then
self.target = self.target:long()
if input:type() == 'torch.CudaTensor' then
self.target = self.target:cudaLong()
else
self.target = self.target:long()
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
Expand All @@ -52,6 +54,11 @@ end

function ClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
if input:type() == 'torch.CudaTensor' then
self.target = self.target:cudaLong()
else
self.target = self.target:long()
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
self.target = target:cudaLong()
Expand Down
11 changes: 9 additions & 2 deletions SpatialClassNLLCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ end

function SpatialClassNLLCriterion:updateOutput(input, target)
if type(target) == 'number' then
if input:type() ~= 'torch.CudaTensor' then
self.target = self.target:long()
if input:type() == 'torch.CudaTensor' then
self.target = self.target:cudaLong()
else
self.target = self.target:long()
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
Expand All @@ -52,6 +54,11 @@ end

function SpatialClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
if input:type() == 'torch.CudaTensor' then
self.target = self.target:cudaLong()
else
self.target = self.target:long()
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
self.target = target:cudaLong()
Expand Down

0 comments on commit 5ade793

Please sign in to comment.