
Commit

moar ToT * T progress
evaleev committed Nov 20, 2023
1 parent c2998a5 commit e4eb2c9
Showing 5 changed files with 194 additions and 122 deletions.
299 changes: 181 additions & 118 deletions src/TiledArray/expressions/cont_engine.h
@@ -107,15 +107,26 @@ class ContEngine : public BinaryEngine<Derived> {

protected:
op_type op_; ///< Tile operation
using tile_element_type = typename value_type::value_type;
std::function<void(tile_element_type&, const tile_element_type&,
const tile_element_type&)>
inner_tile_nonreturn_op_; ///< Tile element operation (only non-null for
///< nested tensor expressions)
std::function<tile_element_type(const tile_element_type&,
const tile_element_type&)>
inner_tile_return_op_; ///< Same as inner_tile_nonreturn_op_ but returns
///< the result

// tile types of the result and (after evaluation) left and right arguments
using result_tile_type = value_type;
using left_tile_type = typename EngineTrait<left_type>::eval_type;
using right_tile_type = typename EngineTrait<right_type>::eval_type;

// tile element types of the result and (after evaluation) left and right
// arguments
using result_tile_element_type = typename result_tile_type::value_type;
using left_tile_element_type = typename left_tile_type::value_type;
using right_tile_element_type = typename right_tile_type::value_type;

std::function<void(result_tile_element_type&, const left_tile_element_type&,
const right_tile_element_type&)>
element_nonreturn_op_; ///< Tile element operation (only non-null for
///< nested tensor expressions)
std::function<result_tile_element_type(const left_tile_element_type&,
const right_tile_element_type&)>
element_return_op_; ///< Same as element_nonreturn_op_ but returns
///< the result
TiledArray::detail::ProcGrid
proc_grid_; ///< Process grid for the contraction
size_type K_ = 1; ///< Inner dimension size
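For orientation, the two callables declared above differ only in how they deliver the result: one writes into an existing element, the other returns a fresh one. A minimal sketch of that relationship, using plain `double` as a hypothetical stand-in for the `*_tile_element_type` aliases (not the actual TiledArray tile types):

```cpp
#include <functional>
#include <iostream>

int main() {
  using elem_t = double;  // hypothetical stand-in for *_tile_element_type

  // "non-return" form: writes/accumulates into an existing result element
  std::function<void(elem_t&, const elem_t&, const elem_t&)> nonreturn_op =
      [](elem_t& result, const elem_t& left, const elem_t& right) {
        result += left * right;  // e.g. a multiply-add, as used by a contraction
      };

  // "return" form: builds a fresh result by delegating to the non-return form,
  // mirroring how element_return_op_ is wired up in init_inner_tile_op()
  std::function<elem_t(const elem_t&, const elem_t&)> return_op =
      [nonreturn_op](const elem_t& left, const elem_t& right) {
        elem_t result{};
        nonreturn_op(result, left, right);
        return result;
      };

  std::cout << return_op(2.0, 3.0) << "\n";  // prints 6
}
```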
@@ -239,8 +250,8 @@ class ContEngine : public BinaryEngine<Derived> {
// precondition checks
// 1. if ToT inner tile op has been initialized
if constexpr (TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
TA_ASSERT(inner_tile_nonreturn_op_);
TA_ASSERT(inner_tile_return_op_);
TA_ASSERT(element_nonreturn_op_);
TA_ASSERT(element_return_op_);
}

// Initialize children
Expand Down Expand Up @@ -271,7 +282,7 @@ class ContEngine : public BinaryEngine<Derived> {
op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
(permute_tiles_ ? perm_ : BipartitePermutation{}),
this->inner_tile_nonreturn_op_);
this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange(outer(perm_));
shape_ = ContEngine_::make_shape(outer(perm_));
@@ -284,7 +295,7 @@ class ContEngine : public BinaryEngine<Derived> {
// factor_ is absorbed into element_nonreturn_op_
op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
BipartitePermutation{}, this->inner_tile_nonreturn_op_);
BipartitePermutation{}, this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange();
shape_ = ContEngine_::make_shape();
@@ -457,120 +468,172 @@ class ContEngine : public BinaryEngine<Derived> {

protected:
void init_inner_tile_op(const IndexList& inner_target_indices) {
if constexpr (TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
using inner_tile_type = typename value_type::value_type;
if constexpr (TiledArray::detail::is_tensor_of_tensor_v<result_tile_type>) {
constexpr bool tot_x_tot = TiledArray::detail::is_tensor_of_tensor_v<
result_tile_type, left_tile_type, right_tile_type>;
const auto inner_prod = this->inner_product_type();
TA_ASSERT(inner_prod == TensorProduct::Contraction ||
inner_prod == TensorProduct::Hadamard);
if (inner_prod == TensorProduct::Contraction) {
using inner_tile_type = typename value_type::value_type;
using contract_inner_tile_type =
TiledArray::detail::ContractReduce<inner_tile_type, inner_tile_type,
inner_tile_type, scalar_type>;
// factor_ is absorbed into inner_tile_nonreturn_op_
auto contrreduce_op =
(inner_target_indices != inner(this->indices_))
? contract_inner_tile_type(
to_cblas_op(this->left_inner_permtype_),
to_cblas_op(this->right_inner_permtype_), this->factor_,
inner_size(this->indices_),
inner_size(this->left_indices_),
inner_size(this->right_indices_),
(this->permute_tiles_ ? inner(this->perm_)
: Permutation{}))
: contract_inner_tile_type(
to_cblas_op(this->left_inner_permtype_),
to_cblas_op(this->right_inner_permtype_), this->factor_,
inner_size(this->indices_),
inner_size(this->left_indices_),
inner_size(this->right_indices_));
this->inner_tile_nonreturn_op_ = [contrreduce_op](
inner_tile_type& result,
const inner_tile_type& left,
const inner_tile_type& right) {
contrreduce_op(result, left, right);
};
TA_ASSERT(tot_x_tot);
if constexpr (tot_x_tot) {
using op_type = TiledArray::detail::ContractReduce<
result_tile_element_type, left_tile_element_type,
right_tile_element_type, scalar_type>;
// factor_ is absorbed into element_nonreturn_op_
auto contrreduce_op =
(inner_target_indices != inner(this->indices_))
? op_type(to_cblas_op(this->left_inner_permtype_),
to_cblas_op(this->right_inner_permtype_),
this->factor_, inner_size(this->indices_),
inner_size(this->left_indices_),
inner_size(this->right_indices_),
(this->permute_tiles_ ? inner(this->perm_)
: Permutation{}))
: op_type(to_cblas_op(this->left_inner_permtype_),
to_cblas_op(this->right_inner_permtype_),
this->factor_, inner_size(this->indices_),
inner_size(this->left_indices_),
inner_size(this->right_indices_));
this->element_nonreturn_op_ =
[contrreduce_op](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
contrreduce_op(result, left, right);
};
} // ToT x ToT
} else if (inner_prod == TensorProduct::Hadamard) {
// inner tile op depends on the outer op ... e.g. if outer op
// is contract then inner must implement (ternary) multiply-add;
// if the outer is hadamard then the inner is binary multiply
const auto outer_prod = this->product_type();
if (this->factor_ == 1) {
using base_op_type =
TiledArray::detail::Mult<inner_tile_type, inner_tile_type,
inner_tile_type, false, false>;
using op_type = TiledArray::detail::BinaryWrapper<
base_op_type>; // can't consume inputs if they are used multiple
// times, e.g. when outer op is gemm
auto mult_op = (inner_target_indices != inner(this->indices_))
? op_type(base_op_type(), this->permute_tiles_
? inner(this->perm_)
: Permutation{})
: op_type(base_op_type());
this->inner_tile_nonreturn_op_ = [mult_op, outer_prod](
inner_tile_type& result,
const inner_tile_type& left,
const inner_tile_type& right) {
if (outer_prod == TensorProduct::Hadamard)
result = mult_op(left, right);
else {
TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
outer_prod == TensorProduct::Contraction);
// there is currently no fused MultAdd ternary Op, only Add and
// Mult thus implement this as 2 separate steps
// TODO optimize by implementing (ternary) MultAdd
if (empty(result))
result = mult_op(left, right);
else {
auto result_increment = mult_op(left, right);
add_to(result, result_increment);
}
}
};
} else {
using base_op_type =
TiledArray::detail::ScalMult<inner_tile_type, inner_tile_type,
inner_tile_type, scalar_type, false,
false>;
using op_type = TiledArray::detail::BinaryWrapper<
base_op_type>; // can't consume inputs if they are used multiple
// times, e.g. when outer op is gemm
auto mult_op = (inner_target_indices != inner(this->indices_))
? op_type(base_op_type(this->factor_),
this->permute_tiles_ ? inner(this->perm_)
: Permutation{})
: op_type(base_op_type(this->factor_));
this->inner_tile_nonreturn_op_ = [mult_op, outer_prod](
inner_tile_type& result,
const inner_tile_type& left,
const inner_tile_type& right) {
TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
outer_prod == TensorProduct::Contraction);
if (outer_prod == TensorProduct::Hadamard)
result = mult_op(left, right);
else {
// there is currently no fused MultAdd ternary Op, only Add and
// Mult thus implement this as 2 separate steps
// TODO optimize by implementing (ternary) MultAdd
if (empty(result))
result = mult_op(left, right);
else {
auto result_increment = mult_op(left, right);
add_to(result, result_increment);
}
}
TA_ASSERT(tot_x_tot);
if constexpr (tot_x_tot) {
// inner tile op depends on the outer op ... e.g. if outer op
// is contract then inner must implement (ternary) multiply-add;
// if the outer is hadamard then the inner is binary multiply
const auto outer_prod = this->product_type();
if (this->factor_ == 1) {
using base_op_type =
TiledArray::detail::Mult<result_tile_element_type,
left_tile_element_type,
right_tile_element_type, false, false>;
using op_type = TiledArray::detail::BinaryWrapper<
base_op_type>; // can't consume inputs if they are used
// multiple times, e.g. when outer op is gemm
auto mult_op =
(inner_target_indices != inner(this->indices_))
? op_type(base_op_type(), this->permute_tiles_
? inner(this->perm_)
: Permutation{})
: op_type(base_op_type());
this->element_nonreturn_op_ =
[mult_op, outer_prod](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
if (outer_prod == TensorProduct::Hadamard)
result = mult_op(left, right);
else {
TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
outer_prod == TensorProduct::Contraction);
// there is currently no fused MultAdd ternary Op, only Add
// and Mult thus implement this as 2 separate steps
// TODO optimize by implementing (ternary) MultAdd
if (empty(result))
result = mult_op(left, right);
else {
auto result_increment = mult_op(left, right);
add_to(result, result_increment);
}
}
};
} else {
using base_op_type = TiledArray::detail::ScalMult<
result_tile_element_type, left_tile_element_type,
right_tile_element_type, scalar_type, false, false>;
using op_type = TiledArray::detail::BinaryWrapper<
base_op_type>; // can't consume inputs if they are used
// multiple times, e.g. when outer op is gemm
auto mult_op =
(inner_target_indices != inner(this->indices_))
? op_type(base_op_type(this->factor_),
this->permute_tiles_ ? inner(this->perm_)
: Permutation{})
: op_type(base_op_type(this->factor_));
this->element_nonreturn_op_ =
[mult_op, outer_prod](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
TA_ASSERT(outer_prod == TensorProduct::Hadamard ||
outer_prod == TensorProduct::Contraction);
if (outer_prod == TensorProduct::Hadamard)
result = mult_op(left, right);
else {
// there is currently no fused MultAdd ternary Op, only Add
// and Mult thus implement this as 2 separate steps
// TODO optimize by implementing (ternary) MultAdd
if (empty(result))
result = mult_op(left, right);
else {
auto result_increment = mult_op(left, right);
add_to(result, result_increment);
}
}
};
}
} // ToT x ToT
} else if (inner_prod == TensorProduct::General) {
TA_ASSERT(!tot_x_tot);
constexpr bool tot_x_t =
TiledArray::detail::is_tensor_of_tensor_v<result_tile_type,
left_tile_type> &&
TiledArray::detail::is_tensor_v<right_tile_type>;
constexpr bool t_x_tot =
TiledArray::detail::is_tensor_of_tensor_v<result_tile_type,
right_tile_type> &&
TiledArray::detail::is_tensor_v<left_tile_type>;
if constexpr (tot_x_t || t_x_tot) {
using arg_tile_element_type =
std::conditional_t<tot_x_t, left_tile_element_type,
right_tile_element_type>;
using scalar_type =
std::conditional_t<tot_x_t, right_tile_element_type,
left_tile_element_type>;

auto scal_op = [do_perm = this->permute_tiles_,
perm = this->permute_tiles_ ? inner(this->perm_)
: Permutation{}](
const left_tile_element_type& left,
const right_tile_element_type& right)
-> result_tile_element_type {
using TiledArray::scale;
if constexpr (tot_x_t) {
if (do_perm)
return scale(left, right, perm);
else
return scale(left, right);
} else if constexpr (t_x_tot) {
if (do_perm)
return scale(right, left, perm);
else
return scale(right, left);
} else
abort(); // unreachable
};
this->element_nonreturn_op_ =
[scal_op](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
result = scal_op(left, right);
};
}
} else
abort(); // unsupported TensorProduct type
TA_ASSERT(inner_tile_nonreturn_op_);
this->inner_tile_return_op_ =
[inner_tile_nonreturn_op = this->inner_tile_nonreturn_op_](
const inner_tile_type& left, const inner_tile_type& right) {
inner_tile_type result;
inner_tile_nonreturn_op(result, left, right);
return result;
};
TA_ASSERT(element_nonreturn_op_);
this->element_return_op_ = [inner_tile_nonreturn_op =
this->element_nonreturn_op_](
const left_tile_element_type& left,
const right_tile_element_type& right) {
result_tile_element_type result;
inner_tile_nonreturn_op(result, left, right);
return result;
};
}
}
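One detail worth calling out from the Hadamard branches above: since there is no fused ternary MultAdd op yet, a contraction over the outer indices accumulates each inner Hadamard product in two steps (multiply, then add_to). A self-contained sketch of that pattern, with `std::vector<double>` as an assumed stand-in for an inner tile and helper names chosen purely for illustration:

```cpp
#include <cstddef>
#include <vector>

using inner_t = std::vector<double>;  // assumed stand-in for an inner tile

// element-wise (Hadamard) product of two equally sized inner tiles
inner_t mult(const inner_t& a, const inner_t& b) {
  inner_t r(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] * b[i];
  return r;
}

void add_to(inner_t& result, const inner_t& increment) {
  for (std::size_t i = 0; i < result.size(); ++i) result[i] += increment[i];
}

// the multiply-add expressed as two steps, mirroring the lambdas above:
// the first contribution initializes the result, later ones accumulate
void mult_add(inner_t& result, const inner_t& left, const inner_t& right) {
  if (result.empty())
    result = mult(left, right);          // first contribution: plain assign
  else
    add_to(result, mult(left, right));   // later contributions: accumulate
}
```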

4 changes: 2 additions & 2 deletions src/TiledArray/expressions/mult_engine.h
@@ -406,7 +406,7 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
// dimensions as well
return op_type(op_base_type());
} else if (inner_prod == TensorProduct::Contraction) {
return op_type(op_base_type(this->inner_tile_return_op_));
return op_type(op_base_type(this->element_return_op_));
} else
abort();
} else { // plain tensors
@@ -431,7 +431,7 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
// dimensions as well
return op_type(op_base_type(), perm);
} else if (inner_prod == TensorProduct::Contraction) {
return op_type(op_base_type(this->inner_tile_return_op_), perm);
return op_type(op_base_type(this->element_return_op_), perm);
} else
abort();
} else { // plain tensor
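For context on the two `make_tile_op` changes above: when the inner product is a contraction, the outer element-wise multiply needs a callable that builds each result element from the corresponding pair of inner tiles, which is exactly what `element_return_op_` supplies. A simplified, assumed sketch of an op that stores such a callable (`MultLike` is illustrative only, not the TiledArray `Mult`/`BinaryWrapper` machinery):

```cpp
#include <functional>
#include <utility>

template <typename Result, typename Left, typename Right>
struct MultLike {
  using element_op_t = std::function<Result(const Left&, const Right&)>;
  element_op_t element_op;  // e.g. the engine's element_return_op_

  explicit MultLike(element_op_t op) : element_op(std::move(op)) {}

  // element-wise application: delegate the per-element work to the stored op
  Result operator()(const Left& left, const Right& right) const {
    return element_op(left, right);
  }
};

// usage sketch: an op whose elements are combined by plain multiplication
// MultLike<double, double, double> op(
//     [](const double& l, const double& r) { return l * r; });
```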
3 changes: 3 additions & 0 deletions src/TiledArray/expressions/product.h
@@ -57,6 +57,9 @@ inline TensorProduct compute_product_type(const IndexList& left_indices,
result = TensorProduct::Hadamard;
else
result = TensorProduct::Contraction;
} else if ((left_indices && !right_indices) ||
(!left_indices && right_indices)) { // used for ToT*T or T*ToT
result = TensorProduct::General;
}
return result;
}
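To make the new branch above concrete: compute_product_type now reports TensorProduct::General when exactly one operand has inner indices, i.e. a ToT*T or T*ToT product. A rough stand-alone sketch of the same decision, using `std::set<std::string>` as an assumed stand-in for IndexList (the helper below is hypothetical, not the library function):

```cpp
#include <set>
#include <string>

enum class TensorProduct { Invalid, Hadamard, Contraction, General };

TensorProduct classify(const std::set<std::string>& left_indices,
                       const std::set<std::string>& right_indices) {
  TensorProduct result = TensorProduct::Invalid;
  if (!left_indices.empty() && !right_indices.empty()) {
    // both operands carry inner indices: same index set -> Hadamard,
    // otherwise some indices are summed over -> Contraction
    result = (left_indices == right_indices) ? TensorProduct::Hadamard
                                             : TensorProduct::Contraction;
  } else if (left_indices.empty() != right_indices.empty()) {
    // exactly one operand carries inner indices: ToT*T or T*ToT
    result = TensorProduct::General;
  }
  return result;
}
```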
2 changes: 2 additions & 0 deletions src/TiledArray/tile_op/scal.h
@@ -128,6 +128,8 @@ class Scal {
return Scal_::template eval<can_consume>(arg);
}

void set_factor(const scalar_type factor) { factor_ = factor; }

}; // class Scal

} // namespace detail
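The added `set_factor` mutator lets the scaling factor be supplied after the op object exists, which is handy when the factor is only discovered while an expression is being initialized. A hypothetical, stripped-down functor showing the same idea (not the real `Scal` class):

```cpp
template <typename Scalar>
class ScalLike {
  Scalar factor_{1};

 public:
  ScalLike() = default;
  explicit ScalLike(const Scalar factor) : factor_(factor) {}

  // inject/overwrite the factor after construction
  void set_factor(const Scalar factor) { factor_ = factor; }

  // scale every element of a tile-like container and return the result
  template <typename Tile>
  Tile operator()(Tile tile) const {
    for (auto& v : tile) v *= factor_;
    return tile;
  }
};

// usage sketch:
//   ScalLike<double> scal;       // factor not known yet
//   scal.set_factor(2.0);        // ... becomes known later
//   auto scaled = scal(std::vector<double>{1.0, 2.0, 3.0});  // {2, 4, 6}
```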
