Skip to content

Commit

Permalink
[Core][MPI] Adding MinLocAll and MaxLocAll to DataCommunicator
Browse files Browse the repository at this point in the history
  • Loading branch information
loumalouomega committed Oct 23, 2023
1 parent 5bf5e40 commit 218acb5
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 91 deletions.
55 changes: 28 additions & 27 deletions kratos/includes/data_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,32 +80,33 @@ virtual void Max(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_
* The returned value is defined on all ranks.
*/
#ifndef KRATOS_BASE_DATA_COMMUNICATOR_DECLARE_ALLREDUCE_INTERFACE_FOR_TYPE
#define KRATOS_BASE_DATA_COMMUNICATOR_DECLARE_ALLREDUCE_INTERFACE_FOR_TYPE(...) \
virtual __VA_ARGS__ SumAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> SumAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void SumAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "SumAll"); \
rGlobalValues = SumAll(rLocalValues); \
} \
virtual __VA_ARGS__ MinAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> MinAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void MinAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "MinAll"); \
rGlobalValues = MinAll(rLocalValues); \
} \
virtual __VA_ARGS__ MaxAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> MaxAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void MaxAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "MaxAll"); \
rGlobalValues = MaxAll(rLocalValues); \
} \

#define KRATOS_BASE_DATA_COMMUNICATOR_DECLARE_ALLREDUCE_INTERFACE_FOR_TYPE(...) \
virtual __VA_ARGS__ SumAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> SumAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void SumAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "SumAll"); \
rGlobalValues = SumAll(rLocalValues); \
} \
virtual __VA_ARGS__ MinAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> MinAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void MinAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "MinAll"); \
rGlobalValues = MinAll(rLocalValues); \
} \
virtual __VA_ARGS__ MaxAll(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> MaxAll(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
virtual void MaxAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
KRATOS_DATA_COMMUNICATOR_DEBUG_SIZE_CHECK(rLocalValues.size(), rGlobalValues.size(), "MaxAll"); \
rGlobalValues = MaxAll(rLocalValues); \
} \
virtual std::pair<__VA_ARGS__, int> MinLocAll(const __VA_ARGS__& rLocalValue) const { return std::pair<__VA_ARGS__, int>(rLocalValue, 0); } \
virtual std::pair<__VA_ARGS__, int> MaxLocAll(const __VA_ARGS__& rLocalValue) const { return std::pair<__VA_ARGS__, int>(rLocalValue, 0); }
#endif

// Compute the partial sum of the given quantity from rank 0 to the current rank (included).
Expand All @@ -115,7 +116,7 @@ virtual void MaxAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__
*/
#ifndef KRATOS_BASE_DATA_COMMUNICATOR_DECLARE_SCANSUM_INTERFACE_FOR_TYPE
#define KRATOS_BASE_DATA_COMMUNICATOR_DECLARE_SCANSUM_INTERFACE_FOR_TYPE(...) \
virtual __VA_ARGS__ ScanSum(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual __VA_ARGS__ ScanSum(const __VA_ARGS__& rLocalValue) const { return rLocalValue; } \
virtual std::vector<__VA_ARGS__> ScanSum(const std::vector<__VA_ARGS__>& rLocalValues) const { \
return rLocalValues; \
} \
Expand Down
8 changes: 8 additions & 0 deletions kratos/mpi/includes/mpi_data_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ void MinAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS_
__VA_ARGS__ MaxAll(const __VA_ARGS__& rLocalValue) const override; \
std::vector<__VA_ARGS__> MaxAll(const std::vector<__VA_ARGS__>& rLocalValues) const override; \
void MaxAll(const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const override; \
std::pair<__VA_ARGS__, int> MinLocAll(const __VA_ARGS__& rLocalValue) const override; \
std::pair<__VA_ARGS__, int> MaxLocAll(const __VA_ARGS__& rLocalValue) const override;

#endif

Expand Down Expand Up @@ -380,6 +382,12 @@ class KRATOS_API(KRATOS_MPI_CORE) MPIDataCommunicator: public DataCommunicator
const std::vector<TDataType>& rLocalValues,
MPI_Op Operation) const;

template<class TDataType>
std::pair<TDataType, int> AllReduceDetailWithLocation(
const std::pair<TDataType, int>& rLocalValues,
MPI_Op Operation
) const;

template<class TDataType> void ScanDetail(
const TDataType& rLocalValues,
TDataType& rReducedValues,
Expand Down
56 changes: 49 additions & 7 deletions kratos/mpi/sources/mpi_data_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
// Main author: Jordi Cotela
//

// System includes
#include <algorithm>

// External includes

// Project includes
#include "includes/parallel_environment.h"

#include "mpi/includes/mpi_data_communicator.h"
Expand Down Expand Up @@ -110,7 +115,14 @@ void MPIDataCommunicator::MaxAll(
const std::vector<__VA_ARGS__>& rLocalValues, std::vector<__VA_ARGS__>& rGlobalValues) const { \
AllReduceDetail(rLocalValues, rGlobalValues, MPI_MAX); \
} \

std::pair<__VA_ARGS__, int> MPIDataCommunicator::MinLocAll(const __VA_ARGS__& rLocalValue) const { \
std::pair<__VA_ARGS__, int> local_values({rLocalValue, Rank()}); \
return AllReduceDetailWithLocation(local_values, MPI_MINLOC); \
} \
std::pair<__VA_ARGS__, int> MPIDataCommunicator::MaxLocAll(const __VA_ARGS__& rLocalValue) const { \
std::pair<__VA_ARGS__, int> local_values({rLocalValue, Rank()}); \
return AllReduceDetailWithLocation(local_values, MPI_MAXLOC); \
}
#endif

#ifndef KRATOS_MPI_DATA_COMMUNICATOR_DEFINE_SCANSUM_INTERFACE_FOR_TYPE
Expand Down Expand Up @@ -586,7 +598,7 @@ template<class TDataType> void MPIDataCommunicator::ReduceDetail(
{
MPIMessage<TDataType> mpi_send_msg, mpi_recv_msg;

#ifdef KRATOS_DEBUG
#ifdef KRATOS_DEBUG
KRATOS_ERROR_IF_NOT(ErrorIfFalseOnAnyRank(IsValidRank(Root)))
<< "In call to MPI_Reduce: " << Root << " is not a valid rank." << std::endl;
const int local_size = mpi_send_msg.Size(rLocalValues);
Expand All @@ -597,7 +609,7 @@ template<class TDataType> void MPIDataCommunicator::ReduceDetail(
KRATOS_ERROR_IF(BroadcastErrorIfTrue(local_size != reduced_size,Root))
<< "Input error in call to MPI_Reduce for rank " << Root << ": "
<< "Sending " << local_size << " values " << "but receiving " << reduced_size << " values." << std::endl;
#endif // KRATOS_DEBUG
#endif // KRATOS_DEBUG

const int ierr = MPI_Reduce(
mpi_send_msg.Buffer(rLocalValues), mpi_recv_msg.Buffer(rReducedValues),
Expand Down Expand Up @@ -649,7 +661,7 @@ template<class TDataType> void MPIDataCommunicator::AllReduceDetail(
{
MPIMessage<TDataType> mpi_send_msg, mpi_recv_msg;

#ifdef KRATOS_DEBUG
#ifdef KRATOS_DEBUG
const int local_size = mpi_send_msg.Size(rLocalValues);
const int reduced_size = mpi_recv_msg.Size(rReducedValues);
KRATOS_ERROR_IF_NOT(IsEqualOnAllRanks(local_size))
Expand All @@ -658,7 +670,7 @@ template<class TDataType> void MPIDataCommunicator::AllReduceDetail(
KRATOS_ERROR_IF(ErrorIfTrueOnAnyRank(local_size != reduced_size))
<< "Input error in call to MPI_Allreduce for rank " << Rank() << ": "
<< "Sending " << local_size << " values " << "but receiving " << reduced_size << " values." << std::endl;
#endif // KRATOS_DEBUG
#endif // KRATOS_DEBUG

const int ierr = MPI_Allreduce(
mpi_send_msg.Buffer(rLocalValues), mpi_recv_msg.Buffer(rReducedValues),
Expand All @@ -669,8 +681,11 @@ template<class TDataType> void MPIDataCommunicator::AllReduceDetail(
mpi_recv_msg.Update(rReducedValues);
}

template<class TDataType> TDataType MPIDataCommunicator::AllReduceDetail(
const TDataType& rLocalValues, MPI_Op Operation) const
template<class TDataType>
TDataType MPIDataCommunicator::AllReduceDetail(
const TDataType& rLocalValues,
MPI_Op Operation
) const
{
TDataType global_values(rLocalValues);
AllReduceDetail(rLocalValues, global_values, Operation);
Expand All @@ -694,6 +709,33 @@ std::vector<TDataType> MPIDataCommunicator::AllReduceDetailVector(
return reduced_values;
}

template<class TDataType>
std::pair<TDataType, int> MPIDataCommunicator::AllReduceDetailWithLocation(
const std::pair<TDataType, int>& rLocalValues,
MPI_Op Operation
) const
{
struct {
TDataType value;
int rank;
} local_reduce, global_reduce;
local_reduce.value = rLocalValues.first;
local_reduce.rank = rLocalValues.second;
MPI_Datatype data_type;
if constexpr (std::is_same_v<TDataType, double>) {
data_type = MPI_DOUBLE_INT;
} else if constexpr (std::is_same_v<TDataType, long int>) {
data_type = MPI_LONG_INT;
} else if constexpr (std::is_same_v<TDataType, int>) {
data_type = MPI_2INT;
} else {
KRATOS_ERROR << "Unsupported type for AllReduceDetailWithLocation" << std::endl;
}
MPI_Allreduce(&local_reduce, &global_reduce, 1, data_type, Operation, mComm);
std::pair<TDataType, int> global_values({global_reduce.value, global_reduce.rank});
return global_values;
}

template<class TDataType> void MPIDataCommunicator::ScanDetail(
const TDataType& rLocalValues, TDataType& rReducedValues,
MPI_Op Operation) const
Expand Down
Loading

0 comments on commit 218acb5

Please sign in to comment.