Skip to content

Commit

Permalink
Tensor(OtherTensor, BipartitePermutation) can handle the case where o…
Browse files Browse the repository at this point in the history
…uter(BipartitePermutation) is null
  • Loading branch information
evaleev committed Jan 22, 2024
1 parent 3657752 commit 751bba4
Showing 1 changed file with 11 additions and 28 deletions.
39 changes: 11 additions & 28 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class Tensor {
/// \param perm The permutation that will be applied to the copy
/// \warning if `T1` is a tensor of tensors its elements are _cloned_ rather
/// than copied to make the semantics of this to be consistent
/// between tensors of scalars and tensors of scalars; specifically,
/// between tensors of scalars and tensors of tensors; specifically,
/// if `T1` is a tensor of scalars the constructed tensor is
/// is independent of \p other, thus should apply clone to inner
/// tensor nests to behave similarly for nested tensors
Expand All @@ -398,8 +398,14 @@ class Tensor {
detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor(const T1& other, const Perm& perm)
: Tensor(outer(perm) * other.range(), 1, default_construct{false}) {
detail::tensor_init(value_converter<typename T1::value_type>, outer(perm),
*this, other);
const auto outer_perm = outer(perm);
if (outer_perm) {
detail::tensor_init(value_converter<typename T1::value_type>, outer_perm,
*this, other);
} else {
detail::tensor_init(value_converter<typename T1::value_type>, *this,
other);
}

// If we actually have a ToT the inner permutation was not applied above so
// we do that now
Expand All @@ -410,7 +416,7 @@ class Tensor {
// not match Tensor");
if constexpr (is_tot && is_bperm) {
if (inner_size(perm) != 0) {
auto inner_perm = inner(perm);
const auto inner_perm = inner(perm);
Permute<value_type, value_type> p;
for (auto& x : *this) x = p(x, inner_perm);
}
Expand Down Expand Up @@ -1285,30 +1291,7 @@ class Tensor {
constexpr bool is_tot = detail::is_tensor_of_tensor_v<Tensor>;
[[maybe_unused]] constexpr bool is_bperm =
detail::is_bipartite_permutation_v<Perm>;
// tile ops pass bipartite permutations here even if this is a plain tensor
// static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does
// not match Tensor");
if constexpr (!is_tot) {
if constexpr (is_bperm) {
TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation
return Tensor(*this, outer(perm));
} else
return Tensor(*this, perm);
} else {
// If we have a ToT we need to apply the permutation in two steps. The
// first step is identical to the non-ToT case (permute the outer modes)
// the second step does the inner modes
Tensor rv(*this, outer(perm));
if constexpr (is_bperm) {
if (inner_size(perm) != 0) {
auto inner_perm = inner(perm);
Permute<value_type, value_type> p;
for (auto& inner_t : rv) inner_t = p(inner_t, inner_perm);
}
}
return rv;
}
abort(); // unreachable
return Tensor(*this, perm);
}

/// Shift the lower and upper bound of this tensor
Expand Down

0 comments on commit 751bba4

Please sign in to comment.