Skip to content

Commit

Permalink
ta_cc_abcd: do not assume real double everywhere!
Browse files (browse the repository at this point in the history)
  • Loading branch information
evaleev committed Jul 22, 2024
1 parent 269c59b commit 0b5244f
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions examples/gemm/ta_cc_abcd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ void cc_abcd(TA::World& world, const TA::TiledRange1& trange_occ,
flops_per_fma * std::pow(n_occ, 2) * std::pow(n_uocc, 4) / 1e9;

// Construct tensors
TA::TArrayD t2(world, trange_oovv);
TA::TArrayD v(world, trange_vvvv);
TA::TArrayD t2_v;
TA::TSpArray<T> t2(world, trange_oovv);
TA::TSpArray<T> v(world, trange_vvvv);
TA::TSpArray<T> t2_v;
// To validate, fill input tensors with random data, otherwise just with 1s
if (do_validate) {
rand_fill_array(t2);
Expand Down Expand Up @@ -247,9 +247,9 @@ void cc_abcd(TA::World& world, const TA::TiledRange1& trange_occ,
// to validate replace: false -> true
if (do_validate) {
// obtain reference result using the high-level DSL
TA::TArrayD t2_v_ref;
TA::TSpArray<T> t2_v_ref;
t2_v_ref("i,j,a,b") = t2("i,j,c,d") * v("c,d,a,b");
TA::TArrayD error;
TA::TSpArray<T> error;
error("i,j,a,b") = t2_v_ref("i,j,a,b") - t2_v("i,j,a,b");
std::cout << "Validating the result (ignore the timings/performance!): "
"||ref_result - result||_2^2 = "
Expand Down Expand Up @@ -394,6 +394,7 @@ template <typename Tile, typename Policy>
void tensor_contract_444(TA::DistArray<Tile, Policy>& tv,
const TA::DistArray<Tile, Policy>& t,
const TA::DistArray<Tile, Policy>& v) {
using Shape = typename Policy::shape_type;
// for convenience, obtain the tiled ranges for the two kinds of dimensions
// used to define t, v, and tv
auto trange_occ = t.trange().dim(0); // the first dimension of t is occ
Expand All @@ -415,10 +416,10 @@ void tensor_contract_444(TA::DistArray<Tile, Policy>& tv,
auto ncols = n_uocc * n_uocc;
TA::detail::ProcGrid proc_grid(world, nrowtiles, ncoltiles, nrows, ncols);
std::shared_ptr<TA::Pmap> pmap;
auto t_eval = make_array_eval(t, t.world(), TA::DenseShape(),
auto t_eval = make_array_eval(t, t.world(), Shape(),
proc_grid.make_row_phase_pmap(ninttiles),
TA::Permutation(), make_array_noop<Tile>());
auto v_eval = make_array_eval(v, v.world(), TA::DenseShape(),
auto v_eval = make_array_eval(v, v.world(), Shape(),
proc_grid.make_col_phase_pmap(ninttiles),
TA::Permutation(), make_array_noop<Tile>());

Expand All @@ -437,7 +438,7 @@ void tensor_contract_444(TA::DistArray<Tile, Policy>& tv,
// 2. there will be a dummy output ArrayEval, its Futures will be set by the
// PTG
auto contract =
make_contract_eval(t_eval, v_eval, world, TA::DenseShape(), pmap,
make_contract_eval(t_eval, v_eval, world, Shape(), pmap,
TA::Permutation(), make_contract<Tile>(4u, 4u, 4u));

// eval() just schedules the Summa task and proceeds
Expand Down

0 comments on commit 0b5244f

Please sign in to comment.