From fdc60ccd3bda1447d0e58c283319c9265bab98eb Mon Sep 17 00:00:00 2001 From: pfeme Date: Sun, 30 Jun 2024 21:40:52 +0200 Subject: [PATCH] add caseclass0 check --- .../scala/zio/schema/codec/JsonCodec.scala | 8 +++++-- .../zio/schema/codec/JsonCodecSpec.scala | 21 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala b/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala index b72d8af6d..abaec6bad 100644 --- a/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala +++ b/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala @@ -378,7 +378,9 @@ object JsonCodec { private def enumEncoder[Z](schema: Schema.Enum[Z], cfg: Config, cases: Schema.Case[Z, _]*): ZJsonEncoder[Z] = // if all cases are CaseClass0, encode as a String - if (schema.annotations.exists(_.isInstanceOf[simpleEnum])) { + if (schema.annotations.exists(_.isInstanceOf[simpleEnum]) && cases.forall( + _.schema.isInstanceOf[Schema.CaseClass0[_]] + )) { val caseMap: Map[Z, String] = schema.nonTransientCases .map( @@ -699,7 +701,9 @@ object JsonCodec { }.toMap // if all cases are CaseClass0, decode as String - if (cases.forall(_.schema.isInstanceOf[Schema.CaseClass0[_]])) { + if (parentSchema.annotations.exists(_.isInstanceOf[simpleEnum]) && cases.forall( + _.schema.isInstanceOf[Schema.CaseClass0[_]] + )) { val caseMap: Map[String, Z] = cases.map(case_ => case_.id -> case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct()).toMap ZJsonDecoder.string.mapOrFail( diff --git a/zio-schema-json/shared/src/test/scala-2/zio/schema/codec/JsonCodecSpec.scala b/zio-schema-json/shared/src/test/scala-2/zio/schema/codec/JsonCodecSpec.scala index 115bee8e5..f9900459a 100644 --- a/zio-schema-json/shared/src/test/scala-2/zio/schema/codec/JsonCodecSpec.scala +++ b/zio-schema-json/shared/src/test/scala-2/zio/schema/codec/JsonCodecSpec.scala @@ -292,6 +292,13 @@ object JsonCodecSpec extends ZIOSpecDefault { charSequenceToByteChunk("""{"oneOf":{"_type":"StringValue2","value":"foo2"}}""") ) }, + test("ADT with intermediate enumeration") { + assertEncodes( + Schema[IntermediateEnum], + IntermediateEnum.FinalElem, + charSequenceToByteChunk(""""FinalElem"""") + ) + }, test("case class") { assertEncodes( searchRequestWithTransientFieldSchema, @@ -1226,6 +1233,12 @@ object JsonCodecSpec extends ZIOSpecDefault { Enumeration2(BooleanValue2(false)) ) }, + test("ADT with intermediate enumeration") { + assertEncodesThenDecodes( + Schema[IntermediateEnum], + IntermediateEnum.FinalElem + ) + }, test("ADT with noDiscriminator") { assertEncodesThenDecodes( Schema[Enumeration3], @@ -1742,6 +1755,14 @@ object JsonCodecSpec extends ZIOSpecDefault { implicit val schema: Schema[Enumeration3] = DeriveSchema.gen[Enumeration3] } + sealed trait IntermediateEnum + + object IntermediateEnum { + sealed trait IntermediateElem extends IntermediateEnum + case object FinalElem extends IntermediateEnum + implicit val schema: zio.schema.Schema[IntermediateEnum] = zio.schema.DeriveSchema.gen[IntermediateEnum] + } + sealed trait Color object Color {