Skip to content

Commit

Permalink
vec_1
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Jun 14, 2024
1 parent e3f0797 commit 5b1cee6
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions src/utils/include/utils/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <span>
#include <type_traits>
#include <vector>

Expand Down Expand Up @@ -79,11 +80,21 @@ template <typename T, std::size_t N> class Vector : public Array<T, N> {

public:
template <class Range>
explicit Vector(Range const &rng) : Vector(std::begin(rng), std::end(rng)) {}
explicit Vector(Range const &rng)
: Vector(std::begin(rng), std::end(rng)) {}
explicit constexpr Vector(T const (&v)[N]) : Base() {
copy_init(std::begin(v), std::end(v));
}

explicit constexpr Vector(std::span<T> span) : Base() {
if (span.size() != N) {
throw std::length_error(
"Construction of Vector from Container of wrong length.");
}

copy_init(span.begin(), span.end());
}

constexpr Vector(std::initializer_list<T> v) : Base() {
if (N != v.size()) {
throw std::length_error(
Expand All @@ -103,14 +114,13 @@ template <typename T, std::size_t N> class Vector : public Array<T, N> {
}
}

/**
* @brief Create a vector that has all entries set to
* one value.
*/
static Vector<T, N> broadcast(T const &s) {
Vector<T, N> ret;
std::fill(ret.begin(), ret.end(), s);

/** @brief Create a vector that has all entries set to the same value. */
DEVICE_QUALIFIER static constexpr Vector<T, N>
broadcast(typename Base::value_type const &value) {
Vector<T, N> ret{};
for (std::size_t i = 0u; i != N; ++i) {
ret[i] = value;
}
return ret;
}

Expand All @@ -122,7 +132,7 @@ template <typename T, std::size_t N> class Vector : public Array<T, N> {
Vector<U, N> ret;

std::transform(begin(), end(), ret.begin(),
[](auto e) { return static_cast<U>(e); });
[](auto const &e) { return static_cast<U>(e); });

return ret;
}
Expand All @@ -140,7 +150,7 @@ template <typename T, std::size_t N> class Vector : public Array<T, N> {
Vector &normalize() {
auto const l = norm();
if (l > T(0)) {
for (std::size_t i = 0; i < N; i++)
for (std::size_t i = 0u; i < N; ++i)
this->operator[](i) /= l;
}

Expand Down Expand Up @@ -247,8 +257,7 @@ template <std::size_t N, typename T>
Vector<T, N> operator-(Vector<T, N> const &a) {
Vector<T, N> ret;

std::transform(std::begin(a), std::end(a), std::begin(ret),
[](T const &v) { return -v; });
std::transform(std::begin(a), std::end(a), std::begin(ret), std::negate<T>());

return ret;
}
Expand Down Expand Up @@ -300,6 +309,15 @@ Vector<T, N> operator/(Vector<T, N> const &a, T const &b) {
return ret;
}

template <std::size_t N, typename T>
Vector<T, N> operator/(T const &a, Vector<T, N> const &b) {
Vector<T, N> ret;

std::transform(std::begin(b), std::end(b), ret.begin(),
[a](T const &val) { return a / val; });
return ret;
}

template <std::size_t N, typename T>
Vector<T, N> &operator/=(Vector<T, N> &a, T const &b) {
std::transform(std::begin(a), std::end(a), std::begin(a),
Expand Down Expand Up @@ -367,7 +385,7 @@ auto hadamard_product(Vector<T, N> const &a, Vector<U, N> const &b) {

Vector<R, N> ret;
std::transform(a.cbegin(), a.cend(), b.cbegin(), ret.begin(),
[](auto ai, auto bi) { return ai * bi; });
[](auto const &ai, auto const &bi) { return ai * bi; });

return ret;
}
Expand Down Expand Up @@ -410,7 +428,7 @@ auto hadamard_division(Vector<T, N> const &a, Vector<U, N> const &b) {

Vector<R, N> ret;
std::transform(a.cbegin(), a.cend(), b.cbegin(), ret.begin(),
[](auto ai, auto bi) { return ai / bi; });
[](auto const &ai, auto const &bi) { return ai / bi; });

return ret;
}
Expand Down Expand Up @@ -448,11 +466,11 @@ auto hadamard_division(T const &a, U const &b) {
}

template <typename T> Vector<T, 3> unit_vector(unsigned int i) {
if (i == 0)
if (i == 0u)
return {T{1}, T{0}, T{0}};
if (i == 1)
if (i == 1u)
return {T{0}, T{1}, T{0}};
if (i == 2)
if (i == 2u)
return {T{0}, T{0}, T{1}};
throw std::domain_error("coordinate out of range");
}
Expand Down

0 comments on commit 5b1cee6

Please sign in to comment.