Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Replaced the type of parameter k from int32 to nnvm::dim
Browse files Browse the repository at this point in the history
  • Loading branch information
ifeherva committed Jul 17, 2018
1 parent c7e3941 commit f54b7ac
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/operator/tensor/diag_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace mxnet {
namespace op {

struct DiagParam : public dmlc::Parameter<DiagParam> {
int32_t k;
nnvm::dim_t k;
DMLC_DECLARE_PARAMETER(DiagParam) {
DMLC_DECLARE_FIELD(k)
.set_default(0)
Expand All @@ -48,7 +48,7 @@ struct DiagParam : public dmlc::Parameter<DiagParam> {
}
};

inline TShape DiagShapeImpl(const TShape& ishape, int32_t k) {
inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
if (ishape.ndim() == 1) {
auto s = ishape[0] + std::abs(k);
return TShape({s, s});
Expand Down Expand Up @@ -105,7 +105,7 @@ template<int req>
struct diag {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
mshadow::Shape<2> ishape, int32_t k) {
mshadow::Shape<2> ishape, const nnvm::dim_t k) {
using namespace mxnet_op;
int j = 0;
if (k > 0) {
Expand All @@ -124,15 +124,15 @@ template<int req>
struct diag_gen {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
mshadow::Shape<2> oshape, int32_t k) {
mshadow::Shape<2> oshape, const nnvm::dim_t k) {
using namespace mxnet_op;

auto j = unravel(i, oshape);
if (j[1] == (j[0] + k)) {
auto l = j[0] < j[1] ? j[0] : j[1];
KERNEL_ASSIGN(out[i], req, a[l]);
} else {
KERNEL_ASSIGN(out[i], req, 0.0);
KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
}
}
};
Expand Down

0 comments on commit f54b7ac

Please sign in to comment.