Refactor split functions, more error detection #37

Merged
merged 13 commits on Oct 8, 2021
12 changes: 10 additions & 2 deletions scripts/build_mac.sh
@@ -4,7 +4,9 @@ set -xe
pushd $(dirname $0)/..

OBS_VERSION=$(brew info --json=v2 --cask obs | jq -r .casks[0].version)
LLVM_VERSION=$(brew info --json=v2 llvm | jq -r .formulae[0].installed[0].version)
LLVM_VERSION=$(brew info --json=v2 llvm@12 | jq -r .formulae[0].installed[0].version)

echo "Using OBS ${OBS_VERSION}, LLVM ${LLVM_VERSION}"

[ -d deps ] || mkdir deps
[ -d deps/obs-studio ] && rm -rf deps/obs-studio
@@ -13,7 +15,13 @@ git -C deps clone --single-branch --depth 1 -b ${OBS_VERSION} https://github.com
mkdir build
pushd build
# cmake .. -DobsPath=../deps/obs-studio -DLLVM_DIR=/usr/local/Cellar/llvm/12.0.1/lib/cmake/llvm
cmake .. -DobsLibPath=/Applications/OBS.app/Contents/Frameworks -DobsIncludePath=$(cd ../deps/obs-studio/libobs; pwd) -DOnnxRuntimePath=$(cd ../deps/onnxruntime; pwd) -DLLVM_DIR=/usr/local/Cellar/llvm/${LLVM_VERSION}/lib/cmake/llvm
cmake .. \
-DobsLibPath=/Applications/OBS.app/Contents/Frameworks \
-DobsIncludePath=$(cd ../deps/obs-studio/libobs; pwd) \
-DOnnxRuntimePath=$(cd ../deps/onnxruntime; pwd) \
-DHalide_DIR=$(cd ../deps/Halide; pwd)/lib/cmake/Halide \
-DHalideHelpers_DIR=$(cd ../deps/Halide; pwd)/lib/cmake/HalideHelpers \
-DLLVM_DIR=/usr/local/Cellar/llvm/${LLVM_VERSION}/lib/cmake/llvm
cmake --build . --config Release
cpack
popd
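
Taken together, the two scripts in this PR are meant to be run from the repository root before packaging. A minimal sketch of the assumed workflow follows; the deps/onnxruntime checkout referenced by -DOnnxRuntimePath is prepared outside these scripts and is an assumption here.

```bash
# Sketch of the assumed macOS build flow for this PR (not part of the diff).
./scripts/install_deps_mac.sh   # fetches Halide 12.0.1 into deps/Halide, installs llvm@12 and the OBS cask
# deps/onnxruntime is assumed to exist already (prepared by a separate step)
./scripts/build_mac.sh          # configures CMake with the Halide/LLVM paths above, builds, and runs cpack
```
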
18 changes: 16 additions & 2 deletions scripts/install_deps_mac.sh
@@ -1,9 +1,23 @@
#!/bin/bash

set -xe
set -xeu
cd $(dirname $0)/..
DEPS_DIR=$(pwd)/deps

HALIDE_VERSION=12.0.1
HALIDE_URL=https://github.com/halide/Halide/releases/download/v12.0.1/Halide-12.0.1-x86-64-osx-5dabcaa9effca1067f907f6c8ea212f3d2b1d99a.tar.gz
HALIDE_TGZ=${DEPS_DIR}/Halide-${HALIDE_VERSION}.tar.gz
HALIDE_DIR=${DEPS_DIR}/Halide

[ -d ${DEPS_DIR} ] || mkdir ${DEPS_DIR}
[ -e ${HALIDE_TGZ} ] || curl -o ${HALIDE_TGZ} -L ${HALIDE_URL}
[ -d ${HALIDE_DIR} ] || mkdir ${HALIDE_DIR} && tar zxf ${HALIDE_TGZ} -C ${HALIDE_DIR} --strip-components 1

# update brew so the newest OBS cask version is available
brew update

# brew install onnxruntime
brew install halide
brew install llvm@12
brew pin llvm@12
brew install obs --cask
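
One note on the Halide extraction guard above: in POSIX shell, `||` and `&&` have equal precedence and group left to right, so `[ -d ${HALIDE_DIR} ] || mkdir ${HALIDE_DIR} && tar ...` runs the tar step even when the directory already exists. A minimal sketch with explicit grouping, assuming the intent is to extract only once:

```bash
# Sketch (assumption: extraction should be skipped when deps/Halide is already populated).
if [ ! -d "${HALIDE_DIR}" ]; then
  mkdir -p "${HALIDE_DIR}"
  tar zxf "${HALIDE_TGZ}" -C "${HALIDE_DIR}" --strip-components 1
fi
```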

283 changes: 172 additions & 111 deletions src/obs-virtualbg-detector.cpp
@@ -83,26 +83,24 @@ void detector_destroy(void *data) {
bfree(filter_data);
}

void detector_update(void *data, obs_data_t *settings) {
virtual_bg_filter_data *filter_data = static_cast<virtual_bg_filter_data *>(data);
if (filter_data == NULL) {
return;
}

filter_data->use_threshold = obs_data_get_bool(settings, USE_THRESHOLD);
filter_data->threshold = (float)obs_data_get_double(settings, THRESHOLD_VALUE);
filter_data->use_mask_blur = obs_data_get_bool(settings, USE_MASK_BLUR);

Ort::SessionOptions sessionOptions;

sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

void detector_setup_ort_session_gpu(Ort::SessionOptions &sessionOptions) {
#if _WIN32
sessionOptions.DisableMemPattern();
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, 0));
try {
sessionOptions.DisableMemPattern();
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, 0));

} catch (const std::exception &ex) {
blog(LOG_ERROR,
"[Virtual BG detector] Can't Append Execution Provider DML. Will use CPU inference (its so heavy). "
"error: %s",
ex.what());
}
#endif
}

void detector_setup_ort_session_load_model(virtual_bg_filter_data *filter_data,
Ort::SessionOptions &sessionOptions) {
char *modelPath = obs_module_file("model.onnx");
static Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "virtual_bg inference");

@@ -123,7 +121,7 @@ void detector_update(void *data, obs_data_t *settings) {
return;
}
bfree(wcharModelPath);
#elif __APPLE__
#else
try {
filter_data->session.reset(new Ort::Session(env, modelPath, sessionOptions));
} catch (const std::exception &ex) {
@@ -133,34 +131,87 @@ void detector_update(void *data, obs_data_t *settings) {
}
bfree(modelPath);
#endif
}

filter_data->input_names[0] = filter_data->session->GetInputName(0, *filter_data->allocator);
filter_data->output_names[0] = filter_data->session->GetOutputName(0, *filter_data->allocator);
auto input_info = filter_data->session->GetInputTypeInfo(0);
auto output_info = filter_data->session->GetOutputTypeInfo(0);
auto input_dims = input_info.GetTensorTypeAndShapeInfo().GetShape();
auto output_dims = output_info.GetTensorTypeAndShapeInfo().GetShape();
filter_data->tensor_width = input_dims[2];
filter_data->tensor_height = input_dims[1];
if (filter_data->input_u8_buffer) {
bfree(filter_data->input_u8_buffer);
void detector_setup_ort_session(virtual_bg_filter_data *filter_data) {
Ort::SessionOptions sessionOptions;

sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

detector_setup_ort_session_gpu(sessionOptions);
detector_setup_ort_session_load_model(filter_data, sessionOptions);
}

void detector_gather_session_information(virtual_bg_filter_data *filter_data) {
if (filter_data->session.get() != nullptr) {
try {
if (filter_data->input_u8_buffer) {
bfree(filter_data->input_u8_buffer);
filter_data->input_u8_buffer = NULL;
}
if (filter_data->mask_u8_buffer) {
bfree(filter_data->mask_u8_buffer);
filter_data->mask_u8_buffer = NULL;
}
if (filter_data->mask_blurred_u8_buffer) {
bfree(filter_data->mask_blurred_u8_buffer);
filter_data->mask_blurred_u8_buffer = NULL;
}
if (filter_data->feedback_buffer) {
bfree(filter_data->feedback_buffer);
filter_data->feedback_buffer = NULL;
}
filter_data->input_names[0] = filter_data->session->GetInputName(0, *filter_data->allocator);
filter_data->output_names[0] = filter_data->session->GetOutputName(0, *filter_data->allocator);
auto input_info = filter_data->session->GetInputTypeInfo(0);
auto output_info = filter_data->session->GetOutputTypeInfo(0);
auto input_dims = input_info.GetTensorTypeAndShapeInfo().GetShape();
auto output_dims = output_info.GetTensorTypeAndShapeInfo().GetShape();
filter_data->tensor_width = input_dims[2];
filter_data->tensor_height = input_dims[1];
filter_data->input_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height * 3);
filter_data->mask_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height);
filter_data->mask_blurred_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height);
filter_data->feedback_buffer =
(float *)bzalloc(sizeof(float) * filter_data->tensor_width * filter_data->tensor_height);

filter_data->input_tensor =
Ort::Value::CreateTensor<float>(*filter_data->allocator, input_dims.data(), input_dims.size());
filter_data->output_tensor =
Ort::Value::CreateTensor<float>(*filter_data->allocator, output_dims.data(), output_dims.size());
} catch (const std::exception &ex) {
blog(LOG_ERROR, "[Virtual BG detector] somethins happens during gathering model information. error: %s",
ex.what());
}
}
}

void detector_update(void *data, obs_data_t *settings) {
virtual_bg_filter_data *filter_data = static_cast<virtual_bg_filter_data *>(data);
if (filter_data == NULL) {
return;
}
filter_data->input_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height * 3);
filter_data->mask_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height);
filter_data->mask_blurred_u8_buffer =
(uint8_t *)bmalloc(sizeof(uint8_t) * filter_data->tensor_width * filter_data->tensor_height);

filter_data->input_tensor =
Ort::Value::CreateTensor<float>(*filter_data->allocator, input_dims.data(), input_dims.size());
filter_data->output_tensor =
Ort::Value::CreateTensor<float>(*filter_data->allocator, output_dims.data(), output_dims.size());

filter_data->use_threshold = obs_data_get_bool(settings, USE_THRESHOLD);
filter_data->threshold = (float)obs_data_get_double(settings, THRESHOLD_VALUE);
filter_data->use_mask_blur = obs_data_get_bool(settings, USE_MASK_BLUR);

if (filter_data->preprocess_scaler) {
video_scaler_destroy(filter_data->preprocess_scaler);
filter_data->preprocess_scaler = NULL;
}

detector_setup_ort_session(filter_data);
detector_gather_session_information(filter_data);
}

void detector_setup_lut() {
for (int i = 0; i < 256; ++i) {
lut[i] = i / 255.0f;
}
}

void *detector_create(obs_data_t *settings, obs_source_t *source) {
@@ -177,15 +228,7 @@ void *detector_create(obs_data_t *settings, obs_source_t *source) {
}

detector_update(filter_data, settings);
if (filter_data->feedback_buffer) {
bfree(filter_data->feedback_buffer);
}
filter_data->feedback_buffer =
(float *)bzalloc(sizeof(float) * filter_data->tensor_width * filter_data->tensor_height);

for (int i = 0; i < 256; ++i) {
lut[i] = i / 255.0f;
}
detector_setup_lut();

return filter_data;
}
@@ -205,6 +248,83 @@ obs_properties_t *detector_properties(void *data) {
return ppts;
}

void detector_setup_preprocess_scaler(virtual_bg_filter_data *filter_data, struct obs_source_frame *frame) {
struct video_scale_info frame_scaler_info {
frame->format, frame->width, frame->height, frame->full_range ? VIDEO_RANGE_FULL : VIDEO_RANGE_DEFAULT,
VIDEO_CS_DEFAULT
};
struct video_scale_info tensor_scaler_info {
VIDEO_FORMAT_BGR3, (uint32_t)filter_data->tensor_width,
(uint32_t)filter_data->tensor_height, VIDEO_RANGE_DEFAULT, VIDEO_CS_DEFAULT
};
int ret = video_scaler_create(&filter_data->preprocess_scaler, &tensor_scaler_info, &frame_scaler_info,
VIDEO_SCALE_BICUBIC);
if (ret != 0) {
blog(LOG_ERROR, "[Virtual BG detector] Can't create video_scaler_create %d", ret);
throw new std::runtime_error("Cant create video_scaler_create");
} else {
blog(LOG_INFO, "[Virtual BG detector] video_scaler_create success. %dx%d -> %dx%d",
filter_data->frame_width, filter_data->frame_height, filter_data->tensor_width,
filter_data->tensor_height);
}
}

void detector_preprocess(virtual_bg_filter_data *filter_data, struct obs_source_frame *frame) {
const uint32_t linesize[] = {(uint32_t)filter_data->tensor_width * 3};
if (!video_scaler_scale(filter_data->preprocess_scaler, &filter_data->input_u8_buffer, linesize,
frame->data, frame->linesize)) {
blog(LOG_ERROR, "[Virtual BG detector] video_scaler_scale failed.");
throw std::runtime_error("video_scaler_scale failed.");
}

float *tensor_buffer = filter_data->input_tensor.GetTensorMutableData<float>();
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
tensor_buffer[i * 3 + 0] = lut[filter_data->input_u8_buffer[i * 3 + 2]];
tensor_buffer[i * 3 + 1] = lut[filter_data->input_u8_buffer[i * 3 + 1]];
tensor_buffer[i * 3 + 2] = lut[filter_data->input_u8_buffer[i * 3 + 0]];
}
}

void detector_inference(virtual_bg_filter_data *filter_data) {
try {
filter_data->session->Run(Ort::RunOptions(NULL), filter_data->input_names, &filter_data->input_tensor, 1,
filter_data->output_names, &filter_data->output_tensor, 1);
} catch (const std::exception &ex) {
blog(LOG_ERROR, "[Virtual BG detector] Error at detector_inference. error: %s", ex.what());
throw;
}
}

void detector_postprocess(virtual_bg_filter_data *filter_data) {
const float *tensor_buffer2 = filter_data->output_tensor.GetTensorData<float>();
if (filter_data->use_threshold) {
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
float val = tensor_buffer2[i] * 0.9f + filter_data->feedback_buffer[i] * 0.1f;
filter_data->mask_u8_buffer[i] = val >= filter_data->threshold ? 255 : 0;
}
} else {
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
float val = tensor_buffer2[i] * 0.8f + filter_data->feedback_buffer[i] * 0.2f;
filter_data->mask_u8_buffer[i] = val * 255.0f;
}
}

if (filter_data->use_mask_blur) {
Halide::Runtime::Buffer<uint8_t> input{filter_data->mask_u8_buffer, (int)filter_data->tensor_width,
filter_data->tensor_height};
Halide::Runtime::Buffer<uint8_t> output{filter_data->mask_blurred_u8_buffer,
(int)filter_data->tensor_width, filter_data->tensor_height};

blur(input, output);
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
filter_data->feedback_buffer[i] = filter_data->mask_blurred_u8_buffer[i] / 255.0f;
}

set_mask_data(filter_data->parent, filter_data->mask_blurred_u8_buffer);
} else {
set_mask_data(filter_data->parent, filter_data->mask_u8_buffer);
}
}

struct obs_source_frame *detector_filter_video(void *data, struct obs_source_frame *frame) {
try {
auto start = std::chrono::high_resolution_clock::now();
@@ -235,70 +355,11 @@ struct obs_source_frame *detector_filter_video(void *data, struct obs_source_fra
}

if (!filter_data->preprocess_scaler) {
struct video_scale_info frame_scaler_info {
frame->format, frame->width, frame->height,
frame->full_range ? VIDEO_RANGE_FULL : VIDEO_RANGE_DEFAULT, VIDEO_CS_DEFAULT
};
struct video_scale_info tensor_scaler_info {
VIDEO_FORMAT_BGR3, (uint32_t)filter_data->tensor_width,
(uint32_t)filter_data->tensor_height, VIDEO_RANGE_DEFAULT, VIDEO_CS_DEFAULT
};
int ret = video_scaler_create(&filter_data->preprocess_scaler, &tensor_scaler_info, &frame_scaler_info,
VIDEO_SCALE_BICUBIC);
if (ret != 0) {
blog(LOG_ERROR, "[Virtual BG detector] Can't create video_scaler_create %d", ret);
return frame;
} else {
blog(LOG_INFO, "[Virtual BG detector] video_scaler_create success. %dx%d -> %dx%d",
filter_data->frame_width, filter_data->frame_height, filter_data->tensor_width,
filter_data->tensor_height);
}
}

const uint32_t linesize[] = {(uint32_t)filter_data->tensor_width * 3};
if (!video_scaler_scale(filter_data->preprocess_scaler, &filter_data->input_u8_buffer, linesize,
frame->data, frame->linesize)) {
blog(LOG_ERROR, "[Virtual BG detector] video_scaler_scale failed.");
return frame;
}

float *tensor_buffer = filter_data->input_tensor.GetTensorMutableData<float>();
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
tensor_buffer[i * 3 + 0] = lut[filter_data->input_u8_buffer[i * 3 + 2]];
tensor_buffer[i * 3 + 1] = lut[filter_data->input_u8_buffer[i * 3 + 1]];
tensor_buffer[i * 3 + 2] = lut[filter_data->input_u8_buffer[i * 3 + 0]];
}
filter_data->session->Run(Ort::RunOptions(NULL), filter_data->input_names, &filter_data->input_tensor, 1,
filter_data->output_names, &filter_data->output_tensor, 1);

const float *tensor_buffer2 = filter_data->output_tensor.GetTensorData<float>();
if (filter_data->use_threshold) {
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
float val = tensor_buffer2[i] * 0.9f + filter_data->feedback_buffer[i] * 0.1f;
filter_data->mask_u8_buffer[i] = val >= filter_data->threshold ? 255 : 0;
}
} else {
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
float val = tensor_buffer2[i] * 0.8f + filter_data->feedback_buffer[i] * 0.2f;
filter_data->mask_u8_buffer[i] = val * 255.0f;
}
}

if (filter_data->use_mask_blur) {
Halide::Runtime::Buffer<uint8_t> input{filter_data->mask_u8_buffer, (int)filter_data->tensor_width,
filter_data->tensor_height};
Halide::Runtime::Buffer<uint8_t> output{filter_data->mask_blurred_u8_buffer,
(int)filter_data->tensor_width, filter_data->tensor_height};

blur(input, output);
for (int i = 0; i < filter_data->tensor_width * filter_data->tensor_height; ++i) {
filter_data->feedback_buffer[i] = filter_data->mask_blurred_u8_buffer[i] / 255.0f;
}

set_mask_data(filter_data->parent, filter_data->mask_blurred_u8_buffer);
} else {
set_mask_data(filter_data->parent, filter_data->mask_u8_buffer);
detector_setup_preprocess_scaler(filter_data, frame);
}
detector_preprocess(filter_data, frame);
detector_inference(filter_data);
detector_postprocess(filter_data);

if (filter_data->cnt % 300 == 0) {
auto stop = std::chrono::high_resolution_clock::now();