Skip to content

Commit

Permalink
Merge branch 'awslabs:main' into fix/dist-files
Browse files Browse the repository at this point in the history
  • Loading branch information
riywo authored Apr 10, 2023
2 parents d0bc925 + 93dfc75 commit 8355cce
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ private void writeSerdeDispatcher(boolean isInput) {
writer.write("throw new Error(\"No supported protocol was found\");");
} else {
String serdeFunctionName = isInput
? ProtocolGenerator.getSerFunctionName(symbol, protocolGenerator.getName())
: ProtocolGenerator.getDeserFunctionName(symbol, protocolGenerator.getName());
? ProtocolGenerator.getSerFunctionShortName(symbol)
: ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.addImport(serdeFunctionName, serdeFunctionName,
Paths.get(".", CodegenUtils.SOURCE_FOLDER, ProtocolGenerator.PROTOCOLS_FOLDER,
ProtocolGenerator.getSanitizedName(protocolGenerator.getName())).toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ private String getDelegateDeserializer(Shape shape) {
private String getDelegateDeserializer(Shape shape, String customDataSource) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);
return ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + customDataSource + ", context)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ public final String unionShape(UnionShape shape) {
private String getDelegateSerializer(Shape shape) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);
return ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,12 @@ protected final void generateDeserFunction(

Symbol symbol = symbolProvider.toSymbol(shape);
// Use the shape name for the function name.
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName =
ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());

writer.addImport(symbol, symbol.getName());
writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: __SerdeContext\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,12 @@ private void generateSerFunction(

Symbol symbol = symbolProvider.toSymbol(shape);
// Use the shape name for the function name.
String methodName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String methodName = ProtocolGenerator.getSerFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());

writer.addImport(symbol, symbol.getName());

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,31 +177,34 @@ public void generateEventStreamDeserializers(

private void generateEventStreamSerializer(GenerationContext context, UnionShape eventsUnion) {
String methodName = getSerFunctionName(context, eventsUnion);
String methodLongName = ProtocolGenerator.getSerFunctionName(getSymbol(context, eventsUnion),
context.getProtocolName());
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
writer.addImport("Message", "__Message", TypeScriptDependency.AWS_SDK_TYPES.packageName);

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: any,\n"
+ " context: $L\n"
+ "): any => {", "}", methodName, getEventStreamSerdeContextType(context, eventsUnion), () -> {
writer.openBlock("const eventMarshallingVisitor = (event: any): __Message => $T.visit(event, {", "});",
eventsUnionSymbol, () -> {
eventsUnion.getAllMembers().forEach((memberName, memberShape) -> {
StructureShape target = model.expectShape(memberShape.getTarget(), StructureShape.class);
StructureShape target = model.expectShape(memberShape.getTarget(), StructureShape.class);
String eventSerMethodName = getEventSerFunctionName(context, target);
writer.write("$L: value => $L(value, context),", memberName, eventSerMethodName);
});
writer.write("_: value => value as any");
});
writer.write("return context.eventStreamMarshaller.serialize(input, eventMarshallingVisitor);");
});
writer.write("return context.eventStreamMarshaller.serialize(input, eventMarshallingVisitor);");
});
}

private String getSerFunctionName(GenerationContext context, Shape shape) {
Symbol symbol = getSymbol(context, shape);
String protocolName = context.getProtocolName();
return ProtocolGenerator.getSerFunctionName(symbol, protocolName);
return ProtocolGenerator.getSerFunctionShortName(symbol);
}

public String getEventSerFunctionName(GenerationContext context, Shape shape) {
Expand Down Expand Up @@ -347,7 +350,7 @@ private void writeEventBody(
writer.write("body = context.utf8Decoder(input.$L);", payloadMemberName);
} else if (payloadShape instanceof BlobShape || payloadShape instanceof StringShape) {
Symbol symbol = getSymbol(context, payloadShape);
String serFunctionName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(payloadShape);
writer.write("body = $L(input.$L, context);", payloadMemberName, serFunctionName);
serializeInputEventDocumentPayload.run();
Expand All @@ -364,7 +367,7 @@ private void writeEventBody(
}
}
Symbol symbol = getSymbol(context, event);
String serFunctionName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(event);
writer.write("body = $L(input, context);", serFunctionName);
serializeInputEventDocumentPayload.run();
Expand All @@ -373,10 +376,14 @@ private void writeEventBody(

private void generateEventStreamDeserializer(GenerationContext context, UnionShape eventsUnion) {
String methodName = getDeserFunctionName(context, eventsUnion);
String methodLongName = ProtocolGenerator.getDeserFunctionName(getSymbol(context, eventsUnion),
context.getProtocolName());
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
String contextType = getEventStreamSerdeContextType(context, eventsUnion);

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: $L\n"
Expand All @@ -401,8 +408,7 @@ private void generateEventStreamDeserializer(GenerationContext context, UnionSha

private String getDeserFunctionName(GenerationContext context, Shape shape) {
Symbol symbol = getSymbol(context, shape);
String protocolName = context.getProtocolName();
return ProtocolGenerator.getDeserFunctionName(symbol, protocolName);
return ProtocolGenerator.getDeserFunctionShortName(symbol);
}

public String getEventDeserFunctionName(GenerationContext context, Shape shape) {
Expand Down Expand Up @@ -444,7 +450,7 @@ private void generateErrorEventUnmarshaller(
TypeScriptWriter writer = context.getWriter();
// If this is an error event, we need to generate the error deserializer.
errorShapesToDeserialize.add(event);
String errorDeserMethodName = getDeserFunctionName(context, event) + "Response";
String errorDeserMethodName = getDeserFunctionName(context, event) + "Res";
if (isErrorCodeInBody) {
// If error code is in body, parseBody() won't be called inside error deser. So we parse body here.
// It's ok to parse body here because body won't be streaming if 'isErrorCodeInBody' is set.
Expand Down Expand Up @@ -489,14 +495,14 @@ private void readEventBody(
} else if (payloadShape instanceof StructureShape || payloadShape instanceof UnionShape) {
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, payloadShape);
String deserFunctionName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("contents.$L = $L(data, context);", payloadMemberName, deserFunctionName);
eventShapesToDeserialize.add(payloadShape);
}
} else {
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, event);
String deserFunctionName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("Object.assign(contents, $L(data, context));", deserFunctionName);
eventShapesToDeserialize.add(event);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ public void generateServiceHandlerFactory(GenerationContext context) {
generateServiceMux(context);
writer.addImport("ServiceException", "__ServiceException", "@aws-smithy/server-common");
writer.openBlock("const serFn: (op: $1T) => __OperationSerializer<$2T<Context>, $1T, __ServiceException> = "
+ "(op) => {", "};", operationsSymbol, serviceSymbol, () -> {
+ "(op) => {", "};", operationsSymbol, serviceSymbol, () -> {
writer.openBlock("switch (op) {", "}", () -> {
operations.stream()
.filter(o -> o.getTrait(HttpTrait.class).isPresent())
Expand Down Expand Up @@ -440,12 +440,12 @@ public void generateOperationHandlerFactory(GenerationContext context, Operation

if (context.getSettings().isDisableDefaultValidation()) {
writer.write("export const get$L = <Context>(operation: __Operation<$T, $T, Context>, "
+ "customizer: __ValidationCustomizer<$S>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
+ "customizer: __ValidationCustomizer<$S>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
operationHandlerSymbol.getName(), inputType, outputType, operationSymbol.getName());
} else {
writer.write("export const get$L = <Context>(operation: __Operation<$T, $T, Context>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
operationHandlerSymbol.getName(), inputType, outputType);
}
writer.indent();
Expand Down Expand Up @@ -642,19 +642,25 @@ private void generateOperationRequestSerializer(
// Ensure that the request type is imported.
writer.addUseImports(requestType);
writer.addImport("Endpoint", "__Endpoint", "@aws-sdk/types");

// e.g., se_ES
String methodName = ProtocolGenerator.getSerFunctionShortName(symbol);
// e.g., serializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getSerFunctionName(symbol, getName());
String methodLongName = ProtocolGenerator.getSerFunctionName(symbol, getName());

// Add the normalized input type.
Symbol inputType = symbol.expectProperty("inputType", Symbol.class);
String contextType = CodegenUtils.getOperationSerializerContextType(writer, context.getModel(), operation);

writer.writeDocs(methodLongName);
writer.openBlock("export const $L = async(\n"
+ " input: $T,\n"
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, inputType, contextType, requestType, () -> {

// Get the hostname, path, port, and scheme from client's resolved endpoint. Then construct the request from
// them. The client's resolved endpoint can be default one or supplied by users.
// Get the hostname, path, port, and scheme from client's resolved endpoint.
// Then construct the request from them. The client's resolved endpoint can
// be default one or supplied by users.
writer.write("const {hostname, protocol = $S, port, path: basePath} = await context.endpoint();", "https");

writeRequestHeaders(context, operation, bindingIndex);
Expand Down Expand Up @@ -777,12 +783,12 @@ private void writeResolvedPath(
Shape target = model.expectShape(binding.getMember().getTarget());

String labelValueProvider = "() => " + getInputValue(
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);

// Get the correct label to use.
Segment uriLabel = uriLabels.stream().filter(s -> s.getContent().equals(memberName)).findFirst().get();
Expand Down Expand Up @@ -1342,7 +1348,7 @@ private String getNamedMembersInputParam(
switch (bindingType) {
case PAYLOAD:
Symbol symbol = context.getSymbolProvider().toSymbol(target);
return ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
throw new CodegenException("Unexpected named member shape binding location `" + bindingType + "`");
Expand Down Expand Up @@ -1887,17 +1893,18 @@ private void readDirectQueryBindings(GenerationContext context, List<HttpBinding
"@aws-smithy/server-common");
writer.write("let queryValue: string;");
writer.openBlock("if (Array.isArray(query[$S])) {", "}",
binding.getLocationName(),
() -> {
writer.openBlock("if (query[$S].length === 1) {", "}",
binding.getLocationName(),
() -> {
writer.write("queryValue = query[$S][0];", binding.getLocationName());
});
writer.openBlock("else {", "}", () -> {
writer.write("throw new __SerializationException();");
});
binding.getLocationName(),
() -> {
writer.openBlock("if (query[$S].length === 1) {", "}",
binding.getLocationName(),
() -> {
writer.write("queryValue = query[$S][0];", binding.getLocationName());
}
);
writer.openBlock("else {", "}", () -> {
writer.write("throw new __SerializationException();");
});
});
writer.openBlock("else {", "}", () -> {
writer.write("queryValue = query[$S] as string;", binding.getLocationName());
});
Expand Down Expand Up @@ -2052,18 +2059,21 @@ private void generateOperationResponseDeserializer(
// Ensure that the response type is imported.
writer.addUseImports(responseType);
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String contextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
context.getModel(), operation);

// Handle the general response.
writer.writeDocs(methodLongName);
writer.openBlock("export const $L = async(\n"
+ " output: $T,\n"
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, responseType, contextType, outputType, () -> {
+ "): Promise<$T> => {", "}",
methodName, responseType, contextType, outputType, () -> {
// Redirect error deserialization to the dispatcher if we receive an error range
// status code that's not the modeled code (300 or higher). This allows for
// returning other 2XX codes that don't match the defined value.
Expand Down Expand Up @@ -2103,10 +2113,13 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
HttpBindingIndex bindingIndex = HttpBindingIndex.of(context.getModel());
Model model = context.getModel();
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName()) + "Response";
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionShortName(errorSymbol) + "Res";
String errorDeserMethodLongName = ProtocolGenerator.getDeserFunctionName(errorSymbol, context.getProtocolName())
+ "Res";

String outputName = isErrorCodeInBody ? "parsedOutput" : "output";

writer.writeDocs(errorDeserMethodLongName);
writer.openBlock("const $L = async (\n"
+ " $L: any,\n"
+ " context: __SerdeContext\n"
Expand Down Expand Up @@ -2661,8 +2674,8 @@ private String getNamedMembersOutputParam(
case PAYLOAD:
// Redirect to a deserialization function.
Symbol symbol = context.getSymbolProvider().toSymbol(target);
return ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
+ "(" + dataSource + ", context)";
return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
throw new CodegenException("Unexpected named member shape binding location `" + bindingType + "`");
}
Expand Down
Loading

0 comments on commit 8355cce

Please sign in to comment.