Skip to content

Commit

Permalink
Merge pull request #180 from Krzmbrzl/fix-tensor-hashing
Browse files Browse the repository at this point in the history
Fix hash_value template resolution
  • Loading branch information
evaleev authored Feb 15, 2024
2 parents 3df7bfd + 074dc6a commit e8aab56
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 45 deletions.
68 changes: 27 additions & 41 deletions SeQuant/core/hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ constexpr hash::Impl hash_version() {
#endif
}

class Expr;

namespace detail {
template <typename T, typename Enabler = void>
struct has_hash_value_member_fn_helper : public std::false_type {};
Expand All @@ -54,30 +52,15 @@ template <typename T>
static constexpr bool has_hash_value_member_fn_v =
detail::has_hash_value_member_fn_helper<T>::value;

#ifdef SEQUANT_USE_SYSTEM_BOOST_HASH
#if SEQUANT_BOOST_VERSION < 108100
template <typename T,
typename = std::enable_if_t<has_hash_value_member_fn_v<T>>>
auto hash_value(const T& obj) {
return obj.hash_value();
}
#endif
#endif // SEQUANT_USE_SYSTEM_BOOST_HASH

namespace detail {

template <typename T, typename = std::void_t<>>
struct has_boost_hash_value : std::false_type {};

#ifdef SEQUANT_USE_SYSTEM_BOOST_HASH
template <typename T>
struct has_boost_hash_value<
T, std::void_t<decltype(sequant_boost::hash_value(std::declval<T>()))>>
struct has_boost_hash_value<T, std::void_t<decltype(sequant_boost::hash_value(
std::declval<const T&>()))>>
: std::true_type {};
#endif // SEQUANT_USE_SYSTEM_BOOST_HASH

template <typename T>
constexpr bool has_boost_hash_value_v = has_boost_hash_value<const T&>::value;

template <typename T, typename = std::void_t<>>
struct has_hash_value : std::false_type {};
Expand All @@ -87,12 +70,30 @@ struct has_hash_value<
T, std::void_t<decltype(hash_value(std::declval<const T&>()))>>
: std::true_type {};

} // namespace detail

template <typename T>
constexpr bool has_hash_value_v = has_hash_value<T>::value;
constexpr bool has_boost_hash_value_v =
detail::has_boost_hash_value<const T&>::value;

} // namespace detail
template <typename T>
constexpr bool has_hash_value_v = detail::has_hash_value<T>::value;

// hash_value specialization for types that have a hash_value member function
template <typename T,
std::enable_if_t<has_hash_value_member_fn_v<T>, short> = 0>
auto hash_value(const T& obj) {
return obj.hash_value();
}

using sequant_boost::hash_value;
// hash_value specialization that don't have a hash_value member function but
// have an applicable boost::hash_value function
template <typename T, std::enable_if_t<!has_hash_value_member_fn_v<T> &&
has_boost_hash_value_v<T>,
int> = 0>
auto hash_value(const T& obj) {
return sequant_boost::hash_value(obj);
}

// clang-format off
// rationale:
Expand Down Expand Up @@ -187,13 +188,7 @@ inline void range(std::size_t& seed, It first, It last) {
template <typename It>
std::size_t hash_range(It begin, It end) {
if (begin != end) {
std::size_t seed;
if constexpr (has_hash_value_member_fn_v<std::decay_t<decltype(*begin)>>)
seed = begin->hash_value();
else {
using sequant_boost::hash_value;
[[maybe_unused]] std::size_t seed = hash_value(*begin);
}
std::size_t seed = hash_value(*begin);
sequant_boost::hash_range(seed, begin + 1, end);
return seed;
} else
Expand All @@ -207,22 +202,13 @@ void hash_range(size_t& seed, It begin, It end) {
}

template <typename T>
struct _<
T, std::enable_if_t<!(detail::has_hash_value_v<T>)&&meta::is_range_v<T>>> {
struct _<T, std::enable_if_t<!(has_hash_value_v<T>)&&meta::is_range_v<T>>> {
std::size_t operator()(T const& v) const { return range(begin(v), end(v)); }
};

template <typename T>
struct _<T, std::enable_if_t<!(
!(detail::has_hash_value_v<T>)&&meta::is_range_v<T>)>> {
std::size_t operator()(T const& v) const {
if constexpr (has_hash_value_member_fn_v<T>)
return v.hash_value();
else {
using sequant_boost::hash_value;
return hash_value(v);
}
}
struct _<T, std::enable_if_t<!(!(has_hash_value_v<T>)&&meta::is_range_v<T>)>> {
std::size_t operator()(T const& v) const { return hash_value(v); }
};

template <typename T>
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_mbpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ TEST_CASE("NBodyOp", "[mbpt]") {
if constexpr (hash_version() == hash::Impl::BoostPre181) {
REQUIRE(
to_latex(simplify(f * t * t)) ==
to_latex(ex<Constant>(2) * f * t1 * t2 + f * t1 * t1 + f * t2 * t2));
to_latex(f * t1 * t1 + f * t2 * t2 + ex<Constant>(2) * f * t1 * t2));
} else {
// std::wcout << "to_latex(simplify(f * t * t)): "
// << to_latex(simplify(f * t * t)) << std::endl;
Expand All @@ -290,8 +290,8 @@ TEST_CASE("NBodyOp", "[mbpt]") {

if constexpr (hash_version() == hash::Impl::BoostPre181) {
REQUIRE(to_latex(simplify(f * t * t * t)) ==
to_latex(f * t1 * t1 * t1 + ex<Constant>(3) * f * t1 * t2 * t2 +
f * t2 * t2 * t2 + ex<Constant>(3) * f * t1 * t1 * t2));
to_latex(ex<Constant>(3) * f * t1 * t2 * t2 + f * t2 * t2 * t2 +
ex<Constant>(3) * f * t1 * t1 * t2 + f * t1 * t1 * t1));
} else {
// std::wcout << "to_latex(simplify(f * t * t * t): "
// << to_latex(simplify(f * t * t * t)) << std::endl;
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST_CASE("Tensor", "[elements]") {
auto t2 = Tensor(L"F", {L"i_2"}, {L"i_1"});
size_t t2_hash;
REQUIRE_NOTHROW(t2_hash = hash_value(t2));
REQUIRE_NOTHROW(t1_hash != t2_hash);
REQUIRE(t1_hash != t2_hash);

} // SECTION("hash")

Expand Down

0 comments on commit e8aab56

Please sign in to comment.