Skip to content

Commit

Permalink
feat: typed enum fields
Browse files Browse the repository at this point in the history
Add typed_enum_fields method to prost-build configuration, which
allows type-checked representation of enumerations in fields of
message structs and variants of oneof enums. The argument and the
invocation order works like with the boxed method.

Depending on the syntax (and preparing for the future support of
editions), the type-checked representation can be closed (for proto2)
or open (for proto3). The former is represented by the generated
enum type itself, while the latter is represented by OpenEnum
wrapping the enum type.

A new enum_type annotation is supported in the prost attribute
inside derives, which allows to specify the type-checked representation
of enum types in message fields and oneof variants.
The accepted values are "open" or "closed".
  • Loading branch information
mzabaluev committed Nov 17, 2024
1 parent 3fee552 commit cb5e59c
Show file tree
Hide file tree
Showing 12 changed files with 26,210 additions and 455 deletions.
89 changes: 66 additions & 23 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl CodeGenerator<'_> {

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let type_ = field.descriptor.r#type();
let repeated = field.descriptor.label.and_then(|v| v.known()) == Some(Label::Repeated);
let repeated = field.descriptor.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self.boxed(&field.descriptor, fq_message_name, None);
Expand All @@ -427,15 +427,19 @@ impl CodeGenerator<'_> {
let type_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
let bytes_type = self
.config
.bytes_type
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
match type_ {
Type::Bytes => {
let bytes_type = self
.config
.bytes_type
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
}
Type::Enum => self.push_enum_type_annotation(fq_message_name, field.descriptor.name()),
_ => {}
}

match field.descriptor.label() {
Expand Down Expand Up @@ -555,12 +559,16 @@ impl CodeGenerator<'_> {
let value_tag = self.map_value_type_tag(value);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
"#[prost({}=\"{}, {}\"",
map_type.annotation(),
key_tag,
value_tag,
field.descriptor.number()
));
if value.r#type() == Type::Enum {
self.push_enum_type_annotation(fq_message_name, field.descriptor.name());
}
self.buf
.push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number()));
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
Expand Down Expand Up @@ -639,11 +647,12 @@ impl CodeGenerator<'_> {

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
field.descriptor.number()
));
self.buf.push_str(&format!("#[prost({}", ty_tag,));
if field.descriptor.r#type() == Type::Enum {
self.push_enum_type_annotation(&oneof_name, field.descriptor.name());
}
self.buf
.push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number()));
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
Expand Down Expand Up @@ -947,6 +956,14 @@ impl CodeGenerator<'_> {
self.buf.push_str("}\n");
}

fn push_enum_type_annotation(&mut self, fq_message_name: &str, field_name: &str) {
match self.enum_field_repr(fq_message_name, field_name) {
EnumRepr::Int => {}
EnumRepr::Open => self.buf.push_str(", enum_type=\"open\""),
EnumRepr::Closed => self.buf.push_str(", enum_type=\"closed\""),
}
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Expand All @@ -966,11 +983,15 @@ impl CodeGenerator<'_> {
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Enum => format!(
"{}::OpenEnum<{}>",
prost_path(self.config),
self.resolve_ident(field.type_name())
),
Type::Enum => match self.enum_field_repr(fq_message_name, field.name()) {
EnumRepr::Int => String::from("i32"),
EnumRepr::Open => format!(
"{}::OpenEnum<{}>",
prost_path(self.config),
self.resolve_ident(field.type_name())
),
EnumRepr::Closed => self.resolve_ident(field.type_name()),
},
}
}

Expand Down Expand Up @@ -1012,6 +1033,22 @@ impl CodeGenerator<'_> {
.join("::")
}

fn enum_field_repr(&self, fq_message_name: &str, field_name: &str) -> EnumRepr {
if self
.config
.typed_enum_fields
.get_first_field(fq_message_name, field_name)
.is_some()
{
match self.syntax {
Syntax::Proto2 => EnumRepr::Closed,
Syntax::Proto3 => EnumRepr::Open,
}
} else {
EnumRepr::Int
}
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Expand Down Expand Up @@ -1074,7 +1111,7 @@ impl CodeGenerator<'_> {
fq_message_name: &str,
oneof: Option<&str>,
) -> bool {
let repeated = field.label.and_then(|v| v.known()) == Some(Label::Repeated);
let repeated = field.label == Some(Label::Repeated as i32);
let fd_type = field.r#type();
if !repeated
&& (fd_type == Type::Message || fd_type == Type::Group)
Expand Down Expand Up @@ -1148,6 +1185,12 @@ fn can_pack(field: &FieldDescriptorProto) -> bool {
)
}

enum EnumRepr {
Int,
Closed,
Open,
}

struct EnumVariantMapping<'a> {
path_idx: usize,
proto_name: &'a str,
Expand Down
26 changes: 26 additions & 0 deletions prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct Config {
pub(crate) enum_attributes: PathMap<String>,
pub(crate) field_attributes: PathMap<String>,
pub(crate) boxed: PathMap<()>,
pub(crate) typed_enum_fields: PathMap<()>,
pub(crate) prost_types: bool,
pub(crate) strip_enum_prefix: bool,
pub(crate) out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -372,6 +373,30 @@ impl Config {
self
}

/// Represent Protobuf enum types encountered in matched fields with types
/// bound to their corresponding Rust enum types, rather than the default `i32`.
///
/// Depending on the proto file syntax, the representation type can be:
/// * For closed enums (in proto2), the corresponding Rust enum type.
/// * For open enums (in proto3), the Rust enum type wrapped in [`OpenEnum`].
///
/// # Arguments
///
/// **`path`** - a path matching any number of fields. These fields will get the type-checked
/// enum representation.
/// For details about matching fields see [`btree_map`](#method.btree_map).
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.typed_enum_fields(".my_messages");
/// ```
pub fn typed_enum_fields(&mut self, path: impl AsRef<str>) -> &mut Self {
self.typed_enum_fields.insert(path.as_ref().to_owned(), ());
self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1158,6 +1183,7 @@ impl default::Default for Config {
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
boxed: PathMap::default(),
typed_enum_fields: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
Expand Down
Loading

0 comments on commit cb5e59c

Please sign in to comment.