C#: Replace P/Invoke with delegate pointers

- Moves interop functions to UnmanagedCallbacks struct that
  contains the function pointers and is passed to C#.

- Implements UnmanagedCallbacksGenerator, a C# source generator that
  generates the UnmanagedCallbacks struct in C# and the body for the
  NativeFuncs methods (their implementation just calls the function
  pointer in the UnmanagedCallbacks). The generated methods are needed
  because .NET pins byref parameters of native calls, even if they are
  'ref struct's, which don't need pinning. The generated methods use
  `Unsafe.AsPointer` so that we can benefit from byref parameters
  without suffering overhead of pinning.

Co-authored-by: Raul Santos <raulsntos@gmail.com>
This commit is contained in:
Ignacio Roldán Etcheverry 2022-08-05 03:32:59 +02:00
parent 186d7f6239
commit 2c180f62d9
21 changed files with 1314 additions and 740 deletions

View file

@ -0,0 +1,24 @@
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
namespace Godot.SourceGenerators.Internal;
internal struct CallbacksData
{
public CallbacksData(INamedTypeSymbol nativeTypeSymbol, INamedTypeSymbol funcStructSymbol)
{
NativeTypeSymbol = nativeTypeSymbol;
FuncStructSymbol = funcStructSymbol;
Methods = NativeTypeSymbol.GetMembers()
.Where(symbol => symbol is IMethodSymbol { IsPartialDefinition: true })
.Cast<IMethodSymbol>()
.ToImmutableArray();
}
public INamedTypeSymbol NativeTypeSymbol { get; }
public INamedTypeSymbol FuncStructSymbol { get; }
public ImmutableArray<IMethodSymbol> Methods { get; }
}

View file

@ -0,0 +1,65 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Godot.SourceGenerators.Internal;
internal static class Common
{
public static void ReportNonPartialUnmanagedCallbacksClass(
GeneratorExecutionContext context,
ClassDeclarationSyntax cds, INamedTypeSymbol symbol
)
{
string message =
"Missing partial modifier on declaration of type '" +
$"{symbol.FullQualifiedName()}' which has attribute '{GeneratorClasses.GenerateUnmanagedCallbacksAttr}'";
string description = $"{message}. Classes with attribute '{GeneratorClasses.GenerateUnmanagedCallbacksAttr}' " +
"must be declared with the partial modifier.";
context.ReportDiagnostic(Diagnostic.Create(
new DiagnosticDescriptor(id: "GODOT-INTERNAL-G0001",
title: message,
messageFormat: message,
category: "Usage",
DiagnosticSeverity.Error,
isEnabledByDefault: true,
description),
cds.GetLocation(),
cds.SyntaxTree.FilePath));
}
public static void ReportNonPartialUnmanagedCallbacksOuterClass(
GeneratorExecutionContext context,
TypeDeclarationSyntax outerTypeDeclSyntax
)
{
var outerSymbol = context.Compilation
.GetSemanticModel(outerTypeDeclSyntax.SyntaxTree)
.GetDeclaredSymbol(outerTypeDeclSyntax);
string fullQualifiedName = outerSymbol is INamedTypeSymbol namedTypeSymbol ?
namedTypeSymbol.FullQualifiedName() :
"type not found";
string message =
$"Missing partial modifier on declaration of type '{fullQualifiedName}', " +
$"which contains one or more subclasses with attribute " +
$"'{GeneratorClasses.GenerateUnmanagedCallbacksAttr}'";
string description = $"{message}. Classes with attribute " +
$"'{GeneratorClasses.GenerateUnmanagedCallbacksAttr}' and their " +
"containing types must be declared with the partial modifier.";
context.ReportDiagnostic(Diagnostic.Create(
new DiagnosticDescriptor(id: "GODOT-INTERNAL-G0002",
title: message,
messageFormat: message,
category: "Usage",
DiagnosticSeverity.Error,
isEnabledByDefault: true,
description),
outerTypeDeclSyntax.GetLocation(),
outerTypeDeclSyntax.SyntaxTree.FilePath));
}
}

View file

@ -0,0 +1,119 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Godot.SourceGenerators.Internal;
internal static class ExtensionMethods
{
public static AttributeData? GetGenerateUnmanagedCallbacksAttribute(this INamedTypeSymbol symbol)
=> symbol.GetAttributes()
.FirstOrDefault(a => a.AttributeClass?.IsGenerateUnmanagedCallbacksAttribute() ?? false);
private static bool HasGenerateUnmanagedCallbacksAttribute(
this ClassDeclarationSyntax cds, Compilation compilation,
out INamedTypeSymbol? symbol
)
{
var sm = compilation.GetSemanticModel(cds.SyntaxTree);
var classTypeSymbol = sm.GetDeclaredSymbol(cds);
if (classTypeSymbol == null)
{
symbol = null;
return false;
}
if (!classTypeSymbol.GetAttributes()
.Any(a => a.AttributeClass?.IsGenerateUnmanagedCallbacksAttribute() ?? false))
{
symbol = null;
return false;
}
symbol = classTypeSymbol;
return true;
}
private static bool IsGenerateUnmanagedCallbacksAttribute(this INamedTypeSymbol symbol)
=> symbol.ToString() == GeneratorClasses.GenerateUnmanagedCallbacksAttr;
public static IEnumerable<(ClassDeclarationSyntax cds, INamedTypeSymbol symbol)> SelectUnmanagedCallbacksClasses(
this IEnumerable<ClassDeclarationSyntax> source,
Compilation compilation
)
{
foreach (var cds in source)
{
if (cds.HasGenerateUnmanagedCallbacksAttribute(compilation, out var symbol))
yield return (cds, symbol!);
}
}
public static bool IsNested(this TypeDeclarationSyntax cds)
=> cds.Parent is TypeDeclarationSyntax;
public static bool IsPartial(this TypeDeclarationSyntax cds)
=> cds.Modifiers.Any(SyntaxKind.PartialKeyword);
public static bool AreAllOuterTypesPartial(
this TypeDeclarationSyntax cds,
out TypeDeclarationSyntax? typeMissingPartial
)
{
SyntaxNode? outerSyntaxNode = cds.Parent;
while (outerSyntaxNode is TypeDeclarationSyntax outerTypeDeclSyntax)
{
if (!outerTypeDeclSyntax.IsPartial())
{
typeMissingPartial = outerTypeDeclSyntax;
return false;
}
outerSyntaxNode = outerSyntaxNode.Parent;
}
typeMissingPartial = null;
return true;
}
public static string GetDeclarationKeyword(this INamedTypeSymbol namedTypeSymbol)
{
string? keyword = namedTypeSymbol.DeclaringSyntaxReferences
.OfType<TypeDeclarationSyntax>().FirstOrDefault()?
.Keyword.Text;
return keyword ?? namedTypeSymbol.TypeKind switch
{
TypeKind.Interface => "interface",
TypeKind.Struct => "struct",
_ => "class"
};
}
private static SymbolDisplayFormat FullyQualifiedFormatOmitGlobal { get; } =
SymbolDisplayFormat.FullyQualifiedFormat
.WithGlobalNamespaceStyle(SymbolDisplayGlobalNamespaceStyle.Omitted);
public static string FullQualifiedName(this ITypeSymbol symbol)
=> symbol.ToDisplayString(NullableFlowState.NotNull, FullyQualifiedFormatOmitGlobal);
public static string NameWithTypeParameters(this INamedTypeSymbol symbol)
{
return symbol.IsGenericType ?
string.Concat(symbol.Name, "<", string.Join(", ", symbol.TypeParameters), ">") :
symbol.Name;
}
public static string FullQualifiedName(this INamespaceSymbol symbol)
=> symbol.ToDisplayString(FullyQualifiedFormatOmitGlobal);
public static string SanitizeQualifiedNameForUniqueHint(this string qualifiedName)
=> qualifiedName
// AddSource() doesn't support angle brackets
.Replace("<", "(Of ")
.Replace(">", ")");
}

View file

@ -0,0 +1,6 @@
namespace Godot.SourceGenerators.Internal;
internal static class GeneratorClasses
{
public const string GenerateUnmanagedCallbacksAttr = "Godot.SourceGenerators.Internal.GenerateUnmanagedCallbacksAttribute";
}

View file

@ -0,0 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<LangVersion>10</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="3.10.0" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" PrivateAssets="all" />
</ItemGroup>
</Project>

View file

@ -0,0 +1,463 @@
using System.Text;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Godot.SourceGenerators.Internal;
[Generator]
public class UnmanagedCallbacksGenerator : ISourceGenerator
{
public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForPostInitialization(ctx => { GenerateAttribute(ctx); });
}
public void Execute(GeneratorExecutionContext context)
{
INamedTypeSymbol[] unmanagedCallbacksClasses = context
.Compilation.SyntaxTrees
.SelectMany(tree =>
tree.GetRoot().DescendantNodes()
.OfType<ClassDeclarationSyntax>()
.SelectUnmanagedCallbacksClasses(context.Compilation)
// Report and skip non-partial classes
.Where(x =>
{
if (x.cds.IsPartial())
{
if (x.cds.IsNested() && !x.cds.AreAllOuterTypesPartial(out var typeMissingPartial))
{
Common.ReportNonPartialUnmanagedCallbacksOuterClass(context, typeMissingPartial!);
return false;
}
return true;
}
Common.ReportNonPartialUnmanagedCallbacksClass(context, x.cds, x.symbol);
return false;
})
.Select(x => x.symbol)
)
.Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default)
.ToArray();
foreach (var symbol in unmanagedCallbacksClasses)
{
var attr = symbol.GetGenerateUnmanagedCallbacksAttribute();
if (attr == null || attr.ConstructorArguments.Length != 1)
{
// TODO: Report error or throw exception, this is an invalid case and should never be reached
System.Diagnostics.Debug.Fail("FAILED!");
continue;
}
var funcStructType = (INamedTypeSymbol?)attr.ConstructorArguments[0].Value;
if (funcStructType == null)
{
// TODO: Report error or throw exception, this is an invalid case and should never be reached
System.Diagnostics.Debug.Fail("FAILED!");
continue;
}
var data = new CallbacksData(symbol, funcStructType);
GenerateInteropMethodImplementations(context, data);
GenerateUnmanagedCallbacksStruct(context, data);
}
}
private void GenerateAttribute(GeneratorPostInitializationContext context)
{
string source = @"using System;
namespace Godot.SourceGenerators.Internal
{
internal class GenerateUnmanagedCallbacksAttribute : Attribute
{
public Type FuncStructType { get; }
public GenerateUnmanagedCallbacksAttribute(Type funcStructType)
{
FuncStructType = funcStructType;
}
}
}";
context.AddSource("GenerateUnmanagedCallbacksAttribute.generated",
SourceText.From(source, Encoding.UTF8));
}
private void GenerateInteropMethodImplementations(GeneratorExecutionContext context, CallbacksData data)
{
var symbol = data.NativeTypeSymbol;
INamespaceSymbol namespaceSymbol = symbol.ContainingNamespace;
string classNs = namespaceSymbol != null && !namespaceSymbol.IsGlobalNamespace ?
namespaceSymbol.FullQualifiedName() :
string.Empty;
bool hasNamespace = classNs.Length != 0;
bool isInnerClass = symbol.ContainingType != null;
var source = new StringBuilder();
var methodSource = new StringBuilder();
var methodCallArguments = new StringBuilder();
var methodSourceAfterCall = new StringBuilder();
source.Append(
@"using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Godot.Bridge;
using Godot.NativeInterop;
#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
");
if (hasNamespace)
{
source.Append("namespace ");
source.Append(classNs);
source.Append("\n{\n");
}
if (isInnerClass)
{
var containingType = symbol.ContainingType;
while (containingType != null)
{
source.Append("partial ");
source.Append(containingType.GetDeclarationKeyword());
source.Append(" ");
source.Append(containingType.NameWithTypeParameters());
source.Append("\n{\n");
containingType = containingType.ContainingType;
}
}
source.Append("[System.Runtime.CompilerServices.SkipLocalsInit]\n");
source.Append($"unsafe partial class {symbol.Name}\n");
source.Append("{\n");
source.Append($" private static {data.FuncStructSymbol.FullQualifiedName()} _unmanagedCallbacks;\n\n");
foreach (var callback in data.Methods)
{
methodSource.Clear();
methodCallArguments.Clear();
methodSourceAfterCall.Clear();
source.Append(" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]\n");
source.Append($" {SyntaxFacts.GetText(callback.DeclaredAccessibility)} ");
if (callback.IsStatic)
source.Append("static ");
source.Append("partial ");
source.Append(callback.ReturnType.FullQualifiedName());
source.Append(' ');
source.Append(callback.Name);
source.Append('(');
for (int i = 0; i < callback.Parameters.Length; i++)
{
var parameter = callback.Parameters[i];
source.Append(parameter.ToDisplayString());
source.Append(' ');
source.Append(parameter.Name);
if (parameter.RefKind == RefKind.Out)
{
// Only assign default if the parameter won't be passed by-ref or copied later.
if (IsGodotInteropStruct(parameter.Type))
methodSource.Append($" {parameter.Name} = default;\n");
}
if (IsByRefParameter(parameter))
{
if (IsGodotInteropStruct(parameter.Type))
{
methodSource.Append(" ");
AppendCustomUnsafeAsPointer(methodSource, parameter, out string varName);
methodCallArguments.Append(varName);
}
else if (parameter.Type.IsValueType)
{
methodSource.Append(" ");
AppendCopyToStackAndGetPointer(methodSource, parameter, out string varName);
methodCallArguments.Append($"&{varName}");
if (parameter.RefKind is RefKind.Out or RefKind.Ref)
{
methodSourceAfterCall.Append($" {parameter.Name} = {varName};\n");
}
}
else
{
// If it's a by-ref param and we can't get the pointer
// just pass it by-ref and let it be pinned.
AppendRefKind(methodCallArguments, parameter.RefKind)
.Append(' ')
.Append(parameter.Name);
}
}
else
{
methodCallArguments.Append(parameter.Name);
}
if (i < callback.Parameters.Length - 1)
{
source.Append(", ");
methodCallArguments.Append(", ");
}
}
source.Append(")\n");
source.Append(" {\n");
source.Append(methodSource);
source.Append(" ");
if (!callback.ReturnsVoid)
{
if (methodSourceAfterCall.Length != 0)
source.Append($"{callback.ReturnType.FullQualifiedName()} ret = ");
else
source.Append("return ");
}
source.Append($"_unmanagedCallbacks.{callback.Name}(");
source.Append(methodCallArguments);
source.Append(");\n");
if (methodSourceAfterCall.Length != 0)
{
source.Append(methodSourceAfterCall);
if (!callback.ReturnsVoid)
source.Append(" return ret;\n");
}
source.Append(" }\n\n");
}
source.Append("}\n");
if (isInnerClass)
{
var containingType = symbol.ContainingType;
while (containingType != null)
{
source.Append("}\n"); // outer class
containingType = containingType.ContainingType;
}
}
if (hasNamespace)
source.Append("\n}");
source.Append("\n\n#pragma warning restore CA1707\n");
context.AddSource($"{data.NativeTypeSymbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
SourceText.From(source.ToString(), Encoding.UTF8));
}
private void GenerateUnmanagedCallbacksStruct(GeneratorExecutionContext context, CallbacksData data)
{
var symbol = data.FuncStructSymbol;
INamespaceSymbol namespaceSymbol = symbol.ContainingNamespace;
string classNs = namespaceSymbol != null && !namespaceSymbol.IsGlobalNamespace ?
namespaceSymbol.FullQualifiedName() :
string.Empty;
bool hasNamespace = classNs.Length != 0;
bool isInnerClass = symbol.ContainingType != null;
var source = new StringBuilder();
source.Append(
@"using System.Runtime.InteropServices;
using Godot.NativeInterop;
#pragma warning disable CA1707 // Disable warning: Identifiers should not contain underscores
");
if (hasNamespace)
{
source.Append("namespace ");
source.Append(classNs);
source.Append("\n{\n");
}
if (isInnerClass)
{
var containingType = symbol.ContainingType;
while (containingType != null)
{
source.Append("partial ");
source.Append(containingType.GetDeclarationKeyword());
source.Append(" ");
source.Append(containingType.NameWithTypeParameters());
source.Append("\n{\n");
containingType = containingType.ContainingType;
}
}
source.Append("[StructLayout(LayoutKind.Sequential)]\n");
source.Append($"unsafe partial struct {symbol.Name}\n{{\n");
foreach (var callback in data.Methods)
{
source.Append(" ");
source.Append(callback.DeclaredAccessibility == Accessibility.Public ? "public " : "internal ");
source.Append("delegate* unmanaged<");
foreach (var parameter in callback.Parameters)
{
if (IsByRefParameter(parameter))
{
if (IsGodotInteropStruct(parameter.Type) || parameter.Type.IsValueType)
{
AppendPointerType(source, parameter.Type);
}
else
{
// If it's a by-ref param and we can't get the pointer
// just pass it by-ref and let it be pinned.
AppendRefKind(source, parameter.RefKind)
.Append(' ')
.Append(parameter.Type.FullQualifiedName());
}
}
else
{
source.Append(parameter.Type.FullQualifiedName());
}
source.Append(", ");
}
source.Append(callback.ReturnType.FullQualifiedName());
source.Append($"> {callback.Name};\n");
}
source.Append("}\n");
if (isInnerClass)
{
var containingType = symbol.ContainingType;
while (containingType != null)
{
source.Append("}\n"); // outer class
containingType = containingType.ContainingType;
}
}
if (hasNamespace)
source.Append("}\n");
source.Append("\n#pragma warning restore CA1707\n");
context.AddSource($"{symbol.FullQualifiedName().SanitizeQualifiedNameForUniqueHint()}.generated",
SourceText.From(source.ToString(), Encoding.UTF8));
}
private static bool IsGodotInteropStruct(ITypeSymbol type) =>
GodotInteropStructs.Contains(type.FullQualifiedName());
private static bool IsByRefParameter(IParameterSymbol parameter) =>
parameter.RefKind is RefKind.In or RefKind.Out or RefKind.Ref;
private static StringBuilder AppendRefKind(StringBuilder source, RefKind refKind) =>
refKind switch
{
RefKind.In => source.Append("in"),
RefKind.Out => source.Append("out"),
RefKind.Ref => source.Append("ref"),
_ => source,
};
private static void AppendPointerType(StringBuilder source, ITypeSymbol type)
{
source.Append(type.FullQualifiedName());
source.Append('*');
}
private static void AppendCustomUnsafeAsPointer(StringBuilder source, IParameterSymbol parameter,
out string varName)
{
varName = $"{parameter.Name}_ptr";
AppendPointerType(source, parameter.Type);
source.Append(' ');
source.Append(varName);
source.Append(" = ");
source.Append('(');
AppendPointerType(source, parameter.Type);
source.Append(')');
if (parameter.RefKind == RefKind.In)
source.Append("CustomUnsafe.ReadOnlyRefAsPointer(in ");
else
source.Append("CustomUnsafe.AsPointer(ref ");
source.Append(parameter.Name);
source.Append(");\n");
}
private static void AppendCopyToStackAndGetPointer(StringBuilder source, IParameterSymbol parameter,
out string varName)
{
varName = $"{parameter.Name}_copy";
source.Append(parameter.Type.FullQualifiedName());
source.Append(' ');
source.Append(varName);
if (parameter.RefKind is RefKind.In or RefKind.Ref)
{
source.Append(" = ");
source.Append(parameter.Name);
}
source.Append(";\n");
}
private static readonly string[] GodotInteropStructs =
{
"Godot.NativeInterop.godot_ref",
"Godot.NativeInterop.godot_variant_call_error",
"Godot.NativeInterop.godot_variant",
"Godot.NativeInterop.godot_string",
"Godot.NativeInterop.godot_string_name",
"Godot.NativeInterop.godot_node_path",
"Godot.NativeInterop.godot_signal",
"Godot.NativeInterop.godot_callable",
"Godot.NativeInterop.godot_array",
"Godot.NativeInterop.godot_dictionary",
"Godot.NativeInterop.godot_packed_byte_array",
"Godot.NativeInterop.godot_packed_int32_array",
"Godot.NativeInterop.godot_packed_int64_array",
"Godot.NativeInterop.godot_packed_float32_array",
"Godot.NativeInterop.godot_packed_float64_array",
"Godot.NativeInterop.godot_packed_string_array",
"Godot.NativeInterop.godot_packed_vector2_array",
"Godot.NativeInterop.godot_packed_vector3_array",
"Godot.NativeInterop.godot_packed_color_array",
};
}