Skip to content

Commit

Permalink
Fix: get stuck when load extension in the concurrency environment
Browse files Browse the repository at this point in the history
- Add a new struct called LoadExtensionPromise
- Remove async modifier in ExtensionDirectory

Close #183
  • Loading branch information
onewe committed Mar 16, 2024
1 parent 7ee9e73 commit 4e02c28
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 55 deletions.
139 changes: 112 additions & 27 deletions dubbo/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ use crate::{
};
use dubbo_base::{extension_param::ExtensionType, url::UrlParam, StdError, Url};
use dubbo_logger::tracing::{error, info};
use std::{future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, Semaphore};

pub static EXTENSIONS: once_cell::sync::Lazy<ExtensionDirectoryCommander> =
once_cell::sync::Lazy::new(|| ExtensionDirectory::init());
Expand All @@ -41,13 +42,11 @@ impl ExtensionDirectory {
let mut extension_directory = ExtensionDirectory::default();

// register static registry extension
let _ = extension_directory
.register(
StaticRegistry::name(),
StaticRegistry::convert_to_extension_factories(),
ExtensionType::Registry,
)
.await;
let _ = extension_directory.register(
StaticRegistry::name(),
StaticRegistry::convert_to_extension_factories(),
ExtensionType::Registry,
);

while let Some(extension_opt) = rx.recv().await {
match extension_opt {
Expand All @@ -57,20 +56,19 @@ impl ExtensionDirectory {
extension_type,
tx,
) => {
let result = extension_directory
.register(extension_name, extension_factories, extension_type)
.await;
let result = extension_directory.register(
extension_name,
extension_factories,
extension_type,
);
let _ = tx.send(result);
}
ExtensionOpt::Remove(extension_name, extension_type, tx) => {
let result = extension_directory
.remove(extension_name, extension_type)
.await;
let result = extension_directory.remove(extension_name, extension_type);
let _ = tx.send(result);
}
ExtensionOpt::Load(url, extension_type, tx) => {
let result = extension_directory.load(url, extension_type).await;
let _ = tx.send(result);
let _ = extension_directory.load(url, extension_type, tx);
}
}
}
Expand All @@ -79,7 +77,7 @@ impl ExtensionDirectory {
ExtensionDirectoryCommander { sender: tx }
}

async fn register(
fn register(
&mut self,
extension_name: String,
extension_factories: ExtensionFactories,
Expand All @@ -89,47 +87,134 @@ impl ExtensionDirectory {
ExtensionType::Registry => match extension_factories {
ExtensionFactories::RegistryExtensionFactory(registry_extension_factory) => {
self.registry_extension_loader
.register(extension_name, registry_extension_factory)
.await;
.register(extension_name, registry_extension_factory);
Ok(())
}
},
}
}

async fn remove(
fn remove(
&mut self,
extension_name: String,
extension_type: ExtensionType,
) -> Result<(), StdError> {
match extension_type {
ExtensionType::Registry => {
self.registry_extension_loader.remove(extension_name).await;
self.registry_extension_loader.remove(extension_name);
Ok(())
}
}
}

async fn load(
fn load(
&mut self,
url: Url,
extension_type: ExtensionType,
) -> Result<Extensions, StdError> {
callback: oneshot::Sender<Result<Extensions, StdError>>,
) {
match extension_type {
ExtensionType::Registry => {
let extension = self.registry_extension_loader.load(&url).await;
let extension = self.registry_extension_loader.load(url);
match extension {
Ok(extension) => Ok(Extensions::Registry(extension)),
Ok(mut extension) => {
tokio::spawn(async move {
let extension = extension.resolve().await;
match extension {
Ok(extension) => {
let _ = callback.send(Ok(Extensions::Registry(extension)));
}
Err(err) => {
error!("load extension failed: {}", err);
let _ = callback.send(Err(err));
}
}
});
}
Err(err) => {
error!("load extension failed: {}", err);
Err(err)
let _ = callback.send(Err(err));
}
}
}
}
}
}

pub(crate) struct LoadExtensionPromise<T> {
extension: Arc<Option<T>>,
fut: Option<Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>>,
semaphore: Arc<Semaphore>,
}

impl<T> LoadExtensionPromise<T>
where
T: Send + Clone + 'static,
{
pub(crate) fn new(
fut: Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>,
) -> Self {
LoadExtensionPromise {
extension: Arc::new(None),
fut: Some(fut),
semaphore: Arc::new(Semaphore::new(0)),
}
}

fn get_extension(&self) -> Option<T> {
self.extension.as_ref().as_ref().map(|a| a.clone())
}

pub(crate) async fn resolve(&mut self) -> Result<T, StdError> {
let extension = self.get_extension();
if let Some(extension) = extension {
return Ok(extension);
}

let fut = self.fut.take();
let Some(mut fut) = fut else {
let _ = self.semaphore.acquire().await;
// check it again
let extension = self.get_extension();
if let Some(extension) = extension {
info!("promise has been resolved.");
return Ok(extension);
}
return Err(LoadExtensionError::new("load extension canceled ".to_string()).into());
};

match fut.as_mut().await {
Ok(extension) => {
info!("create extension success");
let ptr = Arc::as_ptr(&self.extension) as *mut Option<T>;
unsafe {
*ptr = Some(extension.clone());
}
self.semaphore.close();
Ok(extension)
}
Err(err) => {
error!("create extension failed: {}", err);
self.semaphore.close();
Err(LoadExtensionError::new(
"load extension failed, create extension occur an error".to_string(),
)
.into())
}
}
}
}

impl<T> Clone for LoadExtensionPromise<T> {
fn clone(&self) -> Self {
LoadExtensionPromise {
extension: self.extension.clone(),
fut: None,
semaphore: self.semaphore.clone(),
}
}
}

pub struct ExtensionDirectoryCommander {
sender: tokio::sync::mpsc::Sender<ExtensionOpt>,
}
Expand Down Expand Up @@ -280,7 +365,7 @@ pub trait Extension: Sealed {

fn name() -> String;

async fn create(url: &Url) -> Result<Self::Target, StdError>;
async fn create(url: Url) -> Result<Self::Target, StdError>;
}

#[allow(private_bounds)]
Expand Down
52 changes: 30 additions & 22 deletions dubbo/src/extension/registry_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use proxy::RegistryProxy;

use crate::extension::{
ConvertToExtensionFactories, Extension, ExtensionFactories, ExtensionMetaInfo, ExtensionType,
LoadExtensionPromise,
};

// extension://0.0.0.0/?extension-type=registry&extension-name=nacos&registry-url=nacos://127.0.0.1:8848
Expand Down Expand Up @@ -80,20 +81,19 @@ where
fn convert_to_extension_factories() -> ExtensionFactories {
fn constrain<F>(f: F) -> F
where
F: for<'a> Fn(
&'a Url,
F: Fn(
Url,
) -> Pin<
Box<
dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>>
+ Send
+ 'a,
+ Send,
>,
>,
{
f
}

let constructor = constrain(|url: &Url| {
let constructor = constrain(|url: Url| {
let f = <T as Extension>::create(url);
Box::pin(f)
});
Expand All @@ -108,19 +108,18 @@ pub(super) struct RegistryExtensionLoader {
}

impl RegistryExtensionLoader {
pub(crate) async fn register(
&mut self,
extension_name: String,
factory: RegistryExtensionFactory,
) {
pub(crate) fn register(&mut self, extension_name: String, factory: RegistryExtensionFactory) {
self.factories.insert(extension_name, factory);
}

pub(crate) async fn remove(&mut self, extension_name: String) {
pub(crate) fn remove(&mut self, extension_name: String) {
self.factories.remove(&extension_name);
}

pub(crate) async fn load(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
pub(crate) fn load(
&mut self,
url: Url,
) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let extension_name = url.query::<ExtensionName>().unwrap();
let extension_name = extension_name.value();
let factory = self.factories.get_mut(&extension_name).ok_or_else(|| {
Expand All @@ -129,19 +128,19 @@ impl RegistryExtensionLoader {
extension_name
))
})?;
factory.create(url).await
factory.create(url)
}
}

type RegistryConstructor = for<'a> fn(
&'a Url,
type RegistryConstructor = fn(
Url,
) -> Pin<
Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send + 'a>,
Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send>,
>;

pub(crate) struct RegistryExtensionFactory {
constructor: RegistryConstructor,
instances: HashMap<String, RegistryProxy>,
instances: HashMap<String, LoadExtensionPromise<RegistryProxy>>,
}

impl RegistryExtensionFactory {
Expand All @@ -154,7 +153,10 @@ impl RegistryExtensionFactory {
}

impl RegistryExtensionFactory {
pub(super) async fn create(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
pub(super) fn create(
&mut self,
url: Url,
) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let registry_url = url.query::<RegistryUrl>().unwrap();
let registry_url = registry_url.value();
let url_str = registry_url.as_str().to_string();
Expand All @@ -164,10 +166,16 @@ impl RegistryExtensionFactory {
Ok(proxy)
}
None => {
let registry = (self.constructor)(url).await?;
let proxy = <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
self.instances.insert(url_str, proxy.clone());
Ok(proxy)
let registry = (self.constructor)(url);
let fut = Box::pin(async move {
let registry = registry.await?;
let proxy = <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
Ok(proxy)
});

let promise = LoadExtensionPromise::new(fut);
self.instances.insert(url_str, promise.clone());
Ok(promise)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion dubbo/src/registry/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ impl Extension for StaticRegistry {
"static".to_string()
}

async fn create(url: &Url) -> Result<Self::Target, StdError> {
async fn create(url: Url) -> Result<Self::Target, StdError> {
// url example:
// extension://0.0.0.0?extension-type=registry&extension-name=static&registry=static://127.0.0.1
let static_invoker_urls = url.query::<StaticInvokerUrls>();
Expand Down
10 changes: 5 additions & 5 deletions registry/nacos/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl Extension for NacosRegistry {
"nacos".to_string()
}

async fn create(url: &Url) -> Result<Self::Target, StdError> {
async fn create(url: Url) -> Result<Self::Target, StdError> {
// url example:
// extension://0.0.0.0?extension-type=registry&extension-name=nacos&registry=nacos://127.0.0.1:8848
let registry_url = url.query::<RegistryUrl>().unwrap();
Expand Down Expand Up @@ -446,7 +446,7 @@ pub mod tests {
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));

let registry = NacosRegistry::create(&extension_url).await.unwrap();
let registry = NacosRegistry::create(extension_url).await.unwrap();

let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();

Expand Down Expand Up @@ -478,7 +478,7 @@ pub mod tests {
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));

let registry = NacosRegistry::create(&extension_url).await.unwrap();
let registry = NacosRegistry::create(extension_url).await.unwrap();

let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();

Expand Down Expand Up @@ -518,7 +518,7 @@ pub mod tests {
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));

let registry = NacosRegistry::create(&extension_url).await.unwrap();
let registry = NacosRegistry::create(extension_url).await.unwrap();

let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();

Expand Down Expand Up @@ -562,7 +562,7 @@ pub mod tests {
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));

let registry = NacosRegistry::create(&extension_url).await.unwrap();
let registry = NacosRegistry::create(extension_url).await.unwrap();

let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();

Expand Down

0 comments on commit 4e02c28

Please sign in to comment.