Skip to content

Commit

Permalink
move graph walkers to paddle/common (PaddlePaddle#63645)
Browse files Browse the repository at this point in the history
* move graph walkers to paddle/common and fix reference

* fix style

* give up useless walker move

* give up useless walker move and delete file
  • Loading branch information
Hongqing-work authored Apr 18, 2024
1 parent f1a411e commit 18f526e
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 175 deletions.
50 changes: 2 additions & 48 deletions paddle/cinn/common/bfs_walker.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,59 +14,13 @@

#pragma once

#include <array>
#include <functional>
#include <queue>
#include <unordered_set>
#include "paddle/common/bfs_walker.h"

namespace cinn {
namespace common {

// breadth-first search visitor
template <typename NodeType>
class BfsWalker final {
public:
BfsWalker(const BfsWalker&) = delete;
BfsWalker(BfsWalker&&) = delete;

using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;

BfsWalker(const NodesVisitorType& VisitNextNodes)
: VisitNextNodes_(VisitNextNodes) {}

void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
(*this)(nodes.begin(), nodes.end(), NodeHandler);
}

template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandler) const {
std::queue<NodeType> node_queue;
std::unordered_set<NodeType> queued_nodes;
const auto& TryEnqueueNode = [&](NodeType node) {
if (queued_nodes.count(node) == 0) {
node_queue.push(node);
queued_nodes.insert(node);
}
};
for (NodeIt iter = begin; iter != end; ++iter) {
TryEnqueueNode(*iter);
}
while (!node_queue.empty()) {
NodeType node = node_queue.front();
node_queue.pop();
NodeHandler(node);
VisitNextNodes_(node, TryEnqueueNode);
}
}

private:
NodesVisitorType VisitNextNodes_;
};
using BfsWalker = ::common::BfsWalker<NodeType>;

} // namespace common
} // namespace cinn
72 changes: 2 additions & 70 deletions paddle/cinn/common/dfs_walker.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,82 +14,14 @@

#pragma once

#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <stack>
#include <unordered_set>
#include "paddle/common/dfs_walker.h"

namespace cinn {
namespace common {

// depth-first search visitor
template <typename NodeType>
class DfsWalker final {
public:
DfsWalker(const DfsWalker&) = delete;
DfsWalker(DfsWalker&&) = delete;

using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;

DfsWalker(const NodesVisitorType& VisitNextNodes)
: VisitNextNodes_(VisitNextNodes) {}

void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
(*this)(nodes.begin(), nodes.end(), NodeHandler, [&](NodeType) {});
}

template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandler) const {
(*this)(begin, end, NodeHandler, [&](NodeType) {});
}

// https://en.wikipedia.org/wiki/Depth-first_search
template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandlerOnPush,
const NodeHandlerType& NodeHandlerOnPop) const {
std::unordered_set<NodeType> discovered;
struct Neighbours {
NodeType producer;
std::queue<NodeType> consumers;
};
std::stack<Neighbours> stack;
const auto& TryPush = [&](NodeType node) {
if (discovered.count(node) == 0) {
discovered.insert(node);
NodeHandlerOnPush(node);
stack.push(Neighbours{.producer = node});
VisitNextNodes_(node, [&](NodeType next_node) {
stack.top().consumers.push(next_node);
});
}
};
for (NodeIt node_iter = begin; node_iter != end; ++node_iter) {
TryPush(*node_iter);
while (!stack.empty()) {
auto* neighbours = &stack.top();
if (neighbours->consumers.empty()) {
NodeHandlerOnPop(neighbours->producer);
stack.pop();
} else {
TryPush(neighbours->consumers.front());
neighbours->consumers.pop();
}
}
}
}

private:
NodesVisitorType VisitNextNodes_;
};
using DfsWalker = ::common::DfsWalker<NodeType>;

} // namespace common
} // namespace cinn
59 changes: 2 additions & 57 deletions paddle/cinn/common/topo_walker.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,69 +14,14 @@

#pragma once

#include <array>
#include <functional>
#include <queue>
#include <unordered_set>
#include "paddle/common/topo_walker.h"

namespace cinn {
namespace common {

// Topological order visitor
template <typename NodeType>
class TopoWalker final {
public:
TopoWalker(const TopoWalker&) = default;
TopoWalker(TopoWalker&&) = default;

using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;

TopoWalker(const NodesVisitorType& VisitPrevNodesValue,
const NodesVisitorType& VisitNextNodesValue)
: VisitPrevNodes(VisitPrevNodesValue),
VisitNextNodes(VisitNextNodesValue) {}

void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
(*this)(nodes.begin(), nodes.end(), NodeHandler);
}

template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandler) const {
std::queue<NodeType> node_queue;
std::unordered_set<NodeType> queued_nodes;
const auto& TryEnqueueNode = [&](NodeType node) {
if (queued_nodes.count(node) == 0) {
node_queue.push(node);
queued_nodes.insert(node);
}
};
for (NodeIt iter = begin; iter != end; ++iter) {
TryEnqueueNode(*iter);
}
while (!node_queue.empty()) {
NodeType node = node_queue.front();
node_queue.pop();
NodeHandler(node);
VisitNextNodes(node, [&](NodeType node) {
size_t num_unfinished_inputs = 0;
VisitPrevNodes(node, [&](NodeType in_node) {
num_unfinished_inputs += (queued_nodes.count(in_node) > 0 ? 0 : 1);
});
if (num_unfinished_inputs == 0) {
TryEnqueueNode(node);
}
});
}
}

NodesVisitorType VisitPrevNodes;
NodesVisitorType VisitNextNodes;
};
using TopoWalker = ::common::TopoWalker<NodeType>;

} // namespace common
} // namespace cinn
70 changes: 70 additions & 0 deletions paddle/common/bfs_walker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <array>
#include <functional>
#include <queue>
#include <unordered_set>

namespace common {

// breadth-first search visitor
template <typename NodeType>
class BfsWalker final {
public:
BfsWalker(const BfsWalker&) = delete;
BfsWalker(BfsWalker&&) = delete;

using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;

BfsWalker(const NodesVisitorType& VisitNextNodes)
: VisitNextNodes_(VisitNextNodes) {}

void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
(*this)(nodes.begin(), nodes.end(), NodeHandler);
}

template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandler) const {
std::queue<NodeType> node_queue;
std::unordered_set<NodeType> queued_nodes;
const auto& TryEnqueueNode = [&](NodeType node) {
if (queued_nodes.count(node) == 0) {
node_queue.push(node);
queued_nodes.insert(node);
}
};
for (NodeIt iter = begin; iter != end; ++iter) {
TryEnqueueNode(*iter);
}
while (!node_queue.empty()) {
NodeType node = node_queue.front();
node_queue.pop();
NodeHandler(node);
VisitNextNodes_(node, TryEnqueueNode);
}
}

private:
NodesVisitorType VisitNextNodes_;
};

} // namespace common
93 changes: 93 additions & 0 deletions paddle/common/dfs_walker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <stack>
#include <unordered_set>

namespace common {

// depth-first search visitor
template <typename NodeType>
class DfsWalker final {
public:
DfsWalker(const DfsWalker&) = delete;
DfsWalker(DfsWalker&&) = delete;

using NodeHandlerType = std::function<void(NodeType)>;
using NodesVisitorType =
std::function<void(NodeType, const NodeHandlerType&)>;

DfsWalker(const NodesVisitorType& VisitNextNodes)
: VisitNextNodes_(VisitNextNodes) {}

void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
std::array<NodeType, 1> nodes{node};
(*this)(nodes.begin(), nodes.end(), NodeHandler, [&](NodeType) {});
}

template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandler) const {
(*this)(begin, end, NodeHandler, [&](NodeType) {});
}

// https://en.wikipedia.org/wiki/Depth-first_search
template <typename NodeIt>
void operator()(NodeIt begin,
NodeIt end,
const NodeHandlerType& NodeHandlerOnPush,
const NodeHandlerType& NodeHandlerOnPop) const {
std::unordered_set<NodeType> discovered;
struct Neighbours {
NodeType producer;
std::queue<NodeType> consumers;
};
std::stack<Neighbours> stack;
const auto& TryPush = [&](NodeType node) {
if (discovered.count(node) == 0) {
discovered.insert(node);
NodeHandlerOnPush(node);
stack.push(Neighbours{.producer = node});
VisitNextNodes_(node, [&](NodeType next_node) {
stack.top().consumers.push(next_node);
});
}
};
for (NodeIt node_iter = begin; node_iter != end; ++node_iter) {
TryPush(*node_iter);
while (!stack.empty()) {
auto* neighbours = &stack.top();
if (neighbours->consumers.empty()) {
NodeHandlerOnPop(neighbours->producer);
stack.pop();
} else {
TryPush(neighbours->consumers.front());
neighbours->consumers.pop();
}
}
}
}

private:
NodesVisitorType VisitNextNodes_;
};

} // namespace common
Loading

0 comments on commit 18f526e

Please sign in to comment.