-
Notifications
You must be signed in to change notification settings - Fork 295
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from NVIDIA-AI-IOT/cpp
Cpp
- Loading branch information
Showing
15 changed files
with
990 additions
and
609 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
cmake_minimum_required(VERSION 3.6) | ||
project(trt_pose) | ||
|
||
add_library(trt_pose SHARED | ||
trt_pose/plugins/find_peaks.cpp | ||
trt_pose/plugins/refine_peaks.cpp | ||
trt_pose/plugins/paf_score_graph.cpp | ||
trt_pose/plugins/munkres.cpp | ||
trt_pose/plugins/connect_parts.cpp | ||
) | ||
|
||
add_executable(trt_pose_test_all | ||
trt_pose/plugins/test_all.cpp | ||
) | ||
target_link_libraries(trt_pose_test_all trt_pose) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1,99 @@ | ||
#include "connect_parts.hpp" | ||
#include <queue> | ||
|
||
std::size_t connect_parts_out_workspace(const int C, const int M) { | ||
return sizeof(int) * C * M; | ||
} | ||
|
||
void connect_parts_out(int *object_counts, // 1 | ||
int *objects, // PxC | ||
const int *connections, // Kx2xM | ||
const int *topology, // Kx4 | ||
const int *counts, // C | ||
const int K, const int C, const int M, const int P, | ||
void *workspace) { | ||
|
||
// initialize objects | ||
for (int i = 0; i < C * M; i++) { | ||
objects[i] = -1; | ||
} | ||
|
||
// initialize visited | ||
std::memset(workspace, 0, connect_parts_out_workspace(C, M)); | ||
int *visited = (int *)workspace; | ||
|
||
int num_objects = 0; | ||
|
||
for (int c = 0; c < C; c++) { | ||
if (num_objects >= P) { | ||
break; | ||
} | ||
|
||
const int count = counts[c]; | ||
|
||
for (int i = 0; i < count; i++) { | ||
if (num_objects >= P) { | ||
break; | ||
} | ||
|
||
std::queue<std::pair<int, int>> q; | ||
bool new_object = false; | ||
q.push({c, i}); | ||
|
||
void connect_parts_out(torch::Tensor object_counts, torch::Tensor objects, torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, int max_count) | ||
{ | ||
auto options = torch::TensorOptions() | ||
.dtype(torch::kInt32) | ||
.layout(torch::kStrided) | ||
.device(torch::kCPU) | ||
.requires_grad(false); | ||
int N = counts.size(0); | ||
int K = topology.size(0); | ||
int C = counts.size(1); | ||
int M = connections.size(3); | ||
|
||
auto visited = torch::zeros({N, C, M}, options); | ||
auto visited_a = visited.accessor<int, 3>(); | ||
auto counts_a = counts.accessor<int, 2>(); | ||
auto topology_a = topology.accessor<int, 2>(); | ||
auto objects_a = objects.accessor<int, 3>(); | ||
auto object_counts_a = object_counts.accessor<int, 1>(); | ||
auto connections_a = connections.accessor<int, 4>(); | ||
|
||
for (int n = 0; n < N; n++) | ||
{ | ||
int num_objects = 0; | ||
for (int c = 0; c < C; c++) | ||
{ | ||
if (num_objects >= max_count) { | ||
break; | ||
while (!q.empty()) { | ||
auto node = q.front(); | ||
q.pop(); | ||
int c_n = node.first; | ||
int i_n = node.second; | ||
|
||
if (visited[c_n * M + i_n]) { | ||
continue; | ||
} | ||
|
||
visited[c_n * M + i_n] = 1; | ||
new_object = true; | ||
objects[num_objects * C + c_n] = i_n; | ||
|
||
for (int k = 0; k < K; k++) { | ||
const int *tk = &topology[k * 4]; | ||
const int c_a = tk[2]; | ||
const int c_b = tk[3]; | ||
const int *ck = &connections[k * 2 * M]; | ||
|
||
if (c_a == c_n) { | ||
int i_b = ck[i_n]; | ||
if (i_b >= 0) { | ||
q.push({c_b, i_b}); | ||
} | ||
|
||
int count = counts_a[n][c]; | ||
|
||
for (int i = 0; i < count; i++) | ||
{ | ||
if (num_objects >= max_count) { | ||
break; | ||
} | ||
|
||
std::queue<std::pair<int, int>> q; | ||
bool new_object = false; | ||
q.push({c, i}); | ||
|
||
while (!q.empty()) | ||
{ | ||
auto node = q.front(); | ||
q.pop(); | ||
int c_n = node.first; | ||
int i_n = node.second; | ||
|
||
if (visited_a[n][c_n][i_n]) { | ||
continue; | ||
} | ||
|
||
visited_a[n][c_n][i_n] = 1; | ||
new_object = true; | ||
objects_a[n][num_objects][c_n] = i_n; | ||
|
||
for (int k = 0; k < K; k++) | ||
{ | ||
int c_a = topology_a[k][2]; | ||
int c_b = topology_a[k][3]; | ||
|
||
if (c_a == c_n) | ||
{ | ||
int i_b = connections_a[n][k][0][i_n]; | ||
if (i_b >= 0) { | ||
q.push({c_b, i_b}); | ||
} | ||
} | ||
|
||
if (c_b == c_n) | ||
{ | ||
int i_a = connections_a[n][k][1][i_n]; | ||
if (i_a >= 0) { | ||
q.push({c_a, i_a}); | ||
} | ||
} | ||
} | ||
} | ||
|
||
if (new_object) | ||
{ | ||
num_objects++; | ||
} | ||
} | ||
|
||
if (c_b == c_n) { | ||
int i_a = ck[M + i_n]; | ||
if (i_a >= 0) { | ||
q.push({c_a, i_a}); | ||
} | ||
} | ||
} | ||
|
||
object_counts_a[n] = num_objects; | ||
} | ||
|
||
if (new_object) { | ||
num_objects++; | ||
} | ||
} | ||
} | ||
*object_counts = num_objects; | ||
} | ||
|
||
|
||
std::vector<torch::Tensor> connect_parts(torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, int max_count) | ||
{ | ||
auto options = torch::TensorOptions() | ||
.dtype(torch::kInt32) | ||
.layout(torch::kStrided) | ||
.device(torch::kCPU) | ||
.requires_grad(false); | ||
|
||
int N = counts.size(0); | ||
int K = topology.size(0); | ||
int C = counts.size(1); | ||
int M = connections.size(3); | ||
|
||
auto objects = torch::full({N, max_count, C}, -1, options); | ||
auto object_counts = torch::zeros({N}, options); | ||
connect_parts_out(object_counts, objects, connections, topology, counts, max_count); | ||
return {object_counts, objects}; | ||
} | ||
void connect_parts_out_batch(int *object_counts, // N | ||
int *objects, // NxPxC | ||
const int *connections, // NxKx2xM | ||
const int *topology, // Kx4 | ||
const int *counts, // NxC | ||
const int N, const int K, const int C, const int M, | ||
const int P, void *workspace) { | ||
for (int n = 0; n < N; n++) { | ||
connect_parts_out(&object_counts[n], &objects[n * P * C], | ||
&connections[n * K * 2 * M], topology, &counts[n * C], K, | ||
C, M, P, workspace); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,20 @@ | ||
#include <torch/extension.h> | ||
#include <vector> | ||
#include <queue> | ||
#pragma once | ||
#include <cstring> | ||
|
||
std::size_t connect_parts_out_workspace(const int C, const int M); | ||
|
||
void connect_parts_out(torch::Tensor object_counts, torch::Tensor objects, torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, int max_count); | ||
std::vector<torch::Tensor> connect_parts(torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, int max_count); | ||
void connect_parts_out(int *object_counts, // 1 | ||
int *objects, // PxC | ||
const int *connections, // Kx2xM | ||
const int *topology, // Kx4 | ||
const int *counts, // C | ||
const int K, const int C, const int M, const int P, | ||
void *workspace); | ||
|
||
void connect_parts_out_batch(int *object_counts, // N | ||
int *objects, // NxPxC | ||
const int *connections, // NxKx2xM | ||
const int *topology, // Kx4 | ||
const int *counts, // NxC | ||
const int N, const int K, const int C, const int M, | ||
const int P, void *workspace); |
Oops, something went wrong.