Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace definitions Vec with OnceLock slots #992

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 204 additions & 64 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
/// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar.
/// We use DefinitionsBuilder to collect the references / definitions into a single vector
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
use std::collections::hash_map::Entry;
use std::{
collections::hash_map::Entry,
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
};

use pyo3::prelude::*;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;

// An integer id for the reference
pub type ReferenceId = usize;
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};

/// Definitions are validators and serializers that are
/// shared by reference.
Expand All @@ -24,91 +28,227 @@ pub type ReferenceId = usize;
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets built.
pub type Definitions<T> = [T];
// Newtype over a map of definition slots. The keys are the reference strings
// handed to `DefinitionsBuilder::get_definition` / `add_definition`, each
// wrapped in an `Arc`.
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

#[derive(Clone, Debug)]
struct Definition<T> {
pub id: ReferenceId,
pub value: Option<T>,
impl<T> Definitions<T> {
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
self.0.values()
}
}

/// Internal type which contains a definition to be filled
///
/// The slot lives behind an `Arc`, so cloning a `Definition` yields another
/// handle to the same slot rather than a copy of the value.
pub struct Definition<T>(Arc<DefinitionInner<T>>);

impl<T> Definition<T> {
pub fn get(&self) -> Option<&T> {
self.0.value.get()
}
}

// Shared state behind a `Definition`.
struct DefinitionInner<T> {
// The definition itself; empty until it is set, and a `OnceLock` can be set only once.
value: OnceLock<T>,
// Name for the definition, computed on demand via `LazyName::get_or_init`.
name: LazyName,
}

/// Reference to a definition.
///
/// Pairs the reference string with a handle to the definition's slot; the
/// slot may still be empty when the reference is created.
pub struct DefinitionRef<T> {
// The reference string this ref was created from.
name: Arc<String>,
// Handle to the (possibly unfilled) definition slot.
value: Definition<T>,
}

// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
impl<T> Clone for DefinitionRef<T> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
value: self.value.clone(),
}
}
}

impl<T> DefinitionRef<T> {
    /// Returns an identifier for the referenced definition: the address of
    /// the shared allocation behind it, as an integer.
    pub fn id(&self) -> usize {
        let shared = Arc::as_ptr(&self.value.0);
        shared as usize
    }

    /// Returns the definition's name, computing it with `init` on first use.
    ///
    /// Yields `"..."` when the definition has not been filled yet.
    pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
        let inner = &self.value.0;
        let Some(filled) = inner.value.get() else {
            return "...";
        };
        inner.name.get_or_init(|| init(filled))
    }

    /// Returns the referenced value, or `None` while it is still unfilled.
    pub fn get(&self) -> Option<&T> {
        let inner = &self.value.0;
        inner.value.get()
    }
}

impl<T: Debug> Debug for DefinitionRef<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To avoid possible infinite recursion from recursive definitions,
// a DefinitionRef just displays debug as its name
self.name.fmt(f)
}
}

impl<T: Debug> Debug for Definitions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rendered as a list for backwards compatibility; a map rendering
        // would also be possible and may be adopted in a future minor
        // release of pydantic.
        f.write_str("[")?;
        for (index, def) in self.0.values().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{def:?}")?;
        }
        f.write_str("]")
    }
}

impl<T> Clone for Definition<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T: Debug> Debug for Definition<T> {
    /// Formats the stored value, or the placeholder string `"..."` when the
    /// slot has not been filled.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(filled) = self.0.value.get() {
            Debug::fmt(filled, f)
        } else {
            Debug::fmt("...", f)
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
    /// Traverses the referenced value if it has been filled; an unfilled
    /// reference has nothing to visit.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        match self.value.0.value.get() {
            Some(filled) => filled.py_gc_traverse(visit),
            None => Ok(()),
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
    /// Traverses every filled definition, stopping at the first error.
    /// Unfilled slots are skipped.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        self.0
            .values()
            .filter_map(|def| def.0.value.get())
            .try_for_each(|filled| filled.py_gc_traverse(visit))
    }
}

#[derive(Clone, Debug)]
pub struct DefinitionsBuilder<T> {
definitions: AHashMap<String, Definition<T>>,
definitions: Definitions<T>,
}

impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
pub fn new() -> Self {
Self {
definitions: AHashMap::new(),
definitions: Definitions(AHashMap::new()),
}
}

/// Get a ReferenceId for the given reference string.
// This ReferenceId can later be used to retrieve a definition
pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId {
let next_id = self.definitions.len();
pub fn get_definition(&mut self, reference: &str) -> DefinitionRef<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
match self.definitions.entry(reference.to_string()) {
Entry::Occupied(entry) => entry.get().id,
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: None,
});
next_id
}
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
};
DefinitionRef {
name,
value: value.clone(),
}
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
let next_id = self.definitions.len();
match self.definitions.entry(reference.clone()) {
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
None => Ok(entry.get().id),
},
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: Some(value),
});
Ok(next_id)
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
let name = Arc::new(reference);
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
}
}
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
};
Ok(DefinitionRef { name, value })
}

/// Consume this Definitions into a vector of items, indexed by each item's ReferenceId
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Ok(self.definitions)
}
}

/// Retrieve an item definition using a ReferenceId
/// If the definition doesn't yet exist (as happens in recursive types) then we create it
/// At the end (in finish()) we check that there are no undefined definitions
pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> {
let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) {
Some(v) => v,
None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id),
};
match def.value.as_ref() {
Some(v) => Ok(v),
None => py_schema_err!(
"Definitions error: attempted to use `{}` before it was filled",
reference
),
// A name that is computed on first request and then kept.
struct LazyName {
// Holds the name once an initializer has produced it.
initialized: OnceLock<String>,
// Set by `get_or_init` while it runs an initializer; a call that finds the
// flag already set returns "..." instead of running another initializer.
in_recursion: AtomicBool,
}

impl LazyName {
/// Creates an empty name: nothing computed yet and no initialization in progress.
fn new() -> Self {
    let initialized = OnceLock::default();
    let in_recursion = AtomicBool::default();
    Self {
        initialized,
        in_recursion,
    }
}

/// Consume this Definitions into a vector of items, indexed by each item's ReferenceId
pub fn finish(self) -> PyResult<Vec<T>> {
// We need to create a vec of defs according to the order in their ids
let mut defs: Vec<(usize, T)> = Vec::new();
for (reference, def) in self.definitions {
match def.value {
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
Some(v) => defs.push((def.id, v)),
}
/// Gets the validator name, returning the default in the case of recursion loops
fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
if let Some(s) = self.initialized.get() {
return s.as_str();
}

if self
.in_recursion
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return "...";
}
let result = self.initialized.get_or_init(init).as_str();
self.in_recursion.store(false, Ordering::SeqCst);
result
}
}

impl Debug for LazyName {
    /// Formats the stored name, or the placeholder `"..."` when none has
    /// been computed yet.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown: &str = match self.initialized.get() {
            Some(name) => name.as_str(),
            None => "...",
        };
        Debug::fmt(shown, f)
    }
}

impl Clone for LazyName {
fn clone(&self) -> Self {
Self {
initialized: OnceLock::new(),
in_recursion: AtomicBool::new(false),
}
defs.sort_by_key(|(id, _)| *id);
Ok(defs.into_iter().map(|(_, v)| v).collect())
}
}
8 changes: 8 additions & 0 deletions src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};
Expand Down Expand Up @@ -35,6 +37,12 @@ impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
}
}

impl<T: PyGcTraverse> PyGcTraverse for Arc<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
}
}

impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
Expand Down
9 changes: 0 additions & 9 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ use serde::ser::Error;
use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -48,7 +46,6 @@ impl SerializationState {
Extra::new(
py,
mode,
&[],
by_alias,
&self.warnings,
false,
Expand All @@ -72,7 +69,6 @@ impl SerializationState {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct Extra<'a> {
pub mode: &'a SerMode,
pub definitions: &'a Definitions<CombinedSerializer>,
pub ob_type_lookup: &'a ObTypeLookup,
pub warnings: &'a CollectWarnings,
pub by_alias: bool,
Expand All @@ -98,7 +94,6 @@ impl<'a> Extra<'a> {
pub fn new(
py: Python<'a>,
mode: &'a SerMode,
definitions: &'a Definitions<CombinedSerializer>,
by_alias: bool,
warnings: &'a CollectWarnings,
exclude_unset: bool,
Expand All @@ -112,7 +107,6 @@ impl<'a> Extra<'a> {
) -> Self {
Self {
mode,
definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings,
by_alias,
Expand Down Expand Up @@ -156,7 +150,6 @@ impl SerCheck {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct ExtraOwned {
mode: SerMode,
definitions: Vec<CombinedSerializer>,
warnings: CollectWarnings,
by_alias: bool,
exclude_unset: bool,
Expand All @@ -176,7 +169,6 @@ impl ExtraOwned {
pub fn new(extra: &Extra) -> Self {
Self {
mode: extra.mode.clone(),
definitions: extra.definitions.to_vec(),
warnings: extra.warnings.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
Expand All @@ -196,7 +188,6 @@ impl ExtraOwned {
pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> {
Extra {
mode: &self.mode,
definitions: &self.definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings: &self.warnings,
by_alias: self.by_alias,
Expand Down
Loading