forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSpatialFractionalMaxPooling.cu
113 lines (98 loc) · 4.02 KB
/
SpatialFractionalMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <THC/THCDeviceTensor.cuh>
#include <THC/THCDeviceTensorUtils.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCAtomics.cuh>
#include <cfloat>
// Start offset of the pooling window for output position `index` along one
// spatial dimension.
//
// Windows are spaced by alpha = (inputSize - poolSize) / (outputSize - 1),
// evaluated in Acctype, and shifted by the random offset `sample`. The last
// output position is pinned to the final window start (inputSize - poolSize),
// so the last window always ends exactly at the end of the input.
// The Dtype template parameter is not used in the body; it is kept because
// callers spell the call as getInterval<Dtype, Acctype>(...).
template <typename Dtype, typename Acctype>
__device__ inline int getInterval(Acctype sample,
                                  int index,
                                  int inputSize,
                                  int outputSize,
                                  int poolSize) {
  const int lastStart = inputSize - poolSize;

  if (index == outputSize - 1) {
    return lastStart;
  }

  // Stride between consecutive window starts.
  const Acctype alpha =
      static_cast<Acctype>(lastStart) / static_cast<Acctype>(outputSize - 1);

  // Difference of truncated positions, so that index 0 maps to start 0.
  return static_cast<int>((index + sample) * alpha) -
         static_cast<int>(sample * alpha);
}
// Returns true when `val` should replace the running window maximum `maxVal`.
// Both values are compared in Acctype (float for half inputs), so the NaN test
// below does not depend on half-precision comparison intrinsics.
// - A NaN replaces a non-NaN maximum, so NaN propagates to the output.
// - Once the maximum is NaN, nothing replaces it, so the first NaN wins.
// - The strict `>` keeps the first occurrence among equal maxima.
template <typename Dtype, typename Acctype>
__device__ inline bool fractionalMaxPoolIsNewMax(Dtype val, Dtype maxVal) {
  Acctype v = ScalarConvert<Dtype, Acctype>::to(val);
  Acctype m = ScalarConvert<Dtype, Acctype>::to(maxVal);
  bool vIsNan = (v != v);
  bool mIsNan = (m != m);
  return (vIsNan && !mIsNan) || v > m;
}

// Forward pass of 2D fractional max pooling.
//
// Launch layout, as read from the indexing below:
// - blockIdx.z selects the batch entry.
// - blockIdx.y selects the plane.
// - threadIdx.x + blockIdx.x * blockDim.x is a flattened (h, w) output point.
//   Threads past the end of the output plane do nothing.
//
// samples[batch][plane][0] and [1] are the per-plane random offsets used for
// the W and H window positions respectively.
//
// Writes the window maximum to `output` and its flattened position
// h * inputWidth + w (plus TH_INDEX_BASE) to `indices`.
//
// We template on poolSizeW to allow the innermost loop to be unrolled;
// PoolSizeWStatic == -1 selects the runtime `poolSizeW` loop instead.
template <int PoolSizeWStatic, typename Dtype, typename Acctype>
__global__ void SpatialFractionalMaxPooling_updateOutput(
    THCDeviceTensor<Dtype, 4> input,
    THCDeviceTensor<Dtype, 4> output,
    THCDeviceTensor<THCIndex_t, 4> indices,
    THCDeviceTensor<Dtype, 3> samples,
    int poolSizeW, int poolSizeH) {
  // Output (h, w) point that this thread is responsible for
  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < output.getSize(2) * output.getSize(3)) {
    int outputW = ourOutputPoint % output.getSize(3);
    int outputH = ourOutputPoint / output.getSize(3);

    int poolW = getInterval<Dtype, Acctype>(
        ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][0]), outputW,
        input.getSize(3), output.getSize(3), poolSizeW);
    int poolH = getInterval<Dtype, Acctype>(
        ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][1]), outputH,
        input.getSize(2), output.getSize(2), poolSizeH);

    // Seed the running maximum with the first element of the window rather
    // than a sentinel such as THCNumerics<Dtype>::min(). A sentinel is never
    // beaten when the window holds only -inf, only values equal to the
    // sentinel, or only NaN. That would leave no valid index to store for
    // the backward pass.
    int maxIndex = poolH * input.getSize(3) + poolW;
    Dtype maxVal = input[batch][plane][poolH][poolW];

    for (int h = poolH; h < poolH + poolSizeH; ++h) {
      if (PoolSizeWStatic == -1) {
        for (int w = poolW; w < poolW + poolSizeW; ++w) {
          Dtype val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (fractionalMaxPoolIsNewMax<Dtype, Acctype>(val, maxVal)) {
            maxIndex = h * input.getSize(3) + w;
            maxVal = val;
          }
        }
      } else {
#pragma unroll
        for (int i = 0; i < PoolSizeWStatic; ++i) {
          int w = i + poolW;
          Dtype val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (fractionalMaxPoolIsNewMax<Dtype, Acctype>(val, maxVal)) {
            maxIndex = h * input.getSize(3) + w;
            maxVal = val;
          }
        }
      }
    }

    // +1 for Lua index
    indices[batch][plane][outputH][outputW] = maxIndex + TH_INDEX_BASE;
    output[batch][plane][outputH][outputW] = maxVal;
  }
}
// Backward pass of 2D fractional max pooling. Each gradOutput element is
// added into gradInput at the flattened position recorded in `indices`
// (h * gradInputWidth + w, offset by TH_INDEX_BASE).
//
// Launch layout, as read from the indexing below:
// - blockIdx.z selects the batch entry.
// - blockIdx.y selects the plane.
// - threadIdx.x + blockIdx.x * blockDim.x is a flattened (h, w) point of
//   gradOutput.
//
// Values are accumulated into gradInput; the kernel does not clear it first.
// Distinct output points can store the same input index, so the accumulation
// uses atomicAdd.
template <typename Dtype>
__global__ void SpatialFractionalMaxPooling_updateGradInput(
    THCDeviceTensor<Dtype, 4> gradInput,
    THCDeviceTensor<Dtype, 4> gradOutput,
    THCDeviceTensor<THCIndex_t, 4> indices) {
  const int batch = blockIdx.z;
  const int plane = blockIdx.y;
  const int outWidth = gradOutput.getSize(3);
  const int point = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads beyond the last output point of the plane have nothing to do.
  if (point >= gradOutput.getSize(2) * outWidth) {
    return;
  }

  const int outX = point % outWidth;
  const int outY = point / outWidth;

  // Remove the TH_INDEX_BASE offset applied when the index was stored.
  const int flat = indices[batch][plane][outY][outX] - TH_INDEX_BASE;
  assert(flat >= 0);

  const int inWidth = gradInput.getSize(3);
  const int inX = flat % inWidth;
  const int inY = flat / inWidth;
  assert(inY < gradInput.getSize(2));

  atomicAdd(gradInput[batch][plane][inY][inX].data(),
            gradOutput[batch][plane][outY][outX]);
}
#include <THCUNN/generic/SpatialFractionalMaxPooling.cu>
#include <THC/THCGenerateFloatTypes.h>