Add type annotations to conv-relu (pytorch#47680)
Summary:
Fixes pytorch#47679

Pull Request resolved: pytorch#47680

Reviewed By: zhangguanheng66

Differential Revision: D25416628

Pulled By: malfet

fbshipit-source-id: 103bea1e8c300990f74689787a71b1cfe916cfef
guilhermeleobas authored and facebook-github-bot committed Dec 10, 2020
1 parent e9ef1fe commit 5375a47
Showing 3 changed files with 10 additions and 19 deletions.
mypy.ini (15 changes: 3 additions & 12 deletions)
@@ -101,15 +101,15 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.conv]
 ignore_errors = True
 
-[mypy-torch._lobpcg]
-ignore_errors = True
-
+[mypy-torch._appdirs]
+ignore_errors = True
+
 [mypy-torch._utils]
 ignore_errors = True
 
 [mypy-torch._overrides]
 ignore_errors = True
 
 [mypy-torch.utils.tensorboard._caffe2_graph]
 ignore_errors = True
 
@@ -131,15 +131,6 @@ ignore_errors = True
 [mypy-torch.nn.intrinsic.quantized.modules.batchnorm]
 ignore_errors = True
 
-[mypy-torch.nn.intrinsic.quantized.modules.conv_relu]
-ignore_errors = True
-
-[mypy-torch.nn.intrinsic.quantized.modules.bn_relu]
-ignore_errors = True
-
-[mypy-torch.nn.intrinsic.quantized.modules.linear_relu]
-ignore_errors = True
-
 [mypy-torch.nn.intrinsic.qat.modules.conv_fused]
 ignore_errors = True
torch/_lobpcg.py (8 changes: 4 additions & 4 deletions)
@@ -262,7 +262,7 @@ def _symeig_backward(D_grad, U_grad, A, D, U, largest):
 class LOBPCGAutogradFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx,
+    def forward(ctx,  # type: ignore[override]
                 A: Tensor,
                 k: Optional[int] = None,
                 B: Optional[Tensor] = None,
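
Why this override needs an ignore: torch.autograd.Function declares forward to accept arbitrary positional and keyword arguments, and an override that narrows the signature to named parameters fails mypy's Liskov substitutability check even though it is fine at runtime. A self-contained sketch of the same situation, using hypothetical classes rather than the real torch ones:

class Base:
    @staticmethod
    def forward(ctx: object, *args: object) -> object:
        raise NotImplementedError

class Derived(Base):
    @staticmethod
    def forward(ctx: object,  # type: ignore[override]
                a: int,
                k: int = 1) -> int:
        # Narrowing *args to named parameters violates the supertype's
        # signature, so mypy reports [override] without the ignore.
        return a + k

print(Derived.forward(None, 2, k=3))  # prints 5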
@@ -606,7 +606,7 @@ def _lobpcg(A: Tensor,
     bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False)
 
     if not torch.jit.is_scripting():
-        LOBPCG.call_tracker = LOBPCG_call_tracker
+        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore
 
     if len(A.shape) > 2:
         N = int(torch.prod(torch.tensor(A.shape[:-2])))
@@ -628,7 +628,7 @@ def _lobpcg(A: Tensor,
             bXret[i] = worker.X[:, :k]
 
         if not torch.jit.is_scripting():
-            LOBPCG.call_tracker = LOBPCG_call_tracker_orig
+            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore
 
         return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

@@ -640,7 +640,7 @@ def _lobpcg(A: Tensor,
     worker.run()
 
     if not torch.jit.is_scripting():
-        LOBPCG.call_tracker = LOBPCG_call_tracker_orig
+        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore
 
     return worker.E[:k], worker.X[:, :k]
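
The bare "# type: ignore" comments here cover a monkey-patching pattern: call_tracker is a method on the LOBPCG class, and mypy rejects assigning a plain function to a method attribute ("Cannot assign to a method") even though the swap works at runtime. A small sketch of the pattern with hypothetical names:

class Worker:
    def call_tracker(self, n: int) -> None:
        pass  # no-op hook by default

def call_tracker_traced(self: Worker, n: int) -> None:
    print(f"call_tracker invoked with {n}")

call_tracker_orig = Worker.call_tracker

Worker.call_tracker = call_tracker_traced  # type: ignore
Worker().call_tracker(3)  # prints: call_tracker invoked with 3

Worker.call_tracker = call_tracker_orig  # type: ignore  # restore, as above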

torch/nn/intrinsic/quantized/modules/conv_relu.py (6 changes: 3 additions & 3 deletions)
@@ -16,7 +16,7 @@ class ConvReLU1d(nnq.Conv1d):
         Same as torch.nn.quantized.Conv1d
 
     """
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
@@ -55,7 +55,7 @@ class ConvReLU2d(nnq.Conv2d):
         Same as torch.nn.quantized.Conv2d
 
     """
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
@@ -94,7 +94,7 @@ class ConvReLU3d(nnq.Conv3d):
     Attributes: Same as torch.nn.quantized.Conv3d
 
     """
-    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
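
The three [assignment] ignores exist because the base quantized conv classes fix the inferred type of _FLOAT_MODULE (for nnq.Conv1d it is the float Conv1d class), while the fused intrinsic ConvReLU modules are Sequential-based rather than subclasses of those float convolutions, so re-binding the attribute looks like an incompatible assignment to mypy. A self-contained sketch with stand-in classes, not the real torch modules:

class FloatConv:
    pass

class FusedConvReLU:  # like the intrinsic ConvReLU modules: unrelated to FloatConv
    pass

class QuantizedConv:
    _FLOAT_MODULE = FloatConv  # mypy infers the attribute as Type[FloatConv]

class QuantizedConvReLU(QuantizedConv):
    # mypy: Incompatible types in assignment (expression has type
    # "Type[FusedConvReLU]", base class "QuantizedConv" defined the type
    # as "Type[FloatConv]")  [assignment]
    _FLOAT_MODULE = FusedConvReLU  # type: ignore[assignment]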
