in codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java [257:550]
/**
 * Generates the deserialize-step middleware that wires up an operation's event stream, together with
 * the middleware's {@code closeResponseBody} helper method and the function that inserts the
 * middleware (and, for initial-request operations, the event stream serializer helper) into the
 * operation stack.
 *
 * <p>The generated Go middleware:
 * <ul>
 *   <li>if the operation has an input event stream: applies the HTTP transport fixes, builds a sigv4
 *   stream signer seeded with the signed request's signature, and constructs the event writer;</li>
 *   <li>if {@code withInitialMessages} is set and there is an input stream: sends the operation input
 *   as the initial-request event from a goroutine;</li>
 *   <li>invokes the next deserialize handler;</li>
 *   <li>if the operation has an output event stream: constructs the event reader over the response
 *   body and, when {@code withInitialMessages} is set, copies the initial-response event into the
 *   operation output;</li>
 *   <li>attaches the writer/reader to {@code output.eventStream} and starts
 *   {@code waitStreamClose} in a goroutine.</li>
 * </ul>
 * On error, deferred cleanups close the event writer/reader and drain and close the response body.
 *
 * @param context             generation context supplying the model, service, symbols and writer
 * @param operationShape      the event stream operation being generated
 * @param withInitialMessages whether the stream carries an initial-request / initial-response message
 */
private static void generateEventStreamMiddleware(
        GenerationContext context,
        OperationShape operationShape,
        boolean withInitialMessages
) {
    var serviceShape = context.getService();
    var middlewareName = getSerDeName(operationShape, serviceShape, context.getProtocolName(),
            "_deserializeOpEventStream");
    var errorf = getSymbol("Errorf", SmithyGoDependency.FMT, false);
    var getSignedRequestSignature = getSymbol("GetSignedRequestSignature", AwsGoDependency.AWS_SIGNER_V4, false);
    // Referenced through a symbol so the writer manages the signer import/alias itself rather than
    // relying on the "v4" alias that another symbol happens to import.
    var newStreamSigner = getSymbol("NewStreamSigner", AwsGoDependency.AWS_SIGNER_V4, false);
    var symbolProvider = context.getSymbolProvider();
    var model = context.getModel();
    var outputShape = model.expectShape(operationShape.getOutput().get());
    var eventStreamIndex = EventStreamIndex.of(model);
    var inputInfo = eventStreamIndex.getInputInfo(operationShape);
    var outputInfo = eventStreamIndex.getOutputInfo(operationShape);
    var writer = context.getWriter().get();
    var middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware(middlewareName, MiddlewareIdentifier.builder()
            .name("OperationEventStreamDeserializer")
            .build());
    middleware.writeMiddleware(writer,
            (mg, w) -> {
                // Drain/close the HTTP response body whenever the middleware returns an error.
                w.write("""
                        defer func() {
                            if err == nil {
                                return
                            }
                            m.closeResponseBody(out)
                        }()
                        logger := $T(ctx)
                        """, getSymbol("GetLogger",
                        SmithyGoDependency.SMITHY_MIDDLEWARE, false));
                w.write("""
                        request, ok := in.Request.($P)
                        if !ok {
                            return out, metadata, $T("unknown transport type: %T", in.Request)
                        }
                        _ = request
                        """, getSymbol("Request", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf);

                if (inputInfo.isPresent()) {
                    w.write("""
                            if err := $T(request); err != nil {
                                return out, metadata, err
                            }
                            """, getEventStreamApiSymbol("ApplyHTTPTransportFixes"))
                            .write("");
                    // %w keeps the underlying signer error unwrappable by callers.
                    w.writeGoTemplate("""
                            requestSignature, err := $getSignature:T(request.Request)
                            if err != nil {
                                return out, metadata, $errorf:T("failed to get event stream seed signature: %w", err)
                            }
                            identity := getIdentity(ctx)
                            if identity == nil {
                                return out, metadata, $errorf:T("no identity")
                            }
                            creds, ok := identity.($credentialsAdapter:P)
                            if !ok {
                                return out, metadata, $errorf:T("identity is not sigv4 credentials")
                            }
                            rscheme := getResolvedAuthScheme(ctx)
                            if rscheme == nil {
                                return out, metadata, $errorf:T("no resolved auth scheme")
                            }
                            name, ok := $getSigningName:T(&rscheme.SignerProperties)
                            if !ok {
                                return out, metadata, $errorf:T("no sigv4 signing name")
                            }
                            region, ok := $getSigningRegion:T(&rscheme.SignerProperties)
                            if !ok {
                                return out, metadata, $errorf:T("no sigv4 signing region")
                            }
                            signer := $newStreamSigner:T(creds.Credentials, name, region, requestSignature)
                            """,
                            MapUtils.of(
                                    "getSignature", getSignedRequestSignature,
                                    "errorf", GoStdlibTypes.Fmt.Errorf,
                                    "credentialsAdapter", SdkGoTypes.Internal.Auth.Smithy.CredentialsAdapter,
                                    "getSigningName", SmithyGoTypes.Transport.Http.GetSigV4SigningName,
                                    "getSigningRegion", SmithyGoTypes.Transport.Http.GetSigV4SigningRegion,
                                    "newStreamSigner", newStreamSigner
                            ));
                    var events = inputInfo.get().getEventStreamTarget().asUnionShape()
                            .get();
                    var constructorName = getEventStreamWriterImplConstructorName(events,
                            serviceShape);
                    var newEncoder = getEventStreamSymbol("NewEncoder", false);
                    var encoderOptions = getEventStreamSymbol("EncoderOptions");
                    w.openBlock("eventWriter := $L(", ")", constructorName, () -> {
                        w.write("$T(ctx),", getEventStreamApiSymbol("GetInputStreamWriter",
                                        false))
                                .openBlock("$T(func(options $P) {", "}),", newEncoder,
                                        encoderOptions, () -> w
                                                .write("""
                                                        options.Logger = logger
                                                        options.LogMessages = m.LogEventStreamWrites
                                                        """))
                                .write("signer,");
                        if (withInitialMessages) {
                            w.write("$L,", getEventStreamMessageRequestSerializerName(
                                    operationShape.getInput().get(), serviceShape,
                                    context.getProtocolName()));
                        }
                    })
                            // Close the event writer if the middleware ultimately fails.
                            .write("""
                                    defer func() {
                                        if err == nil {
                                            return
                                        }
                                        _ = eventWriter.Close()
                                    }()
                                    """);
                    if (withInitialMessages) {
                        var inputShape = model.expectShape(operationShape.getInput().get());
                        // The initial-request event is sent concurrently; its result is reported on
                        // reqSend (buffered so the goroutine never blocks on the send).
                        w.write("""
                                params, ok := $L(ctx).($P)
                                if !ok || params == nil {
                                    return out, metadata, $T("unexpected nil type: %T", params)
                                }
                                reqSend := make(chan error, 1)
                                go func() {
                                    defer close(reqSend)
                                    sErr := eventWriter.send(ctx, &$T{Value: params})
                                    reqSend <- sErr
                                }()
                                """, CONTEXT_GET_EVENT_STREAM_INPUT, symbolProvider.toSymbol(inputShape),
                                errorf, getWriterEventWrapperInitialRequestType(symbolProvider,
                                        events, serviceShape));
                    }
                }

                var outputSymbol = symbolProvider.toSymbol(outputShape);
                w.write("out, metadata, err = next.HandleDeserialize(ctx, in)");
                w.openBlock("if err != nil {", "}", () -> {
                    if (withInitialMessages && inputInfo.isPresent()) {
                        // Non-blocking check: attach the initial-request send error if one is ready.
                        w.write("""
                                select {
                                case sErr := <-reqSend:
                                    if sErr != nil {
                                        err = $T("%v: %w", err, sErr)
                                    }
                                default:
                                }""", errorf);
                    }
                    w.write("return out, metadata, err");
                }).write("");
                if (withInitialMessages && inputInfo.isPresent()) {
                    w.write("""
                            if err := <-reqSend; err != nil {
                                return out, metadata, err
                            }
                            """);
                }
                w.write("""
                        deserializeOutput, ok := out.RawResponse.($P)
                        if !ok {
                            return out, metadata, $T("unknown transport type: %T", out.RawResponse)
                        }
                        _ = deserializeOutput
                        output, ok := out.Result.($P)
                        if out.Result != nil && !ok {
                            return out, metadata, $T("unexpected output result type: %T", out.Result)
                        } else if out.Result == nil {
                            output = &$T{}
                            out.Result = output
                        }
                        """, getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf,
                        outputSymbol, errorf, outputSymbol
                );

                if (outputInfo.isPresent()) {
                    var events = outputInfo.get().getEventStreamTarget().asUnionShape()
                            .get();
                    var constructorName = getEventStreamReaderImplConstructorName(events,
                            serviceShape);
                    var newDecoder = getEventStreamSymbol("NewDecoder", false);
                    var decoderOptions = getEventStreamSymbol("DecoderOptions");
                    w.openBlock("eventReader := $L(", ")", constructorName, () -> {
                        w.write("deserializeOutput.Body,")
                                .openBlock("$T(func(options $P) {", "}),", newDecoder,
                                        decoderOptions, () -> w
                                                .write("""
                                                        options.Logger = logger
                                                        options.LogMessages = m.LogEventStreamReads
                                                        """));
                        if (withInitialMessages) {
                            w.write("$L,", getEventStreamMessageResponseDeserializerName(
                                    operationShape.getOutput().get(), serviceShape,
                                    context.getProtocolName()));
                        }
                    })
                            // Close the event reader if the middleware ultimately fails.
                            .write("""
                                    defer func() {
                                        if err == nil {
                                            return
                                        }
                                        _ = eventReader.Close()
                                    }()
                                    """);
                    if (withInitialMessages) {
                        // Guard against a typed-nil result so the dereference below cannot panic.
                        w.write("""
                                ir := <-eventReader.initialResponse
                                irv, ok := ir.($P)
                                if !ok || irv == nil {
                                    return out, metadata, $T("unexpected output result type: %T", ir)
                                }
                                *output = *irv
                                """, outputSymbol, errorf);
                    }
                }

                var streamConstructor = EventStreamGenerator.getEventStreamOperationStructureConstructor(
                        serviceShape, operationShape);
                var operationStream = EventStreamGenerator.getEventStreamOperationStructureSymbol(
                        serviceShape, operationShape);
                w.openBlock("output.eventStream = $T(func(stream $P) {", "})", streamConstructor,
                                operationStream, () -> {
                                    inputInfo.ifPresent(eventStreamInfo -> {
                                        w.write("stream.Writer = eventWriter");
                                    });
                                    outputInfo.ifPresent(eventStreamInfo -> {
                                        w.write("stream.Reader = eventReader");
                                    });
                                }).write("")
                        .write("go output.eventStream.waitStreamClose()").write("")
                        .write("return out, metadata, nil");
            },
            (mg, w) -> w.write("""
                    LogEventStreamWrites bool
                    LogEventStreamReads bool
                    """));

    // closeResponseBody drains and closes the HTTP response body so the connection can be reused.
    var deserializeOutput = getSymbol("DeserializeOutput", SmithyGoDependency.SMITHY_MIDDLEWARE);
    var httpResponse = getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
    var copy = getSymbol("Copy", SmithyGoDependency.IO);
    // io.Discard replaces the deprecated ioutil.Discard (identical behavior since Go 1.16).
    var discard = getSymbol("Discard", SmithyGoDependency.IO);
    writer.write("""
            func ($P) closeResponseBody(out $T) {
                if resp, ok := out.RawResponse.($P); ok && resp != nil && resp.Body != nil {
                    _, _ = $T($T, resp.Body)
                    _ = resp.Body.Close()
                }
            }
            """, middleware.getMiddlewareSymbol(), deserializeOutput, httpResponse, copy, discard);

    var stack = getSymbol("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE);
    var after = getSymbol("After", SmithyGoDependency.SMITHY_MIDDLEWARE);
    var before = getSymbol("Before", SmithyGoDependency.SMITHY_MIDDLEWARE);
    writer.openBlock("func $T(stack $P, options Options) error {", "}",
            getAddEventStreamOperationMiddlewareSymbol(operationShape), stack,
            () -> {
                if (withInitialMessages && inputInfo.isPresent()) {
                    writer.write("""
                            if err := stack.Serialize.Insert(&$T{}, "OperationSerializer", $T); err != nil {
                                return err
                            }
                            """, getModuleSymbol(context.getSettings(), EVENT_STREAM_SERIALIZER_HELPER),
                            after);
                }
                writer.write("""
                        if err := stack.Deserialize.Insert(&$T{
                            LogEventStreamWrites: options.ClientLogMode.IsRequestEventMessage(),
                            LogEventStreamReads: options.ClientLogMode.IsResponseEventMessage(),
                        }, "OperationDeserializer", $T); err != nil {
                            return err
                        }
                        return nil
                        """, middleware.getMiddlewareSymbol(), before);
            });
}