Skip to content

Commit

Permalink
Merge pull request #477 from ValeevGroup/evaleev/feature/btas-retile
Browse files Browse the repository at this point in the history
better `btas::Tensor` interoperation with TA tensorials
  • Loading branch information
evaleev authored Sep 25, 2024
2 parents 0c7373c + 3c2f7e5 commit 2d9fd7f
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 11 deletions.
2 changes: 1 addition & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Both methods are supported. However, for most users we _strongly_ recommend to b
- Boost.Test: header-only or (optionally) as a compiled library, *only used for unit testing*
- Boost.Range: header-only, *only used for unit testing*
- [Range-V3](https://github.com/ericniebler/range-v3.git) -- a Ranges library that served as the basis for Ranges component of C++20 and later.
- [BTAS](http://github.com/ValeevGroup/BTAS), tag 4e8f5233aa7881dccdfcc37ce07128833926d3c2 . If usable BTAS installation is not found, TiledArray will download and compile
- [BTAS](http://github.com/ValeevGroup/BTAS), tag 4b3757cc2b5862f93589afc1e37523e543779c7a . If usable BTAS installation is not found, TiledArray will download and compile
BTAS from source. *This is the recommended way to compile BTAS for all users*.
- [MADNESS](https://github.com/m-a-d-n-e-s-s/madness), tag 95589b0d020a076f93d02eead6da654b23dd3d91 .
Only the MADworld runtime and BLAS/LAPACK C API component of MADNESS is used by TiledArray.
Expand Down
4 changes: 2 additions & 2 deletions external/versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 96ac90e8f193ccfaf16f346b4652927d2d362e75)
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)

set(TA_TRACKED_BTAS_TAG 4e8f5233aa7881dccdfcc37ce07128833926d3c2)
set(TA_TRACKED_BTAS_PREVIOUS_TAG b7b2ea7513b087e35c6f1b26184a3904ac1e6b14)
set(TA_TRACKED_BTAS_TAG 4b3757cc2b5862f93589afc1e37523e543779c7a)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 4e8f5233aa7881dccdfcc37ce07128833926d3c2)

set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece)
set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83)
Expand Down
7 changes: 7 additions & 0 deletions src/TiledArray/external/btas.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class boxrange_iteration_order<TiledArray::Range> {
static constexpr int value = row_major;
};

template <typename T, typename A>
class is_tensor<TiledArray::Tensor<T, A>> : public std::true_type {};

template <typename T, typename R, typename O>
class is_tensor<TiledArray::detail::TensorInterface<T, R, O>>
: public std::true_type {};

} // namespace btas

namespace TiledArray {
Expand Down
2 changes: 0 additions & 2 deletions src/TiledArray/range1.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ struct Range1 {
/// \return An iterator that points to the beginning of the local element set
const_iterator cend() const { return end(); }

/// @}

/// shifts this Range1

/// @param[in] shift the shift to apply
Expand Down
28 changes: 26 additions & 2 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ class Tensor {

/// Iterator factory

/// \return An iterator to the first data element
/// \return A const iterator to the first data element
const_iterator begin() const { return (this->data() ? this->data() : NULL); }

/// Iterator factory
Expand All @@ -1013,7 +1013,7 @@ class Tensor {

/// Iterator factory

/// \return An iterator to the last data element
/// \return A const iterator to the last data element
const_iterator end() const {
return (this->data() ? this->data() + this->size() : NULL);
}
Expand All @@ -1023,6 +1023,30 @@ class Tensor {
/// \return An iterator to the last data element
iterator end() { return (this->data() ? this->data() + this->size() : NULL); }

/// Iterator factory

/// \return A const iterator to the first data element
const_iterator cbegin() const { return (this->data() ? this->data() : NULL); }

/// Iterator factory

/// \return A const iterator to the first data element
const_iterator cbegin() { return (this->data() ? this->data() : NULL); }

/// Iterator factory

/// \return A const iterator to the last data element
const_iterator cend() const {
return (this->data() ? this->data() + this->size() : NULL);
}

/// Iterator factory

/// \return A const iterator to the last data element
const_iterator cend() {
return (this->data() ? this->data() + this->size() : NULL);
}

/// Read-only access to the data

/// \return A const pointer to the tensor data
Expand Down
82 changes: 78 additions & 4 deletions src/TiledArray/tensor/tensor_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ class TensorInterface {
/// \param idx The index pack
template <typename... Index>
reference operator()(const Index&... idx) {
TA_ASSERT(range_.includes(idx...));
return data_[range_.ordinal(idx...)];
const auto ord = range_.ordinal(idx...);
return data_[ord];
}

/// Element accessor
Expand All @@ -269,10 +269,84 @@ class TensorInterface {
/// \param idx The index pack
template <typename... Index>
const_reference operator()(const Index&... idx) const {
TA_ASSERT(range_.includes(idx...));
return data_[range_.ordinal(idx...)];
const auto ord = range_.ordinal(idx...);
return data_[ord];
}

/// \brief Tensor interface iterator type
///
/// Iterates over elements of a tensor interface whose range is iterable
template <typename TI = TensorInterface_>
class Iterator : public boost::iterator_facade<
Iterator<TI>,
std::conditional_t<std::is_const_v<TI>,
const typename TI::value_type,
typename TI::value_type>,
boost::forward_traversal_tag> {
public:
using range_iterator = typename TI::range_type::const_iterator;

Iterator(range_iterator idx_it, TI& ti) : idx_it(idx_it), ti(ti) {}

private:
range_iterator idx_it;
TI& ti;

friend class boost::iterator_core_access;

/// \brief increments this iterator
void increment() { ++idx_it; }

/// \brief Iterator comparer
/// \return true, if \c `*this==*other`
bool equal(Iterator const& other) const {
return this->idx_it == other.idx_it;
}

/// \brief dereferences this iterator
/// \return const reference to the current index
auto& dereference() const { return ti(*idx_it); }
};
friend class Iterator<TensorInterface_>;
friend class Iterator<const TensorInterface_>;

typedef Iterator<TensorInterface_> iterator; ///< Iterator type
typedef Iterator<const TensorInterface_> const_iterator; ///< Iterator type

/// Const begin iterator

/// \return An iterator that points to the beginning of this tensor view
const_iterator begin() const {
return const_iterator(range().begin(), *this);
}

/// Const end iterator

/// \return An iterator that points to the end of this tensor view
const_iterator end() const { return const_iterator(range().end(), *this); }

/// Nonconst begin iterator

/// \return An iterator that points to the beginning of this tensor view
iterator begin() { return iterator(range().begin(), *this); }

/// Nonconst begin iterator

/// \return An iterator that points to the beginning of this tensor view
iterator end() { return iterator(range().end(), *this); }

/// Const begin iterator

/// \return An iterator that points to the beginning of this tensor view
const_iterator cbegin() const {
return const_iterator(range().begin(), *this);
}

/// Const end iterator

/// \return An iterator that points to the end of this tensor view
const_iterator cend() const { return const_iterator(range().end(), *this); }

/// Check for empty view

/// \return \c false
Expand Down
20 changes: 20 additions & 0 deletions src/TiledArray/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,26 @@ class Tile {
/// \return A const iterator to the last data element
decltype(auto) end() const { return std::end(tensor()); }

/// Iterator factory

/// \return A const iterator to the first data element
decltype(auto) cbegin() { return std::cbegin(tensor()); }

/// Iterator factory

/// \return A const iterator to the first data element
decltype(auto) cbegin() const { return std::cbegin(tensor()); }

/// Iterator factory

/// \return A const iterator to the last data element
decltype(auto) cend() { return std::cend(tensor()); }

/// Iterator factory

/// \return A const iterator to the last data element
decltype(auto) cend() const { return std::cend(tensor()); }

// Data accessor -------------------------------------------------------

/// Data direct access
Expand Down
21 changes: 21 additions & 0 deletions tests/btas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,27 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(tensor_ctor, Tensor, tensor_types) {
BOOST_REQUIRE_NO_THROW(Tensor t1 = t0);
Tensor t1 = t0;
BOOST_CHECK(t1.empty());

// can copy TA::Tensor to btas::Tensor
TA::Tensor<typename Tensor::value_type> ta_tensor;
ta_tensor = make_rand_tile<decltype(ta_tensor)>(r);
BOOST_REQUIRE_NO_THROW(Tensor(ta_tensor));
Tensor t2(ta_tensor);
for (auto i : r) {
BOOST_CHECK_EQUAL(ta_tensor(i), t2(i));
}

// can copy TA::TensorInterface to btas::Tensor
{
const auto l = {3, 3, 3};
const auto u = r.upbound();
BOOST_REQUIRE(r.includes(l));
BOOST_REQUIRE_NO_THROW(Tensor(ta_tensor.block(l, u)));
Tensor t3(ta_tensor.block(l, u));
for (auto i : t3.range()) {
BOOST_CHECK_EQUAL(ta_tensor(i), t3(i));
}
}
}

BOOST_AUTO_TEST_CASE_TEMPLATE(copy, Array, array_types) {
Expand Down
2 changes: 2 additions & 0 deletions tests/expressions_btas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
*
*/

#include <TiledArray/config.h>

#ifdef TILEDARRAY_HAS_BTAS
#include "expressions_fixture.h"

Expand Down

0 comments on commit 2d9fd7f

Please sign in to comment.