private static void generateEventStreamMiddleware()

in codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java [257:550]


    /**
     * Writes the Go source for an operation's event stream deserialize-step middleware.
     *
     * <p>Three pieces of Go code are emitted to the context's writer, in this order:
     * <ol>
     *   <li>the middleware itself (identifier "OperationEventStreamDeserializer"), whose handler
     *       sets up the event writer and/or event reader around {@code next.HandleDeserialize}
     *       and attaches the resulting stream to the operation output;</li>
     *   <li>a {@code closeResponseBody} method on the middleware type that drains and closes the
     *       HTTP response body;</li>
     *   <li>a function that inserts the middleware (and, for input streams with initial messages,
     *       the serializer helper) into a middleware stack.</li>
     * </ol>
     *
     * @param context             generation context supplying the service, model, symbol provider,
     *                            protocol name, settings and writer
     * @param operationShape      the operation the middleware is generated for; its input and
     *                            output Optionals are unwrapped with {@code get()} without a presence check
     * @param withInitialMessages when true, the generated code additionally passes the initial
     *                            request serializer / initial response deserializer to the stream
     *                            constructors, sends the operation input as an initial request event,
     *                            and copies the initial response into the output
     */
    private static void generateEventStreamMiddleware(
            GenerationContext context,
            OperationShape operationShape,
            boolean withInitialMessages
    ) {

        var serviceShape = context.getService();
        var middlewareName = getSerDeName(operationShape, serviceShape, context.getProtocolName(),
                "_deserializeOpEventStream");

        var errorf = getSymbol("Errorf", SmithyGoDependency.FMT, false);
        var getSignedRequestSignature = getSymbol("GetSignedRequestSignature", AwsGoDependency.AWS_SIGNER_V4, false);
        var symbolProvider = context.getSymbolProvider();
        var model = context.getModel();
        var outputShape = model.expectShape(operationShape.getOutput().get());

        // Each direction is an Optional; the emitted Go below is conditional on which are present.
        var inputInfo = EventStreamIndex.of(model).getInputInfo(operationShape);
        var outputInfo = EventStreamIndex.of(model).getOutputInfo(operationShape);

        // NOTE(review): unwrapped without a presence check — assumes the context always has a writer here.
        var writer = context.getWriter().get();

        var middleware = GoStackStepMiddlewareGenerator.createDeserializeStepMiddleware(middlewareName, MiddlewareIdentifier.builder()
                .name("OperationEventStreamDeserializer")
                .build());

        // First consumer writes the handler body; second consumer writes the middleware's struct fields.
        middleware.writeMiddleware(writer,
                (mg, w) -> {
                    // Generated handler: a deferred func closes the response body whenever the handler
                    // returns a non-nil err, then the logger is fetched from ctx.
                    w.write("""
                            defer func() {
                                if err == nil {
                                    return
                                }
                                m.closeResponseBody(out)
                            }()

                            logger := $T(ctx)
                            """, getSymbol("GetLogger",
                            SmithyGoDependency.SMITHY_MIDDLEWARE, false));

                    // Type-assert the transport request. The emitted `_ = request` presumably keeps the
                    // generated Go compiling when no input stream code ends up using `request`.
                    w.write("""
                            request, ok := in.Request.($P)
                            if !ok {
                                return out, metadata, $T("unknown transport type: %T", in.Request)
                            }
                            _ = request
                            """, getSymbol("Request", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf);

                    // Input event stream: apply transport fixes, build the stream signer, create the event writer.
                    if (inputInfo.isPresent()) {
                        w.write("""
                                if err := $T(request); err != nil {
                                    return out, metadata, err
                                }
                                """, getEventStreamApiSymbol("ApplyHTTPTransportFixes"))
                                .write("");

                        // The stream signer is seeded with the signed request's signature and uses the
                        // credentials, signing name and signing region taken from the identity and
                        // resolved auth scheme on ctx; each missing piece is a generated early return.
                        // NOTE(review): the template names the `v4` package literally rather than via a
                        // symbol — presumably the import comes from the GetSignedRequestSignature symbol; confirm.
                        w.writeGoTemplate("""
                                requestSignature, err := $getSignature:T(request.Request)
                                if err != nil {
                                   return out, metadata, $errorf:T("failed to get event stream seed signature: %v", err)
                                }

                                identity := getIdentity(ctx)
                                if identity == nil {
                                    return out, metadata, $errorf:T("no identity")
                                }

                                creds, ok := identity.($credentialsAdapter:P)
                                if !ok {
                                    return out, metadata, $errorf:T("identity is not sigv4 credentials")
                                }

                                rscheme := getResolvedAuthScheme(ctx)
                                if rscheme == nil {
                                    return out, metadata, $errorf:T("no resolved auth scheme")
                                }

                                name, ok := $getSigningName:T(&rscheme.SignerProperties)
                                if !ok {
                                    return out, metadata, $errorf:T("no sigv4 signing name")
                                }

                                region, ok := $getSigningRegion:T(&rscheme.SignerProperties)
                                if !ok {
                                    return out, metadata, $errorf:T("no sigv4 signing region")
                                }

                                signer := v4.NewStreamSigner(creds.Credentials, name, region, requestSignature)
                                """,
                                MapUtils.of(
                                        "getSignature", getSignedRequestSignature,
                                        "errorf", GoStdlibTypes.Fmt.Errorf,
                                        "credentialsAdapter", SdkGoTypes.Internal.Auth.Smithy.CredentialsAdapter,
                                        "getSigningName", SmithyGoTypes.Transport.Http.GetSigV4SigningName,
                                        "getSigningRegion", SmithyGoTypes.Transport.Http.GetSigV4SigningRegion
                                ));

                        var events = inputInfo.get().getEventStreamTarget().asUnionShape()
                                .get();
                        var constructorName = getEventStreamWriterImplConstructorName(events,
                                serviceShape);
                        var newEncoder = getEventStreamSymbol("NewEncoder", false);
                        var encoderOptions = getEventStreamSymbol("EncoderOptions");
                        // Emits: eventWriter := <ctor>(<input stream writer>, <encoder>, signer[, <initial request serializer>])
                        // followed by a deferred Close of the writer when the handler returns an error.
                        w.openBlock("eventWriter := $L(", ")", constructorName, () -> {
                                    w.write("$T(ctx),", getEventStreamApiSymbol("GetInputStreamWriter",
                                                    false))
                                            .openBlock("$T(func(options $P) {", "}),", newEncoder,
                                                    encoderOptions, () -> w
                                                            .write("""
                                                                   options.Logger = logger
                                                                   options.LogMessages = m.LogEventStreamWrites
                                                                   """))
                                            .write("signer,");
                                    if (withInitialMessages) {
                                        w.write("$L,", getEventStreamMessageRequestSerializerName(
                                                operationShape.getInput().get(), serviceShape,
                                                context.getProtocolName()));
                                    }
                                })
                                .write("""
                                       defer func() {
                                           if err == nil {
                                               return
                                           }
                                           _ = eventWriter.Close()
                                       }()
                                       """);

                        // With initial messages, the generated code sends the operation input as the
                        // initial request event from a goroutine and reports the result on the
                        // 1-buffered channel `reqSend`, which the goroutine closes when done.
                        if (withInitialMessages) {
                            var inputShape = model.expectShape(operationShape.getInput().get());
                            w.write("""
                                    params, ok := $L(ctx).($P)
                                    if !ok || params == nil {
                                        return out, metadata, $T("unexpected nil type: %T", params)
                                    }

                                    reqSend := make(chan error, 1)
                                    go func() {
                                        defer close(reqSend)
                                        sErr := eventWriter.send(ctx, &$T{Value: params})
                                        reqSend <- sErr
                                    }()
                                    """, CONTEXT_GET_EVENT_STREAM_INPUT, symbolProvider.toSymbol(inputShape),
                                    errorf, getWriterEventWrapperInitialRequestType(symbolProvider,
                                            inputInfo.get().getEventStreamTarget().asUnionShape().get(), serviceShape));
                        }
                    }

                    var outputSymbol = symbolProvider.toSymbol(outputShape);
                    // Hand off to the next handler in the deserialize step.
                    w.write("out, metadata, err = next.HandleDeserialize(ctx, in)");

                    // On handler error: if `reqSend` exists, do a non-blocking receive and wrap any send
                    // error into the returned error. `writer` is the outer variable; `w` is the lambda
                    // parameter — NOTE(review): presumably the same writer instance, confirm.
                    writer.openBlock("if err != nil {", "}", () -> {
                        if (withInitialMessages && inputInfo.isPresent()) {
                            w.write("""
                                    select {
                                    case sErr := <-reqSend:
                                        if sErr != nil {
                                            err = $T("%v: %w", err, sErr)
                                        }
                                    default:
                                    }""", errorf);
                        }
                        writer.write("return out, metadata, err");
                    }).write("");

                    // On handler success the generated code blocks on `reqSend` and fails if the
                    // initial request send returned an error. Same condition as where `reqSend` is declared.
                    if (withInitialMessages && inputInfo.isPresent()) {
                        w.write("""
                                if err := <-reqSend; err != nil {
                                    return out, metadata, err
                                }
                                """);
                    }

                    // Type-check the raw response and the result; when no result was set, a fresh
                    // empty output struct is allocated and stored in out.Result.
                    w.write("""
                            deserializeOutput, ok := out.RawResponse.($P)
                            if !ok {
                                return out, metadata, $T("unknown transport type: %T", out.RawResponse)
                            }
                            _ = deserializeOutput

                            output, ok := out.Result.($P)
                            if out.Result != nil && !ok {
                                return out, metadata, $T("unexpected output result type: %T", out.Result)
                            } else if out.Result == nil {
                                output = &$T{}
                                out.Result = output
                            }
                            """, getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf,
                            outputSymbol, errorf, outputSymbol
                    );

                    // Output event stream: create the event reader over the response body.
                    if (outputInfo.isPresent()) {
                        var events = outputInfo.get().getEventStreamTarget().asUnionShape()
                                .get();
                        var constructorName = getEventStreamReaderImplConstructorName(events,
                                serviceShape);
                        var newDecoder = getEventStreamSymbol("NewDecoder", false);
                        var decoderOptions = getEventStreamSymbol("DecoderOptions");
                        // Emits: eventReader := <ctor>(deserializeOutput.Body, <decoder>[, <initial response deserializer>])
                        // followed by a deferred Close of the reader when the handler returns an error.
                        w.openBlock("eventReader := $L(", ")", constructorName, () -> {
                                    w.write("deserializeOutput.Body,")
                                            .openBlock("$T(func(options $P) {", "}),", newDecoder,
                                                    decoderOptions, () -> w
                                                            .write("""
                                                                   options.Logger = logger
                                                                   options.LogMessages = m.LogEventStreamReads
                                                                   """));
                                    if (withInitialMessages) {
                                        w.write("$L,", getEventStreamMessageResponseDeserializerName(
                                                operationShape.getOutput().get(), serviceShape,
                                                context.getProtocolName()));
                                    }
                                })
                                .write("""
                                       defer func() {
                                           if err == nil {
                                               return
                                           }
                                           _ = eventReader.Close()
                                       }()
                                       """);

                        // With initial messages, the generated code receives from
                        // eventReader.initialResponse and overwrites *output with the received value.
                        if (withInitialMessages) {
                            w.write("""
                                    ir := <-eventReader.initialResponse
                                    irv, ok := ir.($P)
                                    if !ok {
                                        return out, metadata, $T("unexpected output result type: %T", ir)
                                    }
                                    *output = *irv
                                    """, outputSymbol, errorf);
                        }
                    }

                    var streamConstructor = EventStreamGenerator.getEventStreamOperationStructureConstructor(
                            serviceShape, operationShape);
                    var operationStream = EventStreamGenerator.getEventStreamOperationStructureSymbol(
                            serviceShape, operationShape);

                    // Attach the stream (writer and/or reader, per direction present) to the output,
                    // start waitStreamClose in a goroutine, and return success.
                    w.openBlock("output.eventStream = $T(func(stream $P) {", "})", streamConstructor,
                                    operationStream, () -> {
                                        inputInfo.ifPresent(eventStreamInfo -> {
                                            w.write("stream.Writer = eventWriter");
                                        });
                                        outputInfo.ifPresent(eventStreamInfo -> {
                                            w.write("stream.Reader = eventReader");
                                        });
                                    }).write("")
                            .write("go output.eventStream.waitStreamClose()").write("")
                            .write("return out, metadata, nil");
                },
                // Fields referenced as m.LogEventStreamWrites / m.LogEventStreamReads in the handler
                // and populated by the registration function below.
                (mg, w) -> w.write("""
                                   LogEventStreamWrites bool
                                   LogEventStreamReads  bool
                                   """));

        // closeResponseBody helper: when the raw response is an HTTP response with a non-nil body,
        // the generated method copies the body to the discard sink and closes it, ignoring both errors.
        var deserializeOutput = getSymbol("DeserializeOutput", SmithyGoDependency.SMITHY_MIDDLEWARE);
        var httpResponse = getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
        var copy = getSymbol("Copy", SmithyGoDependency.IO);
        var discard = getSymbol("Discard", SmithyGoDependency.IOUTIL);
        writer.write("""

                     func ($P) closeResponseBody(out $T) {
                         if resp, ok := out.RawResponse.($P); ok && resp != nil && resp.Body != nil {
                             _, _ = $T($T, resp.Body)
                             _ = resp.Body.Close()
                         }
                     }
                     """, middleware.getMiddlewareSymbol(), deserializeOutput, httpResponse, copy, discard);

        var stack = getSymbol("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE);
        var after = getSymbol("After", SmithyGoDependency.SMITHY_MIDDLEWARE);
        var before = getSymbol("Before", SmithyGoDependency.SMITHY_MIDDLEWARE);

        // Registration function: optionally inserts the serializer helper after "OperationSerializer",
        // then inserts this middleware before "OperationDeserializer" with the log flags taken from
        // options.ClientLogMode.
        writer.openBlock("func $T(stack $P, options Options) error {", "}",
                getAddEventStreamOperationMiddlewareSymbol(operationShape), stack,
                () -> {
                    if (withInitialMessages && inputInfo.isPresent()) {
                        writer.write("""
                                     if err := stack.Serialize.Insert(&$T{}, "OperationSerializer", $T); err != nil {
                                         return err
                                     }
                                     """, getModuleSymbol(context.getSettings(), EVENT_STREAM_SERIALIZER_HELPER),
                                after);
                    }
                    writer.write("""
                                 if err := stack.Deserialize.Insert(&$T{
                                     LogEventStreamWrites: options.ClientLogMode.IsRequestEventMessage(),
                                     LogEventStreamReads:  options.ClientLogMode.IsResponseEventMessage(),
                                 }, "OperationDeserializer", $T); err != nil {
                                     return err
                                 }
                                 return nil
                                 """, middleware.getMiddlewareSymbol(), before);
                });
    }