SharpGen/Generator/Marshallers/MarshallerBase.cs (307 lines of code) (raw):

using System; using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using SharpGen.Logging; using SharpGen.Model; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace SharpGen.Generator.Marshallers { internal enum StructMarshalMethod { From, To, Free } internal abstract partial class MarshallerBase { protected static readonly SyntaxToken PtrIdentifier = Identifier("__ptr"); protected static readonly IdentifierNameSyntax PtrIdentifierName = IdentifierName(PtrIdentifier); protected static readonly SyntaxToken LengthIdentifier = Identifier("__length"); protected static readonly IdentifierNameSyntax LengthIdentifierName = IdentifierName(LengthIdentifier); protected static readonly LiteralExpressionSyntax DefaultLiteral = LiteralExpression( SyntaxKind.DefaultLiteralExpression ); protected static readonly LiteralExpressionSyntax NullLiteral = LiteralExpression( SyntaxKind.NullLiteralExpression ); private readonly Ioc ioc; protected GlobalNamespaceProvider GlobalNamespace => ioc.GlobalNamespace; protected Logger Logger => ioc.Logger; protected MarshallerBase(Ioc ioc) { this.ioc = ioc ?? throw new ArgumentNullException(nameof(ioc)); } protected static TypeSyntax IntPtrType => GeneratorHelpers.IntPtrType; protected static MemberAccessExpressionSyntax IntPtrZero => GeneratorHelpers.IntPtrZero; protected static TypeSyntax VoidPtrType => GeneratorHelpers.VoidPtrType; protected static LiteralExpressionSyntax ZeroLiteral => GeneratorHelpers.ZeroLiteral; private static bool IsNullable(CsMarshalBase marshallable) => marshallable is CsParameter {IsNullable: true}; protected static StatementSyntax GenerateNullCheckIfNeeded(CsMarshalBase marshallable, StatementSyntax statement) => IsNullable(marshallable) ? IfStatement( BinaryExpression(SyntaxKind.NotEqualsExpression, IdentifierName(marshallable.Name), NullLiteral), statement ) : statement; protected static ExpressionSyntax GenerateNullCheckIfNeeded(CsMarshalBase marshallable, ExpressionSyntax expression, ExpressionSyntax nullAlternative) => IsNullable(marshallable) ? ConditionalExpression( BinaryExpression(SyntaxKind.EqualsExpression, IdentifierName(marshallable.Name), NullLiteral), nullAlternative, expression ) : expression; protected static StatementSyntax LoopThroughArrayParameter( CsMarshalBase marshallable, Func<ElementAccessExpressionSyntax, ElementAccessExpressionSyntax, StatementSyntax> loopBodyFactory, string variableName = "i") { var indexVariable = Identifier(variableName); var indexVariableName = IdentifierName(variableName); var arrayIdentifier = IdentifierName(marshallable.Name); var element = ElementAccessExpression( arrayIdentifier, BracketedArgumentList(SingletonSeparatedList(Argument(indexVariableName))) ); var nativeElement = ElementAccessExpression( ParenthesizedExpression(GetMarshalStorageLocation(marshallable)), BracketedArgumentList(SingletonSeparatedList(Argument(indexVariableName))) ); return GenerateNullCheckIfNeeded( marshallable, ForStatement(loopBodyFactory(element, nativeElement)) .WithDeclaration( VariableDeclaration( TypeInt32, SeparatedList( new[] { VariableDeclarator(indexVariable, default, EqualsValueClause(ZeroLiteral)), VariableDeclarator( LengthIdentifier, default, EqualsValueClause(GeneratorHelpers.LengthExpression(arrayIdentifier)) ) } ))) .WithCondition( BinaryExpression(SyntaxKind.LessThanExpression, indexVariableName, LengthIdentifierName) ) .WithIncrementors( SingletonSeparatedList<ExpressionSyntax>( PrefixUnaryExpression( SyntaxKind.PreIncrementExpression, indexVariableName))) ); } protected static StatementSyntax CreateMarshalStructStatement( CsMarshalBase marshallable, StructMarshalMethod marshalMethod, ExpressionSyntax publicElementExpr, ExpressionSyntax marshalElementExpr) { StatementSyntaxList statements = new(); var marshalArgument = Argument(marshalElementExpr).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); if (marshallable.PublicType is CsStruct {GenerateAsClass: true} structType && marshalMethod == StructMarshalMethod.From) { var constructor = ObjectCreationExpression(ParseTypeName(structType.QualifiedName)); var argumentList = !structType.HasCustomMarshal ? ArgumentList(SingletonSeparatedList(marshalArgument)) : ArgumentList(); statements.Add( ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, publicElementExpr, constructor.WithArgumentList(argumentList) ) ) ); if (!structType.HasCustomMarshal) return statements.ToStatement(); } var methodName = IdentifierName($"__Marshal{marshalMethod}"); var invocationExpression = marshallable.IsStaticMarshal ? InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, ParseTypeName(marshallable.PublicType.QualifiedName), methodName ), ArgumentList( SeparatedList( new[] { Argument(publicElementExpr) .WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), marshalArgument } ) ) ) : InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, publicElementExpr, methodName ), ArgumentList(SingletonSeparatedList(marshalArgument)) ); statements.Add(GenerateNullCheckIfNeeded(marshallable, ExpressionStatement(invocationExpression))); return statements.ToStatement(); } protected static StatementSyntax GenerateMarshalStructManagedToNative(CsMarshalBase csElement, ExpressionSyntax publicElement, ExpressionSyntax marshalElement) { var marshalTo = CreateMarshalStructStatement( csElement, StructMarshalMethod.To, publicElement, marshalElement ); return ((CsStruct) csElement.PublicType).HasCustomNew ? Block( CreateMarshalCustomNewStatement(csElement, marshalElement), marshalTo ) : marshalTo; } protected static ExpressionStatementSyntax CreateMarshalCustomNewStatement(CsMarshalBase csElement, ExpressionSyntax marshalElement) { return ExpressionStatement( AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, marshalElement, InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseTypeName(csElement.PublicType.QualifiedName), IdentifierName("__NewNative"))) .WithArgumentList(ArgumentList()))); } protected internal static SyntaxToken GetMarshalStorageLocationIdentifier(CsMarshalCallableBase marshallable) => marshallable switch { CsParameter => Identifier($"{marshallable.Name}_"), CsReturnValue => Identifier(CsReturnValue.MarshalStorageLocation), _ => throw new ArgumentException(nameof(marshallable)) }; protected internal static SyntaxToken GetRefLocationIdentifier(CsMarshalCallableBase marshallable) => marshallable switch { CsParameter => Identifier($"{marshallable.Name}_ref_"), CsReturnValue => throw new Exception("Return values as ref locals are not supported"), _ => throw new ArgumentException(nameof(marshallable)) }; protected internal static ExpressionSyntax GetMarshalStorageLocation(CsMarshalBase marshallable) => marshallable switch { CsParameter parameter => IdentifierName(GetMarshalStorageLocationIdentifier(parameter)), CsReturnValue returnValue => IdentifierName(GetMarshalStorageLocationIdentifier(returnValue)), CsField => MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName("@ref"), IdentifierName(marshallable.Name) ), _ => throw new ArgumentException(nameof(marshallable)) }; protected static StatementSyntax MarshalInterfaceInstanceFromNative(CsMarshalBase csElement, ExpressionSyntax publicElement, ExpressionSyntax marshalElement) => ExpressionStatement( csElement switch { CsParameter {IsFast: true, IsOut: true} => AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, publicElement, IdentifierName("NativePointer") ), marshalElement ), _ => AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, publicElement, ConditionalExpression( BinaryExpression(SyntaxKind.NotEqualsExpression, marshalElement, IntPtrZero), ObjectCreationExpression( ParseTypeName(csElement.PublicType.GetNativeImplementationQualifiedName()) ) .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(marshalElement)))), NullLiteral ) ) } ); protected static ArgumentSyntax GenerateManagedValueTypeArgument(CsParameter csElement) { var arg = Argument(IdentifierName(csElement.Name)); if (csElement.IsOut) { return arg.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)); } if (csElement.PassedByManagedReference) { return arg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } return arg; } protected static ParameterSyntax GenerateManagedValueTypeParameter(CsParameter csElement) { var param = Parameter(Identifier(csElement.Name)); if (csElement.IsOut) { param = param.AddModifiers(Token(SyntaxKind.OutKeyword)); } else if (csElement.PassedByManagedReference) { param = param.AddModifiers(Token(SyntaxKind.RefKeyword)); } var type = ParseTypeName(csElement.PublicType.QualifiedName); if (csElement.IsNullableStruct) { type = NullableType(type); } return param.WithType(type); } protected static ParameterSyntax GenerateManagedArrayParameter(CsParameter csElement) { var param = Parameter(Identifier(csElement.Name)) .WithType(ArrayType(ParseTypeName(csElement.PublicType.QualifiedName), SingletonList(ArrayRankSpecifier()))); if (csElement.HasParams) { param = param.AddModifiers(Token(SyntaxKind.ParamsKeyword)); } return param; } protected StatementSyntax GenerateArrayNativeToManagedExtendedProlog(CsMarshalCallableBase csElement) { // e.g. Function(int[] buffer, int length) // callable is Function // csElement is buffer // lengthParam is length var callable = (CsCallable) csElement.Parent; bool RelationPredicate(LengthRelation relation) => relation.Identifier == csElement.CppElementName; bool MatchPredicate(CsParameter param) => param.Relations.OfType<LengthRelation>().Any(RelationPredicate); var lengthParam = callable.Parameters.Where(MatchPredicate).ToArray(); return lengthParam.Length switch { 0 => NotSupported("Cannot marshal a native array [{0}] to a managed array when length is not specified"), > 1 => NotSupported( "Cannot marshal a native array [{0}] to a managed array when length is specified multiple times" ), _ => LengthRelationMarshaller.GenerateNativeToManaged(csElement, lengthParam[0]) }; StatementSyntax NotSupported(string hint) { Logger.Error(LoggingCodes.InvalidLengthRelation, hint, csElement.QualifiedName); return null; } } protected static StatementSyntax GenerateGCKeepAlive(CsMarshalBase csElement) => ExpressionStatement( InvocationExpression( ParseName("System.GC.KeepAlive"), ArgumentList( SingletonSeparatedList( Argument(IdentifierName(csElement.Name)) ) ) ) ); } }