forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAutogradComposite.cpp
95 lines (79 loc) · 4.06 KB
/
AutogradComposite.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <c10/util/SmallBuffer.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_has_same_storage_numel_native.h>
#include <ATen/ops/_make_dual_native.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
#include <ATen/ops/_unpack_dual_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/zeros.h>
#endif
namespace at {
namespace native {
// We expect this code to only be reached in inference mode and when all inputs are inference tensors
Tensor _make_dual(const Tensor& primal, const Tensor& tangent, int64_t level) {
TORCH_INTERNAL_ASSERT(
InferenceMode::is_enabled() && primal.is_inference() && tangent.is_inference(),
"Expected this function to only be reached in inference mode and when all the "
"inputs are inference tensors. You should NOT call this function directly as "
"native::_make_dual. Please use the dispatcher, i.e., at::_make_dual. Please "
"file an issue if you come across this error otherwise.");
return at::alias(primal);
}
/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> _unpack_dual(const at::Tensor& tensor, int64_t level) {
return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor._fw_grad(level));
}
// NB: This function can be called directly from _set_fw_grad or
// if self is batched, from this function's batching rule.
// See NOTE: [_new_zeros_with_same_feature_meta] for more information.
Tensor _new_zeros_with_same_feature_meta(
const at::Tensor& self,
const at::Tensor& other,
int64_t self_num_batch_dims) {
auto other_sizes = other.sizes();
auto other_strides = other.strides();
auto other_storage_offset = other.storage_offset();
int64_t other_storage_numel = other.storage().nbytes() / other.itemsize();
if (self_num_batch_dims == 0) {
auto new_tensor = at::zeros({other_storage_numel}, other.options());
return new_tensor.as_strided(other_sizes, other_strides, other_storage_offset);
}
auto self_sizes = self.sizes();
// NB: We don't check that the sizes of self is the same as that of other
// because this function is also used in the inplace over view case
// In the inplace over view case we cannot rely on self and other being
// the same size. So we will use the size of other, and simply tack on
// the batch dims from self. For example: If self.sizes: [B, 2, 3],
// and other.size: [6], we return [B, 6].
// Also see the test test_inplace_on_view_not_same_layout, for when we reach
// this case.
constexpr int64_t kSmallBufferSizeHint = 8;
auto out_sizes = c10::SmallBuffer<int64_t, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
std::copy(self_sizes.begin(), self_sizes.begin() + self_num_batch_dims, out_sizes.begin());
std::copy(other_sizes.begin(), other_sizes.end(), out_sizes.begin() + self_num_batch_dims);
// We use the strides of other, and tack on the strides computed with
// the batch dims of self, so that the slices are arranged contiguously
auto out_strides = c10::SmallBuffer<int64_t, kSmallBufferSizeHint>(other.dim() + self_num_batch_dims);
int64_t prod = other_storage_numel;
for (int64_t i = self_num_batch_dims - 1; i >= 0; --i) {
out_strides[i] = prod;
prod *= self_sizes[i];
}
std::copy(other_strides.begin(), other_strides.end(), out_strides.begin() + self_num_batch_dims);
int64_t storage_numel = prod;
// Inherit the TensorOptions of the primal
auto new_tensor = at::zeros({storage_numel}, other.options());
return new_tensor.as_strided(out_sizes, out_strides, other_storage_offset);
}
bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
return base.storage().nbytes() / base.itemsize() == other.storage().nbytes() / other.itemsize();
}
} // namespace native
} // namespace at