Skip to content

Commit

Permalink
Honor custom types for array items
Browse files Browse the repository at this point in the history
Previously this only worked for protocol objects; now it also works for
parameters and in-line request bodies.
  • Loading branch information
kelnos committed Jun 5, 2021
1 parent 28b5ce2 commit 84d2583
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ object ProtocolGenerator {
typeName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName))
tpe <- selectType(typeName)
maybeNestedDefinition <- processProperty(name, schema)
resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema)
resolvedType <- SwaggerUtil.propMetaWithName(tpe, schema, Cl.liftVectorType)
customType <- SwaggerUtil.customTypeName(schema)
propertyRequirement = getPropertyRequirement(schema, requiredFields.contains(name), defaultPropertyRequirement)
defValue <- defaultValue(typeName, schema, propertyRequirement, definitions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ object SwaggerUtil {
case (Some("integer"), Some("int64")) => longType()
case (Some("integer"), fmt) => integerType(fmt).map(log(fmt, _))
case (Some("boolean"), fmt) => booleanType(fmt).map(log(fmt, _))
case (Some("array"), fmt) => arrayType(fmt).map(log(fmt, _))
case (Some("array"), fmt) => stringType(None).flatMap(liftArrayType(_, None)).map(log(fmt, _))
case (Some("file"), fmt) =>
fileType(None).map(log(fmt, _))
case (Some("binary"), fmt) =>
Expand Down Expand Up @@ -330,7 +330,7 @@ object SwaggerUtil {
Fw: FrameworkTerms[L, F]
): F[ResolvedType[L]] = {
import Fw._
propMetaImpl(property)(
propMetaImpl(property, Cl.liftVectorType)(
_.refine({ case o: ObjectSchema => o })(
o =>
for {
Expand All @@ -350,14 +350,14 @@ object SwaggerUtil {
} else Option.empty[L#Type].pure[F]
}

def propMetaWithName[L <: LA, F[_]](tpe: L#Type, property: Tracker[Schema[_]])(
def propMetaWithName[L <: LA, F[_]](tpe: L#Type, property: Tracker[Schema[_]], arrayTypeLifter: (L#Type, Option[L#Type]) => F[L#Type])(
implicit Sc: LanguageTerms[L, F],
Cl: CollectionsLibTerms[L, F],
Sw: SwaggerTerms[L, F],
Fw: FrameworkTerms[L, F]
): F[ResolvedType[L]] = {
import Fw._
propMetaImpl(property)(
propMetaImpl(property, arrayTypeLifter)(
_.refine({ case schema: ObjectSchema if Option(schema.getProperties).exists(p => !p.isEmpty) => schema })(
_ => (Resolved[L](tpe, None, None, None, None): ResolvedType[L]).pure[F]
).orRefine({ case o: ObjectSchema => o })(
Expand All @@ -375,7 +375,7 @@ object SwaggerUtil {
)
}

private def propMetaImpl[L <: LA, F[_]](property: Tracker[Schema[_]])(
private def propMetaImpl[L <: LA, F[_]](property: Tracker[Schema[_]], arrayTypeLifter: (L#Type, Option[L#Type]) => F[L#Type])(
strategy: Tracker[Schema[_]] => Either[Tracker[Schema[_]], F[ResolvedType[L]]]
)(implicit Sc: LanguageTerms[L, F], Cl: CollectionsLibTerms[L, F], Sw: SwaggerTerms[L, F], Fw: FrameworkTerms[L, F]): F[ResolvedType[L]] =
Sw.log.function("propMeta") {
Expand Down Expand Up @@ -411,11 +411,11 @@ object SwaggerUtil {
p =>
for {
items <- getItems(p)
rec <- propMetaImpl[L, F](items)(strategy)
rec <- propMetaImpl[L, F](items, arrayTypeLifter)(strategy)
arrayType <- customArrayTypeName(p).flatMap(_.flatTraverse(x => parseType(Tracker.cloneHistory(p, x))))
res <- rec match {
case Resolved(inner, dep, default, _, _) =>
(liftVectorType(inner, arrayType), default.traverse(liftVectorTerm))
(arrayTypeLifter(inner, arrayType), default.traverse(liftVectorTerm))
.mapN(Resolved[L](_, dep, _, None, None): ResolvedType[L])
case x: DeferredMap[L] => embedArray(x, arrayType)
case x: DeferredArray[L] => embedArray(x, arrayType)
Expand All @@ -430,7 +430,7 @@ object SwaggerUtil {
.downField("additionalProperties", _.getAdditionalProperties())
.map(_.getOrElse(false))
.refine[F[ResolvedType[L]]]({ case b: java.lang.Boolean => b })(_ => objectType(None).map(Resolved[L](_, None, None, None, None)))
.orRefine({ case s: Schema[_] => s })(propMetaImpl[L, F](_)(strategy))
.orRefine({ case s: Schema[_] => s })(propMetaImpl[L, F](_, arrayTypeLifter)(strategy))
.orRefineFallback({ s =>
log.debug(s"Unknown structure cannot be reflected: ${s.unwrapTracker} (${s.showHistory})") >> objectType(None).map(
Resolved[L](_, None, None, None, None)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package com.twilio.guardrail
package generators

import io.swagger.v3.oas.models.media.Schema
import io.swagger.v3.oas.models.parameters._
import cats.syntax.all._
import com.twilio.guardrail.SwaggerUtil.ResolvedType
import com.twilio.guardrail.core.Tracker
import com.twilio.guardrail.extract.{ Default, FileHashAlgorithm }
import com.twilio.guardrail.generators.syntax._
import com.twilio.guardrail.languages.LA
import com.twilio.guardrail.shims._
import com.twilio.guardrail.terms.{ CollectionsLibTerms, LanguageTerms, SwaggerTerms }
import com.twilio.guardrail.terms.framework.FrameworkTerms
import cats.syntax.all._
import com.twilio.guardrail.SwaggerUtil.ResolvedType
import com.twilio.guardrail.terms.{ CollectionsLibTerms, LanguageTerms, SwaggerTerms }
import io.swagger.v3.oas.models.media.Schema
import io.swagger.v3.oas.models.parameters._

case class RawParameterName private[generators] (value: String)
case class RawParameterType private[generators] (tpe: Option[String], format: Option[String])
Expand Down Expand Up @@ -55,9 +55,9 @@ object LanguageParameter {
Cl: CollectionsLibTerms[L, F],
Sw: SwaggerTerms[L, F]
): Tracker[Parameter] => F[LanguageParameter[L]] = { parameter =>
import Cl._
import Fw._
import Sc._
import Cl._
import Sw._

def paramMeta(param: Tracker[Parameter]): F[SwaggerUtil.ResolvedType[L]] = {
Expand Down Expand Up @@ -86,9 +86,12 @@ object LanguageParameter {
customParamTypeName <- SwaggerUtil.customTypeName(param)
customSchemaTypeName <- schema.unwrapTracker.flatTraverse(SwaggerUtil.customTypeName(_: Schema[_]))
customTypeName = Tracker.cloneHistory(schema, customSchemaTypeName).fold(Tracker.cloneHistory(param, customParamTypeName))(_.map(Option.apply))
res <- (SwaggerUtil.typeName[L, F](tpeName.map(Option(_)), fmt, customTypeName), getDefault(tpeName.unwrapTracker, fmt, param))
.mapN(SwaggerUtil.Resolved[L](_, None, _, Some(tpeName.unwrapTracker), fmt.unwrapTracker))
} yield res
tpe <- SwaggerUtil.typeName[L, F](tpeName.map(Option(_)), fmt, customTypeName)
resolvedType <- schema.fold[F[ResolvedType[L]]](
(tpe.pure[F], getDefault(tpeName.unwrapTracker, fmt, param))
.mapN(SwaggerUtil.Resolved[L](_, None, _, Some(tpeName.unwrapTracker), fmt.unwrapTracker))
)(schema => SwaggerUtil.propMetaWithName(tpe, schema, liftArrayType))
} yield resolvedType

def paramHasRefSchema(p: Parameter): Boolean = Option(p.getSchema).exists(s => Option(s.get$ref()).nonEmpty)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.twilio.guardrail.{ SwaggerUtil, Target }
import com.twilio.guardrail.generators.syntax.Java._
import com.twilio.guardrail.languages.JavaLanguage
import com.twilio.guardrail.terms.CollectionsLibTerms

@SuppressWarnings(Array("org.wartremover.warts.Null"))
object JavaCollectionsGenerator {

Expand All @@ -29,8 +30,8 @@ object JavaCollectionsGenerator {

def emptyOptionalTerm(): Target[Node] = buildMethodCall("java.util.Optional.empty")

def arrayType(format: Option[String]): Target[Type] =
safeParseClassOrInterfaceType("java.util.List").map(_.setTypeArguments(new NodeList[Type](STRING_TYPE)))
def liftArrayType(value: Type, customTpe: Option[Type]): Target[Type] =
liftVectorType(value, customTpe)

def liftVectorType(value: Type, customTpe: Option[Type]): Target[Type] =
customTpe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,22 @@ object JavaVavrCollectionsGenerator {

override def emptyOptionalTerm(): Target[Node] = buildMethodCall("io.vavr.control.Option.none")

override def arrayType(format: Option[String]): Target[Type] =
safeParseClassOrInterfaceType("io.vavr.collection.Vector").map(_.setTypeArguments(new NodeList[Type](STRING_TYPE)))

override def liftVectorType(value: Type, customTpe: Option[Type]): Target[Type] =
private def liftSimpleType(value: Type, tpeStr: String, customTpe: Option[Type]): Target[Type] =
customTpe
.fold[Target[ClassOrInterfaceType]](safeParseClassOrInterfaceType("io.vavr.collection.Vector").map(identity))({
.fold[Target[ClassOrInterfaceType]](safeParseClassOrInterfaceType(tpeStr))({
case t: ClassOrInterfaceType =>
Target.pure(t)
case x =>
Target.raiseUserError(s"Unsure how to map $x")
})
.map(_.setTypeArguments(new NodeList(value)))

override def liftArrayType(value: Type, customTpe: Option[Type]): Target[Type] =
liftSimpleType(value, "io.vavr.collection.Vector", customTpe)

override def liftVectorType(value: Type, customTpe: Option[Type]): Target[Type] =
liftSimpleType(value, "io.vavr.collection.Vector", customTpe)

override def liftVectorTerm(value: Node): Target[Node] =
buildMethodCall("io.vavr.collection.Vector.of", Some(value))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ object ScalaCollectionsGenerator {
def liftSomeTerm(value: Term): Target[Term] = Target.pure(q"Some($value)")
def emptyOptionalTerm(): Target[Term] = Target.pure(q"None")

def arrayType(format: Option[String]): Target[Type] = Target.pure(t"Iterable[String]")
def liftArrayType(value: Type, customTpe: Option[Type]): Target[Type] =
Target.pure(t"${customTpe.getOrElse(t"Iterable")}[$value]")
def liftVectorType(value: Type, customTpe: Option[Type]): Target[Type] =
Target.pure(t"${customTpe.getOrElse(t"Vector")}[$value]")
def liftVectorTerm(value: Term): Target[Term] = Target.pure(q"Vector($value)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class CollectionsLibTerms[L <: LA, F[_]] {
def liftSomeTerm(value: L#Term): F[L#Term]
def emptyOptionalTerm(): F[L#Term]

def arrayType(format: Option[String]): F[L#Type]
def liftArrayType(value: L#Type, customTpe: Option[L#Type]): F[L#Type]
def liftVectorType(value: L#Type, customTpe: Option[L#Type]): F[L#Type]
def liftVectorTerm(value: L#Term): F[L#Term]
def emptyArray(): F[L#Term]
Expand All @@ -31,7 +31,7 @@ abstract class CollectionsLibTerms[L <: LA, F[_]] {
newLiftOptionalTerm: L#Term => F[L#Term] = liftOptionalTerm,
newLiftSomeTerm: L#Term => F[L#Term] = liftSomeTerm,
newEmptyOptionalTerm: () => F[L#Term] = emptyOptionalTerm _,
newArrayType: Option[String] => F[L#Type] = arrayType,
newLiftArrayType: (L#Type, Option[L#Type]) => F[L#Type] = liftArrayType,
newLiftVectorType: (L#Type, Option[L#Type]) => F[L#Type] = liftVectorType,
newLiftVectorTerm: L#Term => F[L#Term] = liftVectorTerm,
newEmptyArray: () => F[L#Term] = emptyArray _,
Expand All @@ -46,7 +46,7 @@ abstract class CollectionsLibTerms[L <: LA, F[_]] {
def liftOptionalTerm(value: L#Term) = newLiftOptionalTerm(value)
def liftSomeTerm(value: L#Term) = newLiftSomeTerm(value)
def emptyOptionalTerm() = newEmptyOptionalTerm()
def arrayType(format: Option[String]) = newArrayType(format)
def liftArrayType(value: L#Type, customTpe: Option[L#Type]) = newLiftArrayType(value, customTpe)
def liftVectorType(value: L#Type, customTpe: Option[L#Type]) = newLiftVectorType(value, customTpe)
def liftVectorTerm(value: L#Term) = newLiftVectorTerm(value)
def emptyArray() = newEmptyArray()
Expand Down
78 changes: 75 additions & 3 deletions modules/codegen/src/test/scala/swagger/VendorExtensionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package swagger

import java.util

import cats.data.NonEmptyList
import com.twilio.guardrail.extract.VendorExtension
import com.twilio.guardrail.generators.Scala.AkkaHttp
import com.twilio.guardrail.{ ClassDefinition, Client, Clients, CodegenTarget, Context, ProtocolDefinitions }
import io.swagger.parser.OpenAPIParser
import io.swagger.v3.parser.core.models.ParseOptions

import scala.collection.JavaConverters._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import support.SwaggerSpecRunner

import scala.collection.JavaConverters._

class VendorExtensionTest extends AnyFunSuite with Matchers {
class VendorExtensionTest extends AnyFunSuite with Matchers with SwaggerSpecRunner {

val spec: String = s"""
|swagger: '2.0'
Expand Down Expand Up @@ -80,4 +84,72 @@ class VendorExtensionTest extends AnyFunSuite with Matchers {
} ()
} ()
}

private val itemsOverrideSpec =
s"""
|openapi: 3.0.1
|info:
| title: Foo
| version: 0.0.0
|components:
| schemas:
| Foo:
| type: object
| required:
| - arrayprop
| properties:
| arrayprop:
| type: array
| items:
| type: string
| x-jvm-type: SomethingElse
|paths:
| /foo:
| get:
| operationId: getFoo
| parameters:
| - name: arrayquery
| in: query
| required: true
| schema:
| type: array
| items:
| type: string
| x-jvm-type: SomethingElse
| - name: regularquery
| in: query
| required: true
| schema:
| type: string
| x-jvm-type: SomethingElse
| requestBody:
| required: true
| content:
| application/json:
| schema:
| required:
| - arrayprop
| properties:
| arrayprop:
| type: array
| items:
| type: string
| x-jvm-type: SomethingElse
| responses:
| 200: {}
|""".stripMargin

test("CustomTypeName works on array items") {
val (
ProtocolDefinitions(ClassDefinition("Foo", _, _, protoCls, _, _) :: Nil, _, _, _, _),
Clients(Client(_, _, _, _, NonEmptyList(Right(clientCls), Nil), _) :: Nil, _),
_
) =
runSwaggerSpec(itemsOverrideSpec)(Context.empty, AkkaHttp, targets = NonEmptyList.one(CodegenTarget.Client))

protoCls.syntax shouldBe """case class Foo(arrayprop: Vector[SomethingElse] = Vector.empty)"""
clientCls.syntax should include(
"def getFoo(arrayquery: Iterable[SomethingElse], regularquery: SomethingElse, arrayprop: Iterable[SomethingElse], headers: List[HttpHeader] = Nil)"
)
}
}

0 comments on commit 84d2583

Please sign in to comment.