diff --git a/configs/config_rodnet_cdc_win16.py b/configs/config_rodnet_cdc_win16.py
index 455ce3e..93a50c3 100644
--- a/configs/config_rodnet_cdc_win16.py
+++ b/configs/config_rodnet_cdc_win16.py
@@ -18,7 +18,7 @@
     ),
     demo=dict(
         subdir='demo',
-        seqs=[],
+        # seqs=[],
     ),
 )
diff --git a/rodnet/datasets/CRDataset.py b/rodnet/datasets/CRDataset.py
index e60c3b0..852d3d4 100644
--- a/rodnet/datasets/CRDataset.py
+++ b/rodnet/datasets/CRDataset.py
@@ -2,6 +2,8 @@
 import time
 import random
 import pickle
+import traceback
+
 import numpy as np
 from tqdm import tqdm
@@ -165,6 +167,7 @@ def __getitem__(self, index):
             data_dict['end_frame'] = data_id + self.win_size * self.step - 1
         except:
+            print(f"\033[1;36m {traceback.format_exc()}\033[0m")
             # in case load npy fail
             data_dict['status'] = False
             if not os.path.exists('./tmp'):
diff --git a/rodnet/ops/dcn/src/deform_conv_2d_cuda.cpp b/rodnet/ops/dcn/src/deform_conv_2d_cuda.cpp
index 2321e02..8ef0d23 100644
--- a/rodnet/ops/dcn/src/deform_conv_2d_cuda.cpp
+++ b/rodnet/ops/dcn/src/deform_conv_2d_cuda.cpp
@@ -63,26 +63,26 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                  at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                  int padW, int dilationH, int dilationW, int group,
                  int deformable_group) {
-  AT_CHECK(weight.ndimension() == 4,
+  TORCH_CHECK(weight.ndimension() == 4,
            "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
            "but got: %s",
            weight.ndimension());
 
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
 
-  AT_CHECK(kW > 0 && kH > 0,
+  TORCH_CHECK(kW > 0 && kH > 0,
            "kernel size should be greater than zero, but got kH: %d kW: %d",
            kH, kW);
 
-  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
            "kernel size should be consistent with weight, ",
            "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
            kH, kW, weight.size(2), weight.size(3));
 
-  AT_CHECK(dW > 0 && dH > 0,
+  TORCH_CHECK(dW > 0 && dH > 0,
            "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
 
-  AT_CHECK(
+  TORCH_CHECK(
       dilationW > 0 && dilationH > 0,
       "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
       dilationH, dilationW);
@@ -98,7 +98,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
     dimw++;
   }
 
-  AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
            ndim);
 
   long nInputPlane = weight.size(1) * group;
@@ -110,7 +110,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
   long outputWidth =
       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
 
-  AT_CHECK(nInputPlane % deformable_group == 0,
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
            "input channels must divide deformable group size");
 
   if (outputWidth < 1 || outputHeight < 1)
@@ -120,27 +120,27 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
         nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
         outputWidth);
 
-  AT_CHECK(input.size(1) == nInputPlane,
+  TORCH_CHECK(input.size(1) == nInputPlane,
            "invalid number of input planes, expected: %d, but got: %d",
            nInputPlane, input.size(1));
 
-  AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
            "input image is smaller than kernel");
 
-  AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
            "invalid spatial size of offset, expected height: %d width: %d, but "
            "got height: %d width: %d",
            outputHeight, outputWidth, offset.size(2), offset.size(3));
 
-  AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
            "invalid number of channels of offset");
 
   if (gradOutput != NULL) {
-    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
              "invalid number of gradOutput planes, expected: %d, but got: %d",
              nOutputPlane, gradOutput->size(dimf));
 
-    AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
               gradOutput->size(dimw) == outputWidth),
              "invalid size of gradOutput, expected height: %d width: %d , but "
              "got height: %d width: %d",
@@ -191,7 +191,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
 
-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
 
   output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                         outputHeight, outputWidth});
@@ -298,7 +298,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
 
-  AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
   gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -414,7 +414,7 @@ int deform_conv_backward_parameters_cuda(
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
 
-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
 
   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -494,8 +494,8 @@ void modulated_deform_conv_cuda_forward(
     const int pad_h, const int pad_w, const int dilation_h,
     const int dilation_w, const int group, const int deformable_group,
     const bool with_bias) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
@@ -576,8 +576,8 @@ void modulated_deform_conv_cuda_backward(
     int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
     const bool with_bias) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
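Note (not part of the patch): the recurring change in this file, and in the 3D variants below, is the rename of the removed `AT_CHECK` argument-checking macro to `TORCH_CHECK`. As a hedged sketch under the assumption of a standard PyTorch C++ extension build, the same effect can also be had without touching every call site by aliasing the missing macro; `check_weight` is a hypothetical helper that only mirrors the call style used in these files. `TORCH_CHECK` streams its trailing arguments into the error message rather than printf-formatting them, so the `%s`/`%d` placeholders kept by the patch appear literally in the message text.

```cuda
// Hedged sketch, assuming a standard PyTorch C++ extension build; not part of
// this patch. Aliasing keeps old AT_CHECK call sites compiling on newer
// PyTorch releases that removed the macro.
#include <torch/extension.h>

#ifndef AT_CHECK
#define AT_CHECK TORCH_CHECK  // AT_CHECK was removed in newer PyTorch
#endif

// Hypothetical helper mirroring the argument checks in deform_conv_2d_cuda.cpp.
void check_weight(const at::Tensor &weight) {
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor expected, but got: ", weight.ndimension());
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
}
```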
diff --git a/rodnet/ops/dcn/src/deform_conv_2d_cuda_kernel.cu b/rodnet/ops/dcn/src/deform_conv_2d_cuda_kernel.cu
index e7a26f2..e51b1fb 100644
--- a/rodnet/ops/dcn/src/deform_conv_2d_cuda_kernel.cu
+++ b/rodnet/ops/dcn/src/deform_conv_2d_cuda_kernel.cu
@@ -62,6 +62,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -262,7 +263,7 @@ void deformable_im2col(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *data_col_ = data_col.data();
 
-    deformable_im2col_gpu_kernel<<>>(
+    deformable_im2col_gpu_kernel<<>>(
         num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
         channel_per_deformable_group, parallel_imgs, channels, deformable_group,
@@ -356,7 +357,7 @@ void deformable_col2im(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *grad_im_ = grad_im.data();
 
-    deformable_col2im_gpu_kernel<<>>(
+    deformable_col2im_gpu_kernel<<>>(
         num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
         ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
         channel_per_deformable_group,
@@ -455,7 +456,7 @@ void deformable_col2im_coord(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *grad_offset_ = grad_offset.data();
 
-    deformable_col2im_coord_gpu_kernel<<>>(
+    deformable_col2im_coord_gpu_kernel<<>>(
         num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
         ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
         dilation_w, channel_per_deformable_group,
@@ -785,7 +786,7 @@ void modulated_deformable_im2col_cuda(
     const scalar_t *data_mask_ = data_mask.data();
     scalar_t *data_col_ = data_col.data();
 
-    modulated_deformable_im2col_gpu_kernel<<>>(
+    modulated_deformable_im2col_gpu_kernel<<>>(
         num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
         kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
         dilation_w, channel_per_deformable_group, batch_size, channels,
         deformable_group, height_col, width_col, data_col_);
@@ -817,7 +818,7 @@ void modulated_deformable_col2im_cuda(
     const scalar_t *data_mask_ = data_mask.data();
     scalar_t *grad_im_ = grad_im.data();
 
-    modulated_deformable_col2im_gpu_kernel<<>>(
+    modulated_deformable_col2im_gpu_kernel<<>>(
         num_kernels, data_col_, data_offset_, data_mask_, channels, height_im,
         width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
         dilation_h, dilation_w, channel_per_deformable_group,
@@ -852,7 +853,7 @@ void modulated_deformable_col2im_coord_cuda(
     scalar_t *grad_offset_ = grad_offset.data();
     scalar_t *grad_mask_ = grad_mask.data();
 
-    modulated_deformable_col2im_coord_gpu_kernel<<>>(
+    modulated_deformable_col2im_coord_gpu_kernel<<>>(
         num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels,
         height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
         stride_w, dilation_h, dilation_w, channel_per_deformable_group,
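Note (not part of the patch): the `#include` targets and the arguments inside the `<<>>` launch sites above were lost when this diff was extracted, so the actual change to those lines is not recoverable here. Purely as a hedged illustration of the shape such calls take in a PyTorch CUDA extension, with hypothetical names that are not from RODNet:

```cuda
// Hedged illustration only; it does NOT reproduce the missing launch arguments
// of this patch. A typical extension kernel launch supplies grid/block sizes
// and, optionally, the current ATen CUDA stream.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

constexpr int CUDA_NUM_THREADS = 1024;

inline int GET_BLOCKS(const int n) {
  return (n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

template <typename scalar_t>
__global__ void scale_kernel(const int n, const scalar_t *in, scalar_t *out,
                             scalar_t factor) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * factor;
}

// Hypothetical host wrapper: dispatch on dtype, then launch on the current stream.
void scale_cuda(const at::Tensor &input, at::Tensor &output, double factor) {
  const int n = static_cast<int>(input.numel());
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale_cuda", [&] {
    scale_kernel<scalar_t>
        <<<GET_BLOCKS(n), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            n, input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
            static_cast<scalar_t>(factor));
  });
}
```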
has to be contiguous"); - AT_CHECK(kW > 0 && kH > 0 && kT > 0, + TORCH_CHECK(kW > 0 && kH > 0 && kT > 0, "kernel size should be greater than zero, but got kH: %d kW: %d kT: %d", kH, kW, kT); - AT_CHECK((weight.size(2) == kT && weight.size(3) == kH && weight.size(4) == kW), + TORCH_CHECK((weight.size(2) == kT && weight.size(3) == kH && weight.size(4) == kW), "kernel size should be consistent with weight, ", "but got kH: %d kW: %d kT: %d weight.size(2): %d, weight.size(3): %d, weight.size(4): %d", kH, kW, kT, weight.size(2), weight.size(3), weight.size(4)); - AT_CHECK(dW > 0 && dH > 0 && dT > 0, + TORCH_CHECK(dW > 0 && dH > 0 && dT > 0, "stride should be greater than zero, but got dH: %d dW: %d dT: %d", dH, dW, dT); - AT_CHECK( + TORCH_CHECK( dilationW > 0 && dilationH > 0 && dilationT > 0, "dilation should be greater than 0, but got dilationH: %d dilationW: %d dilationT: %d", dilationH, dilationW, dilationT); @@ -108,7 +108,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, dimw++; } - AT_CHECK(ndim == 4 || ndim == 5, "4D or 5D input tensor expected but got: %s", + TORCH_CHECK(ndim == 4 || ndim == 5, "4D or 5D input tensor expected but got: %s", ndim); long nInputPlane = weight.size(1) * group; @@ -123,7 +123,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, long outputTime = (inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1; - AT_CHECK(nInputPlane % deformable_group == 0, + TORCH_CHECK(nInputPlane % deformable_group == 0, "input channels must divide deformable group size"); if (outputWidth < 1 || outputHeight < 1) @@ -133,27 +133,27 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, outputWidth); - AT_CHECK(input.size(1) == nInputPlane, + TORCH_CHECK(input.size(1) == nInputPlane, "invalid number of input planes, expected: %d, but got: %d", nInputPlane, input.size(1)); - AT_CHECK((inputHeight >= kH && inputWidth >= kW && inputTime >= kT), + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW && inputTime >= kT), "input data is smaller than kernel"); - AT_CHECK((offset.size(2) == outputTime && offset.size(3) == outputHeight && offset.size(4) == outputWidth), + TORCH_CHECK((offset.size(2) == outputTime && offset.size(3) == outputHeight && offset.size(4) == outputWidth), "invalid spatial size of offset, expected time: %d height: %d width: %d, but " "got time: %d height: %d width: %d", outputTime, outputHeight, outputWidth, offset.size(2), offset.size(3), offset.size(4)); - AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW * kT), + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW * kT), "invalid number of channels of offset"); if (gradOutput != NULL) { - AT_CHECK(gradOutput->size(dimf) == nOutputPlane, + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, "invalid number of gradOutput planes, expected: %d, but got: %d", nOutputPlane, gradOutput->size(dimf)); - AT_CHECK((gradOutput->size(dimt) == outputTime && + TORCH_CHECK((gradOutput->size(dimt) == outputTime && gradOutput->size(dimh) == outputHeight && gradOutput->size(dimw) == outputWidth), "invalid size of gradOutput, expected time: %d height: %d width: %d, but " @@ -214,7 +214,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, long outputTime = (inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1; - AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + TORCH_CHECK((offset.size(0) == batchSize), 
"invalid batch size of offset"); output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, outputTime, outputHeight, outputWidth}); @@ -341,7 +341,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, long outputTime = (inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1; - AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); gradInput = gradInput.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth}); columns = at::zeros( {nInputPlane * kW * kH * kT, im2col_step * outputTime * outputHeight * outputWidth}, @@ -463,7 +463,7 @@ int deform_conv_backward_parameters_cuda( long outputTime = (inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1; - AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); columns = at::zeros( {nInputPlane * kW * kH * kT, im2col_step * outputHeight * outputWidth * outputTime}, @@ -543,8 +543,8 @@ void modulated_deform_conv_cuda_forward( const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int group, const int deformable_group, const bool with_bias) { - AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); - AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); at::DeviceGuard guard(input.device()); const int batch = input.size(0); @@ -625,8 +625,8 @@ void modulated_deform_conv_cuda_backward( int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, const bool with_bias) { - AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); - AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); at::DeviceGuard guard(input.device()); const int batch = input.size(0); diff --git a/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu b/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu index 07794d7..7a84c2b 100644 --- a/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu +++ b/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu @@ -63,6 +63,7 @@ #include #include #include +#include #include #include #include @@ -321,7 +322,7 @@ void deformable_im2col( const scalar_t *data_offset_ = data_offset.data(); scalar_t *data_col_ = data_col.data(); - deformable_im2col_gpu_kernel<<>>( + deformable_im2col_gpu_kernel<<>>( num_kernels, data_im_, data_offset_, time, height, width, ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, channels, deformable_group, @@ -437,7 +438,7 @@ void deformable_col2im( const scalar_t *data_offset_ = data_offset.data(); scalar_t *grad_im_ = grad_im.data(); - deformable_col2im_gpu_kernel<<>>( + deformable_col2im_gpu_kernel<<>>( num_kernels, data_col_, data_offset_, channels, time, height, width, ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w, channel_per_deformable_group, @@ -560,7 +561,7 @@ void deformable_col2im_coord( const scalar_t *data_offset_ = 
diff --git a/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu b/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu
index 07794d7..7a84c2b 100644
--- a/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu
+++ b/rodnet/ops/dcn/src/deform_conv_3d_cuda_kernel.cu
@@ -63,6 +63,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -321,7 +322,7 @@ void deformable_im2col(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *data_col_ = data_col.data();
 
-    deformable_im2col_gpu_kernel<<>>(
+    deformable_im2col_gpu_kernel<<>>(
         num_kernels, data_im_, data_offset_, time, height, width, ksize_t,
         ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
         dilation_t, dilation_h, dilation_w, channel_per_deformable_group,
         parallel_imgs, channels, deformable_group,
@@ -437,7 +438,7 @@ void deformable_col2im(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *grad_im_ = grad_im.data();
 
-    deformable_col2im_gpu_kernel<<>>(
+    deformable_col2im_gpu_kernel<<>>(
         num_kernels, data_col_, data_offset_, channels, time, height, width,
         ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h,
         stride_w, dilation_t, dilation_h, dilation_w,
         channel_per_deformable_group,
@@ -560,7 +561,7 @@ void deformable_col2im_coord(
     const scalar_t *data_offset_ = data_offset.data();
     scalar_t *grad_offset_ = grad_offset.data();
 
-    deformable_col2im_coord_gpu_kernel<<>>(
+    deformable_col2im_coord_gpu_kernel<<>>(
         num_kernels, data_col_, data_im_, data_offset_, channels, time, height,
         width, ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t,
         stride_h, stride_w, dilation_t, dilation_h, dilation_w,
         channel_per_deformable_group,
@@ -890,7 +891,7 @@ void modulated_deformable_im2col_cuda(
     const scalar_t *data_mask_ = data_mask.data();
     scalar_t *data_col_ = data_col.data();
 
-    modulated_deformable_im2col_gpu_kernel<<>>(
+    modulated_deformable_im2col_gpu_kernel<<>>(
         num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
         kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
         dilation_w, channel_per_deformable_group, batch_size, channels,
         deformable_group, height_col, width_col, data_col_);
@@ -922,7 +923,7 @@ void modulated_deformable_col2im_cuda(
     const scalar_t *data_mask_ = data_mask.data();
     scalar_t *grad_im_ = grad_im.data();
 
-    modulated_deformable_col2im_gpu_kernel<<>>(
+    modulated_deformable_col2im_gpu_kernel<<>>(
         num_kernels, data_col_, data_offset_, data_mask_, channels, height_im,
         width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
         dilation_h, dilation_w, channel_per_deformable_group,
@@ -957,7 +958,7 @@ void modulated_deformable_col2im_coord_cuda(
     scalar_t *grad_offset_ = grad_offset.data();
     scalar_t *grad_mask_ = grad_mask.data();
 
-    modulated_deformable_col2im_coord_gpu_kernel<<>>(
+    modulated_deformable_col2im_coord_gpu_kernel<<>>(
         num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels,
         height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
         stride_w, dilation_h, dilation_w, channel_per_deformable_group,
diff --git a/rodnet/ops/dcn/src/deform_pool_2d_cuda.cpp b/rodnet/ops/dcn/src/deform_pool_2d_cuda.cpp
index 9e0e3ff..f6f087b 100644
--- a/rodnet/ops/dcn/src/deform_pool_2d_cuda.cpp
+++ b/rodnet/ops/dcn/src/deform_pool_2d_cuda.cpp
@@ -33,7 +33,7 @@ void deform_psroi_pooling_cuda_forward(
     at::Tensor top_count, const int no_trans, const float spatial_scale,
     const int output_dim, const int group_size, const int pooled_size,
     const int part_size, const int sample_per_part, const float trans_std) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
@@ -59,8 +59,8 @@ void deform_psroi_pooling_cuda_backward(
     const int no_trans, const float spatial_scale, const int output_dim,
     const int group_size, const int pooled_size, const int part_size,
     const int sample_per_part, const float trans_std) {
-  AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
diff --git a/rodnet/ops/dcn/src/deform_pool_2d_cuda_kernel.cu b/rodnet/ops/dcn/src/deform_pool_2d_cuda_kernel.cu
index 05b00d4..308fe1e 100644
--- a/rodnet/ops/dcn/src/deform_pool_2d_cuda_kernel.cu
+++ b/rodnet/ops/dcn/src/deform_pool_2d_cuda_kernel.cu
@@ -9,6 +9,7 @@
 // modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
 #include
+#include
 #include
 #include
 #include
@@ -296,7 +297,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
     scalar_t *top_data = out.data();
     scalar_t *top_count_data = top_count.data();
 
-    DeformablePSROIPoolForwardKernel<<>>(
+    DeformablePSROIPoolForwardKernel<<>>(
        count, bottom_data, (scalar_t)spatial_scale, channels, height, width,
        pooled_height, pooled_width, bottom_rois, bottom_trans, no_trans,
        (scalar_t)trans_std, sample_per_part, output_dim, group_size, part_size,
        num_classes, channels_each_class, top_data, top_count_data);
@@ -349,7 +350,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
     scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data();
     const scalar_t *top_count_data = top_count.data();
 
-    DeformablePSROIPoolBackwardAccKernel<<>>(
+    DeformablePSROIPoolBackwardAccKernel<<>>(
        count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale,
        channels, height, width, pooled_height, pooled_width, output_dim,
        bottom_data_diff, bottom_trans_diff, bottom_data, bottom_rois,
        bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
diff --git a/rodnet/ops/dcn/src/deform_pool_3d_cuda.cpp b/rodnet/ops/dcn/src/deform_pool_3d_cuda.cpp
index 9e0e3ff..f6f087b 100644
--- a/rodnet/ops/dcn/src/deform_pool_3d_cuda.cpp
+++ b/rodnet/ops/dcn/src/deform_pool_3d_cuda.cpp
@@ -33,7 +33,7 @@ void deform_psroi_pooling_cuda_forward(
     at::Tensor top_count, const int no_trans, const float spatial_scale,
     const int output_dim, const int group_size, const int pooled_size,
     const int part_size, const int sample_per_part, const float trans_std) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
@@ -59,8 +59,8 @@ void deform_psroi_pooling_cuda_backward(
     const int no_trans, const float spatial_scale, const int output_dim,
     const int group_size, const int pooled_size, const int part_size,
     const int sample_per_part, const float trans_std) {
-  AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
diff --git a/rodnet/ops/dcn/src/deform_pool_3d_cuda_kernel.cu b/rodnet/ops/dcn/src/deform_pool_3d_cuda_kernel.cu
index 05b00d4..308fe1e 100644
--- a/rodnet/ops/dcn/src/deform_pool_3d_cuda_kernel.cu
+++ b/rodnet/ops/dcn/src/deform_pool_3d_cuda_kernel.cu
@@ -9,6 +9,7 @@
 // modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
 #include
+#include
 #include
 #include
 #include
@@ -296,7 +297,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
     scalar_t *top_data = out.data();
     scalar_t *top_count_data = top_count.data();
 
-    DeformablePSROIPoolForwardKernel<<>>(
+    DeformablePSROIPoolForwardKernel<<>>(
        count, bottom_data, (scalar_t)spatial_scale, channels, height, width,
        pooled_height, pooled_width, bottom_rois, bottom_trans, no_trans,
        (scalar_t)trans_std, sample_per_part, output_dim, group_size, part_size,
        num_classes, channels_each_class, top_data, top_count_data);
@@ -349,7 +350,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
     scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data();
     const scalar_t *top_count_data = top_count.data();
 
-    DeformablePSROIPoolBackwardAccKernel<<>>(
+    DeformablePSROIPoolBackwardAccKernel<<>>(
        count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale,
        channels, height, width, pooled_height, pooled_width, output_dim,
        bottom_data_diff, bottom_trans_diff, bottom_data, bottom_rois,
        bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
diff --git a/tools/test.py b/tools/test.py
index 59a07c3..295579b 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -142,7 +142,7 @@ def parse_args():
     data_root = dataset_configs['data_root']
 
     if not args.demo:
-        seq_names = sorted(os.listdir(os.path.join(data_root, dataset_configs['test']['subdir'])))
+        seq_names = sorted(os.listdir(os.path.join(data_root, dataset_configs['train']['subdir'])))
     else:
         seq_names = sorted(os.listdir(os.path.join(data_root, dataset_configs['demo']['subdir'])))
     print(seq_names)
@@ -160,7 +160,7 @@ def parse_args():
     for subset in seq_names:
         print(subset)
         if not args.demo:
-            crdata_test = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='test',
+            crdata_test = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='train',
                                     noise_channel=args.use_noise_channel, subset=subset, is_random_chirp=False)
         else:
             crdata_test = CRDataset(data_dir=args.data_dir, dataset=dataset, config_dict=config_dict, split='demo',