Skip to content

Commit

Permalink
Extend guardrail-dev#1905 to CirceRefinedProtocolGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese committed Dec 26, 2023
1 parent 9e64878 commit 1e0370d
Showing 1 changed file with 104 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
components = components
)
oneOf <- fromOneOf(
clsName = formattedClsName,
clsName = NonEmptyList.of(formattedClsName),
model = comp,
definitions.value,
dtoPackage = dtoPackage,
Expand Down Expand Up @@ -226,7 +226,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
components
)
oneOf <- fromOneOf(
clsName = formattedClsName,
clsName = NonEmptyList.of(formattedClsName),
model = m,
definitions.value,
dtoPackage = dtoPackage,
Expand Down Expand Up @@ -274,7 +274,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
)
(declType, _) <- ModelResolver.determineTypeName[ScalaLanguage, Target](x, Tracker.cloneHistory(x, CustomTypeName(x, prefixes)), components)
oneOf <- fromOneOf(
clsName = formattedClsName,
clsName = NonEmptyList.of(formattedClsName),
model = x,
definitions.value,
dtoPackage = dtoPackage,
Expand Down Expand Up @@ -660,10 +660,15 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
nestedClasses <- nestedDefinitions.flatTraverse {
case classDefinition: ClassDefinition[ScalaLanguage] =>
for {
widenClass <- widenClassDefinition(classDefinition.cls)
companionTerm <- pureTermName(classDefinition.name)
companionDefinition <- wrapToObject(companionTerm, classDefinition.staticDefns.extraImports, classDefinition.staticDefns.definitions, Nil)
widenCompanion <- companionDefinition.traverse(widenObjectDefinition)
widenClass <- widenClassDefinition(classDefinition.cls)
companionTerm <- pureTermName(classDefinition.name)
companionDefinition <- wrapToObject(
companionTerm,
classDefinition.staticDefns.extraImports,
classDefinition.staticDefns.definitions,
classDefinition.staticDefns.statements
)
widenCompanion <- companionDefinition.traverse(widenObjectDefinition)
} yield List(widenClass) ++ widenCompanion.fold(classDefinition.staticDefns.definitions)(List(_))
case enumDefinition: EnumDefinition[ScalaLanguage] =>
for {
Expand All @@ -683,7 +688,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
}

private[this] def fromOneOf(
clsName: String,
clsName: NonEmptyList[String],
model: Tracker[Schema[_]],
definitions: List[(String, Tracker[Schema[_]])],
dtoPackage: List[String],
Expand All @@ -697,45 +702,79 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
Fw: FrameworkTerms[ScalaLanguage, Target],
Sw: OpenAPITerms[ScalaLanguage, Target],
P: ProtocolTerms[ScalaLanguage, Target]
): Target[Either[String, StrictProtocolElems[ScalaLanguage]]] =
): Target[Either[String, ClassDefinition[ScalaLanguage]]] =
NonEmptyList
.fromList(model.downField("oneOf", _.getOneOf()).indexedDistribute)
.fold[Target[Either[String, StrictProtocolElems[ScalaLanguage]]]](Target.pure(Left("Does not have oneOf"))) { xs =>
.fold[Target[Either[String, ClassDefinition[ScalaLanguage]]]](Target.pure(Left("Does not have oneOf"))) { xs =>
for {
nameGenerator <- Target.pure(new java.util.concurrent.atomic.AtomicInteger(1))
getNewName = () => Target.pure(nameGenerator).map(ng => s"Nested${ng.getAndIncrement()}")
getNewName = () => Target.pure(nameGenerator).map(ng => NonEmptyList.ofInitLast(dtoPackage :+ clsName.last, s"Nested${ng.getAndIncrement()}"))
fullyQualifyPackageName = { case (outer, PropMeta(name, tpe)) =>
tpe match {
case Type.Name(tn) =>
for {
fullInstanceType <- Sc.selectType(NonEmptyList.ofInitLast(dtoPackage ++ outer, tn))
} yield PropMeta(name, fullInstanceType)
case other =>
Target.raiseUserError(s"Unknown type structure: ${other}")
}
}: (Option[String], PropMeta[ScalaLanguage]) => Target[PropMeta[ScalaLanguage]]
(rawTypes, nestedClasses) <- xs
.traverse { model =>
for {
resolved <- ModelResolver.propMetaWithName[ScalaLanguage, Target](() => getNewName().flatMap(Sc.pureTypeName), model, components)
resolved <- ModelResolver.propMetaWithName[ScalaLanguage, Target](() => getNewName().flatMap(Sc.selectType), model, components)
(rawType, nestedClasses) <- resolved
.bitraverse(
deferred => Target.fromOption(concreteTypes.find(_.clsName == deferred.value), UserError("Not supported")),
deferred =>
for {
propMeta <- Target.fromOption(concreteTypes.find(_.clsName == deferred.value), UserError("Not supported"))
fqPropMeta <- fullyQualifyPackageName(None, propMeta)
} yield Left(fqPropMeta),
{
case core.Resolved(Type.Name(dtoName), _, _, _) =>
renderIntermediate(
model,
dtoName,
concreteTypes,
definitions,
dtoPackage,
supportPackage,
defaultPropertyRequirement,
components
)
// There's not a great way to close the loop in ModelResolver.propMetaWithName to the fact that
// we are generating a new class, and that we need a whole bunch of other infrastructure here.
// We rely on ModelResolver, but only generate a class if the ModelResolver can't find a class.
case core.Resolved(tpe, _, _, _) if tpe.syntax.startsWith((dtoPackage :+ clsName.last).mkString(".")) =>
for {
dtoName <- tpe match {
case Type.Name(x) => Target.pure(x)
case Type.Select(_, Type.Name(x)) => Target.pure(x)
case other => Target.raiseUserError(s"Unexpected type ${other}")
}
(pm, defns) <- renderIntermediate(
model,
dtoName,
concreteTypes,
definitions,
dtoPackage,
supportPackage,
defaultPropertyRequirement,
components
)
fqPropMeta <- fullyQualifyPackageName(Some(clsName.last), pm)
} yield Right((fqPropMeta, defns))
case core.Resolved(tpe, _, _, _) =>
for {
dtoName <- getNewName()

prefixes <- Cl.vendorPrefixes()
customTpe = CustomTypeName(model, prefixes).getOrElse(dtoName.last)
pm = PropMeta[ScalaLanguage](customTpe, tpe)

} yield Left(pm)
case other => Target.raiseUserError(s"Unexpected case ${other}")
}
)
.map(e => List(e).partitionEither(identity _))
.map(e => List(e.merge).partitionEither(identity _))
} yield (rawType, nestedClasses)
}
.map(_.unzip)
(nestedPMs, (nestedEncoders, nestedDecoders, nestedDefns)) = nestedClasses.toList.flatten.unzip.map(_.unzip3)
rawTypes <- Target.fromOption(
NonEmptyList.fromList(rawTypes.toList.flatten.map((None, _)) ++ nestedPMs.map((Some(clsName), _))),
NonEmptyList.fromList(rawTypes.toList.flatten ++ nestedPMs),
UserError("Expected at least some models")
)
fullType <- Sc.selectType(NonEmptyList.ofInitLast(dtoPackage, clsName))
fullType <- Sc.selectType(clsName.prependList(dtoPackage))

discriminator_ = model.downField("discriminator", _.getDiscriminator()).indexedDistribute

Expand Down Expand Up @@ -777,25 +816,26 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,

(members, (applyAdapters, (decoders, encoders))) <- rawTypes
.traverse {
case (outer, PropMeta(memberName, tpe @ Type.Name(name))) =>
case PropMeta(memberName, tpe) =>
for {
fullInstanceType <- Sc.selectType(NonEmptyList.ofInitLast(dtoPackage ++ outer, name))
termName = Term.Name(name)
applyAdapter = q"implicit val ${Pat.Var(Term.Name(s"from${name}"))}: ${fullInstanceType} => ${Type.Name(clsName)} = members.${termName}.apply _"
member = q"case class ${tpe}(value: ${fullInstanceType}) extends ${Init.After_4_6_0(Type.Name(clsName), Name.Anonymous(), Nil)}"
decoder = validateDecoder(memberName, q"_root_.io.circe.Decoder[${tpe}].map(${Term.Name(s"from${name}")})")
encoder = p"case members.${termName}(member) => ${injectDiscriminator(memberName, q"member")}"
() <- Target.pure(())
termName = Term.Name(memberName)
typeName = Type.Name(memberName)
applyAdapter = q"implicit val ${Pat.Var(Term.Name(s"from${memberName}"))}: ${tpe} => ${Type.Name(clsName.last)} = members.${termName}.apply _"
member = q"case class ${typeName}(value: ${tpe}) extends ${Init.After_4_6_0(fullType, Name.Anonymous(), Nil)}"
decoder = validateDecoder(memberName, q"_root_.io.circe.Decoder[${tpe}].map(${Term.Name(s"from${memberName}")})")
encoder = p"case members.${termName}(member) => ${injectDiscriminator(memberName, q"member")}"
} yield (member, (applyAdapter, (decoder, encoder)))
case (_, other) =>
case other =>
Target.raiseUserError(s"Unsupported case in oneOf, ${other}. Somehow got a complex type, we expected a singular Type.Name.")
}
.map(_.unzip.map(_.unzip.map(_.unzip))) // NonEmptyList#unzip4 adapter :see_no_evil:

encoder = q"""implicit def ${Term.Name(s"encode${clsName}")}: _root_.io.circe.Encoder[${Type
.Name(clsName)}] = _root_.io.circe.Encoder.instance(${Term
encoder = q"""implicit def ${Term.Name(s"encode${clsName.last}")}: _root_.io.circe.Encoder[${Type
.Name(clsName.last)}] = _root_.io.circe.Encoder.instance(${Term
.PartialFunction(encoders.toList)}) """
decoder = q"""
implicit def ${Term.Name(s"decode${clsName}")}: _root_.io.circe.Decoder[${Type.Name(clsName)}] = {
implicit def ${Term.Name(s"decode${clsName.last}")}: _root_.io.circe.Decoder[${Type.Name(clsName.last)}] = {
..${decoders.zipWithIndex.toList.map { case (decoder, idx) =>
q"val ${Pat.Var(Term.Name(s"dec${idx}"))} = ${decoder}"
}};
Expand All @@ -804,16 +844,16 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
"""
statements = List(
q"object members { ..${members.toList} }",
q"def apply[A](value: A)(implicit ev: A => ${Type.Name(clsName)}): ${Type.Name(clsName)} = ev(value)"
q"def apply[A](value: A)(implicit ev: A => ${Type.Name(clsName.last)}): ${Type.Name(clsName.last)} = ev(value)"
) ++ applyAdapters.toList ++ nestedDefns ++ nestedEncoders.flatten ++ nestedDecoders.flatten
defn = q"""sealed abstract class ${Type.Name(clsName)} {}"""
defn = q"""sealed abstract class ${Type.Name(clsName.last)} {}"""
staticDefns = StaticDefns[ScalaLanguage](
className = clsName,
className = clsName.last,
extraImports = List.empty,
definitions = List(encoder, decoder),
statements = statements
)
} yield Right(ClassDefinition[ScalaLanguage](clsName, Type.Name(clsName), fullType, defn, staticDefns, Nil))
} yield Right(ClassDefinition[ScalaLanguage](clsName.last, Type.Name(clsName.last), fullType, defn, staticDefns, Nil))
}

private def prepareProperties(
Expand Down Expand Up @@ -860,17 +900,29 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
.orRefine { case o: ComposedSchema => o }(o =>
for {
parents <- extractParents(o, definitions, concreteTypes, dtoPackage, supportPackage, defaultPropertyRequirement, components)
maybeClassDefinition <- fromModel(
nestedClassName,
o,
parents,
concreteTypes,
definitions,
dtoPackage,
supportPackage,
defaultPropertyRequirement,
components
model <- fromModel(
clsName = nestedClassName,
model = o,
parents = parents,
concreteTypes = concreteTypes,
definitions = definitions,
dtoPackage = dtoPackage,
supportPackage = supportPackage,
defaultPropertyRequirement = defaultPropertyRequirement,
components = components
)
oneOf <- fromOneOf(
clsName = nestedClassName,
model = o,
// parents,
definitions = definitions,
dtoPackage = dtoPackage,
supportPackage = supportPackage,
concreteTypes = concreteTypes,
defaultPropertyRequirement = defaultPropertyRequirement,
components = components
)
maybeClassDefinition = model.orElse(oneOf)
} yield Option(maybeClassDefinition)
)
.orRefine { case a: ArraySchema => a }(_.downField("items", _.getItems()).indexedCosequence.flatTraverse(processProperty(name, _)))
Expand Down

0 comments on commit 1e0370d

Please sign in to comment.