private async Task ConvertToAsyncPackageAsync()

in src/Microsoft.VisualStudio.SDK.Analyzers.CodeFixes/VSSDK001DeriveFromAsyncPackageCodeFix.cs [48:222]


        private async Task<Document> ConvertToAsyncPackageAsync(CodeFixContext context, Diagnostic diagnostic, CancellationToken cancellationToken)
        {
            SemanticModel semanticModel = await context.Document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            Compilation? compilation = await context.Document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
            Assumes.NotNull(compilation);
            SyntaxNode root = await context.Document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            BaseTypeSyntax baseTypeSyntax = root.FindNode(diagnostic.Location.SourceSpan).FirstAncestorOrSelf<BaseTypeSyntax>();
            ClassDeclarationSyntax classDeclarationSyntax = baseTypeSyntax.FirstAncestorOrSelf<ClassDeclarationSyntax>();
            MethodDeclarationSyntax initializeMethodSyntax = classDeclarationSyntax.DescendantNodes()
                .OfType<MethodDeclarationSyntax>()
                .FirstOrDefault(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.OverrideKeyword)) && method.Identifier.Text == Types.Package.Initialize);
            InvocationExpressionSyntax? baseInitializeInvocationSyntax = initializeMethodSyntax?.Body?.DescendantNodes()
                .OfType<InvocationExpressionSyntax>()
                .FirstOrDefault(ies => ies.Expression is MemberAccessExpressionSyntax memberAccess && memberAccess.Name?.Identifier.Text == Types.Package.Initialize && memberAccess.Expression is BaseExpressionSyntax);
            var getServiceInvocationsSyntax = new List<InvocationExpressionSyntax>();
            AttributeSyntax? packageRegistrationSyntax = null;
            {
                INamedTypeSymbol userClassSymbol = semanticModel.GetDeclaredSymbol(classDeclarationSyntax, context.CancellationToken);
                INamedTypeSymbol packageRegistrationType = compilation.GetTypeByMetadataName(Types.PackageRegistrationAttribute.FullName);
                AttributeData? packageRegistrationInstance = userClassSymbol?.GetAttributes().FirstOrDefault(a => Equals(a.AttributeClass, packageRegistrationType));
                if (packageRegistrationInstance?.ApplicationSyntaxReference != null)
                {
                    packageRegistrationSyntax = (AttributeSyntax)await packageRegistrationInstance.ApplicationSyntaxReference.GetSyntaxAsync(cancellationToken).ConfigureAwait(false);
                }
            }

            if (initializeMethodSyntax != null)
            {
                getServiceInvocationsSyntax.AddRange(
                    from invocation in initializeMethodSyntax.DescendantNodes().OfType<InvocationExpressionSyntax>()
                    let memberBinding = invocation.Expression as MemberAccessExpressionSyntax
                    let identifierName = invocation.Expression as IdentifierNameSyntax
                    where identifierName?.Identifier.Text == Types.Package.GetService
                       || (memberBinding.Name.Identifier.Text == Types.Package.GetService && memberBinding.Expression.IsKind(SyntaxKind.ThisExpression))
                    select invocation);
            }

            // Make it easier to track nodes across changes.
            var nodesToTrack = new List<SyntaxNode?>
            {
                baseTypeSyntax,
                initializeMethodSyntax,
                baseInitializeInvocationSyntax,
                packageRegistrationSyntax,
            };
            nodesToTrack.AddRange(getServiceInvocationsSyntax);
            nodesToTrack.RemoveAll(n => n == null);
            SyntaxNode updatedRoot = root.TrackNodes(nodesToTrack);

            // Replace the Package base type with AsyncPackage
            baseTypeSyntax = updatedRoot.GetCurrentNode(baseTypeSyntax);
            SimpleBaseTypeSyntax asyncPackageBaseTypeSyntax = SyntaxFactory.SimpleBaseType(Types.AsyncPackage.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation))
                .WithLeadingTrivia(baseTypeSyntax.GetLeadingTrivia())
                .WithTrailingTrivia(baseTypeSyntax.GetTrailingTrivia());
            updatedRoot = updatedRoot.ReplaceNode(baseTypeSyntax, asyncPackageBaseTypeSyntax);

            // Update the PackageRegistration attribute
            if (packageRegistrationSyntax != null)
            {
                LiteralExpressionSyntax trueExpression = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression);
                packageRegistrationSyntax = updatedRoot.GetCurrentNode(packageRegistrationSyntax);
                AttributeArgumentSyntax allowsBackgroundLoadingSyntax = packageRegistrationSyntax.ArgumentList.Arguments.FirstOrDefault(a => a.NameEquals?.Name?.Identifier.Text == Types.PackageRegistrationAttribute.AllowsBackgroundLoading);
                if (allowsBackgroundLoadingSyntax != null)
                {
                    updatedRoot = updatedRoot.ReplaceNode(
                        allowsBackgroundLoadingSyntax,
                        allowsBackgroundLoadingSyntax.WithExpression(trueExpression));
                }
                else
                {
                    updatedRoot = updatedRoot.ReplaceNode(
                        packageRegistrationSyntax,
                        packageRegistrationSyntax.AddArgumentListArguments(
                            SyntaxFactory.AttributeArgument(trueExpression).WithNameEquals(SyntaxFactory.NameEquals(Types.PackageRegistrationAttribute.AllowsBackgroundLoading))));
                }
            }

            // Find the Initialize override, if present, and update it to InitializeAsync
            if (initializeMethodSyntax != null)
            {
                IdentifierNameSyntax cancellationTokenLocalVarName = SyntaxFactory.IdentifierName("cancellationToken");
                IdentifierNameSyntax progressLocalVarName = SyntaxFactory.IdentifierName("progress");
                initializeMethodSyntax = updatedRoot.GetCurrentNode(initializeMethodSyntax);
                BlockSyntax newBody = initializeMethodSyntax.Body;

                SyntaxTriviaList leadingTrivia = SyntaxFactory.TriviaList(
                    SyntaxFactory.Comment(@"// When initialized asynchronously, we *may* be on a background thread at this point."),
                    SyntaxFactory.CarriageReturnLineFeed,
                    SyntaxFactory.Comment(@"// Do any initialization that requires the UI thread after switching to the UI thread."),
                    SyntaxFactory.CarriageReturnLineFeed,
                    SyntaxFactory.Comment(@"// Otherwise, remove the switch to the UI thread if you don't need it."),
                    SyntaxFactory.CarriageReturnLineFeed);

                ExpressionStatementSyntax switchToMainThreadStatement = SyntaxFactory.ExpressionStatement(
                    SyntaxFactory.AwaitExpression(
                        SyntaxFactory.InvocationExpression(
                            SyntaxFactory.MemberAccessExpression(
                                SyntaxKind.SimpleMemberAccessExpression,
                                SyntaxFactory.MemberAccessExpression(
                                    SyntaxKind.SimpleMemberAccessExpression,
                                    SyntaxFactory.ThisExpression(),
                                    SyntaxFactory.IdentifierName(Types.ThreadHelper.JoinableTaskFactory)),
                                SyntaxFactory.IdentifierName(Types.JoinableTaskFactory.SwitchToMainThreadAsync)))
                            .AddArgumentListArguments(SyntaxFactory.Argument(cancellationTokenLocalVarName))))
                    .WithLeadingTrivia(leadingTrivia)
                    .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed);

                if (baseInitializeInvocationSyntax != null)
                {
                    var baseInitializeAsyncInvocationBookmark = new SyntaxAnnotation();
                    AwaitExpressionSyntax baseInitializeAsyncInvocationSyntax = SyntaxFactory.AwaitExpression(
                        baseInitializeInvocationSyntax
                            .WithLeadingTrivia()
                            .WithExpression(
                                SyntaxFactory.MemberAccessExpression(
                                    SyntaxKind.SimpleMemberAccessExpression,
                                    SyntaxFactory.BaseExpression(),
                                    SyntaxFactory.IdentifierName(Types.AsyncPackage.InitializeAsync)))
                            .AddArgumentListArguments(
                                SyntaxFactory.Argument(cancellationTokenLocalVarName),
                                SyntaxFactory.Argument(progressLocalVarName)))
                        .WithLeadingTrivia(baseInitializeInvocationSyntax.GetLeadingTrivia())
                        .WithAdditionalAnnotations(baseInitializeAsyncInvocationBookmark);
                    newBody = newBody.ReplaceNode(initializeMethodSyntax.GetCurrentNode(baseInitializeInvocationSyntax), baseInitializeAsyncInvocationSyntax);
                    StatementSyntax baseInvocationStatement = newBody.GetAnnotatedNodes(baseInitializeAsyncInvocationBookmark).First().FirstAncestorOrSelf<StatementSyntax>();

                    newBody = newBody.InsertNodesAfter(
                        baseInvocationStatement,
                        new[] { switchToMainThreadStatement.WithLeadingTrivia(switchToMainThreadStatement.GetLeadingTrivia().Insert(0, SyntaxFactory.LineFeed)) });
                }
                else
                {
                    newBody = newBody.WithStatements(
                        newBody.Statements.Insert(0, switchToMainThreadStatement));
                }

                MethodDeclarationSyntax initializeAsyncMethodSyntax = initializeMethodSyntax
                    .WithIdentifier(SyntaxFactory.Identifier(Types.AsyncPackage.InitializeAsync))
                    .WithReturnType(Types.Task.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation))
                    .AddModifiers(SyntaxFactory.Token(SyntaxKind.AsyncKeyword))
                    .AddParameterListParameters(
                        SyntaxFactory.Parameter(cancellationTokenLocalVarName.Identifier).WithType(Types.CancellationToken.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation)),
                        SyntaxFactory.Parameter(progressLocalVarName.Identifier).WithType(Types.IProgress.TypeSyntaxOf(Types.ServiceProgressData.TypeSyntax).WithAdditionalAnnotations(Simplifier.Annotation)))
                    .WithBody(newBody);
                updatedRoot = updatedRoot.ReplaceNode(initializeMethodSyntax, initializeAsyncMethodSyntax);

                // Replace GetService calls with GetServiceAsync
                getServiceInvocationsSyntax = updatedRoot.GetCurrentNodes<InvocationExpressionSyntax>(getServiceInvocationsSyntax).ToList();
                updatedRoot = updatedRoot.ReplaceNodes(
                    getServiceInvocationsSyntax,
                    (orig, node) =>
                    {
                        InvocationExpressionSyntax invocation = node;
                        if (invocation.Expression is IdentifierNameSyntax methodName)
                        {
                            invocation = invocation.WithExpression(SyntaxFactory.IdentifierName(Types.AsyncPackage.GetServiceAsync));
                        }
                        else if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
                        {
                            invocation = invocation.WithExpression(
                                memberAccess.WithName(SyntaxFactory.IdentifierName(Types.AsyncPackage.GetServiceAsync)));
                        }

                        return SyntaxFactory.ParenthesizedExpression(SyntaxFactory.AwaitExpression(invocation))
                            .WithAdditionalAnnotations(Simplifier.Annotation);
                    });

                updatedRoot = await Utils.AddUsingTaskEqualsDirectiveAsync(updatedRoot, cancellationToken);
            }

            Document newDocument = context.Document.WithSyntaxRoot(updatedRoot);
            newDocument = await ImportAdder.AddImportsAsync(newDocument, Simplifier.Annotation, cancellationToken: cancellationToken);

            return newDocument;
        }