src/WebJobs.Extensions.DurableTask.Analyzers/Analyzers/ActivityFunction/FunctionAnalyzer.cs (267 lines of code) (raw):
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.Analyzers
{
/// <summary>
/// Collects ActivityFunctionDefinitions and ActivityFunctionCalls and diagnoses issues on them.
/// Requires full solution analysis.
/// </summary>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class FunctionAnalyzer : DiagnosticAnalyzer
{
// Activity function definitions recorded by FindActivityFunctionDefinition.
private List<ActivityFunctionDefinition> availableFunctions = new List<ActivityFunctionDefinition>();
// Activity function invocations recorded by FindActivityCalls.
private List<ActivityFunctionCall> calledFunctions = new List<ActivityFunctionCall>();
// Semantic model of the orchestrator method most recently visited by FindActivityCalls;
// it is also the model handed to NameAnalyzer.ReportProblems.
private SemanticModel semanticModel;
// Source of the orchestrator methods whose bodies FindActivityCalls scans.
private OrchestratorMethodCollector orchestratorMethodCollector;
/// <summary>
/// Gets the diagnostic descriptors supported by this analyzer: the missing-name and
/// close-name rules, the argument mismatch rule and the return type rule.
/// </summary>
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
    ImmutableArray.Create(
        NameAnalyzer.MissingRule,
        NameAnalyzer.CloseRule,
        ArgumentAnalyzer.MismatchRule,
        FunctionReturnTypeAnalyzer.Rule);
/// <summary>
/// Registers the syntax node and compilation end callbacks used to collect activity
/// function definitions and calls and to report diagnostics on them.
/// </summary>
/// <param name="context">The analysis context on which the actions are registered.</param>
public override void Initialize(AnalysisContext context)
{
    context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);

    context.RegisterCompilationStartAction(compilation =>
    {
        // Initialize may run only once while the analyzer instance is reused for many
        // compilations. The collected definitions and calls must therefore live in state
        // that is created per compilation; otherwise entries from earlier compilations
        // accumulate and are re-analyzed as stale or duplicate data.
        var functionAnalyzer = new FunctionAnalyzer();
        functionAnalyzer.orchestratorMethodCollector = new OrchestratorMethodCollector();

        compilation.RegisterSyntaxNodeAction(functionAnalyzer.orchestratorMethodCollector.FindOrchestratorMethods, SyntaxKind.MethodDeclaration);
        compilation.RegisterSyntaxNodeAction(functionAnalyzer.FindActivityFunctionDefinition, SyntaxKind.Attribute);
        compilation.RegisterCompilationEndAction(functionAnalyzer.CompilationEndActions);
    });
}
/// <summary>
/// Compilation end callback: first collects the activity calls made from the
/// orchestrator methods, then runs the individual analyzers over the collected data.
/// </summary>
/// <param name="context">The compilation context used to report diagnostics.</param>
private void CompilationEndActions(CompilationAnalysisContext context)
{
    // Calls must be collected before the analyzers run, since they read calledFunctions.
    FindActivityCalls();
    RegisterAnalyzers(context);
}
/// <summary>
/// Hands the collected definitions and calls to the name, argument and return type
/// analyzers so that each can report its problems.
/// </summary>
/// <param name="context">The compilation context used to report diagnostics.</param>
private void RegisterAnalyzers(CompilationAnalysisContext context)
{
    var definitions = this.availableFunctions;
    var calls = this.calledFunctions;

    NameAnalyzer.ReportProblems(context, this.semanticModel, definitions, calls);
    ArgumentAnalyzer.ReportProblems(context, definitions, calls);
    FunctionReturnTypeAnalyzer.ReportProblems(context, definitions, calls);
}
/// <summary>
/// Walks every collected orchestrator method and records each CallActivity* invocation
/// found in its body as an <see cref="ActivityFunctionCall"/> in calledFunctions.
/// </summary>
private void FindActivityCalls()
{
    var orchestratorMethods = this.orchestratorMethodCollector.GetOrchestratorMethods();
    foreach (MethodInformation methodInformation in orchestratorMethods)
    {
        var declaration = methodInformation.Declaration;
        if (declaration == null)
        {
            continue;
        }

        // The helpers below read this.semanticModel, so it must match the method being scanned.
        this.semanticModel = methodInformation.SemanticModel;

        foreach (var invocation in declaration.DescendantNodes().OfType<InvocationExpressionSyntax>())
        {
            if (!IsActivityInvocation(invocation))
            {
                continue;
            }

            if (!TryGetFunctionNameFromActivityInvocation(invocation, out SyntaxNode functionNameNode, out string functionName))
            {
                // Do not store an ActivityFunctionCall when there is no function name.
                // Only this invocation is skipped; returning here would silently drop every
                // remaining invocation and orchestrator method.
                continue;
            }

            // The cast is safe: IsActivityInvocation only accepts member access expressions.
            SyntaxNodeUtils.TryGetTypeArgumentIdentifier((MemberAccessExpressionSyntax)invocation.Expression, out SyntaxNode returnTypeNode);
            SyntaxNodeUtils.TryGetITypeSymbol(this.semanticModel, returnTypeNode, out ITypeSymbol returnType);

            TryGetInputNodeFromCallActivityInvocation(this.semanticModel, invocation, out SyntaxNode inputNode);
            SyntaxNodeUtils.TryGetITypeSymbol(this.semanticModel, inputNode, out ITypeSymbol inputType);

            this.calledFunctions.Add(new ActivityFunctionCall
            {
                FunctionName = functionName,
                NameNode = functionNameNode,
                InputNode = inputNode,
                InputType = inputType,
                ReturnTypeNode = returnTypeNode,
                ReturnType = returnType,
                InvocationExpression = invocation
            });
        }
    }
}
/// <summary>
/// Determines whether an invocation is a member access call whose member name starts
/// with "CallActivityAsync" or "CallActivityWithRetryAsync".
/// </summary>
/// <param name="invocationExpression">The invocation to inspect; may be null.</param>
/// <returns>True when the invoked member name has one of the CallActivity* prefixes.</returns>
private bool IsActivityInvocation(InvocationExpressionSyntax invocationExpression)
{
    if (invocationExpression?.Expression is MemberAccessExpressionSyntax memberAccessExpression
        && memberAccessExpression.Name != null)
    {
        // Identifiers are compared ordinally: a culture-sensitive comparison could treat
        // ignorable characters as matches and depends on the machine's current culture.
        var memberName = memberAccessExpression.Name.ToString();
        return memberName.StartsWith("CallActivityAsync", StringComparison.Ordinal)
            || memberName.StartsWith("CallActivityWithRetryAsync", StringComparison.Ordinal);
    }

    return false;
}
/// <summary>
/// Extracts the function name from the first argument of a CallActivity* invocation.
/// </summary>
/// <param name="invocationExpression">The CallActivity* invocation.</param>
/// <param name="functionNameNode">
/// The expression of the first argument, or null when the invocation has no arguments.
/// It stays set even when the name itself could not be parsed.
/// </param>
/// <param name="functionName">The parsed function name, or null when it could not be determined.</param>
/// <returns>True when a function name was parsed.</returns>
private bool TryGetFunctionNameFromActivityInvocation(InvocationExpressionSyntax invocationExpression, out SyntaxNode functionNameNode, out string functionName)
{
    var functionArgument = invocationExpression.ArgumentList?.Arguments.FirstOrDefault();

    // Use the argument's expression rather than its first child node: for a named argument
    // (e.g. functionName: "Foo") the first child is the NameColon node, not the value.
    functionNameNode = functionArgument?.Expression;
    if (functionNameNode != null)
    {
        SyntaxNodeUtils.TryParseFunctionName(this.semanticModel, functionNameNode, out functionName);
        return functionName != null;
    }

    functionNameNode = null;
    functionName = null;
    return false;
}
/// <summary>
/// Finds the syntax node that represents the input passed to a CallActivity* invocation.
/// </summary>
/// <param name="semanticModel">Semantic model used to look up the invoked method's declaration.</param>
/// <param name="invocationExpression">The CallActivity* invocation.</param>
/// <param name="inputNode">The input argument's expression, or null when it could not be found.</param>
/// <returns>True when an input node was found.</returns>
private static bool TryGetInputNodeFromCallActivityInvocation(SemanticModel semanticModel, InvocationExpressionSyntax invocationExpression, out SyntaxNode inputNode)
{
    // If method invocation is a custom CallActivity extension method defined in user code
    if (SyntaxNodeUtils.TryGetDeclaredSyntaxNode(semanticModel, invocationExpression, out SyntaxNode declaration))
    {
        if (TryGetSpecificParameterIndex(declaration, "object input", out int inputParameterIndex)
            && TryGetInvocationArguments(invocationExpression, out IEnumerable<ArgumentSyntax> customArguments))
        {
            // The call may pass fewer arguments than the parameter index suggests (for example
            // when the method is invoked in a different form than the index computation assumes).
            // ElementAtOrDefault returns null for any out-of-range index instead of throwing
            // from inside the analyzer.
            // Expression is used instead of the first child node so that a named argument
            // yields its value rather than its NameColon node.
            inputNode = customArguments.ElementAtOrDefault(inputParameterIndex)?.Expression;
            return inputNode != null;
        }
    }
    // else assume CallActivity is a DurableFunctions method
    else if (TryGetInvocationArguments(invocationExpression, out IEnumerable<ArgumentSyntax> builtInArguments))
    {
        // Input node is currently the last argument on CallActivity* methods. If this is changed, this will not be
        // sufficient to determine which argument is meant to represent the input.
        inputNode = builtInArguments.Last().Expression;
        return inputNode != null;
    }

    inputNode = null;
    return false;
}
/// <summary>
/// Looks for a parameter whose source text equals <paramref name="parameterToFind"/> in a
/// method declaration and yields the index of the first match.
/// </summary>
/// <param name="declaration">The declaration to search; only method declarations are handled.</param>
/// <param name="parameterToFind">Exact source text of the parameter, e.g. "object input".</param>
/// <param name="inputParameterIndex">
/// Index of the matching parameter, reduced by one when <see cref="IsExtensionMethod"/> reports
/// a leading "this IDurableOrchestrationContext" parameter; int.MinValue when nothing matched.
/// </param>
/// <returns>True when a matching parameter was found.</returns>
private static bool TryGetSpecificParameterIndex(SyntaxNode declaration, string parameterToFind, out int inputParameterIndex)
{
    inputParameterIndex = int.MinValue;

    if (!(declaration is MethodDeclarationSyntax methodDeclaration))
    {
        return false;
    }

    List<SyntaxNode> parameterNodes = methodDeclaration.ParameterList.ChildNodes().ToList();
    int matchIndex = parameterNodes.FindIndex(parameter => parameter.ToString() == parameterToFind);
    if (matchIndex < 0)
    {
        return false;
    }

    // The "this" parameter of an extension method has no matching call-site argument.
    inputParameterIndex = IsExtensionMethod(parameterNodes) ? matchIndex - 1 : matchIndex;
    return true;
}
/// <summary>
/// Determines whether a parameter list belongs to an extension method on
/// IDurableOrchestrationContext, i.e. whether the source text of its first parameter
/// starts with "this IDurableOrchestrationContext".
/// </summary>
/// <param name="parameters">The parameter nodes of the method; may be null or empty.</param>
/// <returns>True when the first parameter is a "this IDurableOrchestrationContext" parameter.</returns>
private static bool IsExtensionMethod(IEnumerable<SyntaxNode> parameters)
{
    // FirstOrDefault keeps an empty parameter list from throwing; it simply is not an extension method.
    var firstParameter = parameters?.FirstOrDefault();

    // Source text is compared ordinally so the result does not depend on the current culture.
    return firstParameter != null
        && firstParameter.ToString().StartsWith("this IDurableOrchestrationContext", StringComparison.Ordinal);
}
/// <summary>
/// Retrieves the arguments of an invocation when it has at least one.
/// </summary>
/// <param name="invocationExpression">The invocation whose argument list is read.</param>
/// <param name="arguments">The invocation's arguments, or null when there are none.</param>
/// <returns>True when the invocation has one or more arguments.</returns>
private static bool TryGetInvocationArguments(InvocationExpressionSyntax invocationExpression, out IEnumerable<ArgumentSyntax> arguments)
{
    IEnumerable<ArgumentSyntax> candidates = null;

    var argumentList = invocationExpression.ArgumentList;
    if (argumentList != null)
    {
        candidates = argumentList.Arguments;
    }

    bool hasArguments = candidates != null && candidates.Any();
    arguments = hasArguments ? candidates : null;
    return hasArguments;
}
/// <summary>
/// Syntax node callback for attributes: when the attribute is an activity trigger attribute,
/// records the function it decorates as an <see cref="ActivityFunctionDefinition"/>.
/// </summary>
/// <param name="context">The syntax node context; its node is expected to be an attribute.</param>
public void FindActivityFunctionDefinition(SyntaxNodeAnalysisContext context)
{
    var model = context.SemanticModel;

    var attribute = context.Node as AttributeSyntax;
    if (attribute == null || !SyntaxNodeUtils.IsActivityTriggerAttribute(attribute))
    {
        return;
    }

    // A definition without a function name is not stored.
    if (!SyntaxNodeUtils.TryGetFunctionName(model, attribute, out string functionName))
    {
        return;
    }

    // A definition without a return type is not stored either.
    if (!SyntaxNodeUtils.TryGetMethodReturnTypeNode(attribute, out SyntaxNode returnTypeNode))
    {
        return;
    }

    // The remaining lookups are best effort: their results are stored as returned.
    SyntaxNodeUtils.TryGetITypeSymbol(model, returnTypeNode, out ITypeSymbol returnType);
    SyntaxNodeUtils.TryGetParameterNodeNextToAttribute(attribute, out SyntaxNode parameterNode);
    TryGetDefinitionInputType(model, parameterNode, out ITypeSymbol inputType);

    var definition = new ActivityFunctionDefinition
    {
        FunctionName = functionName,
        ParameterNode = parameterNode,
        InputType = inputType,
        ReturnTypeNode = returnTypeNode,
        ReturnType = returnType
    };
    availableFunctions.Add(definition);
}
/// <summary>
/// Resolves the input type of an activity function from its trigger parameter. When the
/// parameter is a durable activity context, the type is taken from how the context is
/// used inside the method instead.
/// </summary>
/// <param name="semanticModel">Semantic model used to resolve type symbols.</param>
/// <param name="parameterNode">The parameter node next to the activity trigger attribute.</param>
/// <param name="definitionInputType">The resolved input type, or null when resolution failed.</param>
/// <returns>True when an input type was resolved.</returns>
private static bool TryGetDefinitionInputType(SemanticModel semanticModel, SyntaxNode parameterNode, out ITypeSymbol definitionInputType)
{
    if (!SyntaxNodeUtils.TryGetITypeSymbol(semanticModel, parameterNode, out definitionInputType))
    {
        definitionInputType = null;
        return false;
    }

    if (!SyntaxNodeUtils.IsDurableActivityContext(definitionInputType))
    {
        // A plain parameter: its declared type is the input type.
        return true;
    }

    return TryGetInputTypeFromContext(semanticModel, parameterNode, out definitionInputType);
}
/// <summary>
/// Resolves the input type from the type argument of a member access on a durable activity
/// context found in the method that contains <paramref name="node"/>.
/// </summary>
/// <param name="semanticModel">Semantic model used to resolve type symbols.</param>
/// <param name="node">A node inside the activity function's method declaration.</param>
/// <param name="definitionInputType">The resolved type argument, or null when none was found.</param>
/// <returns>True when the type argument was found and resolved to a type symbol.</returns>
private static bool TryGetInputTypeFromContext(SemanticModel semanticModel, SyntaxNode node, out ITypeSymbol definitionInputType)
{
    definitionInputType = null;

    if (!TryGetDurableActivityContextExpression(semanticModel, node, out SyntaxNode contextAccess))
    {
        return false;
    }

    var memberAccess = (MemberAccessExpressionSyntax)contextAccess;
    if (!SyntaxNodeUtils.TryGetTypeArgumentIdentifier(memberAccess, out SyntaxNode typeArgument))
    {
        return false;
    }

    return SyntaxNodeUtils.TryGetITypeSymbol(semanticModel, typeArgument, out definitionInputType);
}
/// <summary>
/// Searches the method that contains <paramref name="node"/> for the first simple member
/// access expression whose first identifier-name child has a durable activity context type.
/// </summary>
/// <param name="semanticModel">Semantic model used to resolve type symbols.</param>
/// <param name="node">A node from which the enclosing method declaration is looked up.</param>
/// <param name="durableContextExpression">The matching member access expression, or null when none was found.</param>
/// <returns>True when a matching member access expression was found.</returns>
private static bool TryGetDurableActivityContextExpression(SemanticModel semanticModel, SyntaxNode node, out SyntaxNode durableContextExpression)
{
    durableContextExpression = null;

    if (!SyntaxNodeUtils.TryGetMethodDeclaration(node, out MethodDeclarationSyntax methodDeclaration))
    {
        return false;
    }

    foreach (var candidate in methodDeclaration.DescendantNodes())
    {
        if (!candidate.IsKind(SyntaxKind.SimpleMemberAccessExpression))
        {
            continue;
        }

        var identifier = candidate.ChildNodes().FirstOrDefault(child => child.IsKind(SyntaxKind.IdentifierName));
        if (identifier == null)
        {
            continue;
        }

        if (SyntaxNodeUtils.TryGetITypeSymbol(semanticModel, identifier, out ITypeSymbol identifierType)
            && SyntaxNodeUtils.IsDurableActivityContext(identifierType))
        {
            durableContextExpression = candidate;
            return true;
        }
    }

    return false;
}
}
}