From 1fa2d2037faf32dd9303216f56b1021c2b216669 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Fri, 9 Oct 2020 22:52:04 +0300 Subject: [PATCH 1/8] ARROW-10479: Get rid of code duplication at decimal type builders Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/array/array_decimal.cc | 29 ++++-------- cpp/src/arrow/array/array_decimal.h | 41 ++++++++--------- cpp/src/arrow/array/builder_decimal.cc | 57 +++++++----------------- cpp/src/arrow/array/builder_decimal.h | 56 +++++++++-------------- cpp/src/arrow/scalar.h | 25 +++++------ cpp/src/arrow/type.cc | 40 ++++++----------- cpp/src/arrow/type.h | 51 +++++++++------------ cpp/src/arrow/type_fwd.h | 23 +++++----- cpp/src/arrow/type_traits.h | 25 +++++------ cpp/src/arrow/util/basic_decimal.cc | 13 +++--- cpp/src/arrow/util/decimal_meta.h | 37 +++++++++++++++ cpp/src/arrow/util/decimal_type_traits.h | 41 +++++++++++++++++ 12 files changed, 221 insertions(+), 217 deletions(-) create mode 100644 cpp/src/arrow/util/decimal_meta.h create mode 100644 cpp/src/arrow/util/decimal_type_traits.h diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index d65f6ee53564f..b895ba72061c6 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -32,32 +32,21 @@ namespace arrow { using internal::checked_cast; -// ---------------------------------------------------------------------- -// Decimal128 -Decimal128Array::Decimal128Array(const std::shared_ptr& data) +template +BaseDecimalArray::BaseDecimalArray(const std::shared_ptr& data) : FixedSizeBinaryArray(data) { - ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL128); + ARROW_CHECK_EQ(data->type->id(), DecimalTypeTraits::Id); } -std::string Decimal128Array::FormatValue(int64_t i) const { - const auto& type_ = checked_cast(*type()); - const Decimal128 value(GetValue(i)); +template +std::string BaseDecimalArray::FormatValue(int64_t i) const { + const auto& type_ = checked_cast(*type()); + const ValueType value(GetValue(i)); return value.ToString(type_.scale()); } -// ---------------------------------------------------------------------- -// Decimal256 - -Decimal256Array::Decimal256Array(const std::shared_ptr& data) - : FixedSizeBinaryArray(data) { - ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL256); -} - -std::string Decimal256Array::FormatValue(int64_t i) const { - const auto& type_ = checked_cast(*type()); - const Decimal256 value(GetValue(i)); - return value.ToString(type_.scale()); -} +template class BaseDecimalArray<128>; +template class BaseDecimalArray<256>; } // namespace arrow diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index 8d7d1c59cd0a8..fbac3ac094507 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -22,45 +22,40 @@ #include #include "arrow/array/array_binary.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/array/data.h" #include "arrow/type.h" #include "arrow/util/visibility.h" namespace arrow { -// ---------------------------------------------------------------------- -// Decimal128Array - -/// Concrete Array class for 128-bit decimal data -class ARROW_EXPORT Decimal128Array : public FixedSizeBinaryArray { +/// Template Array class for decimal data +template +class BaseDecimalArray : public FixedSizeBinaryArray { public: - using TypeClass = Decimal128Type; + using TypeClass = typename DecimalTypeTraits::TypeClass; + using ValueType = typename DecimalTypeTraits::ValueType; using FixedSizeBinaryArray::FixedSizeBinaryArray; - /// \brief Construct Decimal128Array from ArrayData instance - explicit Decimal128Array(const std::shared_ptr& data); + /// \brief Construct DecimalArray from ArrayData instance + explicit BaseDecimalArray(const std::shared_ptr& data); std::string FormatValue(int64_t i) const; }; -// Backward compatibility -using DecimalArray = Decimal128Array; - -// ---------------------------------------------------------------------- -// Decimal256Array - -/// Concrete Array class for 256-bit decimal data -class ARROW_EXPORT Decimal256Array : public FixedSizeBinaryArray { - public: - using TypeClass = Decimal256Type; +/// Array class for decimal 128-bit data +class ARROW_EXPORT Decimal128Array : public BaseDecimalArray<128> { + using BaseDecimalArray<128>::BaseDecimalArray; +}; - using FixedSizeBinaryArray::FixedSizeBinaryArray; +/// Array class for decimal 256-bit data +class ARROW_EXPORT Decimal256Array : public BaseDecimalArray<256> { + using BaseDecimalArray<256>::BaseDecimalArray; +}; - /// \brief Construct Decimal256Array from ArrayData instance - explicit Decimal256Array(const std::shared_ptr& data); +// Backward compatibility +using DecimalArray = Decimal128Array; - std::string FormatValue(int64_t i) const; -}; } // namespace arrow diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index bd7615a730939..6e6f5dc8e199f 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -33,30 +33,35 @@ class Buffer; class MemoryPool; // ---------------------------------------------------------------------- -// Decimal128Builder +// BaseDecimalBuilder -Decimal128Builder::Decimal128Builder(const std::shared_ptr& type, +template +BaseDecimalBuilder::BaseDecimalBuilder(const std::shared_ptr& type, MemoryPool* pool) : FixedSizeBinaryBuilder(type, pool), - decimal_type_(internal::checked_pointer_cast(type)) {} + decimal_type_(internal::checked_pointer_cast(type)) {} -Status Decimal128Builder::Append(Decimal128 value) { +template +Status BaseDecimalBuilder::Append(ValueType value) { RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); UnsafeAppend(value); return Status::OK(); } -void Decimal128Builder::UnsafeAppend(Decimal128 value) { +template +void BaseDecimalBuilder::UnsafeAppend(ValueType value) { value.ToBytes(GetMutableValue(length())); - byte_builder_.UnsafeAdvance(16); + byte_builder_.UnsafeAdvance((width >> 3)); UnsafeAppendToBitmap(true); } -void Decimal128Builder::UnsafeAppend(util::string_view value) { +template +void BaseDecimalBuilder::UnsafeAppend(util::string_view value) { FixedSizeBinaryBuilder::UnsafeAppend(value); } -Status Decimal128Builder::FinishInternal(std::shared_ptr* out) { +template +Status BaseDecimalBuilder::FinishInternal(std::shared_ptr* out) { std::shared_ptr data; RETURN_NOT_OK(byte_builder_.Finish(&data)); std::shared_ptr null_bitmap; @@ -67,39 +72,7 @@ Status Decimal128Builder::FinishInternal(std::shared_ptr* out) { return Status::OK(); } -// ---------------------------------------------------------------------- -// Decimal256Builder - -Decimal256Builder::Decimal256Builder(const std::shared_ptr& type, - MemoryPool* pool) - : FixedSizeBinaryBuilder(type, pool), - decimal_type_(internal::checked_pointer_cast(type)) {} - -Status Decimal256Builder::Append(const Decimal256& value) { - RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); - UnsafeAppend(value); - return Status::OK(); -} - -void Decimal256Builder::UnsafeAppend(const Decimal256& value) { - value.ToBytes(GetMutableValue(length())); - byte_builder_.UnsafeAdvance(32); - UnsafeAppendToBitmap(true); -} - -void Decimal256Builder::UnsafeAppend(util::string_view value) { - FixedSizeBinaryBuilder::UnsafeAppend(value); -} - -Status Decimal256Builder::FinishInternal(std::shared_ptr* out) { - std::shared_ptr data; - RETURN_NOT_OK(byte_builder_.Finish(&data)); - std::shared_ptr null_bitmap; - RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); - - *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); - capacity_ = length_ = null_count_ = 0; - return Status::OK(); -} +template class BaseDecimalBuilder<128>; +template class BaseDecimalBuilder<256>; } // namespace arrow diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 8c75e7dd6747c..f1d7b8387201e 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -23,25 +23,29 @@ #include "arrow/array/builder_base.h" #include "arrow/array/builder_binary.h" #include "arrow/array/data.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/visibility.h" namespace arrow { -class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { - public: - using TypeClass = Decimal128Type; +template +class BaseDecimalBuilder : public FixedSizeBinaryBuilder { +public: + using TypeClass = typename DecimalTypeTraits::TypeClass; + using ArrayType = typename DecimalTypeTraits::ArrayType; + using ValueType = typename DecimalTypeTraits::ValueType; - explicit Decimal128Builder(const std::shared_ptr& type, - MemoryPool* pool = default_memory_pool()); + explicit BaseDecimalBuilder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool()); using FixedSizeBinaryBuilder::Append; using FixedSizeBinaryBuilder::AppendValues; using FixedSizeBinaryBuilder::Reset; - Status Append(Decimal128 val); - void UnsafeAppend(Decimal128 val); + Status Append(ValueType val); + void UnsafeAppend(ValueType val); void UnsafeAppend(util::string_view val); Status FinishInternal(std::shared_ptr* out) override; @@ -50,43 +54,25 @@ class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { using ArrayBuilder::Finish; /// \endcond - Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } std::shared_ptr type() const override { return decimal_type_; } protected: - std::shared_ptr decimal_type_; + std::shared_ptr decimal_type_; }; -class ARROW_EXPORT Decimal256Builder : public FixedSizeBinaryBuilder { - public: - using TypeClass = Decimal256Type; - - explicit Decimal256Builder(const std::shared_ptr& type, - MemoryPool* pool = default_memory_pool()); - - using FixedSizeBinaryBuilder::Append; - using FixedSizeBinaryBuilder::AppendValues; - using FixedSizeBinaryBuilder::Reset; - - Status Append(const Decimal256& val); - void UnsafeAppend(const Decimal256& val); - void UnsafeAppend(util::string_view val); - - Status FinishInternal(std::shared_ptr* out) override; - - /// \cond FALSE - using ArrayBuilder::Finish; - /// \endcond - - Status Finish(std::shared_ptr* out) { return FinishTyped(out); } - - std::shared_ptr type() const override { return decimal_type_; } +/// Builder class for decimal 128-bit +class ARROW_EXPORT Decimal128Builder : public BaseDecimalBuilder<128> { + using BaseDecimalBuilder<128>::BaseDecimalBuilder; +}; - protected: - std::shared_ptr decimal_type_; +/// Builder class for decimal 256-bit +class ARROW_EXPORT Decimal256Builder : public BaseDecimalBuilder<256> { + using BaseDecimalBuilder<256>::BaseDecimalBuilder; }; +// Backward compatibility using DecimalBuilder = Decimal128Builder; } // namespace arrow diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 2888874d29287..99f691e043615 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -36,6 +36,7 @@ #include "arrow/type_traits.h" #include "arrow/util/compare.h" #include "arrow/util/decimal.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/string_view.h" #include "arrow/util/visibility.h" @@ -339,26 +340,24 @@ struct ARROW_EXPORT DurationScalar : public TemporalScalar { using TemporalScalar::TemporalScalar; }; -struct ARROW_EXPORT Decimal128Scalar : public Scalar { +template +struct BaseDecimalScalar : public Scalar { using Scalar::Scalar; - using TypeClass = Decimal128Type; - using ValueType = Decimal128; + using TypeClass = typename DecimalTypeTraits::TypeClass; + using ValueType = typename DecimalTypeTraits::ValueType; - Decimal128Scalar(Decimal128 value, std::shared_ptr type) + BaseDecimalScalar(ValueType value, std::shared_ptr type) : Scalar(std::move(type), true), value(value) {} - Decimal128 value; + ValueType value; }; -struct ARROW_EXPORT Decimal256Scalar : public Scalar { - using Scalar::Scalar; - using TypeClass = Decimal256Type; - using ValueType = Decimal256; - - Decimal256Scalar(Decimal256 value, std::shared_ptr type) - : Scalar(std::move(type), true), value(value) {} +struct ARROW_EXPORT Decimal128Scalar : public BaseDecimalScalar<128> { + using BaseDecimalScalar<128>::BaseDecimalScalar; +}; - Decimal256 value; +struct ARROW_EXPORT Decimal256Scalar : public BaseDecimalScalar<256> { + using BaseDecimalScalar<256>::BaseDecimalScalar; }; struct ARROW_EXPORT BaseListScalar : public Scalar { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 12d3951865f78..0a8dbff9ace57 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -42,6 +42,7 @@ #include "arrow/util/make_unique.h" #include "arrow/util/range.h" #include "arrow/util/vector.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -794,35 +795,22 @@ int32_t DecimalType::DecimalSize(int32_t precision) { } // ---------------------------------------------------------------------- -// Decimal128 type +// Decimal type -Decimal128Type::Decimal128Type(int32_t precision, int32_t scale) - : DecimalType(type_id, 16, precision, scale) { - ARROW_CHECK_GE(precision, kMinPrecision); - ARROW_CHECK_LE(precision, kMaxPrecision); -} -Result> Decimal128Type::Make(int32_t precision, int32_t scale) { - if (precision < kMinPrecision || precision > kMaxPrecision) { - return Status::Invalid("Decimal precision out of range: ", precision); - } - return std::make_shared(precision, scale); -} - -// ---------------------------------------------------------------------- -// Decimal256 type - -Decimal256Type::Decimal256Type(int32_t precision, int32_t scale) - : DecimalType(type_id, 32, precision, scale) { +template +BaseDecimalType::BaseDecimalType(int32_t precision, int32_t scale) + : DecimalType(DecimalTypeTraits::Id, (width >> 3), precision, scale) { ARROW_CHECK_GE(precision, kMinPrecision); ARROW_CHECK_LE(precision, kMaxPrecision); } -Result> Decimal256Type::Make(int32_t precision, int32_t scale) { +template +Result> BaseDecimalType::Make(int32_t precision, int32_t scale) { if (precision < kMinPrecision || precision > kMaxPrecision) { return Status::Invalid("Decimal precision out of range: ", precision); } - return std::make_shared(precision, scale); + return std::make_shared::TypeClass>(precision, scale); } // ---------------------------------------------------------------------- @@ -2203,16 +2191,14 @@ std::shared_ptr decimal256(int32_t precision, int32_t scale) { return std::make_shared(precision, scale); } -std::string Decimal128Type::ToString() const { +template +std::string BaseDecimalType::ToString() const { std::stringstream s; - s << "decimal(" << precision_ << ", " << scale_ << ")"; + s << type_name() << "(" << precision_ << ", " << scale_ << ")"; return s.str(); } -std::string Decimal256Type::ToString() const { - std::stringstream s; - s << "decimal256(" << precision_ << ", " << scale_ << ")"; - return s.str(); -} +template class BaseDecimalType<128>; +template class BaseDecimalType<256>; } // namespace arrow diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 127ed598399d7..4437c5e76c242 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -33,6 +33,7 @@ #include "arrow/util/macros.h" #include "arrow/util/variant.h" #include "arrow/util/visibility.h" +#include "arrow/util/decimal_meta.h" #include "arrow/visitor.h" // IWYU pragma: keep namespace arrow { @@ -892,46 +893,38 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { int32_t scale_; }; -/// \brief Concrete type class for 128-bit decimal data -class ARROW_EXPORT Decimal128Type : public DecimalType { +/// \brief Template type class for decimal data +template +class BaseDecimalType : public DecimalType { public: - static constexpr Type::type type_id = Type::DECIMAL128; - - static constexpr const char* type_name() { return "decimal"; } + static constexpr const char* type_name() { return DecimalMeta::name; } - /// Decimal128Type constructor that aborts on invalid input. - explicit Decimal128Type(int32_t precision, int32_t scale); + /// BaseDecimalType constructor that aborts on invalid input. + explicit BaseDecimalType(int32_t precision, int32_t scale); - /// Decimal128Type constructor that returns an error on invalid input. + /// BaseDecimalType constructor that returns an error on invalid input. static Result> Make(int32_t precision, int32_t scale); std::string ToString() const override; - std::string name() const override { return "decimal"; } + std::string name() const override { return DecimalMeta::name; } static constexpr int32_t kMinPrecision = 1; - static constexpr int32_t kMaxPrecision = 38; - static constexpr int32_t kByteWidth = 16; + static constexpr int32_t kMaxPrecision = DecimalMeta::max_precision; + static constexpr int32_t kByteWidth = width / 8; }; -/// \brief Concrete type class for 256-bit decimal data -class ARROW_EXPORT Decimal256Type : public DecimalType { - public: - static constexpr Type::type type_id = Type::DECIMAL256; - - static constexpr const char* type_name() { return "decimal256"; } - - /// Decimal256Type constructor that aborts on invalid input. - explicit Decimal256Type(int32_t precision, int32_t scale); - - /// Decimal256Type constructor that returns an error on invalid input. - static Result> Make(int32_t precision, int32_t scale); - - std::string ToString() const override; - std::string name() const override { return "decimal256"; } +/// \brief Concrete type class for decimal 128-bit data +class ARROW_EXPORT Decimal128Type : public BaseDecimalType<128> { +public: + static constexpr Type::type type_id = Type::DECIMAL128; + using BaseDecimalType<128>::BaseDecimalType; +}; - static constexpr int32_t kMinPrecision = 1; - static constexpr int32_t kMaxPrecision = 76; - static constexpr int32_t kByteWidth = 32; +/// \brief Concrete type class for decimal 256-bit data +class ARROW_EXPORT Decimal256Type : public BaseDecimalType<256> { +public: + static constexpr Type::type type_id = Type::DECIMAL256; + using BaseDecimalType<256>::BaseDecimalType; }; /// \brief Concrete type class for union data diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index f1000d1fe7fb7..ea8a39f970ebc 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -143,17 +143,20 @@ class StructArray; class StructBuilder; struct StructScalar; -class Decimal128; -class Decimal256; class DecimalType; -class Decimal128Type; -class Decimal256Type; -class Decimal128Array; -class Decimal256Array; -class Decimal128Builder; -class Decimal256Builder; -struct Decimal128Scalar; -struct Decimal256Scalar; + +#define DECIMAL_DECL(width) \ +class Decimal##width; \ +class Decimal##width##Type; \ +class Decimal##width##Array; \ +class Decimal##width##Builder; \ +struct Decimal##width##Scalar; + +DECIMAL_DECL(128) +DECIMAL_DECL(256) + +#undef DECIMAL_DECL + struct UnionMode { enum type { SPARSE, DENSE }; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 2dcfc77c437e2..a569a4cb598a8 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -281,22 +281,21 @@ struct TypeTraits { static inline std::shared_ptr type_singleton() { return float16(); } }; -template <> -struct TypeTraits { - using ArrayType = Decimal128Array; - using BuilderType = Decimal128Builder; - using ScalarType = Decimal128Scalar; - constexpr static bool is_parameter_free = false; -}; -template <> -struct TypeTraits { - using ArrayType = Decimal256Array; - using BuilderType = Decimal256Builder; - using ScalarType = Decimal256Scalar; - constexpr static bool is_parameter_free = false; +#define DECIMAL_TYPE_TRAITS_DEF(width) \ +template <> \ +struct TypeTraits { \ + using ArrayType = Decimal##width##Array; \ + using BuilderType = Decimal##width##Builder; \ + using ScalarType = Decimal##width##Scalar; \ + constexpr static bool is_parameter_free = false; \ }; +DECIMAL_TYPE_TRAITS_DEF(128) +DECIMAL_TYPE_TRAITS_DEF(256) + +#undef DECIMAL_TYPE_TRAITS_DEF + template <> struct TypeTraits { using ArrayType = BinaryArray; diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index 78d5b15d1c040..b2864239b98bf 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -30,6 +30,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/int128_internal.h" #include "arrow/util/int_util_internal.h" +#include "arrow/util/decimal_meta.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -281,7 +282,7 @@ BasicDecimal128 BasicDecimal128::Abs(const BasicDecimal128& in) { bool BasicDecimal128::FitsInPrecision(int32_t precision) const { DCHECK_GT(precision, 0); - DCHECK_LE(precision, 38); + DCHECK_LE(precision, DecimalMeta<128>::max_precision); return BasicDecimal128::Abs(*this) < ScaleMultipliers[precision]; } @@ -930,6 +931,8 @@ DecimalStatus DecimalRescale(const DecimalClass& value, int32_t original_scale, const int32_t abs_delta_scale = std::abs(delta_scale); DecimalClass multiplier = DecimalClass::GetScaleMultiplier(abs_delta_scale); + DCHECK_GE(abs_delta_scale, 1); + DCHECK_LE(abs_delta_scale, DecimalMeta<128>::max_precision); const bool rescale_would_cause_data_loss = RescaleWouldCauseDataLoss(value, delta_scale, multiplier, out); @@ -950,7 +953,7 @@ DecimalStatus BasicDecimal128::Rescale(int32_t original_scale, int32_t new_scale void BasicDecimal128::GetWholeAndFraction(int scale, BasicDecimal128* whole, BasicDecimal128* fraction) const { DCHECK_GE(scale, 0); - DCHECK_LE(scale, 38); + DCHECK_LE(scale, DecimalMeta<128>::max_precision); BasicDecimal128 multiplier(ScaleMultipliers[scale]); auto s = Divide(multiplier, whole, fraction); @@ -959,7 +962,7 @@ void BasicDecimal128::GetWholeAndFraction(int scale, BasicDecimal128* whole, const BasicDecimal128& BasicDecimal128::GetScaleMultiplier(int32_t scale) { DCHECK_GE(scale, 0); - DCHECK_LE(scale, 38); + DCHECK_LE(scale, DecimalMeta<128>::max_precision); return ScaleMultipliers[scale]; } @@ -968,14 +971,14 @@ const BasicDecimal128& BasicDecimal128::GetMaxValue() { return kMaxValue; } BasicDecimal128 BasicDecimal128::IncreaseScaleBy(int32_t increase_by) const { DCHECK_GE(increase_by, 0); - DCHECK_LE(increase_by, 38); + DCHECK_LE(increase_by, DecimalMeta<128>::max_precision); return (*this) * ScaleMultipliers[increase_by]; } BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) const { DCHECK_GE(reduce_by, 0); - DCHECK_LE(reduce_by, 38); + DCHECK_LE(reduce_by, DecimalMeta<128>::max_precision); if (reduce_by == 0) { return *this; diff --git a/cpp/src/arrow/util/decimal_meta.h b/cpp/src/arrow/util/decimal_meta.h new file mode 100644 index 0000000000000..e5eb4d907a22a --- /dev/null +++ b/cpp/src/arrow/util/decimal_meta.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { + +template +struct DecimalMeta; + +template<> +struct DecimalMeta<128> { + static constexpr const char* name = "decimal"; + static constexpr int32_t max_precision = 38; +}; + +template<> +struct DecimalMeta<256> { + static constexpr const char* name = "decimal256"; + static constexpr int32_t max_precision = 76; +}; + +} // namespace arrow diff --git a/cpp/src/arrow/util/decimal_type_traits.h b/cpp/src/arrow/util/decimal_type_traits.h new file mode 100644 index 0000000000000..fd8e9a5e1ff8d --- /dev/null +++ b/cpp/src/arrow/util/decimal_type_traits.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type_fwd.h" + +namespace arrow { + +template +struct DecimalTypeTraits; + +#define DECIMAL_TYPE_TRAITS_DECL(width) \ +template<> \ +struct DecimalTypeTraits { \ + static constexpr Type::type Id = Type::DECIMAL##width; \ + using ArrayType = Decimal##width##Array; \ + using BuilderType = Decimal##width##Builder; \ + using ScalarType = Decimal##width##Scalar; \ + using TypeClass = Decimal##width##Type; \ + using ValueType = Decimal##width; \ +}; + +DECIMAL_TYPE_TRAITS_DECL(128) +DECIMAL_TYPE_TRAITS_DECL(256) + +} // namespace arrow From b6f3cec754c2799c7a6823a502a99e0ea1109b60 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Tue, 10 Nov 2020 18:05:49 +0300 Subject: [PATCH 2/8] Draft implementation of low bitness decimals for C++ Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/array/array_base.cc | 15 +- cpp/src/arrow/array/array_decimal.cc | 3 + cpp/src/arrow/array/array_decimal.h | 10 - cpp/src/arrow/array/array_dict_test.cc | 18 + cpp/src/arrow/array/array_test.cc | 119 +- cpp/src/arrow/array/builder_decimal.cc | 3 + cpp/src/arrow/array/builder_decimal.h | 10 - cpp/src/arrow/array/builder_dict.h | 24 + cpp/src/arrow/array/validate.cc | 8 + cpp/src/arrow/builder.cc | 6 + cpp/src/arrow/compare.cc | 27 +- cpp/src/arrow/dataset/filter.cc | 1775 +++++++++++++++++ cpp/src/arrow/ipc/json_simple.cc | 21 +- cpp/src/arrow/ipc/json_simple_test.cc | 69 +- cpp/src/arrow/pretty_print.cc | 8 +- cpp/src/arrow/pretty_print_test.cc | 8 +- cpp/src/arrow/scalar.cc | 5 + cpp/src/arrow/scalar.h | 8 - cpp/src/arrow/scalar_test.cc | 31 +- cpp/src/arrow/testing/gtest_util.cc | 3 + cpp/src/arrow/testing/json_internal.cc | 57 +- cpp/src/arrow/type.cc | 24 + cpp/src/arrow/type.h | 21 + cpp/src/arrow/type_fwd.h | 50 +- cpp/src/arrow/type_test.cc | 42 + cpp/src/arrow/type_traits.h | 27 + cpp/src/arrow/util/basic_decimal.cc | 231 ++- cpp/src/arrow/util/basic_decimal.h | 160 ++ cpp/src/arrow/util/decimal.cc | 206 +- cpp/src/arrow/util/decimal.h | 74 + cpp/src/arrow/util/decimal_meta.h | 32 + .../arrow/util/decimal_scale_multipliers.h | 143 ++ cpp/src/arrow/util/decimal_test.cc | 187 +- cpp/src/arrow/util/decimal_type_traits.h | 3 + cpp/src/arrow/util/int_util_internal.h | 8 + cpp/src/arrow/visitor.cc | 9 + cpp/src/arrow/visitor.h | 9 + cpp/src/arrow/visitor_inline.h | 3 + 38 files changed, 3110 insertions(+), 347 deletions(-) create mode 100644 cpp/src/arrow/dataset/filter.cc create mode 100644 cpp/src/arrow/util/decimal_scale_multipliers.h diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index 67c5ca84e1f52..2a0a7b7b6338c 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -69,13 +69,18 @@ struct ScalarFromArraySlotImpl { return Finish(a.Value(index_)); } - Status Visit(const Decimal128Array& a) { - return Finish(Decimal128(a.GetValue(index_))); + #define DECL_DECIMAL_VISIT(width) \ + Status Visit(const Decimal##width##Array& a) { \ + return Finish(Decimal##width(a.GetValue(index_))); \ } - Status Visit(const Decimal256Array& a) { - return Finish(Decimal256(a.GetValue(index_))); - } + DECL_DECIMAL_VISIT(16) + DECL_DECIMAL_VISIT(32) + DECL_DECIMAL_VISIT(64) + DECL_DECIMAL_VISIT(128) + DECL_DECIMAL_VISIT(256) + + #undef DECL_DECIMAL_VISIT template Status Visit(const BaseBinaryArray& a) { diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index b895ba72061c6..f05fe7960f2db 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -46,6 +46,9 @@ std::string BaseDecimalArray::FormatValue(int64_t i) const { return value.ToString(type_.scale()); } +template class BaseDecimalArray<16>; +template class BaseDecimalArray<32>; +template class BaseDecimalArray<64>; template class BaseDecimalArray<128>; template class BaseDecimalArray<256>; diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index fbac3ac094507..c9a0aff1bf36c 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -44,16 +44,6 @@ class BaseDecimalArray : public FixedSizeBinaryArray { std::string FormatValue(int64_t i) const; }; -/// Array class for decimal 128-bit data -class ARROW_EXPORT Decimal128Array : public BaseDecimalArray<128> { - using BaseDecimalArray<128>::BaseDecimalArray; -}; - -/// Array class for decimal 256-bit data -class ARROW_EXPORT Decimal256Array : public BaseDecimalArray<256> { - using BaseDecimalArray<256>::BaseDecimalArray; -}; - // Backward compatibility using DecimalArray = Decimal128Array; diff --git a/cpp/src/arrow/array/array_dict_test.cc b/cpp/src/arrow/array/array_dict_test.cc index fca442b256750..498fbd8a812fb 100644 --- a/cpp/src/arrow/array/array_dict_test.cc +++ b/cpp/src/arrow/array/array_dict_test.cc @@ -859,6 +859,18 @@ void TestDecimalDictionaryBuilderBasic(std::shared_ptr decimal_type) { ASSERT_TRUE(expected.Equals(result)); } +TEST(TestDecimal16DictionaryBuilder, Basic) { + TestDecimalDictionaryBuilderBasic(arrow::decimal16(2, 0)); +} + +TEST(TestDecimal32DictionaryBuilder, Basic) { + TestDecimalDictionaryBuilderBasic(arrow::decimal32(2, 0)); +} + +TEST(TestDecimal64DictionaryBuilder, Basic) { + TestDecimalDictionaryBuilderBasic(arrow::decimal64(2, 0)); +} + TEST(TestDecimal128DictionaryBuilder, Basic) { TestDecimalDictionaryBuilderBasic(arrow::decimal128(2, 0)); } @@ -921,6 +933,12 @@ void TestDecimalDictionaryBuilderDoubleTableSize( ASSERT_TRUE(expected.Equals(result)); } +// TEST(TestDecimal64DictionaryBuilder, DoubleTableSize) { +// const auto& decimal_type = arrow::decimal64(18, 0); +// Decimal64Builder decimal_builder(decimal_type); +// TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder); +// } + TEST(TestDecimal128DictionaryBuilder, DoubleTableSize) { const auto& decimal_type = arrow::decimal128(21, 0); Decimal128Builder decimal_builder(decimal_type); diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 89087ee318c60..38c19e4c50109 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2407,6 +2407,32 @@ class DecimalTest : public ::testing::TestWithParam { } } + void InitNoNullsTest(int32_t precision) { + std::vector draw = {DecimalValue(1), DecimalValue(-2), DecimalValue(2389), + DecimalValue(4), DecimalValue(-12348)}; + std::vector valid_bytes = {true, true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); + } + + void InitWithNullsTest(int32_t precision, std::string big_value, std::string big_negate_value) { + std::vector draw = {DecimalValue(1), DecimalValue(2), DecimalValue(-1), + DecimalValue(4), DecimalValue(-1), DecimalValue(1), + DecimalValue(2)}; + DecimalValue big; + ASSERT_OK_AND_ASSIGN(big, DecimalValue::FromString(big_value)); + draw.push_back(big); + + DecimalValue big_negative; + ASSERT_OK_AND_ASSIGN(big_negative, DecimalValue::FromString(big_negate_value)); + draw.push_back(big_negative); + + std::vector valid_bytes = {true, true, false, true, false, + true, true, true, true}; + this->TestCreate(precision, draw, valid_bytes, 0); + this->TestCreate(precision, draw, valid_bytes, 2); + } + template void TestCreate(int32_t precision, const DecimalVector& draw, const std::vector& valid_bytes, int64_t offset) const { @@ -2451,34 +2477,58 @@ class DecimalTest : public ::testing::TestWithParam { } }; +using Decimal16Test = DecimalTest; + +TEST_P(Decimal16Test, NoNulls) { + int32_t precision = GetParam(); + this->InitNoNullsTest(precision); +} + +TEST_P(Decimal16Test, WithNulls) { + int32_t precision = GetParam(); + this->InitWithNullsTest(precision, "163.84", "-163.84"); +} + +INSTANTIATE_TEST_SUITE_P(Decimal16Test, Decimal16Test, ::testing::Range(1, 5)); + +using Decimal32Test = DecimalTest; + +TEST_P(Decimal32Test, NoNulls) { + int32_t precision = GetParam(); + this->InitNoNullsTest(precision); +} + +TEST_P(Decimal32Test, WithNulls) { + int32_t precision = GetParam(); + this->InitWithNullsTest(precision, "107374.1824", "-107374.1824"); +} + +INSTANTIATE_TEST_SUITE_P(Decimal32Test, Decimal32Test, ::testing::Range(1, 10)); + +using Decimal64Test = DecimalTest; + +TEST_P(Decimal64Test, NoNulls) { + int32_t precision = GetParam(); + this->InitNoNullsTest(precision); +} + +TEST_P(Decimal64Test, WithNulls) { + int32_t precision = GetParam(); + this->InitWithNullsTest(precision, "46116860184.27387904", "-46116860184.27387904"); +} + +INSTANTIATE_TEST_SUITE_P(Decimal64Test, Decimal64Test, ::testing::Range(1, 19)); + using Decimal128Test = DecimalTest; TEST_P(Decimal128Test, NoNulls) { int32_t precision = GetParam(); - std::vector draw = {Decimal128(1), Decimal128(-2), Decimal128(2389), - Decimal128(4), Decimal128(-12348)}; - std::vector valid_bytes = {true, true, true, true, true}; - this->TestCreate(precision, draw, valid_bytes, 0); - this->TestCreate(precision, draw, valid_bytes, 2); + this->InitNoNullsTest(precision); } TEST_P(Decimal128Test, WithNulls) { int32_t precision = GetParam(); - std::vector draw = {Decimal128(1), Decimal128(2), Decimal128(-1), - Decimal128(4), Decimal128(-1), Decimal128(1), - Decimal128(2)}; - Decimal128 big; - ASSERT_OK_AND_ASSIGN(big, Decimal128::FromString("230342903942.234234")); - draw.push_back(big); - - Decimal128 big_negative; - ASSERT_OK_AND_ASSIGN(big_negative, Decimal128::FromString("-23049302932.235234")); - draw.push_back(big_negative); - - std::vector valid_bytes = {true, true, false, true, false, - true, true, true, true}; - this->TestCreate(precision, draw, valid_bytes, 0); - this->TestCreate(precision, draw, valid_bytes, 2); + this->InitWithNullsTest(precision, "23049302932.235234", "-23049302932.235234"); } INSTANTIATE_TEST_SUITE_P(Decimal128Test, Decimal128Test, ::testing::Range(1, 38)); @@ -2487,34 +2537,15 @@ using Decimal256Test = DecimalTest; TEST_P(Decimal256Test, NoNulls) { int32_t precision = GetParam(); - std::vector draw = {Decimal256(1), Decimal256(-2), Decimal256(2389), - Decimal256(4), Decimal256(-12348)}; - std::vector valid_bytes = {true, true, true, true, true}; - this->TestCreate(precision, draw, valid_bytes, 0); - this->TestCreate(precision, draw, valid_bytes, 2); + this->InitNoNullsTest(precision); } TEST_P(Decimal256Test, WithNulls) { int32_t precision = GetParam(); - std::vector draw = {Decimal256(1), Decimal256(2), Decimal256(-1), - Decimal256(4), Decimal256(-1), Decimal256(1), - Decimal256(2)}; - Decimal256 big; // (pow(2, 255) - 1) / pow(10, 38) - ASSERT_OK_AND_ASSIGN(big, - Decimal256::FromString("578960446186580977117854925043439539266." - "34992332820282019728792003956564819967")); - draw.push_back(big); - - Decimal256 big_negative; // -pow(2, 255) / pow(10, 38) - ASSERT_OK_AND_ASSIGN(big_negative, - Decimal256::FromString("-578960446186580977117854925043439539266." - "34992332820282019728792003956564819968")); - draw.push_back(big_negative); - - std::vector valid_bytes = {true, true, false, true, false, - true, true, true, true}; - this->TestCreate(precision, draw, valid_bytes, 0); - this->TestCreate(precision, draw, valid_bytes, 2); + this->InitWithNullsTest(precision, "578960446186580977117854925043439539266." + "34992332820282019728792003956564819967", + "-578960446186580977117854925043439539266." + "34992332820282019728792003956564819968"); } INSTANTIATE_TEST_SUITE_P(Decimal256Test, Decimal256Test, diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index 6e6f5dc8e199f..f29fe6cc8c205 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -72,6 +72,9 @@ Status BaseDecimalBuilder::FinishInternal(std::shared_ptr* out return Status::OK(); } +template class BaseDecimalBuilder<16>; +template class BaseDecimalBuilder<32>; +template class BaseDecimalBuilder<64>; template class BaseDecimalBuilder<128>; template class BaseDecimalBuilder<256>; diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index f1d7b8387201e..410bcac235aeb 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -62,16 +62,6 @@ class BaseDecimalBuilder : public FixedSizeBinaryBuilder { std::shared_ptr decimal_type_; }; -/// Builder class for decimal 128-bit -class ARROW_EXPORT Decimal128Builder : public BaseDecimalBuilder<128> { - using BaseDecimalBuilder<128>::BaseDecimalBuilder; -}; - -/// Builder class for decimal 256-bit -class ARROW_EXPORT Decimal256Builder : public BaseDecimalBuilder<256> { - using BaseDecimalBuilder<256>::BaseDecimalBuilder; -}; - // Backward compatibility using DecimalBuilder = Decimal128Builder; diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 40d6ce1ba9a78..131f256c99209 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -238,6 +238,30 @@ class DictionaryBuilderBase : public ArrayBuilder { return Append(util::string_view(value, length)); } + /// \brief Append a decimal (only for Decimal16Type) + template + enable_if_decimal16 Append(const Decimal16& value) { + uint8_t data[2]; + value.ToBytes(data); + return Append(data, 2); + } + + /// \brief Append a decimal (only for Decimal32Type) + template + enable_if_decimal32 Append(const Decimal32& value) { + uint8_t data[4]; + value.ToBytes(data); + return Append(data, 4); + } + + /// \brief Append a decimal (only for Decimal64Type) + template + enable_if_decimal64 Append(const Decimal64& value) { + uint8_t data[8]; + value.ToBytes(data); + return Append(data, 8); + } + /// \brief Append a decimal (only for Decimal128Type) template enable_if_decimal128 Append(const Decimal128& value) { diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 38092045aab74..691d7e4ad939f 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -62,6 +62,14 @@ struct ValidateArrayImpl { return Status::OK(); } + template + Status Visit(const BaseDecimalArray& array) { + if (array.length() > 0 && array.values() == nullptr) { + return Status::Invalid("values is null"); + } + return Status::OK(); + } + Status Visit(const StringType& type) { return ValidateBinaryLike(type); } Status Visit(const BinaryType& type) { return ValidateBinaryLike(type); } diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index f22228a458897..e3b5a69971226 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -50,6 +50,9 @@ struct DictionaryBuilderCase { return Create>(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } + Status Visit(const Decimal16Type&) { return CreateFor(); } + Status Visit(const Decimal32Type&) { return CreateFor(); } + Status Visit(const Decimal64Type&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } @@ -138,6 +141,9 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, BUILDER_CASE(LargeString); BUILDER_CASE(LargeBinary); BUILDER_CASE(FixedSizeBinary); + BUILDER_CASE(Decimal16); + BUILDER_CASE(Decimal32); + BUILDER_CASE(Decimal64); BUILDER_CASE(Decimal128); BUILDER_CASE(Decimal256); diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 4c6f97faf9513..1e5c350e63c70 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -264,6 +264,11 @@ class RangeDataEqualsImpl { Status Visit(const LargeListType& type) { return CompareList(type); } + template + Status Visit(const BaseDecimalArray& left) { + return Visit(checked_cast(left)); + } + Status Visit(const FixedSizeListType& type) { const auto list_size = type.list_size(); const ArrayData& left_data = *left_.child_data[0]; @@ -605,14 +610,9 @@ class TypeEqualsVisitor { return Status::OK(); } - Status Visit(const Decimal128Type& left) { - const auto& right = checked_cast(right_); - result_ = left.precision() == right.precision() && left.scale() == right.scale(); - return Status::OK(); - } - - Status Visit(const Decimal256Type& left) { - const auto& right = checked_cast(right_); + template + Status Visit(const BaseDecimalType& left) { + const auto& right = checked_cast&>(right_); result_ = left.precision() == right.precision() && left.scale() == right.scale(); return Status::OK(); } @@ -721,14 +721,9 @@ class ScalarEqualsVisitor { return Status::OK(); } - Status Visit(const Decimal128Scalar& left) { - const auto& right = checked_cast(right_); - result_ = left.value == right.value; - return Status::OK(); - } - - Status Visit(const Decimal256Scalar& left) { - const auto& right = checked_cast(right_); + template + Status Visit(const BaseDecimalScalar& left) { + const auto& right = checked_cast&>(right_); result_ = left.value == right.value; return Status::OK(); } diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc new file mode 100644 index 0000000000000..a7ce9761feed2 --- /dev/null +++ b/cpp/src/arrow/dataset/filter.cc @@ -0,0 +1,1775 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/filter.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/builder_primitive.h" +#include "arrow/buffer.h" +#include "arrow/compute/api.h" +#include "arrow/dataset/dataset.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/scalar.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/int_util_internal.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/string.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using compute::CompareOperator; +using compute::ExecContext; + +namespace dataset { + +using arrow::internal::checked_cast; +using arrow::internal::checked_pointer_cast; + +inline std::shared_ptr NullExpression() { + return std::make_shared(std::make_shared()); +} + +inline Datum NullDatum() { return Datum(std::make_shared()); } + +bool IsNullDatum(const Datum& datum) { + if (datum.is_scalar()) { + auto scalar = datum.scalar(); + return !scalar->is_valid; + } + + auto array_data = datum.array(); + return array_data->GetNullCount() == array_data->length; +} + +struct Comparison { + enum type { + LESS, + EQUAL, + GREATER, + NULL_, + }; +}; + +Result> EnsureNotDictionary( + const std::shared_ptr& scalar) { + if (scalar->type->id() == Type::DICTIONARY) { + return checked_cast(*scalar).GetEncodedValue(); + } + return scalar; +} + +Result Compare(const Scalar& lhs, const Scalar& rhs); + +struct CompareVisitor { + template + using ScalarType = typename TypeTraits::ScalarType; + + Status Visit(const NullType&) { + result_ = Comparison::NULL_; + return Status::OK(); + } + + Status Visit(const BooleanType&) { return CompareValues(); } + + template + enable_if_physical_floating_point Visit(const T&) { + return CompareValues(); + } + + template + enable_if_physical_signed_integer Visit(const T&) { + return CompareValues(); + } + + template + enable_if_physical_unsigned_integer Visit(const T&) { + return CompareValues(); + } + + template + enable_if_nested Visit(const T&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + template + enable_if_binary_like Visit(const T&) { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); + if (cmp == 0) { + return CompareValues(lhs->size(), rhs->size()); + } + return CompareValues(cmp, 0); + } + + template + enable_if_string_like Visit(const T&) { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); + if (cmp == 0) { + return CompareValues(lhs->size(), rhs->size()); + } + return CompareValues(cmp, 0); + } + + Status Visit(const Decimal16Type&) { return CompareValues(); } + Status Visit(const Decimal32Type&) { return CompareValues(); } + Status Visit(const Decimal64Type&) { return CompareValues(); } + Status Visit(const Decimal128Type&) { return CompareValues(); } + Status Visit(const Decimal256Type&) { return CompareValues(); } + + // Explicit because it falls under `physical_unsigned_integer`. + // TODO(bkietz) whenever we vendor a float16, this can be implemented + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + Status Visit(const ExtensionType&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + Status Visit(const DictionaryType&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + // defer comparison to ScalarType::value + template + Status CompareValues() { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + return CompareValues(lhs, rhs); + } + + // defer comparison to explicit values + template + Status CompareValues(Value lhs, Value rhs) { + result_ = lhs < rhs ? Comparison::LESS + : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; + return Status::OK(); + } + + Comparison::type result_; + const Scalar& lhs_; + const Scalar& rhs_; +}; + +// Compare two scalars +// if either is null, return is null +// TODO(bkietz) extract this to the scalar comparison kernels +Result Compare(const Scalar& lhs, const Scalar& rhs) { + if (!lhs.type->Equals(*rhs.type)) { + return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, + " vs ", *rhs.type); + } + if (!lhs.is_valid || !rhs.is_valid) { + return Comparison::NULL_; + } + CompareVisitor vis{Comparison::NULL_, lhs, rhs}; + RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); + return vis.result_; +} + +CompareOperator InvertCompareOperator(CompareOperator op) { + switch (op) { + case CompareOperator::EQUAL: + return CompareOperator::NOT_EQUAL; + + case CompareOperator::NOT_EQUAL: + return CompareOperator::EQUAL; + + case CompareOperator::GREATER: + return CompareOperator::LESS_EQUAL; + + case CompareOperator::GREATER_EQUAL: + return CompareOperator::LESS; + + case CompareOperator::LESS: + return CompareOperator::GREATER_EQUAL; + + case CompareOperator::LESS_EQUAL: + return CompareOperator::GREATER; + + default: + break; + } + + DCHECK(false); + return CompareOperator::EQUAL; +} + +template +std::shared_ptr InvertBoolean(const Boolean& expr) { + auto lhs = Invert(*expr.left_operand()); + auto rhs = Invert(*expr.right_operand()); + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } + + return nullptr; +} + +std::shared_ptr Invert(const Expression& expr) { + switch (expr.type()) { + case ExpressionType::NOT: + return checked_cast(expr).operand(); + + case ExpressionType::AND: + return InvertBoolean(checked_cast(expr)); + + case ExpressionType::OR: + return InvertBoolean(checked_cast(expr)); + + case ExpressionType::COMPARISON: { + const auto& comparison = checked_cast(expr); + auto inverted_op = InvertCompareOperator(comparison.op()); + return std::make_shared( + inverted_op, comparison.left_operand(), comparison.right_operand()); + } + + default: + break; + } + return nullptr; +} + +std::shared_ptr Expression::Assume(const Expression& given) const { + std::shared_ptr out; + + DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { + if (out != nullptr) { + return Status::OK(); + } + + if (given.type() != ExpressionType::COMPARISON) { + return Status::OK(); + } + + const auto& given_cmp = checked_cast(given); + if (given_cmp.op() != CompareOperator::EQUAL) { + return Status::OK(); + } + + if (this->Equals(given_cmp.left_operand())) { + out = given_cmp.right_operand(); + return Status::OK(); + } + + if (this->Equals(given_cmp.right_operand())) { + out = given_cmp.left_operand(); + return Status::OK(); + } + + return Status::OK(); + })); + + return out ? out : Copy(); +} + +std::shared_ptr ComparisonExpression::Assume(const Expression& given) const { + switch (given.type()) { + case ExpressionType::COMPARISON: { + return AssumeGivenComparison(checked_cast(given)); + } + + case ExpressionType::NOT: { + const auto& given_not = checked_cast(given); + if (auto inverted = Invert(*given_not.operand())) { + return Assume(*inverted); + } + return Copy(); + } + + case ExpressionType::OR: { + const auto& given_or = checked_cast(given); + + auto left_simplified = Assume(*given_or.left_operand()); + auto right_simplified = Assume(*given_or.right_operand()); + + // The result of simplification against the operands of an OrExpression + // cannot be used unless they are identical + if (left_simplified->Equals(right_simplified)) { + return left_simplified; + } + + return Copy(); + } + + case ExpressionType::AND: { + const auto& given_and = checked_cast(given); + + auto simplified = Copy(); + simplified = simplified->Assume(*given_and.left_operand()); + simplified = simplified->Assume(*given_and.right_operand()); + return simplified; + } + + // TODO(bkietz) we should be able to use ExpressionType::IN here + + default: + break; + } + + return Copy(); +} + +// Try to simplify one comparison against another comparison. +// For example, +// ("x"_ > 3) is a subset of ("x"_ > 2), so ("x"_ > 2).Assume("x"_ > 3) == (true) +// ("x"_ < 0) is disjoint with ("x"_ > 2), so ("x"_ > 2).Assume("x"_ < 0) == (false) +// If simplification to (true) or (false) is not possible, pass e through unchanged. +std::shared_ptr ComparisonExpression::AssumeGivenComparison( + const ComparisonExpression& given) const { + if (!left_operand_->Equals(given.left_operand_)) { + return Copy(); + } + + for (auto rhs : {right_operand_, given.right_operand_}) { + if (rhs->type() != ExpressionType::SCALAR) { + return Copy(); + } + } + + auto this_rhs = + EnsureNotDictionary(checked_cast(*right_operand_).value()) + .ValueOr(nullptr); + auto given_rhs = + EnsureNotDictionary( + checked_cast(*given.right_operand_).value()) + .ValueOr(nullptr); + + if (!this_rhs || !given_rhs) { + return Copy(); + } + + auto cmp = Compare(*this_rhs, *given_rhs).ValueOrDie(); + + if (cmp == Comparison::NULL_) { + // the RHS of e or given was null + return NullExpression(); + } + + static auto always = scalar(true); + static auto never = scalar(false); + + if (cmp == Comparison::GREATER) { + // the rhs of e is greater than that of given + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return never; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + } + + if (cmp == Comparison::LESS) { + // the rhs of e is less than that of given + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return never; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + } + + DCHECK_EQ(cmp, Comparison::EQUAL); + + // the rhs of the comparisons are equal + switch (op_) { + case CompareOperator::EQUAL: + switch (given.op()) { + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + return never; + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS_EQUAL: + case CompareOperator::LESS: + return never; + case CompareOperator::GREATER: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::LESS: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return never; + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::GREATER: + return never; + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + return Copy(); +} + +std::shared_ptr AndExpression::Assume(const Expression& given) const { + auto left_operand = left_operand_->Assume(given); + auto right_operand = right_operand_->Assume(given); + + // if either of the operands is trivially false then so is this AND + if (left_operand->Equals(false) || right_operand->Equals(false)) { + return scalar(false); + } + + // if either operand is trivially null then so is this AND + if (left_operand->IsNull() || right_operand->IsNull()) { + return NullExpression(); + } + + // if one of the operands is trivially true then drop it + if (left_operand->Equals(true)) { + return right_operand; + } + if (right_operand->Equals(true)) { + return left_operand; + } + + // if neither of the operands is trivial, simply construct a new AND + return and_(std::move(left_operand), std::move(right_operand)); +} + +std::shared_ptr OrExpression::Assume(const Expression& given) const { + auto left_operand = left_operand_->Assume(given); + auto right_operand = right_operand_->Assume(given); + + // if either of the operands is trivially true then so is this OR + if (left_operand->Equals(true) || right_operand->Equals(true)) { + return scalar(true); + } + + // if either operand is trivially null then so is this OR + if (left_operand->IsNull() || right_operand->IsNull()) { + return NullExpression(); + } + + // if one of the operands is trivially false then drop it + if (left_operand->Equals(false)) { + return right_operand; + } + if (right_operand->Equals(false)) { + return left_operand; + } + + // if neither of the operands is trivial, simply construct a new OR + return or_(std::move(left_operand), std::move(right_operand)); +} + +std::shared_ptr NotExpression::Assume(const Expression& given) const { + auto operand = operand_->Assume(given); + + if (operand->IsNull()) { + return NullExpression(); + } + if (operand->Equals(true)) { + return scalar(false); + } + if (operand->Equals(false)) { + return scalar(true); + } + + return Copy(); +} + +std::shared_ptr InExpression::Assume(const Expression& given) const { + auto operand = operand_->Assume(given); + if (operand->type() != ExpressionType::SCALAR) { + return std::make_shared(std::move(operand), set_); + } + + if (operand->IsNull()) { + return scalar(set_->null_count() > 0); + } + + Datum set, value; + if (set_->type_id() == Type::DICTIONARY) { + const auto& dict_set = checked_cast(*set_); + auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices()); + auto maybe_value = checked_cast( + *checked_cast(*operand).value()) + .GetEncodedValue(); + if (!maybe_decoded.ok() || !maybe_value.ok()) { + return std::make_shared(std::move(operand), set_); + } + set = *maybe_decoded; + value = *maybe_value; + } else { + set = set_; + value = checked_cast(*operand).value(); + } + + compute::CompareOptions eq(CompareOperator::EQUAL); + Result maybe_out = compute::Compare(set, value, eq); + if (!maybe_out.ok()) { + return std::make_shared(std::move(operand), set_); + } + + Datum out = maybe_out.ValueOrDie(); + + DCHECK(out.is_array()); + DCHECK_EQ(out.type()->id(), Type::BOOL); + auto out_array = checked_pointer_cast(out.make_array()); + + for (int64_t i = 0; i < out_array->length(); ++i) { + if (out_array->IsValid(i) && out_array->Value(i)) { + return scalar(true); + } + } + return scalar(false); +} + +std::shared_ptr IsValidExpression::Assume(const Expression& given) const { + auto operand = operand_->Assume(given); + if (operand->type() == ExpressionType::SCALAR) { + return scalar(!operand->IsNull()); + } + + return std::make_shared(std::move(operand)); +} + +std::shared_ptr CastExpression::Assume(const Expression& given) const { + auto operand = operand_->Assume(given); + if (arrow::util::holds_alternative>(to_)) { + auto to_type = arrow::util::get>(to_); + return std::make_shared(std::move(operand), std::move(to_type), + options_); + } + auto like = arrow::util::get>(to_)->Assume(given); + return std::make_shared(std::move(operand), std::move(like), options_); +} + +const std::shared_ptr& CastExpression::to_type() const { + if (arrow::util::holds_alternative>(to_)) { + return arrow::util::get>(to_); + } + static std::shared_ptr null; + return null; +} + +const std::shared_ptr& CastExpression::like_expr() const { + if (arrow::util::holds_alternative>(to_)) { + return arrow::util::get>(to_); + } + static std::shared_ptr null; + return null; +} + +std::string FieldExpression::ToString() const { return name_; } + +std::string OperatorName(compute::CompareOperator op) { + switch (op) { + case CompareOperator::EQUAL: + return "=="; + case CompareOperator::NOT_EQUAL: + return "!="; + case CompareOperator::LESS: + return "<"; + case CompareOperator::LESS_EQUAL: + return "<="; + case CompareOperator::GREATER: + return ">"; + case CompareOperator::GREATER_EQUAL: + return ">="; + default: + DCHECK(false); + } + return ""; +} + +std::string ScalarExpression::ToString() const { + auto type_repr = value_->type->ToString(); + if (!value_->is_valid) { + return "null:" + type_repr; + } + + return value_->ToString() + ":" + type_repr; +} + +using arrow::internal::JoinStrings; + +std::string AndExpression::ToString() const { + return JoinStrings( + {"(", left_operand_->ToString(), " and ", right_operand_->ToString(), ")"}, ""); +} + +std::string OrExpression::ToString() const { + return JoinStrings( + {"(", left_operand_->ToString(), " or ", right_operand_->ToString(), ")"}, ""); +} + +std::string NotExpression::ToString() const { + if (operand_->type() == ExpressionType::IS_VALID) { + const auto& is_valid = checked_cast(*operand_); + return JoinStrings({"(", is_valid.operand()->ToString(), " is null)"}, ""); + } + return JoinStrings({"(not ", operand_->ToString(), ")"}, ""); +} + +std::string IsValidExpression::ToString() const { + return JoinStrings({"(", operand_->ToString(), " is not null)"}, ""); +} + +std::string InExpression::ToString() const { + return JoinStrings({"(", operand_->ToString(), " is in ", set_->ToString(), ")"}, ""); +} + +std::string CastExpression::ToString() const { + std::string to; + if (arrow::util::holds_alternative>(to_)) { + auto to_type = arrow::util::get>(to_); + to = " to " + to_type->ToString(); + } else { + auto like = arrow::util::get>(to_); + to = " like " + like->ToString(); + } + return JoinStrings({"(cast ", operand_->ToString(), std::move(to), ")"}, ""); +} + +std::string ComparisonExpression::ToString() const { + return JoinStrings({"(", left_operand_->ToString(), " ", OperatorName(op()), " ", + right_operand_->ToString(), ")"}, + ""); +} + +bool UnaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + operand_->Equals(checked_cast(other).operand_); +} + +bool BinaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + left_operand_->Equals( + checked_cast(other).left_operand_) && + right_operand_->Equals( + checked_cast(other).right_operand_); +} + +bool ComparisonExpression::Equals(const Expression& other) const { + return BinaryExpression::Equals(other) && + op_ == checked_cast(other).op_; +} + +bool ScalarExpression::Equals(const Expression& other) const { + return other.type() == ExpressionType::SCALAR && + value_->Equals(*checked_cast(other).value_); +} + +bool FieldExpression::Equals(const Expression& other) const { + return other.type() == ExpressionType::FIELD && + name_ == checked_cast(other).name_; +} + +bool Expression::Equals(const std::shared_ptr& other) const { + if (other == nullptr) { + return false; + } + return Equals(*other); +} + +bool Expression::IsNull() const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + const auto& scalar = checked_cast(*this).value(); + if (!scalar->is_valid) { + return true; + } + + return false; +} + +InExpression Expression::In(std::shared_ptr set) const { + return InExpression(Copy(), std::move(set)); +} + +IsValidExpression Expression::IsValid() const { return IsValidExpression(Copy()); } + +std::shared_ptr FieldExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr ScalarExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr and_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); +} + +std::shared_ptr and_(const ExpressionVector& subexpressions) { + auto acc = scalar(true); + for (const auto& next : subexpressions) { + if (next->Equals(false)) return next; + acc = acc->Equals(true) ? next : and_(std::move(acc), next); + } + return acc; +} + +std::shared_ptr or_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); +} + +std::shared_ptr or_(const ExpressionVector& subexpressions) { + auto acc = scalar(false); + for (const auto& next : subexpressions) { + if (next->Equals(true)) return next; + acc = acc->Equals(false) ? next : or_(std::move(acc), next); + } + return acc; +} + +std::shared_ptr not_(std::shared_ptr operand) { + return std::make_shared(std::move(operand)); +} + +AndExpression operator&&(const Expression& lhs, const Expression& rhs) { + return AndExpression(lhs.Copy(), rhs.Copy()); +} + +OrExpression operator||(const Expression& lhs, const Expression& rhs) { + return OrExpression(lhs.Copy(), rhs.Copy()); +} + +NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } + +CastExpression Expression::CastTo(std::shared_ptr type, + compute::CastOptions options) const { + return CastExpression(Copy(), type, std::move(options)); +} + +CastExpression Expression::CastLike(std::shared_ptr expr, + compute::CastOptions options) const { + return CastExpression(Copy(), std::move(expr), std::move(options)); +} + +CastExpression Expression::CastLike(const Expression& expr, + compute::CastOptions options) const { + return CastLike(expr.Copy(), std::move(options)); +} + +Result> ComparisonExpression::Validate( + const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); + ARROW_ASSIGN_OR_RAISE(auto rhs_type, right_operand_->Validate(schema)); + + if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { + return boolean(); + } + + if (!lhs_type->Equals(rhs_type)) { + return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, + " vs ", *rhs_type); + } + + return boolean(); +} + +Status EnsureNullOrBool(const std::string& msg_prefix, + const std::shared_ptr& type) { + if (type->id() == Type::BOOL || type->id() == Type::NA) { + return Status::OK(); + } + return Status::TypeError(msg_prefix, *type); +} + +Result> ValidateBoolean(const ExpressionVector& operands, + const Schema& schema) { + for (const auto& operand : operands) { + ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); + RETURN_NOT_OK( + EnsureNullOrBool("cannot combine expressions including one of type ", type)); + } + return boolean(); +} + +Result> AndExpression::Validate(const Schema& schema) const { + return ValidateBoolean({left_operand_, right_operand_}, schema); +} + +Result> OrExpression::Validate(const Schema& schema) const { + return ValidateBoolean({left_operand_, right_operand_}, schema); +} + +Result> NotExpression::Validate(const Schema& schema) const { + return ValidateBoolean({operand_}, schema); +} + +Result> InExpression::Validate(const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); + if (operand_type->id() == Type::NA || set_->type()->id() == Type::NA) { + return boolean(); + } + + if (!operand_type->Equals(set_->type())) { + return Status::TypeError("mismatch: set type ", *set_->type(), " vs operand type ", + *operand_type); + } + // TODO(bkietz) check if IsIn supports operand_type + return boolean(); +} + +Result> IsValidExpression::Validate( + const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(std::ignore, operand_->Validate(schema)); + return boolean(); +} + +Result> CastExpression::Validate(const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); + std::shared_ptr to_type; + if (arrow::util::holds_alternative>(to_)) { + to_type = arrow::util::get>(to_); + } else { + auto like = arrow::util::get>(to_); + ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema)); + } + + // Until expressions carry a shape, detect scalar and try to cast it. Works + // if the operand is a scalar leaf. + if (operand_->type() == ExpressionType::SCALAR) { + auto scalar_expr = checked_pointer_cast(operand_); + ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type)); + return to_type; + } + + if (!compute::CanCast(*operand_type, *to_type)) { + return Status::Invalid("Cannot cast to ", to_type->ToString()); + } + + return to_type; +} + +Result> ScalarExpression::Validate(const Schema& schema) const { + return value_->type; +} + +Result> FieldExpression::Validate(const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); + if (field != nullptr) { + return field->type(); + } + return null(); +} + +Result CastOrDictionaryEncode(const Datum& arr, + const std::shared_ptr& type, + const compute::CastOptions opts) { + if (type->id() == Type::DICTIONARY) { + const auto& dict_type = checked_cast(*type); + if (dict_type.index_type()->id() != Type::INT32) { + return Status::TypeError("cannot DictionaryEncode to index type ", + *dict_type.index_type()); + } + ARROW_ASSIGN_OR_RAISE(auto dense, compute::Cast(arr, dict_type.value_type(), opts)); + return compute::DictionaryEncode(dense); + } + + return compute::Cast(arr, type, opts); +} + +struct InsertImplicitCastsImpl { + struct ValidatedAndCast { + std::shared_ptr expr; + std::shared_ptr type; + }; + + Result InsertCastsAndValidate(const Expression& expr) { + ValidatedAndCast out; + ARROW_ASSIGN_OR_RAISE(out.expr, InsertImplicitCasts(expr, schema_)); + ARROW_ASSIGN_OR_RAISE(out.type, out.expr->Validate(schema_)); + return std::move(out); + } + + Result> Cast(std::shared_ptr type, + const Expression& expr) { + if (expr.type() != ExpressionType::SCALAR) { + return expr.CastTo(type).Copy(); + } + + // cast the scalar directly + const auto& value = checked_cast(expr).value(); + ARROW_ASSIGN_OR_RAISE(auto cast_value, value->CastTo(std::move(type))); + return scalar(cast_value); + } + + Result> operator()(const InExpression& expr) { + ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); + auto set = expr.set(); + + if (!op.type->Equals(set->type())) { + // cast the set (which we assume to be small) to match op.type + ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode(*set, op.type, {})); + set = encoded_set.make_array(); + } + + return std::make_shared(std::move(op.expr), std::move(set)); + } + + Result> operator()(const NotExpression& expr) { + ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); + + if (op.type->id() != Type::BOOL) { + ARROW_ASSIGN_OR_RAISE(op.expr, Cast(boolean(), *op.expr)); + } + return not_(std::move(op.expr)); + } + + Result> operator()(const AndExpression& expr) { + ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); + + if (lhs.type->id() != Type::BOOL) { + ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); + } + if (rhs.type->id() != Type::BOOL) { + ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); + } + return and_(std::move(lhs.expr), std::move(rhs.expr)); + } + + Result> operator()(const OrExpression& expr) { + ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); + + if (lhs.type->id() != Type::BOOL) { + ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); + } + if (rhs.type->id() != Type::BOOL) { + ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); + } + return or_(std::move(lhs.expr), std::move(rhs.expr)); + } + + Result> operator()(const ComparisonExpression& expr) { + ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); + + if (lhs.type->Equals(rhs.type)) { + return expr.Copy(); + } + + if (lhs.expr->type() == ExpressionType::SCALAR) { + ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(rhs.type, *lhs.expr)); + } else { + ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(lhs.type, *rhs.expr)); + } + return std::make_shared(expr.op(), std::move(lhs.expr), + std::move(rhs.expr)); + } + + Result> operator()(const Expression& expr) const { + return expr.Copy(); + } + + const Schema& schema_; +}; + +Result> InsertImplicitCasts(const Expression& expr, + const Schema& schema) { + RETURN_NOT_OK(schema.CanReferenceFieldsByNames(FieldsInExpression(expr))); + return VisitExpression(expr, InsertImplicitCastsImpl{schema}); +} + +Status VisitConjunctionMembers(const Expression& expr, + const std::function& visitor) { + if (expr.type() == ExpressionType::AND) { + const auto& and_ = checked_cast(expr); + RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); + RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); + return Status::OK(); + } + + return visitor(expr); +} + +std::vector FieldsInExpression(const Expression& expr) { + struct { + void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } + + void operator()(const UnaryExpression& expr) { + VisitExpression(*expr.operand(), *this); + } + + void operator()(const BinaryExpression& expr) { + VisitExpression(*expr.left_operand(), *this); + VisitExpression(*expr.right_operand(), *this); + } + + void operator()(const Expression&) const {} + + std::vector fields; + } visitor; + + VisitExpression(expr, visitor); + return std::move(visitor.fields); +} + +std::vector FieldsInExpression(const std::shared_ptr& expr) { + DCHECK_NE(expr, nullptr); + if (expr == nullptr) { + return {}; + } + + return FieldsInExpression(*expr); +} + +RecordBatchIterator ExpressionEvaluator::FilterBatches(RecordBatchIterator unfiltered, + std::shared_ptr filter, + MemoryPool* pool) { + auto filter_batches = [filter, pool, this](std::shared_ptr unfiltered) { + auto filtered = Evaluate(*filter, *unfiltered, pool).Map([&](Datum selection) { + return Filter(selection, unfiltered, pool); + }); + + if (filtered.ok() && (*filtered)->num_rows() == 0) { + // drop empty batches + return FilterIterator::Reject>(); + } + + return FilterIterator::MaybeAccept(std::move(filtered)); + }; + + return MakeFilterIterator(std::move(filter_batches), std::move(unfiltered)); +} + +std::shared_ptr ExpressionEvaluator::Null() { + struct Impl : ExpressionEvaluator { + Result Evaluate(const Expression& expr, const RecordBatch& batch, + MemoryPool* pool) const override { + ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema())); + return Datum(MakeNullScalar(type)); + } + + Result> Filter(const Datum& selection, + const std::shared_ptr& batch, + MemoryPool* pool) const override { + return batch; + } + }; + + return std::make_shared(); +} + +struct TreeEvaluator::Impl { + Result operator()(const ScalarExpression& expr) const { + return Datum(expr.value()); + } + + Result operator()(const FieldExpression& expr) const { + if (auto column = batch_.GetColumnByName(expr.name())) { + return std::move(column); + } + return NullDatum(); + } + + Result operator()(const AndExpression& expr) const { + return EvaluateBoolean(expr, compute::KleeneAnd); + } + + Result operator()(const OrExpression& expr) const { + return EvaluateBoolean(expr, compute::KleeneOr); + } + + Result EvaluateBoolean(const BinaryExpression& expr, + Result kernel(const Datum& left, + const Datum& right, + ExecContext* ctx)) const { + ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); + + if (lhs.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + auto lhs_array, + MakeArrayFromScalar(*lhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); + lhs = Datum(std::move(lhs_array)); + } + + if (rhs.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + auto rhs_array, + MakeArrayFromScalar(*rhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); + rhs = Datum(std::move(rhs_array)); + } + + return kernel(lhs, rhs, &ctx_); + } + + Result operator()(const NotExpression& expr) const { + ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand())); + if (IsNullDatum(to_invert)) { + return NullDatum(); + } + + if (to_invert.is_scalar()) { + bool trivial_condition = + checked_cast(*to_invert.scalar()).value; + return Datum(std::make_shared(!trivial_condition)); + } + return compute::Invert(to_invert, &ctx_); + } + + Result operator()(const InExpression& expr) const { + ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); + if (IsNullDatum(operand_values)) { + return Datum(expr.set()->null_count() != 0); + } + + DCHECK(operand_values.is_array()); + return compute::IsIn(operand_values, expr.set(), &ctx_); + } + + Result operator()(const IsValidExpression& expr) const { + ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); + if (IsNullDatum(operand_values)) { + return Datum(false); + } + + if (operand_values.is_scalar()) { + return Datum(true); + } + + DCHECK(operand_values.is_array()); + if (operand_values.array()->GetNullCount() == 0) { + return Datum(true); + } + + return Datum(std::make_shared(operand_values.array()->length, + operand_values.array()->buffers[0])); + } + + Result operator()(const CastExpression& expr) const { + ARROW_ASSIGN_OR_RAISE(auto to_type, expr.Validate(*batch_.schema())); + + ARROW_ASSIGN_OR_RAISE(auto to_cast, Evaluate(*expr.operand())); + if (to_cast.is_scalar()) { + return to_cast.scalar()->CastTo(to_type); + } + + DCHECK(to_cast.is_array()); + return CastOrDictionaryEncode(to_cast, to_type, expr.options()); + } + + Result operator()(const ComparisonExpression& expr) const { + ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); + + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return Datum(std::make_shared()); + } + + if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) { + if (lhs.is_array() && rhs.is_array()) { + // decode dictionary arrays + for (Datum* arg : {&lhs, &rhs}) { + auto dict = checked_pointer_cast(arg->make_array()); + ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(), + compute::TakeOptions::Defaults())); + } + } else if (lhs.is_array() || rhs.is_array()) { + auto dict = checked_pointer_cast( + (lhs.is_array() ? lhs : rhs).make_array()); + + ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast( + *(lhs.is_scalar() ? lhs : rhs).scalar()) + .GetEncodedValue()); + if (lhs.is_array()) { + lhs = dict->dictionary(); + rhs = std::move(scalar); + } else { + lhs = std::move(scalar); + rhs = dict->dictionary(); + } + ARROW_ASSIGN_OR_RAISE( + Datum out_dict, + compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_)); + + return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults()); + } + } + + return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_); + } + + Result operator()(const Expression& expr) const { + return Status::NotImplemented("evaluation of ", expr.ToString()); + } + + Result Evaluate(const Expression& expr) const { + return this_->Evaluate(expr, batch_, ctx_.memory_pool()); + } + + const TreeEvaluator* this_; + const RecordBatch& batch_; + mutable compute::ExecContext ctx_; +}; + +Result TreeEvaluator::Evaluate(const Expression& expr, const RecordBatch& batch, + MemoryPool* pool) const { + return VisitExpression(expr, Impl{this, batch, compute::ExecContext{pool}}); +} + +Result> TreeEvaluator::Filter( + const Datum& selection, const std::shared_ptr& batch, + MemoryPool* pool) const { + if (selection.is_array()) { + auto selection_array = selection.make_array(); + compute::ExecContext ctx(pool); + ARROW_ASSIGN_OR_RAISE(Datum filtered, + compute::Filter(batch, selection_array, + compute::FilterOptions::Defaults(), &ctx)); + return filtered.record_batch(); + } + + if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) { + return Status::NotImplemented("Filtering batches against DatumKind::", + selection.kind(), " of type ", *selection.type()); + } + + if (BooleanScalar(true).Equals(*selection.scalar())) { + return batch; + } + + return batch->Slice(0, 0); +} + +const std::shared_ptr& scalar(bool value) { + static auto true_ = scalar(MakeScalar(true)); + static auto false_ = scalar(MakeScalar(false)); + return value ? true_ : false_; +} + +// Serialization is accomplished by converting expressions to single element StructArrays +// then writing that to an IPC file. The last field is always an int32 column containing +// ExpressionType, the rest store the Expression's members. +struct SerializeImpl { + Result> ToArray(const Expression& expr) const { + return VisitExpression(expr, *this); + } + + Result> TaggedWithChildren(const Expression& expr, + ArrayVector children) const { + children.emplace_back(); + ARROW_ASSIGN_OR_RAISE(children.back(), + MakeArrayFromScalar(Int32Scalar(expr.type()), 1)); + + return StructArray::Make(children, std::vector(children.size(), "")); + } + + Result> operator()(const FieldExpression& expr) const { + // store the field's name in a StringArray + ARROW_ASSIGN_OR_RAISE(auto name, MakeArrayFromScalar(StringScalar(expr.name()), 1)); + return TaggedWithChildren(expr, {name}); + } + + Result> operator()(const ScalarExpression& expr) const { + // store the scalar's value in a single element Array + ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1)); + return TaggedWithChildren(expr, {value}); + } + + Result> operator()(const UnaryExpression& expr) const { + // recurse to store the operand in a single element StructArray + ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); + return TaggedWithChildren(expr, {operand}); + } + + Result> operator()(const CastExpression& expr) const { + // recurse to store the operand in a single element StructArray + ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); + + // store the cast target and a discriminant + std::shared_ptr is_like_expr, to; + if (const auto& to_type = expr.to_type()) { + ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(false), 1)); + ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1)); + } + if (const auto& like_expr = expr.like_expr()) { + ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(true), 1)); + ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr)); + } + + return TaggedWithChildren(expr, {operand, is_like_expr, to}); + } + + Result> operator()(const BinaryExpression& expr) const { + // recurse to store the operands in single element StructArrays + ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); + return TaggedWithChildren(expr, {left_operand, right_operand}); + } + + Result> operator()( + const ComparisonExpression& expr) const { + // recurse to store the operands in single element StructArrays + ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); + // store the CompareOperator in a single element Int32Array + ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 1)); + return TaggedWithChildren(expr, {left_operand, right_operand, op}); + } + + Result> operator()(const InExpression& expr) const { + // recurse to store the operand in a single element StructArray + ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); + + // store the set as a single element ListArray + auto set_type = list(expr.set()->type()); + + ARROW_ASSIGN_OR_RAISE(auto set_offsets, AllocateBuffer(sizeof(int32_t) * 2)); + reinterpret_cast(set_offsets->mutable_data())[0] = 0; + reinterpret_cast(set_offsets->mutable_data())[1] = + static_cast(expr.set()->length()); + + auto set_values = expr.set(); + + auto set = std::make_shared(std::move(set_type), 1, std::move(set_offsets), + std::move(set_values)); + return TaggedWithChildren(expr, {operand, set}); + } + + Result> operator()(const Expression& expr) const { + return Status::NotImplemented("serialization of ", expr.ToString()); + } + + Result> ToBuffer(const Expression& expr) const { + ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(expr)); + ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + RETURN_NOT_OK(writer->Close()); + return stream->Finish(); + } +}; + +Result> Expression::Serialize() const { + return SerializeImpl{}.ToBuffer(*this); +} + +struct DeserializeImpl { + Result> FromArray(const Array& array) const { + if (array.type_id() != Type::STRUCT || array.length() != 1) { + return Status::Invalid("can only deserialize expressions from unit-length", + " StructArray, got ", array); + } + const auto& struct_array = checked_cast(array); + + ARROW_ASSIGN_OR_RAISE(auto expression_type, GetExpressionType(struct_array)); + switch (expression_type) { + case ExpressionType::FIELD: { + ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); + return field_ref(std::string(name)); + } + + case ExpressionType::SCALAR: { + ARROW_ASSIGN_OR_RAISE(auto value, struct_array.field(0)->GetScalar(0)); + return scalar(std::move(value)); + } + + case ExpressionType::NOT: { + ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); + return not_(std::move(operand)); + } + + case ExpressionType::CAST: { + ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); + ARROW_ASSIGN_OR_RAISE(auto is_like_expr, GetView(struct_array, 1)); + if (is_like_expr) { + ARROW_ASSIGN_OR_RAISE(auto like_expr, FromArray(*struct_array.field(2))); + return operand->CastLike(std::move(like_expr)).Copy(); + } + return operand->CastTo(struct_array.field(2)->type()).Copy(); + } + + case ExpressionType::AND: { + ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); + ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); + return and_(std::move(left_operand), std::move(right_operand)); + } + + case ExpressionType::OR: { + ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); + ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); + return or_(std::move(left_operand), std::move(right_operand)); + } + + case ExpressionType::COMPARISON: { + ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); + ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); + ARROW_ASSIGN_OR_RAISE(auto op, GetView(struct_array, 2)); + return std::make_shared(static_cast(op), + std::move(left_operand), + std::move(right_operand)); + } + + case ExpressionType::IS_VALID: { + ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); + return std::make_shared(std::move(operand)); + } + + case ExpressionType::IN: { + ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); + if (struct_array.field(1)->type_id() != Type::LIST) { + return Status::TypeError("expected field 1 of ", struct_array, + " to have list type"); + } + auto set = checked_cast(*struct_array.field(1)).values(); + return std::make_shared(std::move(operand), std::move(set)); + } + + default: + break; + } + + return Status::Invalid("non-deserializable ExpressionType ", expression_type); + } + + template ::ArrayType> + static Result().GetView(0))> GetView(const StructArray& array, + int index) { + if (index >= array.num_fields()) { + return Status::IndexError("expected ", array, " to have a child at index ", index); + } + + const auto& child = *array.field(index); + if (child.type_id() != T::type_id) { + return Status::TypeError("expected child ", index, " of ", array, " to have type ", + T::type_id); + } + + return checked_cast(child).GetView(0); + } + + static Result GetExpressionType(const StructArray& array) { + if (array.struct_type()->num_fields() < 1) { + return Status::Invalid("StructArray didn't contain ExpressionType member"); + } + + ARROW_ASSIGN_OR_RAISE(auto expression_type, + GetView(array, array.num_fields() - 1)); + return static_cast(expression_type); + } + + Result> FromBuffer(const Buffer& serialized) { + io::BufferReader stream(serialized); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); + return FromArray(*array); + } +}; + +Result> Expression::Deserialize(const Buffer& serialized) { + return DeserializeImpl{}.FromBuffer(serialized); +} + +// Transform an array of counts to offsets which will divide a ListArray +// into an equal number of slices with corresponding lengths. +inline Result> CountsToOffsets( + std::shared_ptr counts) { + Int32Builder offset_builder; + RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); + offset_builder.UnsafeAppend(0); + + for (int64_t i = 0; i < counts->length(); ++i) { + DCHECK_NE(counts->Value(i), 0); + auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); + offset_builder.UnsafeAppend(next_offset); + } + + std::shared_ptr offsets; + RETURN_NOT_OK(offset_builder.Finish(&offsets)); + return offsets; +} + +// Helper for simultaneous dictionary encoding of multiple arrays. +// +// The fused dictionary is the Cartesian product of the individual dictionaries. +// For example given two arrays A, B where A has unique values ["ex", "why"] +// and B has unique values [0, 1] the fused dictionary is the set of tuples +// [["ex", 0], ["ex", 1], ["why", 0], ["ex", 1]]. +// +// TODO(bkietz) this capability belongs in an Action of the hash kernels, where +// it can be used to group aggregates without materializing a grouped batch. +// For the purposes of writing we need the materialized grouped batch anyway +// since no Writers accept a selection vector. +class StructDictionary { + public: + struct Encoded { + std::shared_ptr indices; + std::shared_ptr dictionary; + }; + + static Result Encode(const ArrayVector& columns) { + Encoded out{nullptr, std::make_shared()}; + + for (const auto& column : columns) { + if (column->null_count() != 0) { + return Status::NotImplemented("Grouping on a field with nulls"); + } + + RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); + } + + return out; + } + + Result> Decode(std::shared_ptr fused_indices, + FieldVector fields) { + std::vector builders(dictionaries_.size()); + for (Int32Builder& b : builders) { + RETURN_NOT_OK(b.Resize(fused_indices->length())); + } + + std::vector codes(dictionaries_.size()); + for (int64_t i = 0; i < fused_indices->length(); ++i) { + Expand(fused_indices->Value(i), codes.data()); + + auto builder_it = builders.begin(); + for (int32_t index : codes) { + builder_it++->UnsafeAppend(index); + } + } + + ArrayVector columns(dictionaries_.size()); + for (size_t i = 0; i < dictionaries_.size(); ++i) { + std::shared_ptr indices; + RETURN_NOT_OK(builders[i].FinishInternal(&indices)); + + ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); + columns[i] = column.make_array(); + } + + return StructArray::Make(std::move(columns), std::move(fields)); + } + + private: + Status AddOne(Datum column, std::shared_ptr* fused_indices) { + ArrayData* encoded; + if (column.type()->id() != Type::DICTIONARY) { + ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); + } + encoded = column.mutable_array(); + + auto indices = + std::make_shared(encoded->length, std::move(encoded->buffers[1])); + + dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); + auto dictionary_size = static_cast(dictionaries_.back()->length()); + + if (*fused_indices == nullptr) { + *fused_indices = std::move(indices); + size_ = dictionary_size; + return Status::OK(); + } + + // It's useful to think about the case where each of dictionaries_ has size 10. + // In this case the decimal digit in the ones place is the code in dictionaries_[0], + // the tens place corresponds to dictionaries_[1], etc. + // The incumbent indices must be shifted to the hundreds place so as not to collide. + ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, + compute::Multiply(indices, MakeScalar(size_))); + + ARROW_ASSIGN_OR_RAISE(new_fused_indices, + compute::Add(new_fused_indices, *fused_indices)); + + *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); + + // XXX should probably cap this at 2**15 or so + ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); + return Status::OK(); + } + + // expand a fused code into component dict codes, order is in order of addition + void Expand(int32_t fused_code, int32_t* codes) { + for (size_t i = 0; i < dictionaries_.size(); ++i) { + auto dictionary_size = static_cast(dictionaries_[i]->length()); + codes[i] = fused_code % dictionary_size; + fused_code /= dictionary_size; + } + } + + int32_t size_; + ArrayVector dictionaries_; +}; + +Result> MakeGroupings(const StructArray& by) { + if (by.num_fields() == 0) { + return Status::NotImplemented("Grouping with no criteria"); + } + + ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); + + ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); + ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); + fused.indices = checked_pointer_cast(sorted.make_array()); + + ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, + compute::ValueCounts(fused.indices)); + fused.indices.reset(); + + auto unique_fused_indices = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); + ARROW_ASSIGN_OR_RAISE( + auto unique_rows, + fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); + + auto counts = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); + ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts))); + + ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, + ListArray::FromArrays(*offsets, *sort_indices)); + + return StructArray::Make( + ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, + std::vector{"values", "groupings"}); +} + +Result> ApplyGroupings(const ListArray& groupings, + const Array& array) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(array, groupings.data()->child_data[0])); + + return std::make_shared(list(array.type()), groupings.length(), + groupings.value_offsets(), sorted.make_array()); +} + +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); + + const auto& sorted_batch = *sorted.record_batch(); + + RecordBatchVector out(static_cast(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } + + return out; +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index fba8194aeb165..5cb903c8ccfe8 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -332,10 +332,17 @@ class DecimalConverter final const DecimalSubtype* decimal_type_; }; -template ::BuilderType> -using Decimal128Converter = DecimalConverter; -template ::BuilderType> -using Decimal256Converter = DecimalConverter; +#define DECL_DECIMAL_CONVERTER(width) \ +template ::BuilderType> \ +using Decimal##width##Converter = DecimalConverter; + +DECL_DECIMAL_CONVERTER(16) +DECL_DECIMAL_CONVERTER(32) +DECL_DECIMAL_CONVERTER(64) +DECL_DECIMAL_CONVERTER(128) +DECL_DECIMAL_CONVERTER(256) + +#undef DECL_DECIMAL_CONVERTER // ------------------------------------------------------------------------ // Converter for timestamp arrays @@ -786,6 +793,9 @@ Status GetDictConverter(const std::shared_ptr& type, PARAM_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter, LargeBinaryType) SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter, FixedSizeBinaryType) + SIMPLE_CONVERTER_CASE(Type::DECIMAL16, Decimal16Converter, Decimal16Type) + SIMPLE_CONVERTER_CASE(Type::DECIMAL32, Decimal32Converter, Decimal32Type) + SIMPLE_CONVERTER_CASE(Type::DECIMAL64, Decimal64Converter, Decimal64Type) SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter, Decimal128Type) SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter, Decimal256Type) default: @@ -843,6 +853,9 @@ Status GetConverter(const std::shared_ptr& type, SIMPLE_CONVERTER_CASE(Type::LARGE_STRING, StringConverter) SIMPLE_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter) SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter<>) + SIMPLE_CONVERTER_CASE(Type::DECIMAL16, Decimal16Converter<>) + SIMPLE_CONVERTER_CASE(Type::DECIMAL32, Decimal32Converter<>) + SIMPLE_CONVERTER_CASE(Type::DECIMAL64, Decimal64Converter<>) SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter<>) SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter<>) SIMPLE_CONVERTER_CASE(Type::SPARSE_UNION, UnionConverter) diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index c5358ac89f186..3967f5f54c0e4 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -41,6 +41,7 @@ #include "arrow/util/bitmap_builders.h" #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" +#include "arrow/util/decimal_type_traits.h" #if defined(_MSC_VER) // "warning C4307: '+': integral constant overflow" @@ -498,8 +499,19 @@ TEST(TestFixedSizeBinary, Dictionary) { ASSERT_RAISES(Invalid, ArrayFromJSON(dictionary(int8(), type), R"(["x"])", &array)); } -template -void TestDecimalBasic(std::shared_ptr type) { +template +class TestDecimal : public testing::Test {}; +using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; + +TYPED_TEST_SUITE(TestDecimal, DecimalTypes); + +TYPED_TEST(TestDecimal, Basic) { + using TypeClass = typename TypeParam::TypeClass; + using DecimalBuilder = typename TypeParam::BuilderType; + using DecimalValue = typename TypeParam::ValueType; + + auto type = std::make_shared(5, 4); + std::shared_ptr expected, actual; ASSERT_OK(ArrayFromJSON(type, "[]", &actual)); @@ -510,54 +522,47 @@ void TestDecimalBasic(std::shared_ptr type) { } AssertArraysEqual(*expected, *actual); - ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", \"-78.9000\"]", &actual)); + ASSERT_OK(ArrayFromJSON(type, "[\"1.2345\", \"-3.2000\"]", &actual)); ASSERT_OK(actual->ValidateFull()); { DecimalBuilder builder(type); - ASSERT_OK(builder.Append(DecimalValue(1234567))); - ASSERT_OK(builder.Append(DecimalValue(-789000))); + ASSERT_OK(builder.Append(DecimalValue(12345))); + ASSERT_OK(builder.Append(DecimalValue(-32000))); ASSERT_OK(builder.Finish(&expected)); } AssertArraysEqual(*expected, *actual); - ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", null]", &actual)); + ASSERT_OK(ArrayFromJSON(type, "[\"1.2345\", null]", &actual)); ASSERT_OK(actual->ValidateFull()); { DecimalBuilder builder(type); - ASSERT_OK(builder.Append(DecimalValue(1234567))); + ASSERT_OK(builder.Append(DecimalValue(12345))); ASSERT_OK(builder.AppendNull()); ASSERT_OK(builder.Finish(&expected)); } AssertArraysEqual(*expected, *actual); } -TEST(TestDecimal128, Basics) { - TestDecimalBasic(decimal128(10, 4)); -} - -TEST(TestDecimal256, Basics) { - TestDecimalBasic(decimal256(10, 4)); -} - -TEST(TestDecimal, Errors) { - for (std::shared_ptr type : {decimal128(10, 4), decimal256(10, 4)}) { - std::shared_ptr array; +TYPED_TEST(TestDecimal, Errors) { + using TypeClass = typename TypeParam::TypeClass; + auto type = std::make_shared(5, 4); - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array)); - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[12.3456]", &array)); - // Bad scale - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.345\"]", &array)); - ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.34560\"]", &array)); - } -} + std::shared_ptr array; -TEST(TestDecimal, Dictionary) { - for (std::shared_ptr type : {decimal128(10, 2), decimal256(10, 2)}) { - AssertJSONDictArray(int32(), type, - R"(["123.45", "-78.90", "-78.90", null, "123.45"])", - /*indices=*/"[0, 1, 1, null, 0]", - /*values=*/R"(["123.45", "-78.90"])"); - } + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array)); + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[1.2345]", &array)); + // Bad scale + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.345\"]", &array)); + ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.34560\"]", &array)); +} + +TYPED_TEST(TestDecimal, Dictionary) { + using TypeClass = typename TypeParam::TypeClass; + auto type = std::make_shared(5, 2); + AssertJSONDictArray(int32(), type, + R"(["123.45", "-78.90", "-78.90", null, "123.45"])", + /*indices=*/"[0, 1, 1, null, 0]", + /*values=*/R"(["123.45", "-78.90"])"); } TEST(TestList, IntegerList) { diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index 8c2ac376d1e47..9bf37328151c4 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -222,12 +222,8 @@ class ArrayPrinter : public PrettyPrinter { return Status::OK(); } - Status WriteDataValues(const Decimal128Array& array) { - WriteValues(array, [&](int64_t i) { (*sink_) << array.FormatValue(i); }); - return Status::OK(); - } - - Status WriteDataValues(const Decimal256Array& array) { + template + Status WriteDataValues(const BaseDecimalArray& array) { WriteValues(array, [&](int64_t i) { (*sink_) << array.FormatValue(i); }); return Status::OK(); } diff --git a/cpp/src/arrow/pretty_print_test.cc b/cpp/src/arrow/pretty_print_test.cc index 538e736518527..ab69a7e782752 100644 --- a/cpp/src/arrow/pretty_print_test.cc +++ b/cpp/src/arrow/pretty_print_test.cc @@ -499,13 +499,13 @@ TEST_F(TestPrettyPrint, FixedSizeBinaryType) { } TEST_F(TestPrettyPrint, DecimalTypes) { - int32_t p = 19; + int32_t p = 5; int32_t s = 4; - for (auto type : {decimal128(p, s), decimal256(p, s)}) { - auto array = ArrayFromJSON(type, "[\"123.4567\", \"456.7891\", null]"); + for (auto type : {decimal16(p, s), decimal32(p, s), decimal64(p, s), decimal128(p, s), decimal256(p, s)}) { + auto array = ArrayFromJSON(type, "[\"1.4567\", \"3.2765\", null]"); - static const char* ex = "[\n 123.4567,\n 456.7891,\n null\n]"; + static const char* ex = "[\n 1.4567,\n 3.2765,\n null\n]"; CheckArray(*array, {0}, ex); } } diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index eca711d7c4f49..351e115e00f44 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -69,6 +69,11 @@ struct ScalarHashImpl { return StdHash(s.value.days) & StdHash(s.value.days); } + template + typename std::enable_if::type Visit(const BaseDecimalScalar& s) { + return StdHash(s.value.Value()); + } + Status Visit(const Decimal128Scalar& s) { return StdHash(s.value.low_bits()) & StdHash(s.value.high_bits()); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 99f691e043615..c91239f495d58 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -352,14 +352,6 @@ struct BaseDecimalScalar : public Scalar { ValueType value; }; -struct ARROW_EXPORT Decimal128Scalar : public BaseDecimalScalar<128> { - using BaseDecimalScalar<128>::BaseDecimalScalar; -}; - -struct ARROW_EXPORT Decimal256Scalar : public BaseDecimalScalar<256> { - using BaseDecimalScalar<256>::BaseDecimalScalar; -}; - struct ARROW_EXPORT BaseListScalar : public Scalar { using Scalar::Scalar; using ValueType = std::shared_ptr; diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 30a39e6e4c031..159fa725e0202 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -31,6 +31,7 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/type_traits.h" +#include "arrow/util/decimal_type_traits.h" namespace arrow { @@ -328,29 +329,23 @@ TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); } TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); } -TEST(TestDecimal128Scalar, Basics) { - auto ty = decimal128(3, 2); - auto pi = Decimal128Scalar(Decimal128("3.14"), ty); - auto null = MakeNullScalar(ty); - ASSERT_EQ(pi.value, Decimal128("3.14")); +template +class TestDecimalScalar : public testing::Test {}; +using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; - // test Array.GetScalar - auto arr = ArrayFromJSON(ty, "[null, \"3.14\"]"); - ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0)); - ASSERT_OK_AND_ASSIGN(auto second, arr->GetScalar(1)); - ASSERT_TRUE(first->Equals(null)); - ASSERT_FALSE(first->Equals(pi)); - ASSERT_TRUE(second->Equals(pi)); - ASSERT_FALSE(second->Equals(null)); -} +TYPED_TEST_SUITE(TestDecimalScalar, DecimalTypes); + +TYPED_TEST(TestDecimalScalar, Basics) { + using ScalarType = typename TypeParam::ScalarType; + using TypeClass = typename TypeParam::TypeClass; + using ValueType = typename TypeParam::ValueType; -TEST(TestDecimal256Scalar, Basics) { - auto ty = decimal256(3, 2); - auto pi = Decimal256Scalar(Decimal256("3.14"), ty); + auto ty = std::make_shared(3, 2); + auto pi = ScalarType(ValueType("3.14"), ty); auto null = MakeNullScalar(ty); - ASSERT_EQ(pi.value, Decimal256("3.14")); + ASSERT_EQ(pi.value, ValueType("3.14")); // test Array.GetScalar auto arr = ArrayFromJSON(ty, "[null, \"3.14\"]"); diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index aa9d22dae2fa5..9307cc771ffc9 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -69,6 +69,9 @@ std::vector AllTypeIds() { Type::HALF_FLOAT, Type::FLOAT, Type::DOUBLE, + Type::DECIMAL16, + Type::DECIMAL32, + Type::DECIMAL64, Type::DECIMAL128, Type::DECIMAL256, Type::DATE32, diff --git a/cpp/src/arrow/testing/json_internal.cc b/cpp/src/arrow/testing/json_internal.cc index 21f7514289db0..ed59d0111066a 100644 --- a/cpp/src/arrow/testing/json_internal.cc +++ b/cpp/src/arrow/testing/json_internal.cc @@ -298,14 +298,8 @@ class SchemaWriter { writer_->Int(type.list_size()); } - void WriteTypeMetadata(const Decimal128Type& type) { - writer_->Key("precision"); - writer_->Int(type.precision()); - writer_->Key("scale"); - writer_->Int(type.scale()); - } - - void WriteTypeMetadata(const Decimal256Type& type) { + template + void WriteTypeMetadata(const BaseDecimalType& type) { writer_->Key("precision"); writer_->Int(type.precision()); writer_->Key("scale"); @@ -384,6 +378,9 @@ class SchemaWriter { return WritePrimitive("fixedsizebinary", type); } + Status Visit(const Decimal16Type& type) { return WritePrimitive("decimal16", type); } + Status Visit(const Decimal32Type& type) { return WritePrimitive("decimal32", type); } + Status Visit(const Decimal64Type& type) { return WritePrimitive("decimal64", type); } Status Visit(const Decimal128Type& type) { return WritePrimitive("decimal", type); } Status Visit(const Decimal256Type& type) { return WritePrimitive("decimal256", type); } Status Visit(const TimestampType& type) { return WritePrimitive("timestamp", type); } @@ -549,23 +546,12 @@ class ArrayWriter { } } - void WriteDataValues(const Decimal128Array& arr) { - static const char null_string[] = "0"; - for (int64_t i = 0; i < arr.length(); ++i) { - if (arr.IsValid(i)) { - const Decimal128 value(arr.GetValue(i)); - writer_->String(value.ToIntegerString()); - } else { - writer_->String(null_string, sizeof(null_string)); - } - } - } - - void WriteDataValues(const Decimal256Array& arr) { + template + void WriteDataValues(const BaseDecimalArray& arr) { static const char null_string[] = "0"; for (int64_t i = 0; i < arr.length(); ++i) { if (arr.IsValid(i)) { - const Decimal256 value(arr.GetValue(i)); + const typename BaseDecimalArray::ValueType value(arr.GetValue(i)); writer_->String(value.ToIntegerString()); } else { writer_->String(null_string, sizeof(null_string)); @@ -860,14 +846,27 @@ Status GetDecimal(const RjObject& json_type, std::shared_ptr* type) { bit_width = maybe_bit_width.ValueOrDie(); } - if (bit_width == 128) { - *type = decimal128(precision, scale); - } else if (bit_width == 256) { - *type = decimal256(precision, scale); - } else { - return Status::Invalid("Only 128 bit and 256 Decimals are supported. Received", - bit_width); + switch (bit_width) { + case 16: + *type = decimal16(precision, scale); + break; + case 32: + *type = decimal32(precision, scale); + break; + case 64: + *type = decimal64(precision, scale); + break; + case 128: + *type = decimal128(precision, scale); + break; + case 256: + *type = decimal256(precision, scale); + break; + default: + return Status::Invalid("Only 128 bit and 256 Decimals are supported. Received", + bit_width); } + return Status::OK(); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 0a8dbff9ace57..ca083ad72e8d8 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -67,6 +67,12 @@ constexpr Type::type FixedSizeBinaryType::type_id; constexpr Type::type StructType::type_id; +constexpr Type::type Decimal16Type::type_id; + +constexpr Type::type Decimal32Type::type_id; + +constexpr Type::type Decimal64Type::type_id; + constexpr Type::type Decimal128Type::type_id; constexpr Type::type Decimal256Type::type_id; @@ -131,6 +137,9 @@ std::string ToString(Type::type id) { TO_STRING_CASE(HALF_FLOAT) TO_STRING_CASE(FLOAT) TO_STRING_CASE(DOUBLE) + TO_STRING_CASE(DECIMAL16) + TO_STRING_CASE(DECIMAL32) + TO_STRING_CASE(DECIMAL64) TO_STRING_CASE(DECIMAL128) TO_STRING_CASE(DECIMAL256) TO_STRING_CASE(DATE32) @@ -2183,6 +2192,18 @@ std::shared_ptr decimal(int32_t precision, int32_t scale) { : decimal256(precision, scale); } +std::shared_ptr decimal16(int32_t precision, int32_t scale) { + return std::make_shared(precision, scale); +} + +std::shared_ptr decimal32(int32_t precision, int32_t scale) { + return std::make_shared(precision, scale); +} + +std::shared_ptr decimal64(int32_t precision, int32_t scale) { + return std::make_shared(precision, scale); +} + std::shared_ptr decimal128(int32_t precision, int32_t scale) { return std::make_shared(precision, scale); } @@ -2198,6 +2219,9 @@ std::string BaseDecimalType::ToString() const { return s.str(); } +template class BaseDecimalType<16>; +template class BaseDecimalType<32>; +template class BaseDecimalType<64>; template class BaseDecimalType<128>; template class BaseDecimalType<256>; diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4437c5e76c242..d0c5f4be35ab0 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -913,6 +913,27 @@ class BaseDecimalType : public DecimalType { static constexpr int32_t kByteWidth = width / 8; }; +/// \brief Concrete type class for decimal 16-bit data +class ARROW_EXPORT Decimal16Type : public BaseDecimalType<16> { +public: + static constexpr Type::type type_id = Type::DECIMAL16; + using BaseDecimalType<16>::BaseDecimalType; +}; + +/// \brief Concrete type class for decimal 32-bit data +class ARROW_EXPORT Decimal32Type : public BaseDecimalType<32> { +public: + static constexpr Type::type type_id = Type::DECIMAL32; + using BaseDecimalType<32>::BaseDecimalType; +}; + +/// \brief Concrete type class for decimal 64-bit data +class ARROW_EXPORT Decimal64Type : public BaseDecimalType<64> { +public: + static constexpr Type::type type_id = Type::DECIMAL64; + using BaseDecimalType<64>::BaseDecimalType; +}; + /// \brief Concrete type class for decimal 128-bit data class ARROW_EXPORT Decimal128Type : public BaseDecimalType<128> { public: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index ea8a39f970ebc..5f85ce462e96d 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -145,13 +145,33 @@ struct StructScalar; class DecimalType; -#define DECIMAL_DECL(width) \ -class Decimal##width; \ -class Decimal##width##Type; \ -class Decimal##width##Array; \ -class Decimal##width##Builder; \ -struct Decimal##width##Scalar; +template +class DecimalAnyWidth; +template +class BaseDecimalArray; + +template +class BaseDecimalBuilder; + +template +struct BaseDecimalScalar; + +using Decimal16 = DecimalAnyWidth<16>; +using Decimal32 = DecimalAnyWidth<32>; +using Decimal64 = DecimalAnyWidth<64>; +class Decimal128; +class Decimal256; + +#define DECIMAL_DECL(width) \ +class Decimal##width##Type; \ +using Decimal##width##Array = BaseDecimalArray; \ +using Decimal##width##Builder = BaseDecimalBuilder; \ +using Decimal##width##Scalar = BaseDecimalScalar; + +DECIMAL_DECL(16) +DECIMAL_DECL(32) +DECIMAL_DECL(64) DECIMAL_DECL(128) DECIMAL_DECL(256) @@ -338,6 +358,12 @@ struct Type { /// DAY_TIME interval in SQL style INTERVAL_DAY_TIME, + DECIMAL16, + + DECIMAL32, + + DECIMAL64, + /// Precision- and scale-based decimal type with 128 bits. DECIMAL128, @@ -444,6 +470,18 @@ std::shared_ptr fixed_size_binary(int32_t byte_width); ARROW_EXPORT std::shared_ptr decimal(int32_t precision, int32_t scale); +/// \brief Create a Decimal16Type instance +ARROW_EXPORT +std::shared_ptr decimal16(int32_t precision, int32_t scale); + +/// \brief Create a Decimal32Type instance +ARROW_EXPORT +std::shared_ptr decimal32(int32_t precision, int32_t scale); + +/// \brief Create a Decimal64Type instance +ARROW_EXPORT +std::shared_ptr decimal64(int32_t precision, int32_t scale); + /// \brief Create a Decimal128Type instance ARROW_EXPORT std::shared_ptr decimal128(int32_t precision, int32_t scale); diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index d5ece2eea8e97..3a150c5e91d88 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1851,6 +1851,48 @@ TEST(TypesTest, TestDecimal256Large) { EXPECT_EQ(t1.bit_width(), 256); } +TEST(TypesTest, TestDecimal16) { + Decimal16Type t1(5, 3); + + EXPECT_EQ(t1.id(), Type::DECIMAL16); + EXPECT_EQ(t1.precision(), 5); + EXPECT_EQ(t1.scale(), 3); + + EXPECT_EQ(t1.ToString(), std::string("decimal16(5, 3)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 2); + EXPECT_EQ(t1.bit_width(), 16); +} + +TEST(TypesTest, TestDecimal32) { + Decimal32Type t1(10, 5); + + EXPECT_EQ(t1.id(), Type::DECIMAL32); + EXPECT_EQ(t1.precision(), 10); + EXPECT_EQ(t1.scale(), 5); + + EXPECT_EQ(t1.ToString(), std::string("decimal32(10, 5)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 4); + EXPECT_EQ(t1.bit_width(), 32); +} + +TEST(TypesTest, TestDecimal64) { + Decimal64Type t1(19, 10); + + EXPECT_EQ(t1.id(), Type::DECIMAL64); + EXPECT_EQ(t1.precision(), 19); + EXPECT_EQ(t1.scale(), 10); + + EXPECT_EQ(t1.ToString(), std::string("decimal64(19, 10)")); + + // Test properties + EXPECT_EQ(t1.byte_width(), 8); + EXPECT_EQ(t1.bit_width(), 64); +} + TEST(TypesTest, TestDecimalEquals) { Decimal128Type t1(8, 4); Decimal128Type t2(8, 4); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index a569a4cb598a8..8865175785536 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -66,6 +66,9 @@ TYPE_ID_TRAIT(TIMESTAMP, TimestampType) TYPE_ID_TRAIT(INTERVAL_DAY_TIME, DayTimeIntervalType) TYPE_ID_TRAIT(INTERVAL_MONTHS, MonthIntervalType) TYPE_ID_TRAIT(DURATION, DurationType) +TYPE_ID_TRAIT(DECIMAL16, Decimal16Type) +TYPE_ID_TRAIT(DECIMAL32, Decimal32Type) +TYPE_ID_TRAIT(DECIMAL64, Decimal64Type) TYPE_ID_TRAIT(DECIMAL128, Decimal128Type) TYPE_ID_TRAIT(DECIMAL256, Decimal256Type) TYPE_ID_TRAIT(STRUCT, StructType) @@ -291,6 +294,9 @@ struct TypeTraits { \ constexpr static bool is_parameter_free = false; \ }; +DECIMAL_TYPE_TRAITS_DEF(16) +DECIMAL_TYPE_TRAITS_DEF(32) +DECIMAL_TYPE_TRAITS_DEF(64) DECIMAL_TYPE_TRAITS_DEF(128) DECIMAL_TYPE_TRAITS_DEF(256) @@ -585,6 +591,24 @@ using is_decimal_type = std::is_base_of; template using enable_if_decimal = enable_if_t::value, R>; +template +using is_decimal16_type = std::is_base_of; + +template +using enable_if_decimal16 = enable_if_t::value, R>; + +template +using is_decimal32_type = std::is_base_of; + +template +using enable_if_decimal32 = enable_if_t::value, R>; + +template +using is_decimal64_type = std::is_base_of; + +template +using enable_if_decimal64 = enable_if_t::value, R>; + template using is_decimal128_type = std::is_base_of; @@ -914,6 +938,9 @@ static inline bool is_dictionary(Type::type type_id) { static inline bool is_fixed_size_binary(Type::type type_id) { switch (type_id) { + case Type::DECIMAL16: + case Type::DECIMAL32: + case Type::DECIMAL64: case Type::DECIMAL128: case Type::DECIMAL256: case Type::FIXED_SIZE_BINARY: diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index b2864239b98bf..08020a4556230 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -30,6 +30,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/int128_internal.h" #include "arrow/util/int_util_internal.h" +#include "arrow/util/decimal_scale_multipliers.h" #include "arrow/util/decimal_meta.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -38,88 +39,8 @@ namespace arrow { using internal::SafeLeftShift; using internal::SafeSignedAdd; +using internal::SafeSignedMultiply; -static const BasicDecimal128 ScaleMultipliers[] = { - BasicDecimal128(1LL), - BasicDecimal128(10LL), - BasicDecimal128(100LL), - BasicDecimal128(1000LL), - BasicDecimal128(10000LL), - BasicDecimal128(100000LL), - BasicDecimal128(1000000LL), - BasicDecimal128(10000000LL), - BasicDecimal128(100000000LL), - BasicDecimal128(1000000000LL), - BasicDecimal128(10000000000LL), - BasicDecimal128(100000000000LL), - BasicDecimal128(1000000000000LL), - BasicDecimal128(10000000000000LL), - BasicDecimal128(100000000000000LL), - BasicDecimal128(1000000000000000LL), - BasicDecimal128(10000000000000000LL), - BasicDecimal128(100000000000000000LL), - BasicDecimal128(1000000000000000000LL), - BasicDecimal128(0LL, 10000000000000000000ULL), - BasicDecimal128(5LL, 7766279631452241920ULL), - BasicDecimal128(54LL, 3875820019684212736ULL), - BasicDecimal128(542LL, 1864712049423024128ULL), - BasicDecimal128(5421LL, 200376420520689664ULL), - BasicDecimal128(54210LL, 2003764205206896640ULL), - BasicDecimal128(542101LL, 1590897978359414784ULL), - BasicDecimal128(5421010LL, 15908979783594147840ULL), - BasicDecimal128(54210108LL, 11515845246265065472ULL), - BasicDecimal128(542101086LL, 4477988020393345024ULL), - BasicDecimal128(5421010862LL, 7886392056514347008ULL), - BasicDecimal128(54210108624LL, 5076944270305263616ULL), - BasicDecimal128(542101086242LL, 13875954555633532928ULL), - BasicDecimal128(5421010862427LL, 9632337040368467968ULL), - BasicDecimal128(54210108624275LL, 4089650035136921600ULL), - BasicDecimal128(542101086242752LL, 4003012203950112768ULL), - BasicDecimal128(5421010862427522LL, 3136633892082024448ULL), - BasicDecimal128(54210108624275221LL, 12919594847110692864ULL), - BasicDecimal128(542101086242752217LL, 68739955140067328ULL), - BasicDecimal128(5421010862427522170LL, 687399551400673280ULL)}; - -static const BasicDecimal128 ScaleMultipliersHalf[] = { - BasicDecimal128(0ULL), - BasicDecimal128(5ULL), - BasicDecimal128(50ULL), - BasicDecimal128(500ULL), - BasicDecimal128(5000ULL), - BasicDecimal128(50000ULL), - BasicDecimal128(500000ULL), - BasicDecimal128(5000000ULL), - BasicDecimal128(50000000ULL), - BasicDecimal128(500000000ULL), - BasicDecimal128(5000000000ULL), - BasicDecimal128(50000000000ULL), - BasicDecimal128(500000000000ULL), - BasicDecimal128(5000000000000ULL), - BasicDecimal128(50000000000000ULL), - BasicDecimal128(500000000000000ULL), - BasicDecimal128(5000000000000000ULL), - BasicDecimal128(50000000000000000ULL), - BasicDecimal128(500000000000000000ULL), - BasicDecimal128(5000000000000000000ULL), - BasicDecimal128(2LL, 13106511852580896768ULL), - BasicDecimal128(27LL, 1937910009842106368ULL), - BasicDecimal128(271LL, 932356024711512064ULL), - BasicDecimal128(2710LL, 9323560247115120640ULL), - BasicDecimal128(27105LL, 1001882102603448320ULL), - BasicDecimal128(271050LL, 10018821026034483200ULL), - BasicDecimal128(2710505LL, 7954489891797073920ULL), - BasicDecimal128(27105054LL, 5757922623132532736ULL), - BasicDecimal128(271050543LL, 2238994010196672512ULL), - BasicDecimal128(2710505431LL, 3943196028257173504ULL), - BasicDecimal128(27105054312LL, 2538472135152631808ULL), - BasicDecimal128(271050543121LL, 6937977277816766464ULL), - BasicDecimal128(2710505431213LL, 14039540557039009792ULL), - BasicDecimal128(27105054312137LL, 11268197054423236608ULL), - BasicDecimal128(271050543121376LL, 2001506101975056384ULL), - BasicDecimal128(2710505431213761LL, 1568316946041012224ULL), - BasicDecimal128(27105054312137610LL, 15683169460410122240ULL), - BasicDecimal128(271050543121376108LL, 9257742014424809472ULL), - BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)}; static const BasicDecimal256 ScaleMultipliersDecimal256[] = { BasicDecimal256({1ULL, 0ULL, 0ULL, 0ULL}), @@ -233,7 +154,7 @@ static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF; static constexpr uint64_t kInt32Mask = 0xFFFFFFFF; #endif -// same as ScaleMultipliers[38] - 1 +// same as ScaleMultipliers128[38] - 1 static constexpr BasicDecimal128 kMaxValue = BasicDecimal128(5421010862427522170LL, 687399551400673280ULL - 1); @@ -283,7 +204,7 @@ BasicDecimal128 BasicDecimal128::Abs(const BasicDecimal128& in) { bool BasicDecimal128::FitsInPrecision(int32_t precision) const { DCHECK_GT(precision, 0); DCHECK_LE(precision, DecimalMeta<128>::max_precision); - return BasicDecimal128::Abs(*this) < ScaleMultipliers[precision]; + return BasicDecimal128::Abs(*this) < ScaleMultipliers128[precision]; } BasicDecimal128& BasicDecimal128::operator+=(const BasicDecimal128& right) { @@ -931,8 +852,6 @@ DecimalStatus DecimalRescale(const DecimalClass& value, int32_t original_scale, const int32_t abs_delta_scale = std::abs(delta_scale); DecimalClass multiplier = DecimalClass::GetScaleMultiplier(abs_delta_scale); - DCHECK_GE(abs_delta_scale, 1); - DCHECK_LE(abs_delta_scale, DecimalMeta<128>::max_precision); const bool rescale_would_cause_data_loss = RescaleWouldCauseDataLoss(value, delta_scale, multiplier, out); @@ -955,7 +874,7 @@ void BasicDecimal128::GetWholeAndFraction(int scale, BasicDecimal128* whole, DCHECK_GE(scale, 0); DCHECK_LE(scale, DecimalMeta<128>::max_precision); - BasicDecimal128 multiplier(ScaleMultipliers[scale]); + BasicDecimal128 multiplier(ScaleMultipliers128[scale]); auto s = Divide(multiplier, whole, fraction); DCHECK_EQ(s, DecimalStatus::kSuccess); } @@ -964,7 +883,7 @@ const BasicDecimal128& BasicDecimal128::GetScaleMultiplier(int32_t scale) { DCHECK_GE(scale, 0); DCHECK_LE(scale, DecimalMeta<128>::max_precision); - return ScaleMultipliers[scale]; + return ScaleMultipliers128[scale]; } const BasicDecimal128& BasicDecimal128::GetMaxValue() { return kMaxValue; } @@ -973,7 +892,7 @@ BasicDecimal128 BasicDecimal128::IncreaseScaleBy(int32_t increase_by) const { DCHECK_GE(increase_by, 0); DCHECK_LE(increase_by, DecimalMeta<128>::max_precision); - return (*this) * ScaleMultipliers[increase_by]; + return (*this) * ScaleMultipliers128[increase_by]; } BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) const { @@ -984,13 +903,13 @@ BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) co return *this; } - BasicDecimal128 divisor(ScaleMultipliers[reduce_by]); + BasicDecimal128 divisor(ScaleMultipliers128[reduce_by]); BasicDecimal128 result; BasicDecimal128 remainder; auto s = Divide(divisor, &result, &remainder); DCHECK_EQ(s, DecimalStatus::kSuccess); if (round) { - auto divisor_half = ScaleMultipliersHalf[reduce_by]; + auto divisor_half = ScaleMultipliersHalf128[reduce_by]; if (remainder.Abs() >= divisor_half) { if (result > 0) { result += 1; @@ -1198,4 +1117,136 @@ BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& ri return result; } +/// BasicDecimalAnyWidth + +template +BasicDecimalAnyWidth::BasicDecimalAnyWidth(const uint8_t* bytes) { + DCHECK_NE(bytes, nullptr); + value = *(reinterpret_cast(bytes)); +}; + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator+=(const BasicDecimalAnyWidth& right) { + value = SafeSignedAdd(value, right.value); + return *this; +} + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator-=(const BasicDecimalAnyWidth& right) { + value -= right.value; + return *this; +} + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator*=(const BasicDecimalAnyWidth& right) { + value = value * right.value; + return *this; +} + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator/=(const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth remainder; + auto s = Divide(right, this, &remainder); + DCHECK_EQ(s, DecimalStatus::kSuccess); + return *this; +} + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator%=(const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result; + auto s = Divide(right, &result, this); + DCHECK_EQ(s, DecimalStatus::kSuccess); + return *this; +} + + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::Abs() { return *this < 0 ? Negate() : *this; } + +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::Abs(const BasicDecimalAnyWidth& in) { + BasicDecimalAnyWidth result(in); + return result.Abs(); +} + +template +DecimalStatus BasicDecimalAnyWidth::Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, + BasicDecimalAnyWidth* remainder) const { + if (divisor.value == 0) { + return DecimalStatus::kDivideByZero; + } + + bool dividen_was_negative = Sign() == -1; + bool divisor_was_negative = divisor.Sign() == -1; + + *result = value / divisor.value; + *remainder = value % divisor.value; + + FixDivisionSigns(result, remainder, dividen_was_negative, divisor_was_negative); + return DecimalStatus::kSuccess; +} + +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::GetScaleMultiplier(int32_t scale) { + DCHECK_GE(scale, 0); + DCHECK_LE(scale, DecimalMeta::max_precision); + + return BasicDecimalAnyWidth(ScaleMultipliersAnyWidth::value[scale]); +} + +template +std::array> 3)> BasicDecimalAnyWidth::ToBytes() const { + std::array> 3)> out{{0}}; + ToBytes(out.data()); + return out; +} + +template +void BasicDecimalAnyWidth::ToBytes(uint8_t* out) const { + DCHECK_NE(out, nullptr); + reinterpret_cast(out)[0] = value; +} + +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::Negate() { + value = - value; + return *this; +} + +template +DecimalStatus BasicDecimalAnyWidth::Rescale(int32_t original_scale, int32_t new_scale, + BasicDecimalAnyWidth* out) const { + return DecimalRescale(*this, original_scale, new_scale, out); +} + +template +bool BasicDecimalAnyWidth::FitsInPrecision(int32_t precision) const { + DCHECK_GT(precision, 0); + DCHECK_LE(precision, DecimalMeta::max_precision); + return BasicDecimalAnyWidth::Abs(*this) < ScaleMultipliersAnyWidth::value[precision]; +} + +template +void BasicDecimalAnyWidth::GetWholeAndFraction(int scale, BasicDecimalAnyWidth* whole, + BasicDecimalAnyWidth* fraction) const { + DCHECK_GE(scale, 0); + DCHECK_LE(scale, DecimalMeta::max_precision); + + BasicDecimalAnyWidth multiplier(ScaleMultipliersAnyWidth::value[scale]); + auto s = Divide(multiplier, whole, fraction); + DCHECK_EQ(s, DecimalStatus::kSuccess); +} + +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::IncreaseScaleBy(int32_t increase_by) const { + DCHECK_GE(increase_by, 0); + DCHECK_LE(increase_by, DecimalMeta::max_precision); + + return (*this) * ScaleMultipliersAnyWidth::value[increase_by]; +} + +template class BasicDecimalAnyWidth<64>; +template class BasicDecimalAnyWidth<32>; +template class BasicDecimalAnyWidth<16>; + } // namespace arrow diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index b62d894207713..0b48f89f1b9db 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -26,6 +26,7 @@ #include "arrow/util/macros.h" #include "arrow/util/type_traits.h" #include "arrow/util/visibility.h" +#include "arrow/util/decimal_meta.h" namespace arrow { @@ -36,6 +37,9 @@ enum class DecimalStatus { kRescaleDataLoss, }; +template +class BasicDecimalAnyWidth; + /// Represents a signed 128-bit integer in two's complement. /// /// This class is also compiled into LLVM IR - so, it should not have cpp references like @@ -59,6 +63,11 @@ class ARROW_EXPORT BasicDecimal128 { : BasicDecimal128(value >= T{0} ? 0 : -1, static_cast(value)) { // NOLINT } + /// \brief Upcast BasicDecimal with less widths + template + constexpr BasicDecimal128(const BasicDecimalAnyWidth& other) noexcept + : BasicDecimal128(other.Value()) {} + /// \brief Create a BasicDecimal128 from an array of bytes. Bytes are assumed to be in /// native-endian byte order. explicit BasicDecimal128(const uint8_t* bytes); @@ -211,6 +220,10 @@ class ARROW_EXPORT BasicDecimal256 { : little_endian_array_({value.low_bits(), static_cast(value.high_bits()), extend(value.high_bits()), extend(value.high_bits())}) {} + template + constexpr BasicDecimal256(const BasicDecimalAnyWidth& other) noexcept + : BasicDecimal256(other.Value()) {} + /// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in /// native-endian byte order. explicit BasicDecimal256(const uint8_t* bytes); @@ -322,4 +335,151 @@ ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right); ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& right); + + +template +class BasicDecimalAnyWidth { + public: + using ValueType = typename IntTypes::signed_type; + /// \brief Empty constructor creates a BasicDecimal256 with a value of 0. + constexpr BasicDecimalAnyWidth() noexcept : value(0) {} + + /// \brief Convert any integer value into a BasicDecimal256. + template ::value && + ((sizeof(T) < sizeof(ValueType)) || ((sizeof(T) == sizeof(ValueType)) && std::is_signed::value) + || std::is_same::value), T>::type> + constexpr BasicDecimalAnyWidth(T value) noexcept + : value(static_cast(value)) {} + + /// \brief Upcast BasicDecimal with less widths + template ::type> + constexpr BasicDecimalAnyWidth(const BasicDecimalAnyWidth<_width>& other) noexcept + : value(static_cast(other.Value())) {} + + /// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in + /// native-endian byte order. + explicit BasicDecimalAnyWidth(const uint8_t* bytes); + + /// \brief Negate the current value (in-place) + BasicDecimalAnyWidth& Negate(); + + /// \brief Absolute value (in-place) + BasicDecimalAnyWidth& Abs(); + + /// \brief Absolute value + static BasicDecimalAnyWidth Abs(const BasicDecimalAnyWidth& left); + + DecimalStatus Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, + BasicDecimalAnyWidth* remainder) const; + + // \brief Scale multiplier for given scale value. + static BasicDecimalAnyWidth GetScaleMultiplier(int32_t scale); + + /// \brief Return the raw bytes of the value in native-endian byte order. + std::array> 3)> ToBytes() const; + void ToBytes(uint8_t* out) const; + + /// \brief Convert BasicDecimal128 from one scale to another + DecimalStatus Rescale(int32_t original_scale, int32_t new_scale, + BasicDecimalAnyWidth* out) const; + + inline int64_t Sign() const { return value >= 0 ? 1 : -1; } + + /// \brief Get the high bits of the two's complement representation of the number. + inline constexpr ValueType Value() const { return value; } + + /// \brief Whether this number fits in the given precision + /// + /// Return true if the number of significant digits is less or equal to `precision`. + bool FitsInPrecision(int32_t precision) const; + + /// \brief separate the integer and fractional parts for the given scale. + void GetWholeAndFraction(int32_t scale, BasicDecimalAnyWidth* whole, + BasicDecimalAnyWidth* fraction) const; + + /// \brief Scale up. + BasicDecimalAnyWidth IncreaseScaleBy(int32_t increase_by) const; + + BasicDecimalAnyWidth& operator +=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator -=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator *=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator /=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator %=(const BasicDecimalAnyWidth&); + + friend bool operator==(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return left.value == right.value; + } + + friend bool operator!=(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return !operator==(left, right); + } + + friend bool operator<(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return left.value < right.value; + } + + friend bool operator<=(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return !operator<(right, left); + } + + friend bool operator>(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return operator<(right, left); + } + + friend bool operator>=(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + return !operator<(left, right); + } + + friend BasicDecimalAnyWidth operator+(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result(left); + result += right; + return result; + }; + + friend BasicDecimalAnyWidth operator-(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result(left); + result -= right; + return result; + }; + + friend BasicDecimalAnyWidth operator*(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result(left); + result *= right; + return result; + }; + + friend BasicDecimalAnyWidth operator/(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result = left; + result /= right; + return result; + }; + + friend BasicDecimalAnyWidth operator%(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { + BasicDecimalAnyWidth result = left; + result %= right; + return result; + } + + private: + ValueType value; +}; + + +using BasicDecimal64 = BasicDecimalAnyWidth<64>; +using BasicDecimal32 = BasicDecimalAnyWidth<32>; +using BasicDecimal16 = BasicDecimalAnyWidth<16>; + } // namespace arrow diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc index dcb2023616aeb..42bc9e7885a20 100644 --- a/cpp/src/arrow/util/decimal.cc +++ b/cpp/src/arrow/util/decimal.cc @@ -463,15 +463,13 @@ inline Status ToArrowStatus(DecimalStatus dstatus, int num_bits) { return Status::OK(); } -} // namespace - -Status Decimal128::FromString(const util::string_view& s, Decimal128* out, - int32_t* precision, int32_t* scale) { +Status FromStringToArray(const util::string_view& s, DecimalComponents& dec, + uint64_t* out, int32_t array_size, + int32_t* precision, int32_t* scale) { if (s.empty()) { return Status::Invalid("Empty string cannot be converted to decimal"); } - DecimalComponents dec; if (!ParseDecimalComponents(s.data(), s.size(), &dec)) { return Status::Invalid("The string '", s, "' is not a valid decimal number"); } @@ -482,26 +480,47 @@ Status Decimal128::FromString(const util::string_view& s, Decimal128* out, if (first_non_zero != std::string::npos) { significant_digits += dec.whole_digits.size() - first_non_zero; } - int32_t parsed_precision = static_cast(significant_digits); - int32_t parsed_scale = 0; - if (dec.has_exponent) { - auto adjusted_exponent = dec.exponent; - auto len = static_cast(significant_digits); - parsed_scale = -adjusted_exponent + len - 1; - } else { - parsed_scale = static_cast(dec.fractional_digits.size()); + if (precision != nullptr) { + *precision = static_cast(significant_digits); + } + + if (scale != nullptr) { + if (dec.has_exponent) { + auto adjusted_exponent = dec.exponent; + auto len = static_cast(significant_digits); + *scale = -adjusted_exponent + len - 1; + } else { + *scale = static_cast(dec.fractional_digits.size()); + } } if (out != nullptr) { - std::array little_endian_array = {0, 0}; - ShiftAndAdd(dec.whole_digits, little_endian_array.data(), little_endian_array.size()); - ShiftAndAdd(dec.fractional_digits, little_endian_array.data(), - little_endian_array.size()); - *out = - Decimal128(static_cast(little_endian_array[1]), little_endian_array[0]); - if (parsed_scale < 0) { - *out *= GetScaleMultiplier(-parsed_scale); + ShiftAndAdd(dec.whole_digits, out, array_size); + ShiftAndAdd(dec.fractional_digits, out, + array_size); + } + + return Status::OK(); +} + +} // namespace + +Status Decimal128::FromString(const util::string_view& s, Decimal128* out, + int32_t* precision, int32_t* scale) { + std::array little_endian_array = {0, 0}; + DecimalComponents dec; + + auto status = FromStringToArray(s, dec, little_endian_array.data(), 2, precision, scale); + if (status != Status::OK()) { + return status; + } + + if (out != nullptr) { + *out = Decimal128(static_cast(little_endian_array[1]), little_endian_array[0]); + + if (scale != nullptr && *scale < 0) { + *out *= GetScaleMultiplier(-*scale); } if (dec.sign == '-') { @@ -509,19 +528,14 @@ Status Decimal128::FromString(const util::string_view& s, Decimal128* out, } } - if (parsed_scale < 0) { - parsed_precision -= parsed_scale; - parsed_scale = 0; - } - - if (precision != nullptr) { - *precision = parsed_precision; - } - if (scale != nullptr) { - *scale = parsed_scale; + if (scale != nullptr && *scale < 0) { + if (precision != nullptr) { + *precision -= *scale; + } + *scale = 0; } - return Status::OK(); + return status; } Status Decimal128::FromString(const std::string& s, Decimal128* out, int32_t* precision, @@ -649,49 +663,22 @@ std::string Decimal256::ToString(int32_t scale) const { Status Decimal256::FromString(const util::string_view& s, Decimal256* out, int32_t* precision, int32_t* scale) { - if (s.empty()) { - return Status::Invalid("Empty string cannot be converted to decimal"); - } - + std::array little_endian_array = {0, 0, 0, 0}; DecimalComponents dec; - if (!ParseDecimalComponents(s.data(), s.size(), &dec)) { - return Status::Invalid("The string '", s, "' is not a valid decimal number"); - } - - // Count number of significant digits (without leading zeros) - size_t first_non_zero = dec.whole_digits.find_first_not_of('0'); - size_t significant_digits = dec.fractional_digits.size(); - if (first_non_zero != std::string::npos) { - significant_digits += dec.whole_digits.size() - first_non_zero; - } - if (precision != nullptr) { - *precision = static_cast(significant_digits); - } - - if (scale != nullptr) { - if (dec.has_exponent) { - auto adjusted_exponent = dec.exponent; - auto len = static_cast(significant_digits); - *scale = -adjusted_exponent + len - 1; - } else { - *scale = static_cast(dec.fractional_digits.size()); - } + auto status = FromStringToArray(s, dec, little_endian_array.data(), 4, precision, scale); + if (status != Status::OK()) { + return status; } if (out != nullptr) { - std::array little_endian_array = {0, 0, 0, 0}; - ShiftAndAdd(dec.whole_digits, little_endian_array.data(), little_endian_array.size()); - ShiftAndAdd(dec.fractional_digits, little_endian_array.data(), - little_endian_array.size()); *out = Decimal256(little_endian_array); - if (dec.sign == '-') { out->Negate(); } } - return Status::OK(); + return status; } Status Decimal256::FromString(const std::string& s, Decimal256* out, int32_t* precision, @@ -768,4 +755,95 @@ std::ostream& operator<<(std::ostream& os, const Decimal256& decimal) { os << decimal.ToIntegerString(); return os; } + +template +DecimalAnyWidth::DecimalAnyWidth(const std::string& str) : DecimalAnyWidth() { + *this = DecimalAnyWidth::FromString(str).ValueOrDie(); +} + +template +std::string DecimalAnyWidth::ToIntegerString() const { + std::stringstream ss; + ss << this->Value(); + return ss.str(); +} + +template +std::string DecimalAnyWidth::ToString(int32_t scale) const { + std::string str(ToIntegerString()); + AdjustIntegerStringWithScale(scale, &str); + return str; +} + +template +Status DecimalAnyWidth::FromString(const util::string_view& s, DecimalAnyWidth* out, + int32_t* precision, int32_t* scale) { + std::array little_endian_array = {0}; + DecimalComponents dec; + + auto status = FromStringToArray(s, dec, little_endian_array.data(), 1, precision, scale); + if (status != Status::OK()) { + return status; + } + + if (out != nullptr) { + *out = DecimalAnyWidth(static_cast(little_endian_array[0])); + + if (scale != nullptr && *scale < 0) { + *out *= BasicDecimalAnyWidth::GetScaleMultiplier(*scale); + } + + if (dec.sign == '-') { + out->Negate(); + } + } + + if (scale != nullptr && *scale < 0) { + if (precision != nullptr) { + *precision -= *scale; + } + *scale = 0; + } + + return status; +} + +template +Status DecimalAnyWidth::FromString(const std::string& s, DecimalAnyWidth* out, int32_t* precision, + int32_t* scale) { + return FromString(util::string_view(s), out, precision, scale); +} + +template +Status DecimalAnyWidth::FromString(const char* s, DecimalAnyWidth* out, int32_t* precision, + int32_t* scale) { + return FromString(util::string_view(s), out, precision, scale); +} + +template +Result::_DecimalType> DecimalAnyWidth::FromString(const util::string_view& s) { + _DecimalType out; + RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr)); + return std::move(out); +} + +template +Result::_DecimalType> DecimalAnyWidth::FromString(const std::string& s) { + return FromString(util::string_view(s)); +} + +template +Result::_DecimalType> DecimalAnyWidth::FromString(const char* s) { + return FromString(util::string_view(s)); +} + +template +Status DecimalAnyWidth::ToArrowStatus(DecimalStatus dstatus) const { + return arrow::ToArrowStatus(dstatus, width); +} + +template class DecimalAnyWidth<16>; +template class DecimalAnyWidth<32>; +template class DecimalAnyWidth<64>; + } // namespace arrow diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index 3d41ae460e418..d467754943fe3 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -26,6 +26,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/basic_decimal.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/string_view.h" namespace arrow { @@ -257,4 +258,77 @@ class ARROW_EXPORT Decimal256 : public BasicDecimal256 { Status ToArrowStatus(DecimalStatus dstatus) const; }; + +template +class DecimalAnyWidth : public BasicDecimalAnyWidth { + public: + + using _DecimalType = typename DecimalTypeTraits::ValueType; + using ValueType = typename BasicDecimalAnyWidth::ValueType; + + /// \cond FALSE + // (need to avoid a duplicate definition in Sphinx) + using BasicDecimalAnyWidth::BasicDecimalAnyWidth; + /// \endcond + + /// \brief constructor creates a Decimal256 from a BasicDecimal128. + constexpr DecimalAnyWidth(const BasicDecimalAnyWidth& value) noexcept : BasicDecimalAnyWidth(value) {} + + /// \brief Parse the number from a base 10 string representation. + explicit DecimalAnyWidth(const std::string& value); + + /// \brief Empty constructor creates a Decimal256 with a value of 0. + // This is required on some older compilers. + constexpr DecimalAnyWidth() noexcept : BasicDecimalAnyWidth() {} + + /// \brief Convert the Decimal256 value to a base 10 decimal string with the given + /// scale. + std::string ToString(int32_t scale) const; + + /// \brief Convert the value to an integer string + std::string ToIntegerString() const; + + /// \brief Convert a decimal string to a Decimal256 value, optionally including + /// precision and scale if they're passed in and not null. + static Status FromString(const util::string_view& s, DecimalAnyWidth* out, + int32_t* precision, int32_t* scale = NULLPTR); + static Status FromString(const std::string& s, DecimalAnyWidth* out, int32_t* precision, + int32_t* scale = NULLPTR); + static Status FromString(const char* s, DecimalAnyWidth* out, int32_t* precision, + int32_t* scale = NULLPTR); + static Result<_DecimalType> FromString(const util::string_view& s); + static Result<_DecimalType> FromString(const std::string& s); + static Result<_DecimalType> FromString(const char* s); + + /// \brief Convert Decimal256 from one scale to another + Result<_DecimalType> Rescale(int32_t original_scale, int32_t new_scale) const { + _DecimalType out; + auto dstatus = BasicDecimalAnyWidth::Rescale(original_scale, new_scale, &out); + ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus)); + return std::move(out); + } + + friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os, + const DecimalAnyWidth& decimal) { + os << decimal.ToIntegerString(); + return os; + } + + private: + /// Converts internal error code to Status + Status ToArrowStatus(DecimalStatus dstatus) const; +}; + +// class ARROW_EXPORT Decimal16 : DecimalAnyWidth<16> { +// using DecimalAnyWidth<16>::DecimalAnyWidth; +// }; + +// class ARROW_EXPORT Decimal32 : DecimalAnyWidth<32> { +// using DecimalAnyWidth<32>::DecimalAnyWidth; +// }; + +// class ARROW_EXPORT Decimal64 : DecimalAnyWidth<64> { +// using DecimalAnyWidth<64>::DecimalAnyWidth; +// }; + } // namespace arrow diff --git a/cpp/src/arrow/util/decimal_meta.h b/cpp/src/arrow/util/decimal_meta.h index e5eb4d907a22a..59bb1d2206b23 100644 --- a/cpp/src/arrow/util/decimal_meta.h +++ b/cpp/src/arrow/util/decimal_meta.h @@ -19,9 +19,41 @@ namespace arrow { +template +struct IntTypes {}; + +#define IntTypes_DECL(bit_width) \ +template<> \ +struct IntTypes{ \ + using signed_type = int##bit_width##_t; \ + using unsigned_type = uint##bit_width##_t; \ +}; + +IntTypes_DECL(64); +IntTypes_DECL(32); +IntTypes_DECL(16); + template struct DecimalMeta; +template<> +struct DecimalMeta<16> { + static constexpr const char* name = "decimal16"; + static constexpr int32_t max_precision = 5; +}; + +template<> +struct DecimalMeta<32> { + static constexpr const char* name = "decimal32"; + static constexpr int32_t max_precision = 10; +}; + +template<> +struct DecimalMeta<64> { + static constexpr const char* name = "decimal64"; + static constexpr int32_t max_precision = 19; +}; + template<> struct DecimalMeta<128> { static constexpr const char* name = "decimal"; diff --git a/cpp/src/arrow/util/decimal_scale_multipliers.h b/cpp/src/arrow/util/decimal_scale_multipliers.h new file mode 100644 index 0000000000000..23067bc860a02 --- /dev/null +++ b/cpp/src/arrow/util/decimal_scale_multipliers.h @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/util/basic_decimal.h" + +namespace arrow { + +template +struct ScaleMultipliersAnyWidth {}; + +#define DECL_ANY_SCALE_MULTIPLIERS(width) \ +template<> \ +struct ScaleMultipliersAnyWidth { \ + static const int##width##_t value[]; \ +}; \ +const int##width##_t ScaleMultipliersAnyWidth::value[] = { \ + int##width##_t(1LL), \ + int##width##_t(10LL), \ + int##width##_t(100LL), \ + int##width##_t(1000LL), \ + int##width##_t(10000LL), \ + int##width##_t(100000LL), \ + int##width##_t(1000000LL), \ + int##width##_t(10000000LL), \ + int##width##_t(100000000LL), \ + int##width##_t(1000000000LL), \ + int##width##_t(10000000000LL), \ + int##width##_t(100000000000LL), \ + int##width##_t(1000000000000LL), \ + int##width##_t(10000000000000LL), \ + int##width##_t(100000000000000LL), \ + int##width##_t(1000000000000000LL), \ + int##width##_t(10000000000000000LL), \ + int##width##_t(100000000000000000LL), \ + int##width##_t(1000000000000000000LL) \ +}; + +DECL_ANY_SCALE_MULTIPLIERS(16) +DECL_ANY_SCALE_MULTIPLIERS(32) +DECL_ANY_SCALE_MULTIPLIERS(64) + +#undef DECL_ANY_SCALE_MULTIPLIERS + +static const BasicDecimal128 ScaleMultipliers128[] = { + BasicDecimal128(1LL), + BasicDecimal128(10LL), + BasicDecimal128(100LL), + BasicDecimal128(1000LL), + BasicDecimal128(10000LL), + BasicDecimal128(100000LL), + BasicDecimal128(1000000LL), + BasicDecimal128(10000000LL), + BasicDecimal128(100000000LL), + BasicDecimal128(1000000000LL), + BasicDecimal128(10000000000LL), + BasicDecimal128(100000000000LL), + BasicDecimal128(1000000000000LL), + BasicDecimal128(10000000000000LL), + BasicDecimal128(100000000000000LL), + BasicDecimal128(1000000000000000LL), + BasicDecimal128(10000000000000000LL), + BasicDecimal128(100000000000000000LL), + BasicDecimal128(1000000000000000000LL), + BasicDecimal128(0LL, 10000000000000000000ULL), + BasicDecimal128(5LL, 7766279631452241920ULL), + BasicDecimal128(54LL, 3875820019684212736ULL), + BasicDecimal128(542LL, 1864712049423024128ULL), + BasicDecimal128(5421LL, 200376420520689664ULL), + BasicDecimal128(54210LL, 2003764205206896640ULL), + BasicDecimal128(542101LL, 1590897978359414784ULL), + BasicDecimal128(5421010LL, 15908979783594147840ULL), + BasicDecimal128(54210108LL, 11515845246265065472ULL), + BasicDecimal128(542101086LL, 4477988020393345024ULL), + BasicDecimal128(5421010862LL, 7886392056514347008ULL), + BasicDecimal128(54210108624LL, 5076944270305263616ULL), + BasicDecimal128(542101086242LL, 13875954555633532928ULL), + BasicDecimal128(5421010862427LL, 9632337040368467968ULL), + BasicDecimal128(54210108624275LL, 4089650035136921600ULL), + BasicDecimal128(542101086242752LL, 4003012203950112768ULL), + BasicDecimal128(5421010862427522LL, 3136633892082024448ULL), + BasicDecimal128(54210108624275221LL, 12919594847110692864ULL), + BasicDecimal128(542101086242752217LL, 68739955140067328ULL), + BasicDecimal128(5421010862427522170LL, 687399551400673280ULL)}; + + +static const BasicDecimal128 ScaleMultipliersHalf128[] = { + BasicDecimal128(0ULL), + BasicDecimal128(5ULL), + BasicDecimal128(50ULL), + BasicDecimal128(500ULL), + BasicDecimal128(5000ULL), + BasicDecimal128(50000ULL), + BasicDecimal128(500000ULL), + BasicDecimal128(5000000ULL), + BasicDecimal128(50000000ULL), + BasicDecimal128(500000000ULL), + BasicDecimal128(5000000000ULL), + BasicDecimal128(50000000000ULL), + BasicDecimal128(500000000000ULL), + BasicDecimal128(5000000000000ULL), + BasicDecimal128(50000000000000ULL), + BasicDecimal128(500000000000000ULL), + BasicDecimal128(5000000000000000ULL), + BasicDecimal128(50000000000000000ULL), + BasicDecimal128(500000000000000000ULL), + BasicDecimal128(5000000000000000000ULL), + BasicDecimal128(2LL, 13106511852580896768ULL), + BasicDecimal128(27LL, 1937910009842106368ULL), + BasicDecimal128(271LL, 932356024711512064ULL), + BasicDecimal128(2710LL, 9323560247115120640ULL), + BasicDecimal128(27105LL, 1001882102603448320ULL), + BasicDecimal128(271050LL, 10018821026034483200ULL), + BasicDecimal128(2710505LL, 7954489891797073920ULL), + BasicDecimal128(27105054LL, 5757922623132532736ULL), + BasicDecimal128(271050543LL, 2238994010196672512ULL), + BasicDecimal128(2710505431LL, 3943196028257173504ULL), + BasicDecimal128(27105054312LL, 2538472135152631808ULL), + BasicDecimal128(271050543121LL, 6937977277816766464ULL), + BasicDecimal128(2710505431213LL, 14039540557039009792ULL), + BasicDecimal128(27105054312137LL, 11268197054423236608ULL), + BasicDecimal128(271050543121376LL, 2001506101975056384ULL), + BasicDecimal128(2710505431213761LL, 1568316946041012224ULL), + BasicDecimal128(27105054312137610LL, 15683169460410122240ULL), + BasicDecimal128(271050543121376108LL, 9257742014424809472ULL), + BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)}; + +} // namespace arrow diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc index 40ae49da2ce60..0dd99c7f10469 100644 --- a/cpp/src/arrow/util/decimal_test.cc +++ b/cpp/src/arrow/util/decimal_test.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -899,7 +900,6 @@ std::vector GetRandomNumbers(int32_t size) { auto rand = random::RandomArrayGenerator(0x5487655); auto x_array = rand.Numeric(size, static_cast(0), std::numeric_limits::max(), 0); - auto x_ptr = x_array->data()->template GetValues(1); std::vector ret; for (int i = 0; i < size; ++i) { @@ -1567,4 +1567,189 @@ TEST_P(Decimal256ToStringTest, ToString) { INSTANTIATE_TEST_SUITE_P(Decimal256ToStringTest, Decimal256ToStringTest, ::testing::ValuesIn(kToStringTestData)); + +// DecimalAnyWidth + +template +class DecimalAnyWidthTest : public ::testing::Test { }; + +template +class Decimal16Test : public ::testing::Test { }; + +template +class Decimal32Test : public ::testing::Test { }; + +template +class Decimal64Test : public ::testing::Test { }; + +using DecimalTypes = ::testing::Types; + +struct DecimalFromStringParams { + std::string value; + int expected_value; + int32_t expected_scale; + int32_t expected_precision; +}; + +static const std::vector DecimalFromStringParamsList = { + {"1234", 1234, 0, 4}, + {"12.34", 1234, 2, 4}, + {"+12.34", 1234, 2, 4}, + {"-12.34", -1234, 2, 4}, + {".0000", 0, 4, 4} +}; + +TYPED_TEST_SUITE(DecimalAnyWidthTest, DecimalTypes); + +TYPED_TEST(DecimalAnyWidthTest, FromString) { + for (auto& param : DecimalFromStringParamsList) { + TypeParam d; + int precision, scale; + + ASSERT_OK(TypeParam::FromString(param.value, &d, &precision, &scale)); + + ASSERT_EQ(param.expected_value, d); + ASSERT_EQ(param.expected_precision, precision); + ASSERT_EQ(param.expected_scale, scale); + } +} + +TYPED_TEST(DecimalAnyWidthTest, FromBool) { + ASSERT_EQ(TypeParam(0), TypeParam(false)); + ASSERT_EQ(TypeParam(1), TypeParam(true)); +} + +using Decimal16Types = + ::testing::Types; + +using Decimal32Types = + ::testing::Types; + +using Decimal64Types = + ::testing::Types; + +TYPED_TEST_SUITE(Decimal16Test, Decimal16Types); + +TYPED_TEST(Decimal16Test, Decimal16Types) { + TypeParam value = 42; + TypeParam max_value = std::numeric_limits::max(); + TypeParam min_value = std::numeric_limits::min(); + + Decimal16 d(value); + ASSERT_EQ(value, d); + + // Constructing from int will cause overflow + if (std::is_same::value) { + Decimal16 max_value_d(max_value); + ASSERT_EQ(static_cast(max_value), max_value_d); + + Decimal16 min_value_d(min_value); + ASSERT_EQ(static_cast(min_value), min_value_d); + } + else { + Decimal16 max_value_d(max_value); + ASSERT_EQ(max_value, max_value_d); + + Decimal16 min_value_d(min_value); + ASSERT_EQ(min_value, min_value_d); + } +} + +TYPED_TEST_SUITE(Decimal32Test, Decimal32Types); + +TYPED_TEST(Decimal32Test, Decimal32Types) { + TypeParam value = 42; + TypeParam max_value = std::numeric_limits::max(); + TypeParam min_value = std::numeric_limits::min(); + + Decimal32 d(value); + ASSERT_EQ(value, d); + + Decimal32 max_value_d(max_value); + ASSERT_EQ(max_value, max_value_d); + + Decimal32 min_value_d(min_value); + ASSERT_EQ(min_value, min_value_d); +} + +TYPED_TEST_SUITE(Decimal64Test, Decimal64Types); + +TYPED_TEST(Decimal64Test, Decimal64Types) { + TypeParam value = 42; + TypeParam max_value = std::numeric_limits::max(); + TypeParam min_value = std::numeric_limits::min(); + + Decimal64 d(value); + ASSERT_EQ(value, d); + + Decimal64 max_value_d(max_value); + ASSERT_EQ(max_value, max_value_d); + + Decimal64 min_value_d(min_value); + ASSERT_EQ(min_value, min_value_d); +} + +static const std::vector DecimalAnyWidthValues = { -2, -1, 0, 1, 2}; + +TYPED_TEST(DecimalAnyWidthTest, ComparatorTest) { + for (size_t i=0; i j, d1 > d2); + ASSERT_EQ(i >= j, d1 >= d2); + } + } +} + +TYPED_TEST(DecimalAnyWidthTest, UpCast) { + TypeParam d(42); + Decimal64 d64(d); + + ASSERT_EQ(d, d64); +} + +template +struct DecimalAnyWidthBinaryParams { + static const std::vector>> value; +}; + +template +const std::vector>> DecimalAnyWidthBinaryParams::value = { + {"+", [](T x, T y) -> T { return x + y;} }, + {"-", [](T x, T y) -> T { return x - y;} }, + {"*", [](T x, T y) -> T { return x * y;} }, + {"/", [](T x, T y) -> T { return y == 0? 0 : x / y;} }, + {"%", [](T x, T y) -> T { return y == 0? 0 : x % y;} }, +}; + +TYPED_TEST(DecimalAnyWidthTest, BinaryOperations) { + using ValueType = typename arrow::DecimalAnyWidthTest_BinaryOperations_Test::TypeParam::ValueType; + using ArrowValueType = typename arrow::CTypeTraits::ArrowType; + + auto DecimalFns = DecimalAnyWidthBinaryParams::value; + auto NumericFns = DecimalAnyWidthBinaryParams::value; + + for (size_t i = 0; i < DecimalFns.size(); i++){ + for (auto x : GetRandomNumbers(8)) { + for (auto y : GetRandomNumbers(8)) { + TypeParam d1(x), d2(y); + ASSERT_EQ(NumericFns[i].second(x, y), DecimalFns[i].second(d1, d2)) + << d1 << DecimalFns[i].first << " " << d2 << " " << " != " << NumericFns[i].second(x, y); + } + } + } +} + + + + + } // namespace arrow diff --git a/cpp/src/arrow/util/decimal_type_traits.h b/cpp/src/arrow/util/decimal_type_traits.h index fd8e9a5e1ff8d..df06fa9cb2549 100644 --- a/cpp/src/arrow/util/decimal_type_traits.h +++ b/cpp/src/arrow/util/decimal_type_traits.h @@ -35,6 +35,9 @@ struct DecimalTypeTraits { \ using ValueType = Decimal##width; \ }; +DECIMAL_TYPE_TRAITS_DECL(16) +DECIMAL_TYPE_TRAITS_DECL(32) +DECIMAL_TYPE_TRAITS_DECL(64) DECIMAL_TYPE_TRAITS_DECL(128) DECIMAL_TYPE_TRAITS_DECL(256) diff --git a/cpp/src/arrow/util/int_util_internal.h b/cpp/src/arrow/util/int_util_internal.h index de39229cfdd9f..246e25c15eafa 100644 --- a/cpp/src/arrow/util/int_util_internal.h +++ b/cpp/src/arrow/util/int_util_internal.h @@ -79,6 +79,14 @@ SignedInt SafeSignedSubtract(SignedInt u, SignedInt v) { static_cast(v)); } +/// Signed multiply with well-defined behaviour on overflow (as unsigned) +template +SignedInt SafeSignedMultiply(SignedInt u, SignedInt v) { + using UnsignedInt = typename std::make_unsigned::type; + return static_cast(static_cast(u) * + static_cast(v)); +} + /// Signed left shift with well-defined behaviour on negative numbers or overflow template SignedInt SafeLeftShift(SignedInt u, Shift shift) { diff --git a/cpp/src/arrow/visitor.cc b/cpp/src/arrow/visitor.cc index 851785081c792..70b2d3969d36e 100644 --- a/cpp/src/arrow/visitor.cc +++ b/cpp/src/arrow/visitor.cc @@ -66,6 +66,9 @@ ARRAY_VISITOR_DEFAULT(StructArray) ARRAY_VISITOR_DEFAULT(SparseUnionArray) ARRAY_VISITOR_DEFAULT(DenseUnionArray) ARRAY_VISITOR_DEFAULT(DictionaryArray) +ARRAY_VISITOR_DEFAULT(Decimal16Array) +ARRAY_VISITOR_DEFAULT(Decimal32Array) +ARRAY_VISITOR_DEFAULT(Decimal64Array) ARRAY_VISITOR_DEFAULT(Decimal128Array) ARRAY_VISITOR_DEFAULT(Decimal256Array) ARRAY_VISITOR_DEFAULT(ExtensionArray) @@ -106,6 +109,9 @@ TYPE_VISITOR_DEFAULT(TimestampType) TYPE_VISITOR_DEFAULT(DayTimeIntervalType) TYPE_VISITOR_DEFAULT(MonthIntervalType) TYPE_VISITOR_DEFAULT(DurationType) +TYPE_VISITOR_DEFAULT(Decimal16Type) +TYPE_VISITOR_DEFAULT(Decimal32Type) +TYPE_VISITOR_DEFAULT(Decimal64Type) TYPE_VISITOR_DEFAULT(Decimal128Type) TYPE_VISITOR_DEFAULT(Decimal256Type) TYPE_VISITOR_DEFAULT(ListType) @@ -155,6 +161,9 @@ SCALAR_VISITOR_DEFAULT(TimestampScalar) SCALAR_VISITOR_DEFAULT(DayTimeIntervalScalar) SCALAR_VISITOR_DEFAULT(MonthIntervalScalar) SCALAR_VISITOR_DEFAULT(DurationScalar) +SCALAR_VISITOR_DEFAULT(Decimal16Scalar) +SCALAR_VISITOR_DEFAULT(Decimal32Scalar) +SCALAR_VISITOR_DEFAULT(Decimal64Scalar) SCALAR_VISITOR_DEFAULT(Decimal128Scalar) SCALAR_VISITOR_DEFAULT(Decimal256Scalar) SCALAR_VISITOR_DEFAULT(ListScalar) diff --git a/cpp/src/arrow/visitor.h b/cpp/src/arrow/visitor.h index 0382e461199c7..c0281992d5cdb 100644 --- a/cpp/src/arrow/visitor.h +++ b/cpp/src/arrow/visitor.h @@ -53,6 +53,9 @@ class ARROW_EXPORT ArrayVisitor { virtual Status Visit(const DayTimeIntervalArray& array); virtual Status Visit(const MonthIntervalArray& array); virtual Status Visit(const DurationArray& array); + virtual Status Visit(const Decimal16Array& array); + virtual Status Visit(const Decimal32Array& array); + virtual Status Visit(const Decimal64Array& array); virtual Status Visit(const Decimal128Array& array); virtual Status Visit(const Decimal256Array& array); virtual Status Visit(const ListArray& array); @@ -96,6 +99,9 @@ class ARROW_EXPORT TypeVisitor { virtual Status Visit(const MonthIntervalType& type); virtual Status Visit(const DayTimeIntervalType& type); virtual Status Visit(const DurationType& type); + virtual Status Visit(const Decimal16Type& type); + virtual Status Visit(const Decimal32Type& type); + virtual Status Visit(const Decimal64Type& type); virtual Status Visit(const Decimal128Type& type); virtual Status Visit(const Decimal256Type& type); virtual Status Visit(const ListType& type); @@ -139,6 +145,9 @@ class ARROW_EXPORT ScalarVisitor { virtual Status Visit(const DayTimeIntervalScalar& scalar); virtual Status Visit(const MonthIntervalScalar& scalar); virtual Status Visit(const DurationScalar& scalar); + virtual Status Visit(const Decimal16Scalar& scalar); + virtual Status Visit(const Decimal32Scalar& scalar); + virtual Status Visit(const Decimal64Scalar& scalar); virtual Status Visit(const Decimal128Scalar& scalar); virtual Status Visit(const Decimal256Scalar& scalar); virtual Status Visit(const ListScalar& scalar); diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h index 132c35aeaa134..3e1f86f15850e 100644 --- a/cpp/src/arrow/visitor_inline.h +++ b/cpp/src/arrow/visitor_inline.h @@ -67,6 +67,9 @@ namespace arrow { ACTION(Time64); \ ACTION(MonthInterval); \ ACTION(DayTimeInterval); \ + ACTION(Decimal16); \ + ACTION(Decimal32); \ + ACTION(Decimal64); \ ACTION(Decimal128); \ ACTION(Decimal256); \ ACTION(List); \ From 4d1519892f8c925d8dac5e1fb6caeb475719a601 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Mon, 28 Dec 2020 15:31:03 +0300 Subject: [PATCH 3/8] Added decimal convertions from/to python Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/python/arrow_to_pandas.cc | 38 +++++------------ cpp/src/arrow/python/decimal.cc | 30 +++++++++++++ cpp/src/arrow/python/decimal.h | 57 +++++++++++++++++++++++++ cpp/src/arrow/python/python_test.cc | 41 +++++++++++++----- cpp/src/arrow/python/python_to_arrow.cc | 18 ++++++++ cpp/src/arrow/util/decimal.h | 4 +- 6 files changed, 147 insertions(+), 41 deletions(-) diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 092452850301f..665df75df6a08 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -167,6 +167,9 @@ static inline bool ListTypeSupported(const DataType& type) { case Type::UINT64: case Type::FLOAT: case Type::DOUBLE: + case Type::DECIMAL16: + case Type::DECIMAL32: + case Type::DECIMAL64: case Type::DECIMAL128: case Type::DECIMAL256: case Type::BINARY: @@ -1021,7 +1024,7 @@ struct ObjectWriterVisitor { } template - enable_if_t::value || is_fixed_size_binary_type::value, + enable_if_t<(is_base_binary_type::value || is_fixed_size_binary_type::value) && !is_decimal_type::value, Status> Visit(const Type& type) { auto WrapValue = [](const util::string_view& view, PyObject** out) { @@ -1094,7 +1097,8 @@ struct ObjectWriterVisitor { return Status::OK(); } - Status Visit(const Decimal128Type& type) { + template + Status Visit(const BaseDecimalType& type) { OwnedRef decimal; OwnedRef Decimal; RETURN_NOT_OK(internal::ImportModule("decimal", &decimal)); @@ -1102,32 +1106,7 @@ struct ObjectWriterVisitor { PyObject* decimal_constructor = Decimal.obj(); for (int c = 0; c < data.num_chunks(); c++) { - const auto& arr = checked_cast(*data.chunk(c)); - - for (int64_t i = 0; i < arr.length(); ++i) { - if (arr.IsNull(i)) { - Py_INCREF(Py_None); - *out_values++ = Py_None; - } else { - *out_values++ = - internal::DecimalFromString(decimal_constructor, arr.FormatValue(i)); - RETURN_IF_PYERROR(); - } - } - } - - return Status::OK(); - } - - Status Visit(const Decimal256Type& type) { - OwnedRef decimal; - OwnedRef Decimal; - RETURN_NOT_OK(internal::ImportModule("decimal", &decimal)); - RETURN_NOT_OK(internal::ImportFromModule(decimal.obj(), "Decimal", &Decimal)); - PyObject* decimal_constructor = Decimal.obj(); - - for (int c = 0; c < data.num_chunks(); c++) { - const auto& arr = checked_cast(*data.chunk(c)); + const auto& arr = checked_cast&>(*data.chunk(c)); for (int64_t i = 0; i < arr.length(); ++i) { if (arr.IsNull(i)) { @@ -1871,6 +1850,9 @@ static Status GetPandasWriterType(const ChunkedArray& data, const PandasOptions& case Type::STRUCT: // fall through case Type::TIME32: // fall through case Type::TIME64: // fall through + case Type::DECIMAL16: // fall through + case Type::DECIMAL32: // fall through + case Type::DECIMAL64: // fall through case Type::DECIMAL128: // fall through case Type::DECIMAL256: // fall through *output_type = PandasWriter::OBJECT; diff --git a/cpp/src/arrow/python/decimal.cc b/cpp/src/arrow/python/decimal.cc index 67389095b946e..25d8af59ee15e 100644 --- a/cpp/src/arrow/python/decimal.cc +++ b/cpp/src/arrow/python/decimal.cc @@ -166,6 +166,36 @@ Status InternalDecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, } // namespace +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal16* out) { + return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); +} + +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + Decimal16* out) { + return InternalDecimalFromPyObject(obj, arrow_type, out); +} + +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal32* out) { + return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); +} + +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + Decimal32* out) { + return InternalDecimalFromPyObject(obj, arrow_type, out); +} + +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal64* out) { + return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); +} + +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, + Decimal64* out) { + return InternalDecimalFromPyObject(obj, arrow_type, out); +} + Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, Decimal128* out) { return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); diff --git a/cpp/src/arrow/python/decimal.h b/cpp/src/arrow/python/decimal.h index 1187037aed29e..08ce344f86469 100644 --- a/cpp/src/arrow/python/decimal.h +++ b/cpp/src/arrow/python/decimal.h @@ -24,6 +24,12 @@ namespace arrow { +template +class DecimalAnyWidth; + +using Decimal16 = DecimalAnyWidth<16>; +using Decimal32 = DecimalAnyWidth<32>; +using Decimal64 = DecimalAnyWidth<64>; class Decimal128; class Decimal256; @@ -56,6 +62,57 @@ ARROW_PYTHON_EXPORT PyObject* DecimalFromString(PyObject* decimal_constructor, const std::string& decimal_string); +// \brief Convert a Python decimal to an Arrow Decimal16 object +// \param[in] python_decimal A Python decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal16 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal16* out); + +// \brief Convert a Python object to an Arrow Decimal16 object +// \param[in] python_decimal A Python int or decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal16 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal16* out); + +// \brief Convert a Python decimal to an Arrow Decimal32 object +// \param[in] python_decimal A Python decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal32 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal32* out); + +// \brief Convert a Python object to an Arrow Decimal32 object +// \param[in] python_decimal A Python int or decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal32 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal32* out); + +// \brief Convert a Python decimal to an Arrow Decimal64 object +// \param[in] python_decimal A Python decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal64 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type, + Decimal64* out); + +// \brief Convert a Python object to an Arrow Decimal64 object +// \param[in] python_decimal A Python int or decimal.Decimal instance +// \param[in] arrow_type An instance of arrow::DecimalType +// \param[out] out A pointer to a Decimal64 +// \return The status of the operation +ARROW_PYTHON_EXPORT +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal64* out); + // \brief Convert a Python decimal to an Arrow Decimal128 object // \param[in] python_decimal A Python decimal.Decimal instance // \param[in] arrow_type An instance of arrow::DecimalType diff --git a/cpp/src/arrow/python/python_test.cc b/cpp/src/arrow/python/python_test.cc index 33e0ee9b1c9ab..b4a45831f6b9b 100644 --- a/cpp/src/arrow/python/python_test.cc +++ b/cpp/src/arrow/python/python_test.cc @@ -28,6 +28,7 @@ #include "arrow/table.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/decimal.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/optional.h" #include "arrow/python/arrow_to_pandas.h" @@ -358,6 +359,12 @@ void DecimalTestFromPythonDecimalRescale(std::shared_ptr type, TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) { // We fail when truncating values that would lose data if cast to a decimal type with // lower scale + DecimalTestFromPythonDecimalRescale(::arrow::decimal16(5, 2), + this->CreatePythonDecimal("1.001"), {}); + DecimalTestFromPythonDecimalRescale(::arrow::decimal32(10, 2), + this->CreatePythonDecimal("1.001"), {}); + DecimalTestFromPythonDecimalRescale(::arrow::decimal64(10, 2), + this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal256(10, 2), @@ -367,6 +374,12 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) { TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) { // We allow truncation of values that do not lose precision when dividing by 10 * the // difference between the scales, e.g., 1.000 -> 1.00 + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal16(5, 2), this->CreatePythonDecimal("1.000"), 100); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal32(10, 2), this->CreatePythonDecimal("1.000"), 100); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal64(10, 2), this->CreatePythonDecimal("1.000"), 100); DecimalTestFromPythonDecimalRescale( ::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.000"), 100); DecimalTestFromPythonDecimalRescale( @@ -374,25 +387,31 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) { } TEST_F(DecimalTest, FromPythonNegativeDecimalRescale) { + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal16(5, 4), this->CreatePythonDecimal("-1.000"), -10000); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal32(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); + DecimalTestFromPythonDecimalRescale( + ::arrow::decimal64(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); DecimalTestFromPythonDecimalRescale( ::arrow::decimal128(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); DecimalTestFromPythonDecimalRescale( ::arrow::decimal256(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); } -TEST_F(DecimalTest, Decimal128FromPythonInteger) { - Decimal128 value; - OwnedRef python_long(PyLong_FromLong(42)); - auto type = ::arrow::decimal128(10, 2); - const auto& decimal_type = checked_cast(*type); - ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value)); - ASSERT_EQ(4200, value); -} +template +class DecimalTestConversion : public testing::Test {}; +using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; -TEST_F(DecimalTest, Decimal256FromPythonInteger) { - Decimal256 value; +TYPED_TEST_SUITE(DecimalTestConversion, DecimalTypes); + +TYPED_TEST(DecimalTestConversion, Basics) { + using TypeClass = typename TypeParam::TypeClass; + using ValueType = typename TypeParam::ValueType; + + ValueType value; OwnedRef python_long(PyLong_FromLong(42)); - auto type = ::arrow::decimal256(10, 2); + auto type = std::make_shared(5, 2); const auto& decimal_type = checked_cast(*type); ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value)); ASSERT_EQ(4200, value); diff --git a/cpp/src/arrow/python/python_to_arrow.cc b/cpp/src/arrow/python/python_to_arrow.cc index b136bec9709d6..2d9de621af2b9 100644 --- a/cpp/src/arrow/python/python_to_arrow.cc +++ b/cpp/src/arrow/python/python_to_arrow.cc @@ -164,6 +164,24 @@ class PyValue { return value; } + static Result Convert(const Decimal16Type* type, const O&, I obj) { + Decimal16 value; + RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); + return value; + } + + static Result Convert(const Decimal32Type* type, const O&, I obj) { + Decimal32 value; + RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); + return value; + } + + static Result Convert(const Decimal64Type* type, const O&, I obj) { + Decimal64 value; + RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); + return value; + } + static Result Convert(const Decimal128Type* type, const O&, I obj) { Decimal128 value; RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value)); diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index d467754943fe3..afc089b21333a 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -260,7 +260,7 @@ class ARROW_EXPORT Decimal256 : public BasicDecimal256 { template -class DecimalAnyWidth : public BasicDecimalAnyWidth { +class ARROW_EXPORT DecimalAnyWidth : public BasicDecimalAnyWidth { public: using _DecimalType = typename DecimalTypeTraits::ValueType; @@ -308,7 +308,7 @@ class DecimalAnyWidth : public BasicDecimalAnyWidth { return std::move(out); } - friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os, + friend std::ostream& operator<<(std::ostream& os, const DecimalAnyWidth& decimal) { os << decimal.ToIntegerString(); return os; From a108f38e9957edc43eb4103bb51105e706e4d632 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Thu, 14 Jan 2021 13:19:06 +0300 Subject: [PATCH 4/8] Attempt to fix linking error on Windows Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/array/array_decimal.cc | 10 +++++----- cpp/src/arrow/array/array_decimal.h | 2 +- cpp/src/arrow/array/builder_decimal.cc | 10 +++++----- cpp/src/arrow/array/builder_decimal.h | 2 +- cpp/src/arrow/type.h | 2 +- cpp/src/arrow/util/basic_decimal.cc | 6 +++--- cpp/src/arrow/util/basic_decimal.h | 2 +- cpp/src/arrow/util/decimal.cc | 6 +++--- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index f05fe7960f2db..5e7c6c85beab8 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -46,10 +46,10 @@ std::string BaseDecimalArray::FormatValue(int64_t i) const { return value.ToString(type_.scale()); } -template class BaseDecimalArray<16>; -template class BaseDecimalArray<32>; -template class BaseDecimalArray<64>; -template class BaseDecimalArray<128>; -template class BaseDecimalArray<256>; +template class ARROW_EXPORT BaseDecimalArray<16>; +template class ARROW_EXPORT BaseDecimalArray<32>; +template class ARROW_EXPORT BaseDecimalArray<64>; +template class ARROW_EXPORT BaseDecimalArray<128>; +template class ARROW_EXPORT BaseDecimalArray<256>; } // namespace arrow diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index c9a0aff1bf36c..b7c8515304863 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -31,7 +31,7 @@ namespace arrow { /// Template Array class for decimal data template -class BaseDecimalArray : public FixedSizeBinaryArray { +class ARROW_EXPORT BaseDecimalArray : public FixedSizeBinaryArray { public: using TypeClass = typename DecimalTypeTraits::TypeClass; using ValueType = typename DecimalTypeTraits::ValueType; diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index f29fe6cc8c205..ad49cc2466a92 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -72,10 +72,10 @@ Status BaseDecimalBuilder::FinishInternal(std::shared_ptr* out return Status::OK(); } -template class BaseDecimalBuilder<16>; -template class BaseDecimalBuilder<32>; -template class BaseDecimalBuilder<64>; -template class BaseDecimalBuilder<128>; -template class BaseDecimalBuilder<256>; +template class ARROW_EXPORT BaseDecimalBuilder<16>; +template class ARROW_EXPORT BaseDecimalBuilder<32>; +template class ARROW_EXPORT BaseDecimalBuilder<64>; +template class ARROW_EXPORT BaseDecimalBuilder<128>; +template class ARROW_EXPORT BaseDecimalBuilder<256>; } // namespace arrow diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 410bcac235aeb..4c27fc0cf1b64 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -31,7 +31,7 @@ namespace arrow { template -class BaseDecimalBuilder : public FixedSizeBinaryBuilder { +class ARROW_EXPORT BaseDecimalBuilder : public FixedSizeBinaryBuilder { public: using TypeClass = typename DecimalTypeTraits::TypeClass; using ArrayType = typename DecimalTypeTraits::ArrayType; diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index d0c5f4be35ab0..86d1715e21172 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -895,7 +895,7 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { /// \brief Template type class for decimal data template -class BaseDecimalType : public DecimalType { +class ARROW_EXPORT BaseDecimalType : public DecimalType { public: static constexpr const char* type_name() { return DecimalMeta::name; } diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index 08020a4556230..50e9ffccd769a 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -1245,8 +1245,8 @@ BasicDecimalAnyWidth BasicDecimalAnyWidth::IncreaseScaleBy(int32_t return (*this) * ScaleMultipliersAnyWidth::value[increase_by]; } -template class BasicDecimalAnyWidth<64>; -template class BasicDecimalAnyWidth<32>; -template class BasicDecimalAnyWidth<16>; +template class ARROW_EXPORT BasicDecimalAnyWidth<64>; +template class ARROW_EXPORT BasicDecimalAnyWidth<32>; +template class ARROW_EXPORT BasicDecimalAnyWidth<16>; } // namespace arrow diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index 0b48f89f1b9db..6f08ec0b65dc3 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -338,7 +338,7 @@ ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left, template -class BasicDecimalAnyWidth { +class ARROW_EXPORT BasicDecimalAnyWidth { public: using ValueType = typename IntTypes::signed_type; /// \brief Empty constructor creates a BasicDecimal256 with a value of 0. diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc index 42bc9e7885a20..2de218a076b56 100644 --- a/cpp/src/arrow/util/decimal.cc +++ b/cpp/src/arrow/util/decimal.cc @@ -842,8 +842,8 @@ Status DecimalAnyWidth::ToArrowStatus(DecimalStatus dstatus) const { return arrow::ToArrowStatus(dstatus, width); } -template class DecimalAnyWidth<16>; -template class DecimalAnyWidth<32>; -template class DecimalAnyWidth<64>; +template class ARROW_EXPORT DecimalAnyWidth<16>; +template class ARROW_EXPORT DecimalAnyWidth<32>; +template class ARROW_EXPORT DecimalAnyWidth<64>; } // namespace arrow From 0c8240470fc0d95ce2205193522695c3c6784c15 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Fri, 15 Jan 2021 11:04:06 +0300 Subject: [PATCH 5/8] Edited R datatypes enum Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/dataset/filter.cc | 1775 ------------------------------- r/R/enums.R | 31 +- 2 files changed, 17 insertions(+), 1789 deletions(-) delete mode 100644 cpp/src/arrow/dataset/filter.cc diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc deleted file mode 100644 index a7ce9761feed2..0000000000000 --- a/cpp/src/arrow/dataset/filter.cc +++ /dev/null @@ -1,1775 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/dataset/filter.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "arrow/array/builder_primitive.h" -#include "arrow/buffer.h" -#include "arrow/compute/api.h" -#include "arrow/dataset/dataset.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_internal.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using compute::CompareOperator; -using compute::ExecContext; - -namespace dataset { - -using arrow::internal::checked_cast; -using arrow::internal::checked_pointer_cast; - -inline std::shared_ptr NullExpression() { - return std::make_shared(std::make_shared()); -} - -inline Datum NullDatum() { return Datum(std::make_shared()); } - -bool IsNullDatum(const Datum& datum) { - if (datum.is_scalar()) { - auto scalar = datum.scalar(); - return !scalar->is_valid; - } - - auto array_data = datum.array(); - return array_data->GetNullCount() == array_data->length; -} - -struct Comparison { - enum type { - LESS, - EQUAL, - GREATER, - NULL_, - }; -}; - -Result> EnsureNotDictionary( - const std::shared_ptr& scalar) { - if (scalar->type->id() == Type::DICTIONARY) { - return checked_cast(*scalar).GetEncodedValue(); - } - return scalar; -} - -Result Compare(const Scalar& lhs, const Scalar& rhs); - -struct CompareVisitor { - template - using ScalarType = typename TypeTraits::ScalarType; - - Status Visit(const NullType&) { - result_ = Comparison::NULL_; - return Status::OK(); - } - - Status Visit(const BooleanType&) { return CompareValues(); } - - template - enable_if_physical_floating_point Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_signed_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_unsigned_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_nested Visit(const T&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - template - enable_if_binary_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - template - enable_if_string_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - Status Visit(const Decimal16Type&) { return CompareValues(); } - Status Visit(const Decimal32Type&) { return CompareValues(); } - Status Visit(const Decimal64Type&) { return CompareValues(); } - Status Visit(const Decimal128Type&) { return CompareValues(); } - Status Visit(const Decimal256Type&) { return CompareValues(); } - - // Explicit because it falls under `physical_unsigned_integer`. - // TODO(bkietz) whenever we vendor a float16, this can be implemented - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const ExtensionType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const DictionaryType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - // defer comparison to ScalarType::value - template - Status CompareValues() { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - return CompareValues(lhs, rhs); - } - - // defer comparison to explicit values - template - Status CompareValues(Value lhs, Value rhs) { - result_ = lhs < rhs ? Comparison::LESS - : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; - return Status::OK(); - } - - Comparison::type result_; - const Scalar& lhs_; - const Scalar& rhs_; -}; - -// Compare two scalars -// if either is null, return is null -// TODO(bkietz) extract this to the scalar comparison kernels -Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, - " vs ", *rhs.type); - } - if (!lhs.is_valid || !rhs.is_valid) { - return Comparison::NULL_; - } - CompareVisitor vis{Comparison::NULL_, lhs, rhs}; - RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return vis.result_; -} - -CompareOperator InvertCompareOperator(CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return CompareOperator::NOT_EQUAL; - - case CompareOperator::NOT_EQUAL: - return CompareOperator::EQUAL; - - case CompareOperator::GREATER: - return CompareOperator::LESS_EQUAL; - - case CompareOperator::GREATER_EQUAL: - return CompareOperator::LESS; - - case CompareOperator::LESS: - return CompareOperator::GREATER_EQUAL; - - case CompareOperator::LESS_EQUAL: - return CompareOperator::GREATER; - - default: - break; - } - - DCHECK(false); - return CompareOperator::EQUAL; -} - -template -std::shared_ptr InvertBoolean(const Boolean& expr) { - auto lhs = Invert(*expr.left_operand()); - auto rhs = Invert(*expr.right_operand()); - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - return nullptr; -} - -std::shared_ptr Invert(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::NOT: - return checked_cast(expr).operand(); - - case ExpressionType::AND: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::OR: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::COMPARISON: { - const auto& comparison = checked_cast(expr); - auto inverted_op = InvertCompareOperator(comparison.op()); - return std::make_shared( - inverted_op, comparison.left_operand(), comparison.right_operand()); - } - - default: - break; - } - return nullptr; -} - -std::shared_ptr Expression::Assume(const Expression& given) const { - std::shared_ptr out; - - DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { - if (out != nullptr) { - return Status::OK(); - } - - if (given.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& given_cmp = checked_cast(given); - if (given_cmp.op() != CompareOperator::EQUAL) { - return Status::OK(); - } - - if (this->Equals(given_cmp.left_operand())) { - out = given_cmp.right_operand(); - return Status::OK(); - } - - if (this->Equals(given_cmp.right_operand())) { - out = given_cmp.left_operand(); - return Status::OK(); - } - - return Status::OK(); - })); - - return out ? out : Copy(); -} - -std::shared_ptr ComparisonExpression::Assume(const Expression& given) const { - switch (given.type()) { - case ExpressionType::COMPARISON: { - return AssumeGivenComparison(checked_cast(given)); - } - - case ExpressionType::NOT: { - const auto& given_not = checked_cast(given); - if (auto inverted = Invert(*given_not.operand())) { - return Assume(*inverted); - } - return Copy(); - } - - case ExpressionType::OR: { - const auto& given_or = checked_cast(given); - - auto left_simplified = Assume(*given_or.left_operand()); - auto right_simplified = Assume(*given_or.right_operand()); - - // The result of simplification against the operands of an OrExpression - // cannot be used unless they are identical - if (left_simplified->Equals(right_simplified)) { - return left_simplified; - } - - return Copy(); - } - - case ExpressionType::AND: { - const auto& given_and = checked_cast(given); - - auto simplified = Copy(); - simplified = simplified->Assume(*given_and.left_operand()); - simplified = simplified->Assume(*given_and.right_operand()); - return simplified; - } - - // TODO(bkietz) we should be able to use ExpressionType::IN here - - default: - break; - } - - return Copy(); -} - -// Try to simplify one comparison against another comparison. -// For example, -// ("x"_ > 3) is a subset of ("x"_ > 2), so ("x"_ > 2).Assume("x"_ > 3) == (true) -// ("x"_ < 0) is disjoint with ("x"_ > 2), so ("x"_ > 2).Assume("x"_ < 0) == (false) -// If simplification to (true) or (false) is not possible, pass e through unchanged. -std::shared_ptr ComparisonExpression::AssumeGivenComparison( - const ComparisonExpression& given) const { - if (!left_operand_->Equals(given.left_operand_)) { - return Copy(); - } - - for (auto rhs : {right_operand_, given.right_operand_}) { - if (rhs->type() != ExpressionType::SCALAR) { - return Copy(); - } - } - - auto this_rhs = - EnsureNotDictionary(checked_cast(*right_operand_).value()) - .ValueOr(nullptr); - auto given_rhs = - EnsureNotDictionary( - checked_cast(*given.right_operand_).value()) - .ValueOr(nullptr); - - if (!this_rhs || !given_rhs) { - return Copy(); - } - - auto cmp = Compare(*this_rhs, *given_rhs).ValueOrDie(); - - if (cmp == Comparison::NULL_) { - // the RHS of e or given was null - return NullExpression(); - } - - static auto always = scalar(true); - static auto never = scalar(false); - - if (cmp == Comparison::GREATER) { - // the rhs of e is greater than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - if (cmp == Comparison::LESS) { - // the rhs of e is less than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - DCHECK_EQ(cmp, Comparison::EQUAL); - - // the rhs of the comparisons are equal - switch (op_) { - case CompareOperator::EQUAL: - switch (given.op()) { - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - return never; - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS_EQUAL: - case CompareOperator::LESS: - return never; - case CompareOperator::GREATER: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::LESS: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::GREATER: - return never; - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - return Copy(); -} - -std::shared_ptr AndExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially false then so is this AND - if (left_operand->Equals(false) || right_operand->Equals(false)) { - return scalar(false); - } - - // if either operand is trivially null then so is this AND - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially true then drop it - if (left_operand->Equals(true)) { - return right_operand; - } - if (right_operand->Equals(true)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new AND - return and_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr OrExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially true then so is this OR - if (left_operand->Equals(true) || right_operand->Equals(true)) { - return scalar(true); - } - - // if either operand is trivially null then so is this OR - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially false then drop it - if (left_operand->Equals(false)) { - return right_operand; - } - if (right_operand->Equals(false)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new OR - return or_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr NotExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - - if (operand->IsNull()) { - return NullExpression(); - } - if (operand->Equals(true)) { - return scalar(false); - } - if (operand->Equals(false)) { - return scalar(true); - } - - return Copy(); -} - -std::shared_ptr InExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (operand->type() != ExpressionType::SCALAR) { - return std::make_shared(std::move(operand), set_); - } - - if (operand->IsNull()) { - return scalar(set_->null_count() > 0); - } - - Datum set, value; - if (set_->type_id() == Type::DICTIONARY) { - const auto& dict_set = checked_cast(*set_); - auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices()); - auto maybe_value = checked_cast( - *checked_cast(*operand).value()) - .GetEncodedValue(); - if (!maybe_decoded.ok() || !maybe_value.ok()) { - return std::make_shared(std::move(operand), set_); - } - set = *maybe_decoded; - value = *maybe_value; - } else { - set = set_; - value = checked_cast(*operand).value(); - } - - compute::CompareOptions eq(CompareOperator::EQUAL); - Result maybe_out = compute::Compare(set, value, eq); - if (!maybe_out.ok()) { - return std::make_shared(std::move(operand), set_); - } - - Datum out = maybe_out.ValueOrDie(); - - DCHECK(out.is_array()); - DCHECK_EQ(out.type()->id(), Type::BOOL); - auto out_array = checked_pointer_cast(out.make_array()); - - for (int64_t i = 0; i < out_array->length(); ++i) { - if (out_array->IsValid(i) && out_array->Value(i)) { - return scalar(true); - } - } - return scalar(false); -} - -std::shared_ptr IsValidExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (operand->type() == ExpressionType::SCALAR) { - return scalar(!operand->IsNull()); - } - - return std::make_shared(std::move(operand)); -} - -std::shared_ptr CastExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - return std::make_shared(std::move(operand), std::move(to_type), - options_); - } - auto like = arrow::util::get>(to_)->Assume(given); - return std::make_shared(std::move(operand), std::move(like), options_); -} - -const std::shared_ptr& CastExpression::to_type() const { - if (arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -const std::shared_ptr& CastExpression::like_expr() const { - if (arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -std::string FieldExpression::ToString() const { return name_; } - -std::string OperatorName(compute::CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return "=="; - case CompareOperator::NOT_EQUAL: - return "!="; - case CompareOperator::LESS: - return "<"; - case CompareOperator::LESS_EQUAL: - return "<="; - case CompareOperator::GREATER: - return ">"; - case CompareOperator::GREATER_EQUAL: - return ">="; - default: - DCHECK(false); - } - return ""; -} - -std::string ScalarExpression::ToString() const { - auto type_repr = value_->type->ToString(); - if (!value_->is_valid) { - return "null:" + type_repr; - } - - return value_->ToString() + ":" + type_repr; -} - -using arrow::internal::JoinStrings; - -std::string AndExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " and ", right_operand_->ToString(), ")"}, ""); -} - -std::string OrExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " or ", right_operand_->ToString(), ")"}, ""); -} - -std::string NotExpression::ToString() const { - if (operand_->type() == ExpressionType::IS_VALID) { - const auto& is_valid = checked_cast(*operand_); - return JoinStrings({"(", is_valid.operand()->ToString(), " is null)"}, ""); - } - return JoinStrings({"(not ", operand_->ToString(), ")"}, ""); -} - -std::string IsValidExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is not null)"}, ""); -} - -std::string InExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is in ", set_->ToString(), ")"}, ""); -} - -std::string CastExpression::ToString() const { - std::string to; - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - to = " to " + to_type->ToString(); - } else { - auto like = arrow::util::get>(to_); - to = " like " + like->ToString(); - } - return JoinStrings({"(cast ", operand_->ToString(), std::move(to), ")"}, ""); -} - -std::string ComparisonExpression::ToString() const { - return JoinStrings({"(", left_operand_->ToString(), " ", OperatorName(op()), " ", - right_operand_->ToString(), ")"}, - ""); -} - -bool UnaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - operand_->Equals(checked_cast(other).operand_); -} - -bool BinaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - left_operand_->Equals( - checked_cast(other).left_operand_) && - right_operand_->Equals( - checked_cast(other).right_operand_); -} - -bool ComparisonExpression::Equals(const Expression& other) const { - return BinaryExpression::Equals(other) && - op_ == checked_cast(other).op_; -} - -bool ScalarExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::SCALAR && - value_->Equals(*checked_cast(other).value_); -} - -bool FieldExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::FIELD && - name_ == checked_cast(other).name_; -} - -bool Expression::Equals(const std::shared_ptr& other) const { - if (other == nullptr) { - return false; - } - return Equals(*other); -} - -bool Expression::IsNull() const { - if (type_ != ExpressionType::SCALAR) { - return false; - } - - const auto& scalar = checked_cast(*this).value(); - if (!scalar->is_valid) { - return true; - } - - return false; -} - -InExpression Expression::In(std::shared_ptr set) const { - return InExpression(Copy(), std::move(set)); -} - -IsValidExpression Expression::IsValid() const { return IsValidExpression(Copy()); } - -std::shared_ptr FieldExpression::Copy() const { - return std::make_shared(*this); -} - -std::shared_ptr ScalarExpression::Copy() const { - return std::make_shared(*this); -} - -std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr and_(const ExpressionVector& subexpressions) { - auto acc = scalar(true); - for (const auto& next : subexpressions) { - if (next->Equals(false)) return next; - acc = acc->Equals(true) ? next : and_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr or_(const ExpressionVector& subexpressions) { - auto acc = scalar(false); - for (const auto& next : subexpressions) { - if (next->Equals(true)) return next; - acc = acc->Equals(false) ? next : or_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared(std::move(operand)); -} - -AndExpression operator&&(const Expression& lhs, const Expression& rhs) { - return AndExpression(lhs.Copy(), rhs.Copy()); -} - -OrExpression operator||(const Expression& lhs, const Expression& rhs) { - return OrExpression(lhs.Copy(), rhs.Copy()); -} - -NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } - -CastExpression Expression::CastTo(std::shared_ptr type, - compute::CastOptions options) const { - return CastExpression(Copy(), type, std::move(options)); -} - -CastExpression Expression::CastLike(std::shared_ptr expr, - compute::CastOptions options) const { - return CastExpression(Copy(), std::move(expr), std::move(options)); -} - -CastExpression Expression::CastLike(const Expression& expr, - compute::CastOptions options) const { - return CastLike(expr.Copy(), std::move(options)); -} - -Result> ComparisonExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); - ARROW_ASSIGN_OR_RAISE(auto rhs_type, right_operand_->Validate(schema)); - - if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { - return boolean(); - } - - if (!lhs_type->Equals(rhs_type)) { - return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, - " vs ", *rhs_type); - } - - return boolean(); -} - -Status EnsureNullOrBool(const std::string& msg_prefix, - const std::shared_ptr& type) { - if (type->id() == Type::BOOL || type->id() == Type::NA) { - return Status::OK(); - } - return Status::TypeError(msg_prefix, *type); -} - -Result> ValidateBoolean(const ExpressionVector& operands, - const Schema& schema) { - for (const auto& operand : operands) { - ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); - RETURN_NOT_OK( - EnsureNullOrBool("cannot combine expressions including one of type ", type)); - } - return boolean(); -} - -Result> AndExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> OrExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> NotExpression::Validate(const Schema& schema) const { - return ValidateBoolean({operand_}, schema); -} - -Result> InExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (operand_type->id() == Type::NA || set_->type()->id() == Type::NA) { - return boolean(); - } - - if (!operand_type->Equals(set_->type())) { - return Status::TypeError("mismatch: set type ", *set_->type(), " vs operand type ", - *operand_type); - } - // TODO(bkietz) check if IsIn supports operand_type - return boolean(); -} - -Result> IsValidExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(std::ignore, operand_->Validate(schema)); - return boolean(); -} - -Result> CastExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - std::shared_ptr to_type; - if (arrow::util::holds_alternative>(to_)) { - to_type = arrow::util::get>(to_); - } else { - auto like = arrow::util::get>(to_); - ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema)); - } - - // Until expressions carry a shape, detect scalar and try to cast it. Works - // if the operand is a scalar leaf. - if (operand_->type() == ExpressionType::SCALAR) { - auto scalar_expr = checked_pointer_cast(operand_); - ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type)); - return to_type; - } - - if (!compute::CanCast(*operand_type, *to_type)) { - return Status::Invalid("Cannot cast to ", to_type->ToString()); - } - - return to_type; -} - -Result> ScalarExpression::Validate(const Schema& schema) const { - return value_->type; -} - -Result> FieldExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); - if (field != nullptr) { - return field->type(); - } - return null(); -} - -Result CastOrDictionaryEncode(const Datum& arr, - const std::shared_ptr& type, - const compute::CastOptions opts) { - if (type->id() == Type::DICTIONARY) { - const auto& dict_type = checked_cast(*type); - if (dict_type.index_type()->id() != Type::INT32) { - return Status::TypeError("cannot DictionaryEncode to index type ", - *dict_type.index_type()); - } - ARROW_ASSIGN_OR_RAISE(auto dense, compute::Cast(arr, dict_type.value_type(), opts)); - return compute::DictionaryEncode(dense); - } - - return compute::Cast(arr, type, opts); -} - -struct InsertImplicitCastsImpl { - struct ValidatedAndCast { - std::shared_ptr expr; - std::shared_ptr type; - }; - - Result InsertCastsAndValidate(const Expression& expr) { - ValidatedAndCast out; - ARROW_ASSIGN_OR_RAISE(out.expr, InsertImplicitCasts(expr, schema_)); - ARROW_ASSIGN_OR_RAISE(out.type, out.expr->Validate(schema_)); - return std::move(out); - } - - Result> Cast(std::shared_ptr type, - const Expression& expr) { - if (expr.type() != ExpressionType::SCALAR) { - return expr.CastTo(type).Copy(); - } - - // cast the scalar directly - const auto& value = checked_cast(expr).value(); - ARROW_ASSIGN_OR_RAISE(auto cast_value, value->CastTo(std::move(type))); - return scalar(cast_value); - } - - Result> operator()(const InExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - auto set = expr.set(); - - if (!op.type->Equals(set->type())) { - // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode(*set, op.type, {})); - set = encoded_set.make_array(); - } - - return std::make_shared(std::move(op.expr), std::move(set)); - } - - Result> operator()(const NotExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - - if (op.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(op.expr, Cast(boolean(), *op.expr)); - } - return not_(std::move(op.expr)); - } - - Result> operator()(const AndExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return and_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const OrExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return or_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const ComparisonExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->Equals(rhs.type)) { - return expr.Copy(); - } - - if (lhs.expr->type() == ExpressionType::SCALAR) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(rhs.type, *lhs.expr)); - } else { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(lhs.type, *rhs.expr)); - } - return std::make_shared(expr.op(), std::move(lhs.expr), - std::move(rhs.expr)); - } - - Result> operator()(const Expression& expr) const { - return expr.Copy(); - } - - const Schema& schema_; -}; - -Result> InsertImplicitCasts(const Expression& expr, - const Schema& schema) { - RETURN_NOT_OK(schema.CanReferenceFieldsByNames(FieldsInExpression(expr))); - return VisitExpression(expr, InsertImplicitCastsImpl{schema}); -} - -Status VisitConjunctionMembers(const Expression& expr, - const std::function& visitor) { - if (expr.type() == ExpressionType::AND) { - const auto& and_ = checked_cast(expr); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); - return Status::OK(); - } - - return visitor(expr); -} - -std::vector FieldsInExpression(const Expression& expr) { - struct { - void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } - - void operator()(const UnaryExpression& expr) { - VisitExpression(*expr.operand(), *this); - } - - void operator()(const BinaryExpression& expr) { - VisitExpression(*expr.left_operand(), *this); - VisitExpression(*expr.right_operand(), *this); - } - - void operator()(const Expression&) const {} - - std::vector fields; - } visitor; - - VisitExpression(expr, visitor); - return std::move(visitor.fields); -} - -std::vector FieldsInExpression(const std::shared_ptr& expr) { - DCHECK_NE(expr, nullptr); - if (expr == nullptr) { - return {}; - } - - return FieldsInExpression(*expr); -} - -RecordBatchIterator ExpressionEvaluator::FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool) { - auto filter_batches = [filter, pool, this](std::shared_ptr unfiltered) { - auto filtered = Evaluate(*filter, *unfiltered, pool).Map([&](Datum selection) { - return Filter(selection, unfiltered, pool); - }); - - if (filtered.ok() && (*filtered)->num_rows() == 0) { - // drop empty batches - return FilterIterator::Reject>(); - } - - return FilterIterator::MaybeAccept(std::move(filtered)); - }; - - return MakeFilterIterator(std::move(filter_batches), std::move(unfiltered)); -} - -std::shared_ptr ExpressionEvaluator::Null() { - struct Impl : ExpressionEvaluator { - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema())); - return Datum(MakeNullScalar(type)); - } - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override { - return batch; - } - }; - - return std::make_shared(); -} - -struct TreeEvaluator::Impl { - Result operator()(const ScalarExpression& expr) const { - return Datum(expr.value()); - } - - Result operator()(const FieldExpression& expr) const { - if (auto column = batch_.GetColumnByName(expr.name())) { - return std::move(column); - } - return NullDatum(); - } - - Result operator()(const AndExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneAnd); - } - - Result operator()(const OrExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneOr); - } - - Result EvaluateBoolean(const BinaryExpression& expr, - Result kernel(const Datum& left, - const Datum& right, - ExecContext* ctx)) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (lhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto lhs_array, - MakeArrayFromScalar(*lhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - lhs = Datum(std::move(lhs_array)); - } - - if (rhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto rhs_array, - MakeArrayFromScalar(*rhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - rhs = Datum(std::move(rhs_array)); - } - - return kernel(lhs, rhs, &ctx_); - } - - Result operator()(const NotExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand())); - if (IsNullDatum(to_invert)) { - return NullDatum(); - } - - if (to_invert.is_scalar()) { - bool trivial_condition = - checked_cast(*to_invert.scalar()).value; - return Datum(std::make_shared(!trivial_condition)); - } - return compute::Invert(to_invert, &ctx_); - } - - Result operator()(const InExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(expr.set()->null_count() != 0); - } - - DCHECK(operand_values.is_array()); - return compute::IsIn(operand_values, expr.set(), &ctx_); - } - - Result operator()(const IsValidExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(false); - } - - if (operand_values.is_scalar()) { - return Datum(true); - } - - DCHECK(operand_values.is_array()); - if (operand_values.array()->GetNullCount() == 0) { - return Datum(true); - } - - return Datum(std::make_shared(operand_values.array()->length, - operand_values.array()->buffers[0])); - } - - Result operator()(const CastExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto to_type, expr.Validate(*batch_.schema())); - - ARROW_ASSIGN_OR_RAISE(auto to_cast, Evaluate(*expr.operand())); - if (to_cast.is_scalar()) { - return to_cast.scalar()->CastTo(to_type); - } - - DCHECK(to_cast.is_array()); - return CastOrDictionaryEncode(to_cast, to_type, expr.options()); - } - - Result operator()(const ComparisonExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (IsNullDatum(lhs) || IsNullDatum(rhs)) { - return Datum(std::make_shared()); - } - - if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) { - if (lhs.is_array() && rhs.is_array()) { - // decode dictionary arrays - for (Datum* arg : {&lhs, &rhs}) { - auto dict = checked_pointer_cast(arg->make_array()); - ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(), - compute::TakeOptions::Defaults())); - } - } else if (lhs.is_array() || rhs.is_array()) { - auto dict = checked_pointer_cast( - (lhs.is_array() ? lhs : rhs).make_array()); - - ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast( - *(lhs.is_scalar() ? lhs : rhs).scalar()) - .GetEncodedValue()); - if (lhs.is_array()) { - lhs = dict->dictionary(); - rhs = std::move(scalar); - } else { - lhs = std::move(scalar); - rhs = dict->dictionary(); - } - ARROW_ASSIGN_OR_RAISE( - Datum out_dict, - compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_)); - - return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults()); - } - } - - return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_); - } - - Result operator()(const Expression& expr) const { - return Status::NotImplemented("evaluation of ", expr.ToString()); - } - - Result Evaluate(const Expression& expr) const { - return this_->Evaluate(expr, batch_, ctx_.memory_pool()); - } - - const TreeEvaluator* this_; - const RecordBatch& batch_; - mutable compute::ExecContext ctx_; -}; - -Result TreeEvaluator::Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const { - return VisitExpression(expr, Impl{this, batch, compute::ExecContext{pool}}); -} - -Result> TreeEvaluator::Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const { - if (selection.is_array()) { - auto selection_array = selection.make_array(); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum filtered, - compute::Filter(batch, selection_array, - compute::FilterOptions::Defaults(), &ctx)); - return filtered.record_batch(); - } - - if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) { - return Status::NotImplemented("Filtering batches against DatumKind::", - selection.kind(), " of type ", *selection.type()); - } - - if (BooleanScalar(true).Equals(*selection.scalar())) { - return batch; - } - - return batch->Slice(0, 0); -} - -const std::shared_ptr& scalar(bool value) { - static auto true_ = scalar(MakeScalar(true)); - static auto false_ = scalar(MakeScalar(false)); - return value ? true_ : false_; -} - -// Serialization is accomplished by converting expressions to single element StructArrays -// then writing that to an IPC file. The last field is always an int32 column containing -// ExpressionType, the rest store the Expression's members. -struct SerializeImpl { - Result> ToArray(const Expression& expr) const { - return VisitExpression(expr, *this); - } - - Result> TaggedWithChildren(const Expression& expr, - ArrayVector children) const { - children.emplace_back(); - ARROW_ASSIGN_OR_RAISE(children.back(), - MakeArrayFromScalar(Int32Scalar(expr.type()), 1)); - - return StructArray::Make(children, std::vector(children.size(), "")); - } - - Result> operator()(const FieldExpression& expr) const { - // store the field's name in a StringArray - ARROW_ASSIGN_OR_RAISE(auto name, MakeArrayFromScalar(StringScalar(expr.name()), 1)); - return TaggedWithChildren(expr, {name}); - } - - Result> operator()(const ScalarExpression& expr) const { - // store the scalar's value in a single element Array - ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1)); - return TaggedWithChildren(expr, {value}); - } - - Result> operator()(const UnaryExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - return TaggedWithChildren(expr, {operand}); - } - - Result> operator()(const CastExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the cast target and a discriminant - std::shared_ptr is_like_expr, to; - if (const auto& to_type = expr.to_type()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(false), 1)); - ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1)); - } - if (const auto& like_expr = expr.like_expr()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(true), 1)); - ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr)); - } - - return TaggedWithChildren(expr, {operand, is_like_expr, to}); - } - - Result> operator()(const BinaryExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - return TaggedWithChildren(expr, {left_operand, right_operand}); - } - - Result> operator()( - const ComparisonExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - // store the CompareOperator in a single element Int32Array - ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 1)); - return TaggedWithChildren(expr, {left_operand, right_operand, op}); - } - - Result> operator()(const InExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the set as a single element ListArray - auto set_type = list(expr.set()->type()); - - ARROW_ASSIGN_OR_RAISE(auto set_offsets, AllocateBuffer(sizeof(int32_t) * 2)); - reinterpret_cast(set_offsets->mutable_data())[0] = 0; - reinterpret_cast(set_offsets->mutable_data())[1] = - static_cast(expr.set()->length()); - - auto set_values = expr.set(); - - auto set = std::make_shared(std::move(set_type), 1, std::move(set_offsets), - std::move(set_values)); - return TaggedWithChildren(expr, {operand, set}); - } - - Result> operator()(const Expression& expr) const { - return Status::NotImplemented("serialization of ", expr.ToString()); - } - - Result> ToBuffer(const Expression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(expr)); - ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)); - ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); - ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - RETURN_NOT_OK(writer->Close()); - return stream->Finish(); - } -}; - -Result> Expression::Serialize() const { - return SerializeImpl{}.ToBuffer(*this); -} - -struct DeserializeImpl { - Result> FromArray(const Array& array) const { - if (array.type_id() != Type::STRUCT || array.length() != 1) { - return Status::Invalid("can only deserialize expressions from unit-length", - " StructArray, got ", array); - } - const auto& struct_array = checked_cast(array); - - ARROW_ASSIGN_OR_RAISE(auto expression_type, GetExpressionType(struct_array)); - switch (expression_type) { - case ExpressionType::FIELD: { - ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return field_ref(std::string(name)); - } - - case ExpressionType::SCALAR: { - ARROW_ASSIGN_OR_RAISE(auto value, struct_array.field(0)->GetScalar(0)); - return scalar(std::move(value)); - } - - case ExpressionType::NOT: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return not_(std::move(operand)); - } - - case ExpressionType::CAST: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto is_like_expr, GetView(struct_array, 1)); - if (is_like_expr) { - ARROW_ASSIGN_OR_RAISE(auto like_expr, FromArray(*struct_array.field(2))); - return operand->CastLike(std::move(like_expr)).Copy(); - } - return operand->CastTo(struct_array.field(2)->type()).Copy(); - } - - case ExpressionType::AND: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return and_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::OR: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return or_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::COMPARISON: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - ARROW_ASSIGN_OR_RAISE(auto op, GetView(struct_array, 2)); - return std::make_shared(static_cast(op), - std::move(left_operand), - std::move(right_operand)); - } - - case ExpressionType::IS_VALID: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return std::make_shared(std::move(operand)); - } - - case ExpressionType::IN: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - if (struct_array.field(1)->type_id() != Type::LIST) { - return Status::TypeError("expected field 1 of ", struct_array, - " to have list type"); - } - auto set = checked_cast(*struct_array.field(1)).values(); - return std::make_shared(std::move(operand), std::move(set)); - } - - default: - break; - } - - return Status::Invalid("non-deserializable ExpressionType ", expression_type); - } - - template ::ArrayType> - static Result().GetView(0))> GetView(const StructArray& array, - int index) { - if (index >= array.num_fields()) { - return Status::IndexError("expected ", array, " to have a child at index ", index); - } - - const auto& child = *array.field(index); - if (child.type_id() != T::type_id) { - return Status::TypeError("expected child ", index, " of ", array, " to have type ", - T::type_id); - } - - return checked_cast(child).GetView(0); - } - - static Result GetExpressionType(const StructArray& array) { - if (array.struct_type()->num_fields() < 1) { - return Status::Invalid("StructArray didn't contain ExpressionType member"); - } - - ARROW_ASSIGN_OR_RAISE(auto expression_type, - GetView(array, array.num_fields() - 1)); - return static_cast(expression_type); - } - - Result> FromBuffer(const Buffer& serialized) { - io::BufferReader stream(serialized); - ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); - ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); - ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); - return FromArray(*array); - } -}; - -Result> Expression::Deserialize(const Buffer& serialized) { - return DeserializeImpl{}.FromBuffer(serialized); -} - -// Transform an array of counts to offsets which will divide a ListArray -// into an equal number of slices with corresponding lengths. -inline Result> CountsToOffsets( - std::shared_ptr counts) { - Int32Builder offset_builder; - RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); - offset_builder.UnsafeAppend(0); - - for (int64_t i = 0; i < counts->length(); ++i) { - DCHECK_NE(counts->Value(i), 0); - auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); - offset_builder.UnsafeAppend(next_offset); - } - - std::shared_ptr offsets; - RETURN_NOT_OK(offset_builder.Finish(&offsets)); - return offsets; -} - -// Helper for simultaneous dictionary encoding of multiple arrays. -// -// The fused dictionary is the Cartesian product of the individual dictionaries. -// For example given two arrays A, B where A has unique values ["ex", "why"] -// and B has unique values [0, 1] the fused dictionary is the set of tuples -// [["ex", 0], ["ex", 1], ["why", 0], ["ex", 1]]. -// -// TODO(bkietz) this capability belongs in an Action of the hash kernels, where -// it can be used to group aggregates without materializing a grouped batch. -// For the purposes of writing we need the materialized grouped batch anyway -// since no Writers accept a selection vector. -class StructDictionary { - public: - struct Encoded { - std::shared_ptr indices; - std::shared_ptr dictionary; - }; - - static Result Encode(const ArrayVector& columns) { - Encoded out{nullptr, std::make_shared()}; - - for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); - } - - return out; - } - - Result> Decode(std::shared_ptr fused_indices, - FieldVector fields) { - std::vector builders(dictionaries_.size()); - for (Int32Builder& b : builders) { - RETURN_NOT_OK(b.Resize(fused_indices->length())); - } - - std::vector codes(dictionaries_.size()); - for (int64_t i = 0; i < fused_indices->length(); ++i) { - Expand(fused_indices->Value(i), codes.data()); - - auto builder_it = builders.begin(); - for (int32_t index : codes) { - builder_it++->UnsafeAppend(index); - } - } - - ArrayVector columns(dictionaries_.size()); - for (size_t i = 0; i < dictionaries_.size(); ++i) { - std::shared_ptr indices; - RETURN_NOT_OK(builders[i].FinishInternal(&indices)); - - ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); - columns[i] = column.make_array(); - } - - return StructArray::Make(std::move(columns), std::move(fields)); - } - - private: - Status AddOne(Datum column, std::shared_ptr* fused_indices) { - ArrayData* encoded; - if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); - } - encoded = column.mutable_array(); - - auto indices = - std::make_shared(encoded->length, std::move(encoded->buffers[1])); - - dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); - auto dictionary_size = static_cast(dictionaries_.back()->length()); - - if (*fused_indices == nullptr) { - *fused_indices = std::move(indices); - size_ = dictionary_size; - return Status::OK(); - } - - // It's useful to think about the case where each of dictionaries_ has size 10. - // In this case the decimal digit in the ones place is the code in dictionaries_[0], - // the tens place corresponds to dictionaries_[1], etc. - // The incumbent indices must be shifted to the hundreds place so as not to collide. - ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, - compute::Multiply(indices, MakeScalar(size_))); - - ARROW_ASSIGN_OR_RAISE(new_fused_indices, - compute::Add(new_fused_indices, *fused_indices)); - - *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); - - // XXX should probably cap this at 2**15 or so - ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); - return Status::OK(); - } - - // expand a fused code into component dict codes, order is in order of addition - void Expand(int32_t fused_code, int32_t* codes) { - for (size_t i = 0; i < dictionaries_.size(); ++i) { - auto dictionary_size = static_cast(dictionaries_[i]->length()); - codes[i] = fused_code % dictionary_size; - fused_code /= dictionary_size; - } - } - - int32_t size_; - ArrayVector dictionaries_; -}; - -Result> MakeGroupings(const StructArray& by) { - if (by.num_fields() == 0) { - return Status::NotImplemented("Grouping with no criteria"); - } - - ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); - - ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); - ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); - fused.indices = checked_pointer_cast(sorted.make_array()); - - ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, - compute::ValueCounts(fused.indices)); - fused.indices.reset(); - - auto unique_fused_indices = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); - ARROW_ASSIGN_OR_RAISE( - auto unique_rows, - fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); - - auto counts = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); - ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts))); - - ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, - ListArray::FromArrays(*offsets, *sort_indices)); - - return StructArray::Make( - ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, - std::vector{"values", "groupings"}); -} - -Result> ApplyGroupings(const ListArray& groupings, - const Array& array) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(array, groupings.data()->child_data[0])); - - return std::make_shared(list(array.type()), groupings.length(), - groupings.value_offsets(), sorted.make_array()); -} - -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(batch, groupings.data()->child_data[0])); - - const auto& sorted_batch = *sorted.record_batch(); - - RecordBatchVector out(static_cast(groupings.length())); - for (size_t i = 0; i < out.size(); ++i) { - out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); - } - - return out; -} - -} // namespace dataset -} // namespace arrow diff --git a/r/R/enums.R b/r/R/enums.R index 14910bc92e03a..bf98ed8818070 100644 --- a/r/R/enums.R +++ b/r/R/enums.R @@ -65,20 +65,23 @@ Type <- enum("Type::type", TIME64 = 20L, INTERVAL_MONTHS = 21L, INTERVAL_DAY_TIME = 22L, - DECIMAL = 23L, - DECIMAL256 = 24L, - LIST = 25L, - STRUCT = 26L, - SPARSE_UNION = 27L, - DENSE_UNION = 28L, - DICTIONARY = 29L, - MAP = 30L, - EXTENSION = 31L, - FIXED_SIZE_LIST = 32L, - DURATION = 33L, - LARGE_STRING = 34L, - LARGE_BINARY = 35L, - LARGE_LIST = 36L + DECIMAL16 = 23L, + DECIMAL32 = 24L, + DECIMAL64 = 25L, + DECIMAL = 26L, + DECIMAL256 = 27L, + LIST = 28L, + STRUCT = 29L, + SPARSE_UNION = 30L, + DENSE_UNION = 31L, + DICTIONARY = 32L, + MAP = 33L, + EXTENSION = 34L, + FIXED_SIZE_LIST = 35L, + DURATION = 36L, + LARGE_STRING = 37L, + LARGE_BINARY = 38L, + LARGE_LIST = 39L ) #' @rdname enums From 91ed793d3a87427e138c7d1c94ff6da7d6f27af2 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Fri, 15 Jan 2021 15:36:32 +0300 Subject: [PATCH 6/8] Tests fix Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/util/basic_decimal.h | 10 +++++----- cpp/src/arrow/util/decimal_test.cc | 12 +++++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index 6f08ec0b65dc3..ac2011452ce12 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -341,10 +341,10 @@ template class ARROW_EXPORT BasicDecimalAnyWidth { public: using ValueType = typename IntTypes::signed_type; - /// \brief Empty constructor creates a BasicDecimal256 with a value of 0. + /// \brief Empty constructor creates a BasicDecimal with a value of 0. constexpr BasicDecimalAnyWidth() noexcept : value(0) {} - /// \brief Convert any integer value into a BasicDecimal256. + /// \brief Convert any integer value into a BasicDecimal. template ::value && @@ -359,7 +359,7 @@ class ARROW_EXPORT BasicDecimalAnyWidth { constexpr BasicDecimalAnyWidth(const BasicDecimalAnyWidth<_width>& other) noexcept : value(static_cast(other.Value())) {} - /// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in + /// \brief Create a BasicDecimal from an array of bytes. Bytes are assumed to be in /// native-endian byte order. explicit BasicDecimalAnyWidth(const uint8_t* bytes); @@ -375,14 +375,14 @@ class ARROW_EXPORT BasicDecimalAnyWidth { DecimalStatus Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, BasicDecimalAnyWidth* remainder) const; - // \brief Scale multiplier for given scale value. + /// \brief Scale multiplier for given scale value. static BasicDecimalAnyWidth GetScaleMultiplier(int32_t scale); /// \brief Return the raw bytes of the value in native-endian byte order. std::array> 3)> ToBytes() const; void ToBytes(uint8_t* out) const; - /// \brief Convert BasicDecimal128 from one scale to another + /// \brief Convert BasicDecimal from one scale to another DecimalStatus Rescale(int32_t original_scale, int32_t new_scale, BasicDecimalAnyWidth* out) const; diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc index 0dd99c7f10469..e58f4912f3eae 100644 --- a/cpp/src/arrow/util/decimal_test.cc +++ b/cpp/src/arrow/util/decimal_test.cc @@ -1735,21 +1735,19 @@ TYPED_TEST(DecimalAnyWidthTest, BinaryOperations) { using ArrowValueType = typename arrow::CTypeTraits::ArrowType; auto DecimalFns = DecimalAnyWidthBinaryParams::value; - auto NumericFns = DecimalAnyWidthBinaryParams::value; + auto NumericFns = DecimalAnyWidthBinaryParams::value; for (size_t i = 0; i < DecimalFns.size(); i++){ for (auto x : GetRandomNumbers(8)) { for (auto y : GetRandomNumbers(8)) { TypeParam d1(x), d2(y); - ASSERT_EQ(NumericFns[i].second(x, y), DecimalFns[i].second(d1, d2)) - << d1 << DecimalFns[i].first << " " << d2 << " " << " != " << NumericFns[i].second(x, y); + auto result = DecimalFns[i].second(d1, d2); + auto reference = static_cast(NumericFns[i].second(x, y)); + ASSERT_EQ(reference, result) + << d1 << " " << DecimalFns[i].first << " " << d2 << " " << " != " << result; } } } } - - - - } // namespace arrow From ba94f9ce6607a1d5dc77e7a3b432dc55ceaec83b Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Fri, 15 Jan 2021 17:14:44 +0300 Subject: [PATCH 7/8] Lint fixes Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/array/array_base.cc | 4 +- cpp/src/arrow/array/array_decimal.cc | 5 +- cpp/src/arrow/array/array_decimal.h | 5 +- cpp/src/arrow/array/array_test.cc | 21 +++-- cpp/src/arrow/array/builder_decimal.cc | 12 +-- cpp/src/arrow/array/builder_decimal.h | 6 +- cpp/src/arrow/array/validate.cc | 2 +- cpp/src/arrow/ipc/json_simple.cc | 8 +- cpp/src/arrow/ipc/json_simple_test.cc | 7 +- cpp/src/arrow/pretty_print_test.cc | 3 +- cpp/src/arrow/python/arrow_to_pandas.cc | 7 +- cpp/src/arrow/python/decimal.cc | 9 +- cpp/src/arrow/python/decimal.h | 2 +- cpp/src/arrow/python/python_test.cc | 24 ++--- cpp/src/arrow/scalar.cc | 4 +- cpp/src/arrow/scalar.h | 2 +- cpp/src/arrow/scalar_test.cc | 7 +- cpp/src/arrow/testing/json_internal.cc | 4 +- cpp/src/arrow/type.cc | 12 +-- cpp/src/arrow/type.h | 14 +-- cpp/src/arrow/type_fwd.h | 19 ++-- cpp/src/arrow/type_traits.h | 17 ++-- cpp/src/arrow/util/basic_decimal.cc | 93 +++++++++++-------- cpp/src/arrow/util/basic_decimal.h | 45 ++++----- cpp/src/arrow/util/decimal.cc | 62 +++++++------ cpp/src/arrow/util/decimal.h | 12 +-- cpp/src/arrow/util/decimal_meta.h | 26 +++--- .../arrow/util/decimal_scale_multipliers.h | 54 ++++++----- cpp/src/arrow/util/decimal_test.cc | 66 +++++++------ cpp/src/arrow/util/decimal_type_traits.h | 22 ++--- 30 files changed, 297 insertions(+), 277 deletions(-) diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index 2a0a7b7b6338c..682bc781cf2f5 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -69,7 +69,7 @@ struct ScalarFromArraySlotImpl { return Finish(a.Value(index_)); } - #define DECL_DECIMAL_VISIT(width) \ +#define DECL_DECIMAL_VISIT(width) \ Status Visit(const Decimal##width##Array& a) { \ return Finish(Decimal##width(a.GetValue(index_))); \ } @@ -80,7 +80,7 @@ struct ScalarFromArraySlotImpl { DECL_DECIMAL_VISIT(128) DECL_DECIMAL_VISIT(256) - #undef DECL_DECIMAL_VISIT +#undef DECL_DECIMAL_VISIT template Status Visit(const BaseBinaryArray& a) { diff --git a/cpp/src/arrow/array/array_decimal.cc b/cpp/src/arrow/array/array_decimal.cc index 5e7c6c85beab8..3d0187bd2ebec 100644 --- a/cpp/src/arrow/array/array_decimal.cc +++ b/cpp/src/arrow/array/array_decimal.cc @@ -32,14 +32,13 @@ namespace arrow { using internal::checked_cast; - -template +template BaseDecimalArray::BaseDecimalArray(const std::shared_ptr& data) : FixedSizeBinaryArray(data) { ARROW_CHECK_EQ(data->type->id(), DecimalTypeTraits::Id); } -template +template std::string BaseDecimalArray::FormatValue(int64_t i) const { const auto& type_ = checked_cast(*type()); const ValueType value(GetValue(i)); diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index b7c8515304863..7a35da7fa7aa2 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -22,15 +22,15 @@ #include #include "arrow/array/array_binary.h" -#include "arrow/util/decimal_type_traits.h" #include "arrow/array/data.h" #include "arrow/type.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/visibility.h" namespace arrow { /// Template Array class for decimal data -template +template class ARROW_EXPORT BaseDecimalArray : public FixedSizeBinaryArray { public: using TypeClass = typename DecimalTypeTraits::TypeClass; @@ -47,5 +47,4 @@ class ARROW_EXPORT BaseDecimalArray : public FixedSizeBinaryArray { // Backward compatibility using DecimalArray = Decimal128Array; - } // namespace arrow diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 38c19e4c50109..0f9b8fda85ef0 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2408,17 +2408,19 @@ class DecimalTest : public ::testing::TestWithParam { } void InitNoNullsTest(int32_t precision) { - std::vector draw = {DecimalValue(1), DecimalValue(-2), DecimalValue(2389), - DecimalValue(4), DecimalValue(-12348)}; + std::vector draw = {DecimalValue(1), DecimalValue(-2), + DecimalValue(2389), DecimalValue(4), + DecimalValue(-12348)}; std::vector valid_bytes = {true, true, true, true, true}; this->TestCreate(precision, draw, valid_bytes, 0); this->TestCreate(precision, draw, valid_bytes, 2); } - void InitWithNullsTest(int32_t precision, std::string big_value, std::string big_negate_value) { + void InitWithNullsTest(int32_t precision, std::string big_value, + std::string big_negate_value) { std::vector draw = {DecimalValue(1), DecimalValue(2), DecimalValue(-1), - DecimalValue(4), DecimalValue(-1), DecimalValue(1), - DecimalValue(2)}; + DecimalValue(4), DecimalValue(-1), DecimalValue(1), + DecimalValue(2)}; DecimalValue big; ASSERT_OK_AND_ASSIGN(big, DecimalValue::FromString(big_value)); draw.push_back(big); @@ -2542,10 +2544,11 @@ TEST_P(Decimal256Test, NoNulls) { TEST_P(Decimal256Test, WithNulls) { int32_t precision = GetParam(); - this->InitWithNullsTest(precision, "578960446186580977117854925043439539266." - "34992332820282019728792003956564819967", - "-578960446186580977117854925043439539266." - "34992332820282019728792003956564819968"); + this->InitWithNullsTest(precision, + "578960446186580977117854925043439539266." + "34992332820282019728792003956564819967", + "-578960446186580977117854925043439539266." + "34992332820282019728792003956564819968"); } INSTANTIATE_TEST_SUITE_P(Decimal256Test, Decimal256Test, diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index ad49cc2466a92..19e5ddd8b5a8d 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -35,32 +35,32 @@ class MemoryPool; // ---------------------------------------------------------------------- // BaseDecimalBuilder -template +template BaseDecimalBuilder::BaseDecimalBuilder(const std::shared_ptr& type, - MemoryPool* pool) + MemoryPool* pool) : FixedSizeBinaryBuilder(type, pool), decimal_type_(internal::checked_pointer_cast(type)) {} -template +template Status BaseDecimalBuilder::Append(ValueType value) { RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1)); UnsafeAppend(value); return Status::OK(); } -template +template void BaseDecimalBuilder::UnsafeAppend(ValueType value) { value.ToBytes(GetMutableValue(length())); byte_builder_.UnsafeAdvance((width >> 3)); UnsafeAppendToBitmap(true); } -template +template void BaseDecimalBuilder::UnsafeAppend(util::string_view value) { FixedSizeBinaryBuilder::UnsafeAppend(value); } -template +template Status BaseDecimalBuilder::FinishInternal(std::shared_ptr* out) { std::shared_ptr data; RETURN_NOT_OK(byte_builder_.Finish(&data)); diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 4c27fc0cf1b64..ebad9127d86d6 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -23,16 +23,16 @@ #include "arrow/array/builder_base.h" #include "arrow/array/builder_binary.h" #include "arrow/array/data.h" -#include "arrow/util/decimal_type_traits.h" #include "arrow/status.h" #include "arrow/type.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/visibility.h" namespace arrow { -template +template class ARROW_EXPORT BaseDecimalBuilder : public FixedSizeBinaryBuilder { -public: + public: using TypeClass = typename DecimalTypeTraits::TypeClass; using ArrayType = typename DecimalTypeTraits::ArrayType; using ValueType = typename DecimalTypeTraits::ValueType; diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 691d7e4ad939f..47f5309c3c8ea 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -62,7 +62,7 @@ struct ValidateArrayImpl { return Status::OK(); } - template + template Status Visit(const BaseDecimalArray& array) { if (array.length() > 0 && array.values() == nullptr) { return Status::Invalid("values is null"); diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 5cb903c8ccfe8..1a7a142f9f873 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -332,9 +332,11 @@ class DecimalConverter final const DecimalSubtype* decimal_type_; }; -#define DECL_DECIMAL_CONVERTER(width) \ -template ::BuilderType> \ -using Decimal##width##Converter = DecimalConverter; +#define DECL_DECIMAL_CONVERTER(width) \ + template ::BuilderType> \ + using Decimal##width##Converter = \ + DecimalConverter; DECL_DECIMAL_CONVERTER(16) DECL_DECIMAL_CONVERTER(32) diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index 3967f5f54c0e4..b44a17ce09dc1 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -501,7 +501,9 @@ TEST(TestFixedSizeBinary, Dictionary) { template class TestDecimal : public testing::Test {}; -using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; +using DecimalTypes = + ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, + DecimalTypeTraits<128>, DecimalTypeTraits<256>>; TYPED_TEST_SUITE(TestDecimal, DecimalTypes); @@ -559,8 +561,7 @@ TYPED_TEST(TestDecimal, Errors) { TYPED_TEST(TestDecimal, Dictionary) { using TypeClass = typename TypeParam::TypeClass; auto type = std::make_shared(5, 2); - AssertJSONDictArray(int32(), type, - R"(["123.45", "-78.90", "-78.90", null, "123.45"])", + AssertJSONDictArray(int32(), type, R"(["123.45", "-78.90", "-78.90", null, "123.45"])", /*indices=*/"[0, 1, 1, null, 0]", /*values=*/R"(["123.45", "-78.90"])"); } diff --git a/cpp/src/arrow/pretty_print_test.cc b/cpp/src/arrow/pretty_print_test.cc index ab69a7e782752..078bd5abccd61 100644 --- a/cpp/src/arrow/pretty_print_test.cc +++ b/cpp/src/arrow/pretty_print_test.cc @@ -502,7 +502,8 @@ TEST_F(TestPrettyPrint, DecimalTypes) { int32_t p = 5; int32_t s = 4; - for (auto type : {decimal16(p, s), decimal32(p, s), decimal64(p, s), decimal128(p, s), decimal256(p, s)}) { + for (auto type : {decimal16(p, s), decimal32(p, s), decimal64(p, s), decimal128(p, s), + decimal256(p, s)}) { auto array = ArrayFromJSON(type, "[\"1.4567\", \"3.2765\", null]"); static const char* ex = "[\n 1.4567,\n 3.2765,\n null\n]"; diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index 665df75df6a08..27903edc00c77 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -1024,7 +1024,9 @@ struct ObjectWriterVisitor { } template - enable_if_t<(is_base_binary_type::value || is_fixed_size_binary_type::value) && !is_decimal_type::value, + enable_if_t<(is_base_binary_type::value || + is_fixed_size_binary_type::value) && + !is_decimal_type::value, Status> Visit(const Type& type) { auto WrapValue = [](const util::string_view& view, PyObject** out) { @@ -1106,7 +1108,8 @@ struct ObjectWriterVisitor { PyObject* decimal_constructor = Decimal.obj(); for (int c = 0; c < data.num_chunks(); c++) { - const auto& arr = checked_cast&>(*data.chunk(c)); + const auto& arr = + checked_cast&>(*data.chunk(c)); for (int64_t i = 0; i < arr.length(); ++i) { if (arr.IsNull(i)) { diff --git a/cpp/src/arrow/python/decimal.cc b/cpp/src/arrow/python/decimal.cc index 25d8af59ee15e..8f36812da26c0 100644 --- a/cpp/src/arrow/python/decimal.cc +++ b/cpp/src/arrow/python/decimal.cc @@ -171,8 +171,7 @@ Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arr return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); } -Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, - Decimal16* out) { +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal16* out) { return InternalDecimalFromPyObject(obj, arrow_type, out); } @@ -181,8 +180,7 @@ Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arr return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); } -Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, - Decimal32* out) { +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal32* out) { return InternalDecimalFromPyObject(obj, arrow_type, out); } @@ -191,8 +189,7 @@ Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arr return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out); } -Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, - Decimal64* out) { +Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal64* out) { return InternalDecimalFromPyObject(obj, arrow_type, out); } diff --git a/cpp/src/arrow/python/decimal.h b/cpp/src/arrow/python/decimal.h index 08ce344f86469..5a7698a3929b5 100644 --- a/cpp/src/arrow/python/decimal.h +++ b/cpp/src/arrow/python/decimal.h @@ -24,7 +24,7 @@ namespace arrow { -template +template class DecimalAnyWidth; using Decimal16 = DecimalAnyWidth<16>; diff --git a/cpp/src/arrow/python/python_test.cc b/cpp/src/arrow/python/python_test.cc index b4a45831f6b9b..799b7c1d4098a 100644 --- a/cpp/src/arrow/python/python_test.cc +++ b/cpp/src/arrow/python/python_test.cc @@ -360,11 +360,11 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) { // We fail when truncating values that would lose data if cast to a decimal type with // lower scale DecimalTestFromPythonDecimalRescale(::arrow::decimal16(5, 2), - this->CreatePythonDecimal("1.001"), {}); + this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal32(10, 2), - this->CreatePythonDecimal("1.001"), {}); + this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal64(10, 2), - this->CreatePythonDecimal("1.001"), {}); + this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.001"), {}); DecimalTestFromPythonDecimalRescale(::arrow::decimal256(10, 2), @@ -374,12 +374,12 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) { TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) { // We allow truncation of values that do not lose precision when dividing by 10 * the // difference between the scales, e.g., 1.000 -> 1.00 - DecimalTestFromPythonDecimalRescale( - ::arrow::decimal16(5, 2), this->CreatePythonDecimal("1.000"), 100); - DecimalTestFromPythonDecimalRescale( - ::arrow::decimal32(10, 2), this->CreatePythonDecimal("1.000"), 100); - DecimalTestFromPythonDecimalRescale( - ::arrow::decimal64(10, 2), this->CreatePythonDecimal("1.000"), 100); + DecimalTestFromPythonDecimalRescale(::arrow::decimal16(5, 2), + this->CreatePythonDecimal("1.000"), 100); + DecimalTestFromPythonDecimalRescale(::arrow::decimal32(10, 2), + this->CreatePythonDecimal("1.000"), 100); + DecimalTestFromPythonDecimalRescale(::arrow::decimal64(10, 2), + this->CreatePythonDecimal("1.000"), 100); DecimalTestFromPythonDecimalRescale( ::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.000"), 100); DecimalTestFromPythonDecimalRescale( @@ -399,9 +399,11 @@ TEST_F(DecimalTest, FromPythonNegativeDecimalRescale) { ::arrow::decimal256(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000); } -template +template class DecimalTestConversion : public testing::Test {}; -using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; +using DecimalTypes = + ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, + DecimalTypeTraits<128>, DecimalTypeTraits<256>>; TYPED_TEST_SUITE(DecimalTestConversion, DecimalTypes); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 351e115e00f44..4c01c38adca06 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -69,8 +69,8 @@ struct ScalarHashImpl { return StdHash(s.value.days) & StdHash(s.value.days); } - template - typename std::enable_if::type Visit(const BaseDecimalScalar& s) { + template ::type> + Status Visit(const BaseDecimalScalar& s) { return StdHash(s.value.Value()); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index c91239f495d58..01decf9637cd6 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -340,7 +340,7 @@ struct ARROW_EXPORT DurationScalar : public TemporalScalar { using TemporalScalar::TemporalScalar; }; -template +template struct BaseDecimalScalar : public Scalar { using Scalar::Scalar; using TypeClass = typename DecimalTypeTraits::TypeClass; diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 159fa725e0202..64c2a5aa234d5 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -329,10 +329,11 @@ TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); } TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); } - -template +template class TestDecimalScalar : public testing::Test {}; -using DecimalTypes = ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, DecimalTypeTraits<128>, DecimalTypeTraits<256>>; +using DecimalTypes = + ::testing::Types, DecimalTypeTraits<32>, DecimalTypeTraits<64>, + DecimalTypeTraits<128>, DecimalTypeTraits<256>>; TYPED_TEST_SUITE(TestDecimalScalar, DecimalTypes); diff --git a/cpp/src/arrow/testing/json_internal.cc b/cpp/src/arrow/testing/json_internal.cc index ed59d0111066a..4d5e4c7e5efd4 100644 --- a/cpp/src/arrow/testing/json_internal.cc +++ b/cpp/src/arrow/testing/json_internal.cc @@ -546,7 +546,7 @@ class ArrayWriter { } } - template + template void WriteDataValues(const BaseDecimalArray& arr) { static const char null_string[] = "0"; for (int64_t i = 0; i < arr.length(); ++i) { @@ -864,7 +864,7 @@ Status GetDecimal(const RjObject& json_type, std::shared_ptr* type) { break; default: return Status::Invalid("Only 128 bit and 256 Decimals are supported. Received", - bit_width); + bit_width); } return Status::OK(); diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ca083ad72e8d8..20d71818a683d 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -35,6 +35,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/decimal_type_traits.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/key_value_metadata.h" @@ -42,7 +43,6 @@ #include "arrow/util/make_unique.h" #include "arrow/util/range.h" #include "arrow/util/vector.h" -#include "arrow/util/decimal_type_traits.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -806,16 +806,16 @@ int32_t DecimalType::DecimalSize(int32_t precision) { // ---------------------------------------------------------------------- // Decimal type - -template +template BaseDecimalType::BaseDecimalType(int32_t precision, int32_t scale) : DecimalType(DecimalTypeTraits::Id, (width >> 3), precision, scale) { ARROW_CHECK_GE(precision, kMinPrecision); ARROW_CHECK_LE(precision, kMaxPrecision); } -template -Result> BaseDecimalType::Make(int32_t precision, int32_t scale) { +template +Result> BaseDecimalType::Make(int32_t precision, + int32_t scale) { if (precision < kMinPrecision || precision > kMaxPrecision) { return Status::Invalid("Decimal precision out of range: ", precision); } @@ -2212,7 +2212,7 @@ std::shared_ptr decimal256(int32_t precision, int32_t scale) { return std::make_shared(precision, scale); } -template +template std::string BaseDecimalType::ToString() const { std::stringstream s; s << type_name() << "(" << precision_ << ", " << scale_ << ")"; diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 86d1715e21172..edfcde60bc25a 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -30,10 +30,10 @@ #include "arrow/result.h" #include "arrow/type_fwd.h" // IWYU pragma: export #include "arrow/util/checked_cast.h" +#include "arrow/util/decimal_meta.h" #include "arrow/util/macros.h" #include "arrow/util/variant.h" #include "arrow/util/visibility.h" -#include "arrow/util/decimal_meta.h" #include "arrow/visitor.h" // IWYU pragma: keep namespace arrow { @@ -894,7 +894,7 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { }; /// \brief Template type class for decimal data -template +template class ARROW_EXPORT BaseDecimalType : public DecimalType { public: static constexpr const char* type_name() { return DecimalMeta::name; } @@ -915,35 +915,35 @@ class ARROW_EXPORT BaseDecimalType : public DecimalType { /// \brief Concrete type class for decimal 16-bit data class ARROW_EXPORT Decimal16Type : public BaseDecimalType<16> { -public: + public: static constexpr Type::type type_id = Type::DECIMAL16; using BaseDecimalType<16>::BaseDecimalType; }; /// \brief Concrete type class for decimal 32-bit data class ARROW_EXPORT Decimal32Type : public BaseDecimalType<32> { -public: + public: static constexpr Type::type type_id = Type::DECIMAL32; using BaseDecimalType<32>::BaseDecimalType; }; /// \brief Concrete type class for decimal 64-bit data class ARROW_EXPORT Decimal64Type : public BaseDecimalType<64> { -public: + public: static constexpr Type::type type_id = Type::DECIMAL64; using BaseDecimalType<64>::BaseDecimalType; }; /// \brief Concrete type class for decimal 128-bit data class ARROW_EXPORT Decimal128Type : public BaseDecimalType<128> { -public: + public: static constexpr Type::type type_id = Type::DECIMAL128; using BaseDecimalType<128>::BaseDecimalType; }; /// \brief Concrete type class for decimal 256-bit data class ARROW_EXPORT Decimal256Type : public BaseDecimalType<256> { -public: + public: static constexpr Type::type type_id = Type::DECIMAL256; using BaseDecimalType<256>::BaseDecimalType; }; diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 5f85ce462e96d..ca917cde7beba 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -145,16 +145,16 @@ struct StructScalar; class DecimalType; -template +template class DecimalAnyWidth; -template +template class BaseDecimalArray; -template +template class BaseDecimalBuilder; -template +template struct BaseDecimalScalar; using Decimal16 = DecimalAnyWidth<16>; @@ -163,11 +163,11 @@ using Decimal64 = DecimalAnyWidth<64>; class Decimal128; class Decimal256; -#define DECIMAL_DECL(width) \ -class Decimal##width##Type; \ -using Decimal##width##Array = BaseDecimalArray; \ -using Decimal##width##Builder = BaseDecimalBuilder; \ -using Decimal##width##Scalar = BaseDecimalScalar; +#define DECIMAL_DECL(width) \ + class Decimal##width##Type; \ + using Decimal##width##Array = BaseDecimalArray; \ + using Decimal##width##Builder = BaseDecimalBuilder; \ + using Decimal##width##Scalar = BaseDecimalScalar; DECIMAL_DECL(16) DECIMAL_DECL(32) @@ -177,7 +177,6 @@ DECIMAL_DECL(256) #undef DECIMAL_DECL - struct UnionMode { enum type { SPARSE, DENSE }; }; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 8865175785536..b1312219af0de 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -284,15 +284,14 @@ struct TypeTraits { static inline std::shared_ptr type_singleton() { return float16(); } }; - -#define DECIMAL_TYPE_TRAITS_DEF(width) \ -template <> \ -struct TypeTraits { \ - using ArrayType = Decimal##width##Array; \ - using BuilderType = Decimal##width##Builder; \ - using ScalarType = Decimal##width##Scalar; \ - constexpr static bool is_parameter_free = false; \ -}; +#define DECIMAL_TYPE_TRAITS_DEF(width) \ + template <> \ + struct TypeTraits { \ + using ArrayType = Decimal##width##Array; \ + using BuilderType = Decimal##width##Builder; \ + using ScalarType = Decimal##width##Scalar; \ + constexpr static bool is_parameter_free = false; \ + }; DECIMAL_TYPE_TRAITS_DEF(16) DECIMAL_TYPE_TRAITS_DEF(32) diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index 50e9ffccd769a..ccf7af82c0827 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -28,10 +28,10 @@ #include #include "arrow/util/bit_util.h" +#include "arrow/util/decimal_meta.h" +#include "arrow/util/decimal_scale_multipliers.h" #include "arrow/util/int128_internal.h" #include "arrow/util/int_util_internal.h" -#include "arrow/util/decimal_scale_multipliers.h" -#include "arrow/util/decimal_meta.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -41,7 +41,6 @@ using internal::SafeLeftShift; using internal::SafeSignedAdd; using internal::SafeSignedMultiply; - static const BasicDecimal256 ScaleMultipliersDecimal256[] = { BasicDecimal256({1ULL, 0ULL, 0ULL, 0ULL}), BasicDecimal256({10ULL, 0ULL, 0ULL, 0ULL}), @@ -1119,59 +1118,67 @@ BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& ri /// BasicDecimalAnyWidth -template +template BasicDecimalAnyWidth::BasicDecimalAnyWidth(const uint8_t* bytes) { DCHECK_NE(bytes, nullptr); value = *(reinterpret_cast(bytes)); -}; +} -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator+=(const BasicDecimalAnyWidth& right) { +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator+=( + const BasicDecimalAnyWidth& right) { value = SafeSignedAdd(value, right.value); return *this; } -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator-=(const BasicDecimalAnyWidth& right) { +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator-=( + const BasicDecimalAnyWidth& right) { value -= right.value; return *this; } -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator*=(const BasicDecimalAnyWidth& right) { - value = value * right.value; +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator*=( + const BasicDecimalAnyWidth& right) { + value = SafeSignedMultiply(value, right.value); return *this; } -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator/=(const BasicDecimalAnyWidth& right) { +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator/=( + const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth remainder; auto s = Divide(right, this, &remainder); DCHECK_EQ(s, DecimalStatus::kSuccess); return *this; } -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator%=(const BasicDecimalAnyWidth& right) { +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::operator%=( + const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth result; auto s = Divide(right, &result, this); DCHECK_EQ(s, DecimalStatus::kSuccess); return *this; } +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::Abs() { + return *this < 0 ? Negate() : *this; +} -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::Abs() { return *this < 0 ? Negate() : *this; } - -template -BasicDecimalAnyWidth BasicDecimalAnyWidth::Abs(const BasicDecimalAnyWidth& in) { +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::Abs( + const BasicDecimalAnyWidth& in) { BasicDecimalAnyWidth result(in); return result.Abs(); } -template -DecimalStatus BasicDecimalAnyWidth::Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, - BasicDecimalAnyWidth* remainder) const { +template +DecimalStatus BasicDecimalAnyWidth::Divide(const BasicDecimalAnyWidth& divisor, + BasicDecimalAnyWidth* result, + BasicDecimalAnyWidth* remainder) const { if (divisor.value == 0) { return DecimalStatus::kDivideByZero; } @@ -1186,49 +1193,52 @@ DecimalStatus BasicDecimalAnyWidth::Divide(const BasicDecimalAnyWidth& di return DecimalStatus::kSuccess; } -template -BasicDecimalAnyWidth BasicDecimalAnyWidth::GetScaleMultiplier(int32_t scale) { +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::GetScaleMultiplier( + int32_t scale) { DCHECK_GE(scale, 0); DCHECK_LE(scale, DecimalMeta::max_precision); return BasicDecimalAnyWidth(ScaleMultipliersAnyWidth::value[scale]); } -template +template std::array> 3)> BasicDecimalAnyWidth::ToBytes() const { std::array> 3)> out{{0}}; ToBytes(out.data()); return out; } -template +template void BasicDecimalAnyWidth::ToBytes(uint8_t* out) const { DCHECK_NE(out, nullptr); reinterpret_cast(out)[0] = value; } -template -BasicDecimalAnyWidth& BasicDecimalAnyWidth::Negate() { - value = - value; +template +BasicDecimalAnyWidth& BasicDecimalAnyWidth::Negate() { + value = -value; return *this; } -template -DecimalStatus BasicDecimalAnyWidth::Rescale(int32_t original_scale, int32_t new_scale, - BasicDecimalAnyWidth* out) const { +template +DecimalStatus BasicDecimalAnyWidth::Rescale( + int32_t original_scale, int32_t new_scale, BasicDecimalAnyWidth* out) const { return DecimalRescale(*this, original_scale, new_scale, out); } -template +template bool BasicDecimalAnyWidth::FitsInPrecision(int32_t precision) const { DCHECK_GT(precision, 0); DCHECK_LE(precision, DecimalMeta::max_precision); - return BasicDecimalAnyWidth::Abs(*this) < ScaleMultipliersAnyWidth::value[precision]; + return BasicDecimalAnyWidth::Abs(*this) < + ScaleMultipliersAnyWidth::value[precision]; } -template -void BasicDecimalAnyWidth::GetWholeAndFraction(int scale, BasicDecimalAnyWidth* whole, - BasicDecimalAnyWidth* fraction) const { +template +void BasicDecimalAnyWidth::GetWholeAndFraction( + int scale, BasicDecimalAnyWidth* whole, + BasicDecimalAnyWidth* fraction) const { DCHECK_GE(scale, 0); DCHECK_LE(scale, DecimalMeta::max_precision); @@ -1237,8 +1247,9 @@ void BasicDecimalAnyWidth::GetWholeAndFraction(int scale, BasicDecimalAny DCHECK_EQ(s, DecimalStatus::kSuccess); } -template -BasicDecimalAnyWidth BasicDecimalAnyWidth::IncreaseScaleBy(int32_t increase_by) const { +template +BasicDecimalAnyWidth BasicDecimalAnyWidth::IncreaseScaleBy( + int32_t increase_by) const { DCHECK_GE(increase_by, 0); DCHECK_LE(increase_by, DecimalMeta::max_precision); diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index ac2011452ce12..6fc8b6f53d573 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -23,10 +23,10 @@ #include #include +#include "arrow/util/decimal_meta.h" #include "arrow/util/macros.h" #include "arrow/util/type_traits.h" #include "arrow/util/visibility.h" -#include "arrow/util/decimal_meta.h" namespace arrow { @@ -37,7 +37,7 @@ enum class DecimalStatus { kRescaleDataLoss, }; -template +template class BasicDecimalAnyWidth; /// Represents a signed 128-bit integer in two's complement. @@ -336,8 +336,7 @@ ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left, ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& right); - -template +template class ARROW_EXPORT BasicDecimalAnyWidth { public: using ValueType = typename IntTypes::signed_type; @@ -346,13 +345,15 @@ class ARROW_EXPORT BasicDecimalAnyWidth { /// \brief Convert any integer value into a BasicDecimal. template ::value && - ((sizeof(T) < sizeof(ValueType)) || ((sizeof(T) == sizeof(ValueType)) && std::is_signed::value) - || std::is_same::value), T>::type> + typename = typename std::enable_if::value && + ((sizeof(T) < sizeof(ValueType)) || + ((sizeof(T) == sizeof(ValueType)) && + std::is_signed::value) || + std::is_same::value), + T>::type> constexpr BasicDecimalAnyWidth(T value) noexcept : value(static_cast(value)) {} - + /// \brief Upcast BasicDecimal with less widths template ::type> @@ -374,7 +375,7 @@ class ARROW_EXPORT BasicDecimalAnyWidth { DecimalStatus Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, BasicDecimalAnyWidth* remainder) const; - + /// \brief Scale multiplier for given scale value. static BasicDecimalAnyWidth GetScaleMultiplier(int32_t scale); @@ -399,15 +400,15 @@ class ARROW_EXPORT BasicDecimalAnyWidth { /// \brief separate the integer and fractional parts for the given scale. void GetWholeAndFraction(int32_t scale, BasicDecimalAnyWidth* whole, BasicDecimalAnyWidth* fraction) const; - + /// \brief Scale up. BasicDecimalAnyWidth IncreaseScaleBy(int32_t increase_by) const; - BasicDecimalAnyWidth& operator +=(const BasicDecimalAnyWidth&); - BasicDecimalAnyWidth& operator -=(const BasicDecimalAnyWidth&); - BasicDecimalAnyWidth& operator *=(const BasicDecimalAnyWidth&); - BasicDecimalAnyWidth& operator /=(const BasicDecimalAnyWidth&); - BasicDecimalAnyWidth& operator %=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator+=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator-=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator*=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator/=(const BasicDecimalAnyWidth&); + BasicDecimalAnyWidth& operator%=(const BasicDecimalAnyWidth&); friend bool operator==(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { @@ -444,30 +445,31 @@ class ARROW_EXPORT BasicDecimalAnyWidth { BasicDecimalAnyWidth result(left); result += right; return result; - }; + } friend BasicDecimalAnyWidth operator-(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth result(left); result -= right; return result; - }; + } friend BasicDecimalAnyWidth operator*(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth result(left); result *= right; return result; - }; + } friend BasicDecimalAnyWidth operator/(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth result = left; result /= right; return result; - }; + } - friend BasicDecimalAnyWidth operator%(const BasicDecimalAnyWidth& left, const BasicDecimalAnyWidth& right) { + friend BasicDecimalAnyWidth operator%(const BasicDecimalAnyWidth& left, + const BasicDecimalAnyWidth& right) { BasicDecimalAnyWidth result = left; result %= right; return result; @@ -477,7 +479,6 @@ class ARROW_EXPORT BasicDecimalAnyWidth { ValueType value; }; - using BasicDecimal64 = BasicDecimalAnyWidth<64>; using BasicDecimal32 = BasicDecimalAnyWidth<32>; using BasicDecimal16 = BasicDecimalAnyWidth<16>; diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc index 2de218a076b56..35566f274282d 100644 --- a/cpp/src/arrow/util/decimal.cc +++ b/cpp/src/arrow/util/decimal.cc @@ -464,8 +464,8 @@ inline Status ToArrowStatus(DecimalStatus dstatus, int num_bits) { } Status FromStringToArray(const util::string_view& s, DecimalComponents& dec, - uint64_t* out, int32_t array_size, - int32_t* precision, int32_t* scale) { + uint64_t* out, int32_t array_size, int32_t* precision, + int32_t* scale) { if (s.empty()) { return Status::Invalid("Empty string cannot be converted to decimal"); } @@ -497,8 +497,7 @@ Status FromStringToArray(const util::string_view& s, DecimalComponents& dec, if (out != nullptr) { ShiftAndAdd(dec.whole_digits, out, array_size); - ShiftAndAdd(dec.fractional_digits, out, - array_size); + ShiftAndAdd(dec.fractional_digits, out, array_size); } return Status::OK(); @@ -511,13 +510,15 @@ Status Decimal128::FromString(const util::string_view& s, Decimal128* out, std::array little_endian_array = {0, 0}; DecimalComponents dec; - auto status = FromStringToArray(s, dec, little_endian_array.data(), 2, precision, scale); + auto status = + FromStringToArray(s, dec, little_endian_array.data(), 2, precision, scale); if (status != Status::OK()) { return status; } if (out != nullptr) { - *out = Decimal128(static_cast(little_endian_array[1]), little_endian_array[0]); + *out = + Decimal128(static_cast(little_endian_array[1]), little_endian_array[0]); if (scale != nullptr && *scale < 0) { *out *= GetScaleMultiplier(-*scale); @@ -666,7 +667,8 @@ Status Decimal256::FromString(const util::string_view& s, Decimal256* out, std::array little_endian_array = {0, 0, 0, 0}; DecimalComponents dec; - auto status = FromStringToArray(s, dec, little_endian_array.data(), 4, precision, scale); + auto status = + FromStringToArray(s, dec, little_endian_array.data(), 4, precision, scale); if (status != Status::OK()) { return status; } @@ -756,32 +758,34 @@ std::ostream& operator<<(std::ostream& os, const Decimal256& decimal) { return os; } -template +template DecimalAnyWidth::DecimalAnyWidth(const std::string& str) : DecimalAnyWidth() { *this = DecimalAnyWidth::FromString(str).ValueOrDie(); } -template +template std::string DecimalAnyWidth::ToIntegerString() const { std::stringstream ss; ss << this->Value(); return ss.str(); } -template +template std::string DecimalAnyWidth::ToString(int32_t scale) const { std::string str(ToIntegerString()); AdjustIntegerStringWithScale(scale, &str); return str; } -template -Status DecimalAnyWidth::FromString(const util::string_view& s, DecimalAnyWidth* out, - int32_t* precision, int32_t* scale) { +template +Status DecimalAnyWidth::FromString(const util::string_view& s, + DecimalAnyWidth* out, int32_t* precision, + int32_t* scale) { std::array little_endian_array = {0}; DecimalComponents dec; - auto status = FromStringToArray(s, dec, little_endian_array.data(), 1, precision, scale); + auto status = + FromStringToArray(s, dec, little_endian_array.data(), 1, precision, scale); if (status != Status::OK()) { return status; } @@ -808,36 +812,40 @@ Status DecimalAnyWidth::FromString(const util::string_view& s, DecimalAny return status; } -template -Status DecimalAnyWidth::FromString(const std::string& s, DecimalAnyWidth* out, int32_t* precision, - int32_t* scale) { +template +Status DecimalAnyWidth::FromString(const std::string& s, + DecimalAnyWidth* out, int32_t* precision, + int32_t* scale) { return FromString(util::string_view(s), out, precision, scale); } -template -Status DecimalAnyWidth::FromString(const char* s, DecimalAnyWidth* out, int32_t* precision, - int32_t* scale) { +template +Status DecimalAnyWidth::FromString(const char* s, DecimalAnyWidth* out, + int32_t* precision, int32_t* scale) { return FromString(util::string_view(s), out, precision, scale); } -template -Result::_DecimalType> DecimalAnyWidth::FromString(const util::string_view& s) { +template +Result::_DecimalType> DecimalAnyWidth::FromString( + const util::string_view& s) { _DecimalType out; RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr)); return std::move(out); } -template -Result::_DecimalType> DecimalAnyWidth::FromString(const std::string& s) { +template +Result::_DecimalType> DecimalAnyWidth::FromString( + const std::string& s) { return FromString(util::string_view(s)); } -template -Result::_DecimalType> DecimalAnyWidth::FromString(const char* s) { +template +Result::_DecimalType> DecimalAnyWidth::FromString( + const char* s) { return FromString(util::string_view(s)); } -template +template Status DecimalAnyWidth::ToArrowStatus(DecimalStatus dstatus) const { return arrow::ToArrowStatus(dstatus, width); } diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index afc089b21333a..0f8f8d3060213 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -258,11 +258,9 @@ class ARROW_EXPORT Decimal256 : public BasicDecimal256 { Status ToArrowStatus(DecimalStatus dstatus) const; }; - -template +template class ARROW_EXPORT DecimalAnyWidth : public BasicDecimalAnyWidth { - public: - + public: using _DecimalType = typename DecimalTypeTraits::ValueType; using ValueType = typename BasicDecimalAnyWidth::ValueType; @@ -272,7 +270,8 @@ class ARROW_EXPORT DecimalAnyWidth : public BasicDecimalAnyWidth { /// \endcond /// \brief constructor creates a Decimal256 from a BasicDecimal128. - constexpr DecimalAnyWidth(const BasicDecimalAnyWidth& value) noexcept : BasicDecimalAnyWidth(value) {} + constexpr DecimalAnyWidth(const BasicDecimalAnyWidth& value) noexcept + : BasicDecimalAnyWidth(value) {} /// \brief Parse the number from a base 10 string representation. explicit DecimalAnyWidth(const std::string& value); @@ -308,8 +307,7 @@ class ARROW_EXPORT DecimalAnyWidth : public BasicDecimalAnyWidth { return std::move(out); } - friend std::ostream& operator<<(std::ostream& os, - const DecimalAnyWidth& decimal) { + friend std::ostream& operator<<(std::ostream& os, const DecimalAnyWidth& decimal) { os << decimal.ToIntegerString(); return os; } diff --git a/cpp/src/arrow/util/decimal_meta.h b/cpp/src/arrow/util/decimal_meta.h index 59bb1d2206b23..727890926fdef 100644 --- a/cpp/src/arrow/util/decimal_meta.h +++ b/cpp/src/arrow/util/decimal_meta.h @@ -19,48 +19,48 @@ namespace arrow { -template +template struct IntTypes {}; -#define IntTypes_DECL(bit_width) \ -template<> \ -struct IntTypes{ \ - using signed_type = int##bit_width##_t; \ - using unsigned_type = uint##bit_width##_t; \ -}; +#define IntTypes_DECL(bit_width) \ + template <> \ + struct IntTypes { \ + using signed_type = int##bit_width##_t; \ + using unsigned_type = uint##bit_width##_t; \ + }; IntTypes_DECL(64); IntTypes_DECL(32); IntTypes_DECL(16); -template +template struct DecimalMeta; -template<> +template <> struct DecimalMeta<16> { static constexpr const char* name = "decimal16"; static constexpr int32_t max_precision = 5; }; -template<> +template <> struct DecimalMeta<32> { static constexpr const char* name = "decimal32"; static constexpr int32_t max_precision = 10; }; -template<> +template <> struct DecimalMeta<64> { static constexpr const char* name = "decimal64"; static constexpr int32_t max_precision = 19; }; -template<> +template <> struct DecimalMeta<128> { static constexpr const char* name = "decimal"; static constexpr int32_t max_precision = 38; }; -template<> +template <> struct DecimalMeta<256> { static constexpr const char* name = "decimal256"; static constexpr int32_t max_precision = 76; diff --git a/cpp/src/arrow/util/decimal_scale_multipliers.h b/cpp/src/arrow/util/decimal_scale_multipliers.h index 23067bc860a02..0efb9d8784f97 100644 --- a/cpp/src/arrow/util/decimal_scale_multipliers.h +++ b/cpp/src/arrow/util/decimal_scale_multipliers.h @@ -21,35 +21,34 @@ namespace arrow { -template +template struct ScaleMultipliersAnyWidth {}; #define DECL_ANY_SCALE_MULTIPLIERS(width) \ -template<> \ -struct ScaleMultipliersAnyWidth { \ - static const int##width##_t value[]; \ -}; \ -const int##width##_t ScaleMultipliersAnyWidth::value[] = { \ - int##width##_t(1LL), \ - int##width##_t(10LL), \ - int##width##_t(100LL), \ - int##width##_t(1000LL), \ - int##width##_t(10000LL), \ - int##width##_t(100000LL), \ - int##width##_t(1000000LL), \ - int##width##_t(10000000LL), \ - int##width##_t(100000000LL), \ - int##width##_t(1000000000LL), \ - int##width##_t(10000000000LL), \ - int##width##_t(100000000000LL), \ - int##width##_t(1000000000000LL), \ - int##width##_t(10000000000000LL), \ - int##width##_t(100000000000000LL), \ - int##width##_t(1000000000000000LL), \ - int##width##_t(10000000000000000LL), \ - int##width##_t(100000000000000000LL), \ - int##width##_t(1000000000000000000LL) \ -}; + template <> \ + struct ScaleMultipliersAnyWidth { \ + static const int##width##_t value[]; \ + }; \ + const int##width##_t ScaleMultipliersAnyWidth::value[] = { \ + int##width##_t(1LL), \ + int##width##_t(10LL), \ + int##width##_t(100LL), \ + int##width##_t(1000LL), \ + int##width##_t(10000LL), \ + int##width##_t(100000LL), \ + int##width##_t(1000000LL), \ + int##width##_t(10000000LL), \ + int##width##_t(100000000LL), \ + int##width##_t(1000000000LL), \ + int##width##_t(10000000000LL), \ + int##width##_t(100000000000LL), \ + int##width##_t(1000000000000LL), \ + int##width##_t(10000000000000LL), \ + int##width##_t(100000000000000LL), \ + int##width##_t(1000000000000000LL), \ + int##width##_t(10000000000000000LL), \ + int##width##_t(100000000000000000LL), \ + int##width##_t(1000000000000000000LL)}; DECL_ANY_SCALE_MULTIPLIERS(16) DECL_ANY_SCALE_MULTIPLIERS(32) @@ -98,7 +97,6 @@ static const BasicDecimal128 ScaleMultipliers128[] = { BasicDecimal128(542101086242752217LL, 68739955140067328ULL), BasicDecimal128(5421010862427522170LL, 687399551400673280ULL)}; - static const BasicDecimal128 ScaleMultipliersHalf128[] = { BasicDecimal128(0ULL), BasicDecimal128(5ULL), @@ -140,4 +138,4 @@ static const BasicDecimal128 ScaleMultipliersHalf128[] = { BasicDecimal128(271050543121376108LL, 9257742014424809472ULL), BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)}; -} // namespace arrow +} // namespace arrow diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc index e58f4912f3eae..2065472a99385 100644 --- a/cpp/src/arrow/util/decimal_test.cc +++ b/cpp/src/arrow/util/decimal_test.cc @@ -19,12 +19,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include @@ -1567,20 +1567,19 @@ TEST_P(Decimal256ToStringTest, ToString) { INSTANTIATE_TEST_SUITE_P(Decimal256ToStringTest, Decimal256ToStringTest, ::testing::ValuesIn(kToStringTestData)); - // DecimalAnyWidth template -class DecimalAnyWidthTest : public ::testing::Test { }; +class DecimalAnyWidthTest : public ::testing::Test {}; template -class Decimal16Test : public ::testing::Test { }; +class Decimal16Test : public ::testing::Test {}; template -class Decimal32Test : public ::testing::Test { }; +class Decimal32Test : public ::testing::Test {}; template -class Decimal64Test : public ::testing::Test { }; +class Decimal64Test : public ::testing::Test {}; using DecimalTypes = ::testing::Types; @@ -1592,12 +1591,11 @@ struct DecimalFromStringParams { }; static const std::vector DecimalFromStringParamsList = { - {"1234", 1234, 0, 4}, - {"12.34", 1234, 2, 4}, - {"+12.34", 1234, 2, 4}, - {"-12.34", -1234, 2, 4}, - {".0000", 0, 4, 4} -}; + {"1234", 1234, 0, 4}, + {"12.34", 1234, 2, 4}, + {"+12.34", 1234, 2, 4}, + {"-12.34", -1234, 2, 4}, + {".0000", 0, 4, 4}}; TYPED_TEST_SUITE(DecimalAnyWidthTest, DecimalTypes); @@ -1619,16 +1617,14 @@ TYPED_TEST(DecimalAnyWidthTest, FromBool) { ASSERT_EQ(TypeParam(1), TypeParam(true)); } -using Decimal16Types = - ::testing::Types; +using Decimal16Types = ::testing::Types; // NOLINT using Decimal32Types = - ::testing::Types; + ::testing::Types; // NOLINT using Decimal64Types = - ::testing::Types; + ::testing::Types; // NOLINT TYPED_TEST_SUITE(Decimal16Test, Decimal16Types); @@ -1647,8 +1643,7 @@ TYPED_TEST(Decimal16Test, Decimal16Types) { Decimal16 min_value_d(min_value); ASSERT_EQ(static_cast(min_value), min_value_d); - } - else { + } else { Decimal16 max_value_d(max_value); ASSERT_EQ(max_value, max_value_d); @@ -1691,12 +1686,12 @@ TYPED_TEST(Decimal64Test, Decimal64Types) { ASSERT_EQ(min_value, min_value_d); } -static const std::vector DecimalAnyWidthValues = { -2, -1, 0, 1, 2}; +static const std::vector DecimalAnyWidthValues = {-2, -1, 0, 1, 2}; TYPED_TEST(DecimalAnyWidthTest, ComparatorTest) { - for (size_t i=0; i +template struct DecimalAnyWidthBinaryParams { static const std::vector>> value; }; -template -const std::vector>> DecimalAnyWidthBinaryParams::value = { - {"+", [](T x, T y) -> T { return x + y;} }, - {"-", [](T x, T y) -> T { return x - y;} }, - {"*", [](T x, T y) -> T { return x * y;} }, - {"/", [](T x, T y) -> T { return y == 0? 0 : x / y;} }, - {"%", [](T x, T y) -> T { return y == 0? 0 : x % y;} }, +template +const std::vector>> + DecimalAnyWidthBinaryParams::value = { + {"+", [](T x, T y) -> T { return x + y; }}, + {"-", [](T x, T y) -> T { return x - y; }}, + {"*", [](T x, T y) -> T { return x * y; }}, + {"/", [](T x, T y) -> T { return y == 0 ? 0 : x / y; }}, + {"%", [](T x, T y) -> T { return y == 0 ? 0 : x % y; }}, }; TYPED_TEST(DecimalAnyWidthTest, BinaryOperations) { - using ValueType = typename arrow::DecimalAnyWidthTest_BinaryOperations_Test::TypeParam::ValueType; + using ValueType = typename arrow::DecimalAnyWidthTest_BinaryOperations_Test< + gtest_TypeParam_>::TypeParam::ValueType; using ArrowValueType = typename arrow::CTypeTraits::ArrowType; auto DecimalFns = DecimalAnyWidthBinaryParams::value; auto NumericFns = DecimalAnyWidthBinaryParams::value; - for (size_t i = 0; i < DecimalFns.size(); i++){ + for (size_t i = 0; i < DecimalFns.size(); i++) { for (auto x : GetRandomNumbers(8)) { for (auto y : GetRandomNumbers(8)) { TypeParam d1(x), d2(y); auto result = DecimalFns[i].second(d1, d2); auto reference = static_cast(NumericFns[i].second(x, y)); ASSERT_EQ(reference, result) - << d1 << " " << DecimalFns[i].first << " " << d2 << " " << " != " << result; + << d1 << " " << DecimalFns[i].first << " " << d2 << " " + << " != " << result; } } } diff --git a/cpp/src/arrow/util/decimal_type_traits.h b/cpp/src/arrow/util/decimal_type_traits.h index df06fa9cb2549..e3202d5c3ca70 100644 --- a/cpp/src/arrow/util/decimal_type_traits.h +++ b/cpp/src/arrow/util/decimal_type_traits.h @@ -21,19 +21,19 @@ namespace arrow { -template +template struct DecimalTypeTraits; -#define DECIMAL_TYPE_TRAITS_DECL(width) \ -template<> \ -struct DecimalTypeTraits { \ - static constexpr Type::type Id = Type::DECIMAL##width; \ - using ArrayType = Decimal##width##Array; \ - using BuilderType = Decimal##width##Builder; \ - using ScalarType = Decimal##width##Scalar; \ - using TypeClass = Decimal##width##Type; \ - using ValueType = Decimal##width; \ -}; +#define DECIMAL_TYPE_TRAITS_DECL(width) \ + template <> \ + struct DecimalTypeTraits { \ + static constexpr Type::type Id = Type::DECIMAL##width; \ + using ArrayType = Decimal##width##Array; \ + using BuilderType = Decimal##width##Builder; \ + using ScalarType = Decimal##width##Scalar; \ + using TypeClass = Decimal##width##Type; \ + using ValueType = Decimal##width; \ + }; DECIMAL_TYPE_TRAITS_DECL(16) DECIMAL_TYPE_TRAITS_DECL(32) From b22cdfd113685a20517124e2a06701ac36d83b0a Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Fri, 15 Jan 2021 20:54:16 +0300 Subject: [PATCH 8/8] Tests fixed Signed-off-by: Dmitry Chigarev --- cpp/src/arrow/array/array_decimal.h | 2 +- cpp/src/arrow/array/array_dict_test.cc | 6 ------ cpp/src/arrow/array/builder_decimal.h | 2 +- cpp/src/arrow/util/basic_decimal.cc | 7 +++++-- cpp/src/arrow/util/basic_decimal.h | 20 ++++++++++++++++---- cpp/src/arrow/util/decimal.h | 12 ------------ cpp/src/arrow/util/decimal_test.cc | 14 +++++++------- 7 files changed, 30 insertions(+), 33 deletions(-) diff --git a/cpp/src/arrow/array/array_decimal.h b/cpp/src/arrow/array/array_decimal.h index 7a35da7fa7aa2..0f8d8101912f8 100644 --- a/cpp/src/arrow/array/array_decimal.h +++ b/cpp/src/arrow/array/array_decimal.h @@ -45,6 +45,6 @@ class ARROW_EXPORT BaseDecimalArray : public FixedSizeBinaryArray { }; // Backward compatibility -using DecimalArray = Decimal128Array; +using DecimalArray = BaseDecimalArray<128>; } // namespace arrow diff --git a/cpp/src/arrow/array/array_dict_test.cc b/cpp/src/arrow/array/array_dict_test.cc index 498fbd8a812fb..2bc36a82d1123 100644 --- a/cpp/src/arrow/array/array_dict_test.cc +++ b/cpp/src/arrow/array/array_dict_test.cc @@ -933,12 +933,6 @@ void TestDecimalDictionaryBuilderDoubleTableSize( ASSERT_TRUE(expected.Equals(result)); } -// TEST(TestDecimal64DictionaryBuilder, DoubleTableSize) { -// const auto& decimal_type = arrow::decimal64(18, 0); -// Decimal64Builder decimal_builder(decimal_type); -// TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder); -// } - TEST(TestDecimal128DictionaryBuilder, DoubleTableSize) { const auto& decimal_type = arrow::decimal128(21, 0); Decimal128Builder decimal_builder(decimal_type); diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index ebad9127d86d6..a04dfd415fdaf 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -63,6 +63,6 @@ class ARROW_EXPORT BaseDecimalBuilder : public FixedSizeBinaryBuilder { }; // Backward compatibility -using DecimalBuilder = Decimal128Builder; +using DecimalBuilder = BaseDecimalBuilder<128>; } // namespace arrow diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index ccf7af82c0827..e1ca749e0c27d 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -1186,8 +1186,11 @@ DecimalStatus BasicDecimalAnyWidth::Divide(const BasicDecimalAnyWidth& di bool dividen_was_negative = Sign() == -1; bool divisor_was_negative = divisor.Sign() == -1; - *result = value / divisor.value; - *remainder = value % divisor.value; + BasicDecimalAnyWidth dividen_abs = BasicDecimalAnyWidth::Abs(*this); + BasicDecimalAnyWidth divisor_abs = BasicDecimalAnyWidth::Abs(divisor); + + *result = dividen_abs.value / divisor_abs.value; + *remainder = dividen_abs.value % divisor_abs.value; FixDivisionSigns(result, remainder, dividen_was_negative, divisor_was_negative); return DecimalStatus::kSuccess; diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h index 6fc8b6f53d573..2e458c46b699b 100644 --- a/cpp/src/arrow/util/basic_decimal.h +++ b/cpp/src/arrow/util/basic_decimal.h @@ -339,6 +339,7 @@ ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left, template class ARROW_EXPORT BasicDecimalAnyWidth { public: + static constexpr int bit_width = width; using ValueType = typename IntTypes::signed_type; /// \brief Empty constructor creates a BasicDecimal with a value of 0. constexpr BasicDecimalAnyWidth() noexcept : value(0) {} @@ -354,16 +355,16 @@ class ARROW_EXPORT BasicDecimalAnyWidth { constexpr BasicDecimalAnyWidth(T value) noexcept : value(static_cast(value)) {} + /// \brief Create a BasicDecimal from an array of bytes. Bytes are assumed to be in + /// native-endian byte order. + explicit BasicDecimalAnyWidth(const uint8_t* bytes); + /// \brief Upcast BasicDecimal with less widths template ::type> constexpr BasicDecimalAnyWidth(const BasicDecimalAnyWidth<_width>& other) noexcept : value(static_cast(other.Value())) {} - /// \brief Create a BasicDecimal from an array of bytes. Bytes are assumed to be in - /// native-endian byte order. - explicit BasicDecimalAnyWidth(const uint8_t* bytes); - /// \brief Negate the current value (in-place) BasicDecimalAnyWidth& Negate(); @@ -373,6 +374,17 @@ class ARROW_EXPORT BasicDecimalAnyWidth { /// \brief Absolute value static BasicDecimalAnyWidth Abs(const BasicDecimalAnyWidth& left); + /// Divide this number by right and return the result. + /// + /// This operation is not destructive. + /// The answer rounds to zero. Signs work like: + /// 21 / 5 -> 4, 1 + /// -21 / 5 -> -4, -1 + /// 21 / -5 -> -4, 1 + /// -21 / -5 -> 4, -1 + /// \param[in] divisor the number to divide by + /// \param[out] result the quotient + /// \param[out] remainder the remainder after the division DecimalStatus Divide(const BasicDecimalAnyWidth& divisor, BasicDecimalAnyWidth* result, BasicDecimalAnyWidth* remainder) const; diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index 0f8f8d3060213..e8cec7f1f68ab 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -317,16 +317,4 @@ class ARROW_EXPORT DecimalAnyWidth : public BasicDecimalAnyWidth { Status ToArrowStatus(DecimalStatus dstatus) const; }; -// class ARROW_EXPORT Decimal16 : DecimalAnyWidth<16> { -// using DecimalAnyWidth<16>::DecimalAnyWidth; -// }; - -// class ARROW_EXPORT Decimal32 : DecimalAnyWidth<32> { -// using DecimalAnyWidth<32>::DecimalAnyWidth; -// }; - -// class ARROW_EXPORT Decimal64 : DecimalAnyWidth<64> { -// using DecimalAnyWidth<64>::DecimalAnyWidth; -// }; - } // namespace arrow diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc index 2065472a99385..95bb3ac0c8a20 100644 --- a/cpp/src/arrow/util/decimal_test.cc +++ b/cpp/src/arrow/util/decimal_test.cc @@ -1729,20 +1729,20 @@ const std::vector>> TYPED_TEST(DecimalAnyWidthTest, BinaryOperations) { using ValueType = typename arrow::DecimalAnyWidthTest_BinaryOperations_Test< gtest_TypeParam_>::TypeParam::ValueType; - using ArrowValueType = typename arrow::CTypeTraits::ArrowType; auto DecimalFns = DecimalAnyWidthBinaryParams::value; auto NumericFns = DecimalAnyWidthBinaryParams::value; for (size_t i = 0; i < DecimalFns.size(); i++) { - for (auto x : GetRandomNumbers(8)) { - for (auto y : GetRandomNumbers(8)) { + for (ValueType x : GetRandomNumbers(8)) { + for (ValueType y : GetRandomNumbers(8)) { TypeParam d1(x), d2(y); - auto result = DecimalFns[i].second(d1, d2); - auto reference = static_cast(NumericFns[i].second(x, y)); + TypeParam result = DecimalFns[i].second(d1, d2); + ValueType reference = static_cast(NumericFns[i].second(x, y)); ASSERT_EQ(reference, result) - << d1 << " " << DecimalFns[i].first << " " << d2 << " " - << " != " << result; + << "(" << x << " " << DecimalFns[i].first << " " << y << " = " << reference + << ") != (" << d1 << " " << DecimalFns[i].first << " " << d2 << " = " + << result << ")"; } } }