Skip to content

Commit

Permalink
reorganized
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub committed Mar 2, 2020
1 parent 8f39e9e commit 2963bfd
Show file tree
Hide file tree
Showing 22 changed files with 135 additions and 30 deletions.
12 changes: 6 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ cmake_minimum_required(VERSION 3.6)
project(trt_pose)

add_library(trt_pose SHARED
trt_pose/plugins/find_peaks.cpp
trt_pose/plugins/refine_peaks.cpp
trt_pose/plugins/paf_score_graph.cpp
trt_pose/plugins/munkres.cpp
trt_pose/plugins/connect_parts.cpp
trt_pose/parse/find_peaks.cpp
trt_pose/parse/refine_peaks.cpp
trt_pose/parse/paf_score_graph.cpp
trt_pose/parse/munkres.cpp
trt_pose/parse/connect_parts.cpp
)

add_executable(trt_pose_test_all
trt_pose/plugins/test_all.cpp
trt_pose/parse/test_all.cpp
)
target_link_libraries(trt_pose_test_all trt_pose)
16 changes: 8 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
packages=find_packages(),
ext_package='trt_pose',
ext_modules=[cpp_extension.CppExtension('plugins', [
'trt_pose/plugins/find_peaks.cpp',
'trt_pose/plugins/paf_score_graph.cpp',
'trt_pose/plugins/refine_peaks.cpp',
'trt_pose/plugins/plugins.cpp',
'trt_pose/plugins/munkres.cpp',
'trt_pose/plugins/connect_parts.cpp',
'trt_pose/plugins/generate_cmap.cpp',
'trt_pose/plugins/generate_paf.cpp',
'trt_pose/parse/find_peaks.cpp',
'trt_pose/parse/paf_score_graph.cpp',
'trt_pose/parse/refine_peaks.cpp',
'trt_pose/parse/munkres.cpp',
'trt_pose/parse/connect_parts.cpp',
'trt_pose/plugins.cpp',
'trt_pose/train/generate_cmap.cpp',
'trt_pose/train/generate_paf.cpp',
])],
cmdclass={'build_ext': cpp_extension.BuildExtension},
install_requires=[
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "connect_parts.hpp"
#include <queue>

namespace trt_pose {
namespace parse {

std::size_t connect_parts_out_workspace(const int C, const int M) {
return sizeof(int) * C * M;
}
Expand Down Expand Up @@ -97,3 +100,6 @@ void connect_parts_out_batch(int *object_counts, // N
C, M, P, workspace);
}
}

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once
#include <cstring>

namespace trt_pose {
namespace parse {

std::size_t connect_parts_out_workspace(const int C, const int M);

void connect_parts_out(int *object_counts, // 1
Expand All @@ -18,3 +21,6 @@ void connect_parts_out_batch(int *object_counts, // N
const int *counts, // NxC
const int N, const int K, const int C, const int M,
const int P, void *workspace);

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

namespace trt_pose {
namespace parse {

void find_peaks_out_hw(int *counts, // 1
int *peaks, // Mx2
const float *input, // HxW
Expand Down Expand Up @@ -75,3 +78,6 @@ void find_peaks_out_nchw(int *counts, // C
window_size);
}
}

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

namespace trt_pose {
namespace parse {

void find_peaks_out_hw(int *counts, // 1
int *peaks, // Mx2
const float *input, // HxW
Expand All @@ -18,3 +21,6 @@ void find_peaks_out_nchw(int *counts, // NxC
const int N, const int C, const int H, const int W,
const int M, const float threshold,
const int window_size);

} // namespace parse
} // namespace trt_pose
8 changes: 8 additions & 0 deletions trt_pose/plugins/munkres.cpp → trt_pose/parse/munkres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
#include "utils/CoverTable.hpp"
#include "utils/PairGraph.hpp"

namespace trt_pose {
namespace parse {

using namespace utils;

void subMinRow(float *cost_graph, const int M, const int nrows,
const int ncols) {
for (int i = 0; i < nrows; i++) {
Expand Down Expand Up @@ -253,3 +258,6 @@ void assignment_out_nk(int *connections, // NxKx2xM
workspace);
}
}

} // namespace parse
} // namespace trt_pose
6 changes: 6 additions & 0 deletions trt_pose/plugins/munkres.hpp → trt_pose/parse/munkres.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include <cstring>

namespace trt_pose {
namespace parse {

std::size_t assignment_out_workspace(const int M);

void assignment_out(int *connections, // 2xM
Expand All @@ -22,3 +25,6 @@ void assignment_out_nk(int *connections, // NxKx2xM
const int *counts, // NxC
const int N, const int C, const int K, const int M,
const float score_threshold, void *workspace);

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#define EPS 1e-5

namespace trt_pose {
namespace parse {

void paf_score_graph_out_hw(float *score_graph, // MxM
const float *paf_i, // HxW
const float *paf_j, // HxW
Expand Down Expand Up @@ -113,3 +116,6 @@ void paf_score_graph_out_nkhw(float *score_graph, // NxKxMxM
num_integral_samples);
}
}

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

namespace trt_pose {
namespace parse {

void paf_score_graph_out_hw(float *score_graph, // MxM
const float *paf_i, // HxW
const float *paf_j, // HxW
Expand All @@ -25,3 +28,6 @@ void paf_score_graph_out_nkhw(float *score_graph, // NxKxMxM
const int N, const int K, const int C,
const int H, const int W, const int M,
const int num_integral_samples);

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "refine_peaks.hpp"

namespace trt_pose {
namespace parse {

inline int reflect(int idx, int min, int max) {
if (idx < min) {
return -idx;
Expand Down Expand Up @@ -74,3 +77,6 @@ void refine_peaks_out_nchw(float *refined_peaks, // NxCxMx2
M, window_size);
}
}

} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

namespace trt_pose {
namespace parse {

void refine_peaks_out_hw(float *refined_peaks, // Mx2
const int *counts, // 1
const int *peaks, // Mx2
Expand All @@ -19,3 +22,6 @@ void refine_peaks_out_nchw(float *refined_peaks, // NxCxMx2
const float *cmap, const int N, const int C,
const int H, const int W, const int M,
const int window_size);

} // namespace parse
} // namespace trt_pose
3 changes: 3 additions & 0 deletions trt_pose/plugins/test_all.cpp → trt_pose/parse/test_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

#define ABS(x) ((x) > 0 ? (x) : (-x))

using namespace trt_pose;
using namespace trt_pose::parse;

void test_find_peaks_out_hw()
{
const int H = 4;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include <memory>
#include <vector>

namespace trt_pose {
namespace parse {
namespace utils {

class CoverTable
{
public:
Expand Down Expand Up @@ -66,3 +70,7 @@ class CoverTable
std::vector<bool> rows;
std::vector<bool> cols;
};

} // namespace utils
} // namespace parse
} // namespace trt_pose
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include <memory>
#include <vector>

namespace trt_pose {
namespace parse {
namespace utils {

class PairGraph
{
public:
Expand Down Expand Up @@ -111,3 +115,7 @@ class PairGraph
std::vector<int> rows;
std::vector<int> cols;
};

} // namespace utils
} // namespace parse
} // namespace trt_pose
17 changes: 10 additions & 7 deletions trt_pose/plugins/plugins.cpp → trt_pose/plugins.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include "connect_parts.hpp"
#include "find_peaks.hpp"
#include "generate_cmap.hpp"
#include "generate_paf.hpp"
#include "munkres.hpp"
#include "paf_score_graph.hpp"
#include "refine_peaks.hpp"
#include "parse/connect_parts.hpp"
#include "parse/find_peaks.hpp"
#include "parse/munkres.hpp"
#include "parse/paf_score_graph.hpp"
#include "parse/refine_peaks.hpp"
#include "train/generate_cmap.hpp"
#include "train/generate_paf.hpp"
#include <torch/extension.h>
#include <vector>

using namespace trt_pose::parse;
using namespace trt_pose::train;

void find_peaks_out_torch(torch::Tensor counts, torch::Tensor peaks,
torch::Tensor input, const float threshold,
const int window_size, const int max_count) {
Expand Down
6 changes: 0 additions & 6 deletions trt_pose/plugins/generate_cmap.hpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "generate_cmap.hpp"

namespace trt_pose {
namespace train {

torch::Tensor generate_cmap(torch::Tensor counts, torch::Tensor peaks, int height, int width, float stdev, int window)
{
Expand Down Expand Up @@ -62,4 +64,7 @@ torch::Tensor generate_cmap(torch::Tensor counts, torch::Tensor peaks, int heigh
}

return cmap;
}
}

} // namespace trt_pose::train
} // namespace trt_pose
11 changes: 11 additions & 0 deletions trt_pose/train/generate_cmap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include <torch/extension.h>
#include <vector>
#include <cmath>

namespace trt_pose {
namespace train {

torch::Tensor generate_cmap(torch::Tensor counts, torch::Tensor peaks, int height, int width, float stdev, int window);

} // namespace trt_pose::train
} // namespace trt_pose
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#define EPS 1e-5;

namespace trt_pose {
namespace train {

torch::Tensor generate_paf(torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, torch::Tensor peaks, int height, int width, float stdev)
{
auto options = torch::TensorOptions()
Expand Down Expand Up @@ -91,4 +94,7 @@ torch::Tensor generate_paf(torch::Tensor connections, torch::Tensor topology, to
}

return paf;
}
}

} // namespace trt_pose::train
} // namespace trt_pose
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@
#include <vector>
#include <cmath>

namespace trt_pose {
namespace train {

torch::Tensor generate_paf(torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, torch::Tensor peaks, int height, int width, float stdev);
torch::Tensor generate_paf(torch::Tensor connections, torch::Tensor topology, torch::Tensor counts, torch::Tensor peaks, int height, int width, float stdev);

} // namespace trt_pose::train
} // namespace trt_pose

0 comments on commit 2963bfd

Please sign in to comment.