Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Exp] Add GetMaxDepth #11970

Merged
merged 7 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions kratos/expression/binary_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ const std::vector<std::size_t> BinaryExpression<TOperationType>::GetItemShape()
return this->mpLeft->GetItemShape();
}

template <class TOperationType>
std::size_t BinaryExpression<TOperationType>::GetMaxDepth() const
{
return std::max(this->mpLeft->GetMaxDepth(), this->mpRight->GetMaxDepth()) + 1;
}

template <class TOperationType>
std::string BinaryExpression<TOperationType>::Info() const
{
Expand Down
2 changes: 2 additions & 0 deletions kratos/expression/binary_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class KRATOS_API(KRATOS_CORE) BinaryExpression : public Expression {

const std::vector<IndexType> GetItemShape() const override;

IndexType GetMaxDepth() const override;

std::string Info() const override;

///@}
Expand Down
6 changes: 6 additions & 0 deletions kratos/expression/container_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,12 @@ const TContainerType& ContainerExpression<TContainerType, TMeshType>::GetContain
return ContainerExpressionHelperUtilities::GetContainer<TContainerType>(ContainerExpressionHelperUtilities::GetMesh<TMeshType>(*mpModelPart));
}

template <class TContainerType, MeshType TMeshType>
std::size_t ContainerExpression<TContainerType, TMeshType>::GetMaxDepth() const
{
return this->GetExpression().GetMaxDepth();
}

template <class TContainerType, MeshType TMeshType>
std::string ContainerExpression<TContainerType, TMeshType>::Info() const
{
Expand Down
9 changes: 9 additions & 0 deletions kratos/expression/container_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ class KRATOS_API(KRATOS_CORE) ContainerExpression {
*/
const TContainerType& GetContainer() const;

/**
* @brief Get the Max Depth of the lazy expression tree.
*
* Returns the maximum depth of the lazy expression tree.
*
* @return IndexType Max depth of the lazy expression tree.
*/
IndexType GetMaxDepth() const;

/**
* @brief Get the info string
*
Expand Down
12 changes: 12 additions & 0 deletions kratos/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ class KRATOS_API(KRATOS_CORE) Expression {
*/
IndexType GetItemComponentCount() const;

/**
* @brief Get the Max Depth of the lazy expression tree.
*
* Returns the maximum depth of the lazy expression tree.
*
* @warning This is a recursive computation, hence this should not
* be done repeatedly unless necessary.
*
* @return IndexType Max depth of the lazy expression tree.
*/
virtual IndexType GetMaxDepth() const = 0;

///@}
///@name Input and output
///@{
Expand Down
2 changes: 2 additions & 0 deletions kratos/expression/literal_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class KRATOS_API(KRATOS_CORE) LiteralExpression : public Expression {

const std::vector<IndexType> GetItemShape() const override;

IndexType GetMaxDepth() const override { return 1; }

std::string Info() const override;

///@}
Expand Down
2 changes: 2 additions & 0 deletions kratos/expression/literal_flat_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class KRATOS_API(KRATOS_CORE) LiteralFlatExpression : public Expression {

const std::vector<IndexType> GetItemShape() const override;

IndexType GetMaxDepth() const override { return 1; }

IndexType size() const noexcept { return mData.size(); }

iterator begin() noexcept { return mData.begin(); }
Expand Down
9 changes: 9 additions & 0 deletions kratos/expression/unary_combine_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ class UnaryCombineExpression : public Expression {
return mStrides.size() == 1 && mStrides.back() == 1 ? std::vector<IndexType> {} : std::vector<IndexType> {mStrides.back()};
}

IndexType GetMaxDepth() const override
{
IndexType max_depth = 0;
for (const auto& p_expression : mSourceExpressions) {
max_depth = std::max(max_depth, p_expression->GetMaxDepth());
}
return max_depth + 1;
}

std::string Info() const override
{
std::stringstream msg;
Expand Down
5 changes: 5 additions & 0 deletions kratos/expression/unary_reshape_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class UnaryReshapeExpression : public Expression {
return mShape;
}

IndexType GetMaxDepth() const override
{
return mpSourceExpression->GetMaxDepth() + 1;
}

std::string Info() const override
{
std::stringstream msg;
Expand Down
5 changes: 5 additions & 0 deletions kratos/expression/unary_slice_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ const std::vector<std::size_t> UnarySliceExpression::GetItemShape() const
}
}

std::size_t UnarySliceExpression::GetMaxDepth() const
{
return mpSourceExpression->GetMaxDepth() + 1;
}

std::string UnarySliceExpression::Info() const
{
std::stringstream msg;
Expand Down
2 changes: 2 additions & 0 deletions kratos/expression/unary_slice_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class KRATOS_API(KRATOS_CORE) UnarySliceExpression : public Expression {

const std::vector<IndexType> GetItemShape() const override;

IndexType GetMaxDepth() const override;

std::string Info() const override;

///@}
Expand Down
1 change: 1 addition & 0 deletions kratos/python/add_container_expression_to_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ void AddContainerExpressionToPython(pybind11::module& m, const std::string& rNam
.def("GetContainer", py::overload_cast<>(&container_expression_holder_base::GetContainer), py::return_value_policy::reference)
.def("GetItemShape", &container_expression_holder_base::GetItemShape)
.def("GetItemComponentCount", &container_expression_holder_base::GetItemComponentCount)
.def("GetMaxDepth", &container_expression_holder_base::GetMaxDepth)
.def("Slice",
&container_expression_holder_base::Slice,
py::arg("offset"),
Expand Down
10 changes: 10 additions & 0 deletions kratos/python/add_expression_io_to_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ class ExpressionTrampoline final : public Expression
);
}

IndexType GetMaxDepth() const override
{
PYBIND11_OVERRIDE_PURE(
IndexType, /*return type*/
Expression, /*base type*/
GetMaxDepth /*function name*/
);
}

std::string Info() const override
{
PYBIND11_OVERRIDE_PURE(
Expand Down Expand Up @@ -209,6 +218,7 @@ void AddExpressionIOToPython(pybind11::module& rModule)
.def("GetItemShape", &Expression::GetItemShape)
.def("NumberOfEntities", &Expression::NumberOfEntities)
.def("GetItemComponentCount", &Expression::GetItemComponentCount)
.def("GetMaxDepth", &Expression::GetMaxDepth)
.def("__add__", [](Expression::Pointer pLeft, double Right) {return pLeft + Right;})
//.def("__add__", [](double Left, Expression::Pointer pRight) {return Left + pRight;})
.def("__add__", [](Expression::Pointer pLeft, Expression::Pointer pRight) {return pLeft + pRight;})
Expand Down
16 changes: 15 additions & 1 deletion kratos/tests/test_container_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import KratosMultiphysics.KratosUnittest as kratos_unittest

class TestContainerExpression(ABC):
ExpressionUnionType = Union[Kratos.Expression.NodalExpression, Kratos.Expression.ConditionExpression, Kratos.Expression.ElementExpression]
@classmethod
def CreateEntities(cls):
cls.model = Kratos.Model()
Expand Down Expand Up @@ -665,8 +666,21 @@ def test_VectorWrite(self):
velocity = Kratos.Array3([vector[i * 3], vector[i * 3 + 1], vector[i * 3 + 2]])
self.assertVectorAlmostEqual(velocity, self._GetValue(entity, Kratos.VELOCITY))

def test_GetMaxDepth(self):
a = self._GetContainerExpression()
self._Read(a, Kratos.GREEN_LAGRANGE_STRAIN_TENSOR)
b = a + 10
c = b * 2 + a
d = c ** 2
e = d.Comb([a, d])
f = e.Reshape([6, 2])
g = f.Slice(2, 4)
h = g.Reshape([2, 2])
i = h - a
self.assertEqual(i.GetMaxDepth(), 10)

@abstractmethod
def _GetContainerExpression(self) -> Union[Kratos.Expression.NodalExpression, Kratos.Expression.ElementExpression, Kratos.Expression.ConditionExpression]:
def _GetContainerExpression(self) -> ExpressionUnionType:
pass

@abstractmethod
Expand Down
Loading