Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(query): introduce udf runtime pool #17304

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/query/service/src/pipelines/builders/builder_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use databend_common_exception::Result;
use databend_common_pipeline_transforms::processors::TransformPipelineHelper;
use databend_common_sql::executor::physical_plans::Udf;
use databend_common_storages_fuse::TableContext;

use crate::pipelines::processors::transforms::TransformUdfScript;
use crate::pipelines::processors::transforms::TransformUdfServer;
Expand All @@ -29,15 +25,12 @@ impl PipelineBuilder {
self.build_pipeline(&udf.input)?;

if udf.script_udf {
let index_seq = Arc::new(AtomicUsize::new(0));
let runtime_num = self.ctx.get_settings().get_max_threads()? as usize;
let runtimes = TransformUdfScript::init_runtime(&udf.udf_funcs, runtime_num)?;
let runtimes = TransformUdfScript::init_runtime(&udf.udf_funcs)?;
self.main_pipeline.try_add_transformer(|| {
Ok(TransformUdfScript::new(
self.func_ctx.clone(),
udf.udf_funcs.clone(),
runtimes.clone(),
index_seq.clone(),
))
})
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ pub use udaf_script::*;
pub use utils::*;

pub use self::serde::*;
use super::runtime_pool;
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::fmt;
use std::io::BufRead;
use std::io::Cursor;
use std::sync::Arc;
use std::sync::Mutex;

use arrow_array::Array;
use arrow_array::RecordBatch;
Expand All @@ -40,8 +39,8 @@ use databend_common_functions::aggregates::AggregateFunction;
use databend_common_sql::plans::UDFLanguage;
use databend_common_sql::plans::UDFScriptCode;

#[cfg(feature = "python-udf")]
use super::super::python_udf::GLOBAL_PYTHON_RUNTIME;
use super::runtime_pool::Pool;
use super::runtime_pool::RuntimeBuilder;

pub struct AggregateUdfScript {
display_name: String,
Expand Down Expand Up @@ -138,6 +137,15 @@ impl AggregateFunction for AggregateUdfScript {
builder.append_column(&result);
Ok(())
}

fn need_manual_drop_state(&self) -> bool {
true
}

unsafe fn drop_state(&self, place: StateAddr) {
let state = place.get::<UdfAggState>();
std::ptr::drop_in_place(state);
}
}

impl fmt::Display for AggregateUdfScript {
Expand Down Expand Up @@ -244,19 +252,19 @@ pub fn create_udaf_script_function(
let UDFScriptCode { language, code, .. } = code;
let runtime = match language {
UDFLanguage::JavaScript => {
let pool = JsRuntimePool::new(
let builder = JsRuntimeBuilder {
name,
String::from_utf8(code.to_vec())?,
ArrowType::Struct(
code: String::from_utf8(code.to_vec())?,
state_type: ArrowType::Struct(
state_fields
.iter()
.map(|f| f.into())
.collect::<Vec<arrow_schema::Field>>()
.into(),
),
output_type,
);
UDAFRuntime::JavaScript(pool)
};
UDAFRuntime::JavaScript(JsRuntimePool::new(builder))
}
UDFLanguage::WebAssembly => unimplemented!(),
#[cfg(not(feature = "python-udf"))]
Expand All @@ -267,22 +275,19 @@ pub fn create_udaf_script_function(
}
#[cfg(feature = "python-udf")]
UDFLanguage::Python => {
let mut runtime = GLOBAL_PYTHON_RUNTIME.write();
let code = String::from_utf8(code.to_vec())?;
runtime.add_aggregate(
&name,
ArrowType::Struct(
let builder = python_pool::PyRuntimeBuilder {
name,
code: String::from_utf8(code.to_vec())?,
state_type: ArrowType::Struct(
state_fields
.iter()
.map(|f| f.into())
.collect::<Vec<arrow_schema::Field>>()
.into(),
),
ArrowType::from(&output_type),
arrow_udf_python::CallMode::CalledOnNullInput,
&code,
)?;
UDAFRuntime::Python(PythonInfo { name, output_type })
output_type,
};
UDAFRuntime::Python(Pool::new(builder))
}
};
let init_state = runtime
Expand All @@ -297,27 +302,17 @@ pub fn create_udaf_script_function(
}))
}

struct JsRuntimePool {
struct JsRuntimeBuilder {
name: String,
code: String,
state_type: ArrowType,
output_type: DataType,

runtimes: Mutex<Vec<arrow_udf_js::Runtime>>,
}

impl JsRuntimePool {
fn new(name: String, code: String, state_type: ArrowType, output_type: DataType) -> Self {
Self {
name,
code,
state_type,
output_type,
runtimes: Mutex::new(vec![]),
}
}
impl RuntimeBuilder<arrow_udf_js::Runtime> for JsRuntimeBuilder {
type Error = ErrorCode;

fn create(&self) -> Result<arrow_udf_js::Runtime> {
fn build(&self) -> std::result::Result<arrow_udf_js::Runtime, Self::Error> {
let mut runtime = match arrow_udf_js::Runtime::new() {
Ok(runtime) => runtime,
Err(e) => {
Expand All @@ -344,65 +339,78 @@ impl JsRuntimePool {

Ok(runtime)
}
}

fn call<T, F>(&self, op: F) -> anyhow::Result<T>
where F: FnOnce(&arrow_udf_js::Runtime) -> anyhow::Result<T> {
let mut runtimes = self.runtimes.lock().unwrap();
let runtime = match runtimes.pop() {
Some(runtime) => runtime,
None => self.create()?,
};
drop(runtimes);
type JsRuntimePool = Pool<arrow_udf_js::Runtime, JsRuntimeBuilder>;

#[cfg(feature = "python-udf")]
mod python_pool {
use super::*;

let result = op(&runtime)?;
pub(super) struct PyRuntimeBuilder {
pub name: String,
pub code: String,
pub state_type: ArrowType,
pub output_type: DataType,
}

let mut runtimes = self.runtimes.lock().unwrap();
runtimes.push(runtime);
impl RuntimeBuilder<arrow_udf_python::Runtime> for PyRuntimeBuilder {
type Error = ErrorCode;

Ok(result)
fn build(&self) -> std::result::Result<arrow_udf_python::Runtime, Self::Error> {
let mut runtime = arrow_udf_python::Builder::default()
.sandboxed(true)
.build()?;
let output_type: ArrowType = (&self.output_type).into();
runtime.add_aggregate(
&self.name,
self.state_type.clone(),
output_type,
arrow_udf_python::CallMode::CalledOnNullInput,
&self.code,
)?;
Ok(runtime)
}
}

pub type PyRuntimePool = Pool<arrow_udf_python::Runtime, PyRuntimeBuilder>;
}

enum UDAFRuntime {
JavaScript(JsRuntimePool),
#[expect(unused)]
WebAssembly,
#[cfg(feature = "python-udf")]
Python(PythonInfo),
}

#[cfg(feature = "python-udf")]
struct PythonInfo {
name: String,
output_type: DataType,
Python(python_pool::PyRuntimePool),
}

impl UDAFRuntime {
fn name(&self) -> &str {
match self {
UDAFRuntime::JavaScript(pool) => &pool.name,
UDAFRuntime::JavaScript(pool) => &pool.builder.name,
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => &info.name,
UDAFRuntime::Python(info) => &info.builder.name,
_ => unimplemented!(),
}
}

fn return_type(&self) -> DataType {
match self {
UDAFRuntime::JavaScript(pool) => pool.output_type.clone(),
UDAFRuntime::JavaScript(pool) => pool.builder.output_type.clone(),
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => info.output_type.clone(),
UDAFRuntime::Python(info) => info.builder.output_type.clone(),
_ => unimplemented!(),
}
}

fn create_state(&self) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.create_state(&pool.name)),
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.create_state(&pool.builder.name))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.create_state(&info.name)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.create_state(&pool.builder.name))
}
_ => unimplemented!(),
}?;
Expand All @@ -412,12 +420,11 @@ impl UDAFRuntime {
fn accumulate(&self, state: &UdfAggState, input: &RecordBatch) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.accumulate(&pool.name, &state.0, input))
pool.call(|runtime| runtime.accumulate(&pool.builder.name, &state.0, input))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.accumulate(&info.name, &state.0, input)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.accumulate(&pool.builder.name, &state.0, input))
}
_ => unimplemented!(),
}?;
Expand All @@ -426,11 +433,12 @@ impl UDAFRuntime {

fn merge(&self, states: &Arc<dyn Array>) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.merge(&pool.name, states)),
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.merge(&pool.builder.name, states))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.merge(&info.name, states)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.merge(&pool.builder.name, states))
}
_ => unimplemented!(),
}?;
Expand All @@ -440,12 +448,11 @@ impl UDAFRuntime {
fn finish(&self, state: &UdfAggState) -> anyhow::Result<Arc<dyn Array>> {
match self {
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.finish(&pool.name, &state.0))
pool.call(|runtime| runtime.finish(&pool.builder.name, &state.0))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.finish(&info.name, &state.0)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.finish(&pool.builder.name, &state.0))
}
_ => unimplemented!(),
}
Expand Down Expand Up @@ -495,9 +502,9 @@ mod tests {
Field::new("sum", ArrowType::Int64, false),
Field::new("weight", ArrowType::Int64, false),
];
let pool = JsRuntimePool::new(
agg_name.clone(),
r#"
let builder = JsRuntimeBuilder {
name: agg_name.clone(),
code: r#"
export function create_state() {
return {sum: 0, weight: 0};
}
Expand All @@ -521,9 +528,10 @@ export function finish(state) {
}
"#
.to_string(),
ArrowType::Struct(fields.clone().into()),
Float32Type::data_type(),
);
state_type: ArrowType::Struct(fields.clone().into()),
output_type: Float32Type::data_type(),
};
let pool = JsRuntimePool::new(builder);

let state = pool.call(|runtime| runtime.create_state(&agg_name))?;

Expand Down
14 changes: 1 addition & 13 deletions src/query/service/src/pipelines/processors/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
pub mod aggregator;
mod hash_join;
pub(crate) mod range_join;
mod runtime_pool;
mod transform_add_computed_columns;
mod transform_add_const_columns;
mod transform_add_internal_columns;
Expand Down Expand Up @@ -66,16 +67,3 @@ pub use transform_stream_sort_spill::*;
pub use transform_udf_script::TransformUdfScript;
pub use transform_udf_server::TransformUdfServer;
pub use window::*;

#[cfg(feature = "python-udf")]
mod python_udf {
use std::sync::Arc;
use std::sync::LazyLock;

use arrow_udf_python::Runtime;
use parking_lot::RwLock;

/// python runtime should be only initialized once by gil lock, see: https://github.com/python/cpython/blob/main/Python/pystate.c
pub static GLOBAL_PYTHON_RUNTIME: LazyLock<Arc<RwLock<Runtime>>> =
LazyLock::new(|| Arc::new(RwLock::new(Runtime::new().unwrap())));
}
Loading
Loading