Skip to content

Commit

Permalink
introduced TA::Tile::at_ordinal + strengthen disambiguation checks fo…
Browse files Browse the repository at this point in the history
…r potential at_ordinal uses
  • Loading branch information
evaleev committed Sep 22, 2024
1 parent f294db3 commit f613831
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ class Tensor {
const_reference operator[](const Ordinal ord) const {
TA_ASSERT(!this->empty());
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus assume at_ordinal() if this->rank()==1
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator[](index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
Expand All @@ -726,7 +726,7 @@ class Tensor {
reference operator[](const Ordinal ord) {
TA_ASSERT(!this->empty());
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus assume at_ordinal() if this->rank()==1
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator[](index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
Expand Down Expand Up @@ -848,7 +848,7 @@ class Tensor {
TA_ASSERT(!this->empty());
TA_ASSERT(this->nbatch() == 1);
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus assume at_ordinal() if this->rank()==1
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator()(index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
Expand All @@ -869,7 +869,7 @@ class Tensor {
TA_ASSERT(!this->empty());
TA_ASSERT(this->nbatch() == 1);
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus assume at_ordinal() if this->rank()==1
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator()(index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
Expand Down Expand Up @@ -960,6 +960,12 @@ class Tensor {
const_reference operator()(const Index&... i) const {
TA_ASSERT(!this->empty());
TA_ASSERT(this->nbatch() == 1);
TA_ASSERT(this->range().rank() == sizeof...(Index));
// can't distinguish between operator()(Index...) and operator()(ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator()(index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
using Int = std::common_type_t<Index...>;
const auto iord = this->range_.ordinal(
std::array<Int, sizeof...(Index)>{{static_cast<Int>(i)...}});
Expand All @@ -982,6 +988,12 @@ class Tensor {
reference operator()(const Index&... i) {
TA_ASSERT(!this->empty());
TA_ASSERT(this->nbatch() == 1);
TA_ASSERT(this->range().rank() == sizeof...(Index));
// can't distinguish between operator()(Index...) and operator()(ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range_.rank() != 1 &&
"use Tensor::operator()(index) or "
"Tensor::at_ordinal(index_ordinal) if this->range().rank()==1");
using Int = std::common_type_t<Index...>;
const auto iord = this->range_.ordinal(
std::array<Int, sizeof...(Index)>{{static_cast<Int>(i)...}});
Expand Down
52 changes: 52 additions & 0 deletions src/TiledArray/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ class Tile {
std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
const_reference operator[](const Ordinal ord) const {
TA_ASSERT(pimpl_);
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range().rank() != 1 &&
"use Tile::operator[](index) or "
"Tile::at_ordinal(index_ordinal) if this->range().rank()==1");
TA_ASSERT(tensor().range().includes_ordinal(ord));
return tensor().data()[ord];
}
Expand All @@ -264,6 +269,41 @@ class Tile {
template <typename Ordinal,
std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
reference operator[](const Ordinal ord) {
TA_ASSERT(pimpl_);
// can't distinguish between operator[](Index...) and operator[](ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range().rank() != 1 &&
"use Tile::operator[](index) or "
"Tile::at_ordinal(index_ordinal) if this->range().rank()==1");
TA_ASSERT(tensor().range().includes_ordinal(ord));
return tensor().data()[ord];
}

/// Const element accessor

/// \tparam Ordinal an integer type that represents an ordinal
/// \param[in] ord an ordinal index
/// \return Const reference to the element at position \c ord .
/// \note This asserts (using TA_ASSERT) that this is not empty and ord is
/// included in the range
template <typename Ordinal,
std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
const_reference at_ordinal(const Ordinal ord) const {
TA_ASSERT(pimpl_);
TA_ASSERT(tensor().range().includes_ordinal(ord));
return tensor().data()[ord];
}

/// Element accessor

/// \tparam Ordinal an integer type that represents an ordinal
/// \param[in] ord an ordinal index
/// \return Reference to the element at position \c ord .
/// \note This asserts (using TA_ASSERT) that this is not empty and ord is
/// included in the range
template <typename Ordinal,
std::enable_if_t<std::is_integral<Ordinal>::value>* = nullptr>
reference at_ordinal(const Ordinal ord) {
TA_ASSERT(pimpl_);
TA_ASSERT(tensor().range().includes_ordinal(ord));
return tensor().data()[ord];
Expand Down Expand Up @@ -401,6 +441,12 @@ class Tile {
detail::is_integral_list<Index...>::value>* = nullptr>
const_reference operator()(const Index&... i) const {
TA_ASSERT(pimpl_);
TA_ASSERT(this->range().rank() == sizeof...(Index));
// can't distinguish between operator()(Index...) and operator()(ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range().rank() != 1 &&
"use Tile::operator()(index) or "
"Tile::at_ordinal(index_ordinal) if this->range().rank()==1");
TA_ASSERT(tensor().range().includes(i...));
return tensor().data()[tensor().range().ordinal(i...)];
}
Expand All @@ -417,6 +463,12 @@ class Tile {
detail::is_integral_list<Index...>::value>* = nullptr>
reference operator()(const Index&... i) {
TA_ASSERT(pimpl_);
TA_ASSERT(this->range().rank() == sizeof...(Index));
// can't distinguish between operator()(Index...) and operator()(ordinal)
// thus insist on at_ordinal() if this->rank()==1
TA_ASSERT(this->range().rank() != 1 &&
"use Tile::operator()(index) or "
"Tile::at_ordinal(index_ordinal) if this->range().rank()==1");
TA_ASSERT(tensor().range().includes(i...));
return tensor().data()[tensor().range().ordinal(i...)];
}
Expand Down

0 comments on commit f613831

Please sign in to comment.