1

Summary

I'm trying to write a small API for my project that would provide a near drop-in replacement for the System.DllImportAttribute. How would I dynamically either replace or add a body to a method marked static extern?

Background

I've seen a few answers on here (here and here) that show how to intercept and dynamically replace methods, but they are not extern, and I can't get any of them to work with methods that aren't.

My current API does the following:

  1. Finds all methods marked with a NativeCallAttribute and returns a MethodInfo[].
  2. For each MethodInfo in the returned MethodInfo[]:
    1. Loads the specified library (if not already loaded) using either LoadLibrary or dlopen.
    2. Get an IntPtr representing the function pointer for a specified method from the loaded library using either GetProcAddress or dlsym.
    3. Generates a delegate for the native function pointer based on the current MethodInfo.
    4. Gets a MethodInfo from the generated delegate to replace the existing one.
    5. Replaces the old method with the new one.

Code

My current implementation takes a similar approach to this project in terms of getting an attribute set up and gathering information about the attached method, and this stackoverflow answer in terms of replacing the function body with a native delegate.

The current API that I have to load the library and get the function pointer work as they should (as they work in other situations than this), and they have been excluded from the code below.

Current API

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using System.Security;
using TCDFx.ComponentModel;

namespace TCDFx.InteropServices
{
    // Indicates that the attributed method is exposed by an native assembly as a static entry point.
    [CLSCompliant(false)]
    [SuppressUnmanagedCodeSecurity]
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)]
    public sealed class NativeCallAttribute : Attribute
    {
        // The name of the native method.
        public string EntryPoint;

        // Initializes a new instance of the NativeCallAttribute.
        public NativeCallAttribute(params string[] assemblyNames)
        {
            if (assemblyNames == null || assemblyNames.Length == 0)
                throw new NativeCallException("No assembly specified.");

            string[] names = new string[] { };
            int i = 0;
            foreach (string name in assemblyNames)
            {
                if (!string.IsNullOrWhiteSpace(name))
                {
                    names[i] = name;
                    i++;
                }
            }

            AssemblyNames = names;
        }

        // An ordered list of assembly names.
        public string[] AssemblyNames { get; }
    }

    [SuppressUnmanagedCodeSecurity]
    public static class NativeCalls
    {
        private static readonly object sync = new object();

        // Replaces all defined functions with the 'NativeCallAttribute' that are 'static' and 'extern' with their native call.
        public static void Load()
        {
            lock (sync)
            {
                MethodInfo[] funcInfo = GetNativeCalls();

                for (int i = 0; i < funcInfo.Length; i++)
                {
                    NativeCallAttribute attribute = funcInfo[i].GetCustomAttribute<NativeCallAttribute>(false);
                    NativeAssemblyBase nativeAssembly;

                    if (IsAssemblyCached(attribute.AssemblyNames, out NativeAssemblyBase cachedAssembly))
                        nativeAssembly = cachedAssembly;
                    else
                    {
                        if (TryLoadAssembly(attribute.AssemblyNames, out NativeAssemblyBase loadedAssembly, out Exception loadingEx))
                            nativeAssembly = loadedAssembly;
                        else
                            throw loadingEx;
                    }

                    string funcName = attribute.EntryPoint ?? funcInfo[i].Name;
                    IntPtr funcPtr = nativeAssembly.LoadFunctionPointer(funcName);

                    Delegate funcDelegate = GenerateNativeDelegate(funcName, nativeAssembly.Name, funcInfo[i], funcPtr);
                    MethodInfo funcInfoNew = funcDelegate.GetMethodInfo();

                    MethodReplacementState state = ReplaceMethod(funcInfo[i], funcInfoNew);
                    replacements.Add(state);
                }
            }
        }

        // Gets all methods marked with a 'NativeCallAttribute'.
        private static MethodInfo[] GetNativeCalls()
        {
            List<MethodInfo> result = new List<MethodInfo>();
            Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies();
            for (int i = 0; i < assemblies.Length; i++)
            {
                Type[] types = assemblies[i].GetTypes();
                for (int ii = 0; ii < types.Length; ii++)
                {
                    MethodInfo[] methods = types[ii].GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
                    for (int iii = 0; iii < methods.Length; iii++)
                    {
                        Attribute attr = methods[iii].GetCustomAttribute<NativeCallAttribute>(false);
                        if (attr != null)
                            result.Add(methods[iii]);
                    }
                }
            }
            return result.ToArray();
        }

        // Gets a 'Delegate' for a native function pointer with information provided from the method to replace.
        private static Delegate GenerateNativeDelegate(string funcName, string assemblyName, MethodInfo funcInfo, IntPtr funcPtr)
        {
            Type returnType = funcInfo.ReturnType;
            ParameterInfo[] @params = funcInfo.GetParameters();
            Type[] paramTypes = new Type[] { };

            for (int i = 0; i < @params.Length; i++)
                paramTypes[i] = @params[i].ParameterType;

            DynamicMethod nativeMethod = new DynamicMethod($"{assemblyName}_{funcName}", returnType, paramTypes, funcInfo.Module);
            ILGenerator ilGenerator = nativeMethod.GetILGenerator();

            // Generate the arguments
            for (int i = 0; i < @params.Length; i++)
            {
                //TODO: See if I need separate out code for this...
                if (@params[i].ParameterType.IsByRef || @params[i].IsOut)
                {
                    ilGenerator.Emit(OpCodes.Ldarg, i);
                    ilGenerator.Emit(OpCodes.Ldnull);
                    ilGenerator.Emit(OpCodes.Stind_Ref);
                }
                else
                {
                    ilGenerator.Emit(OpCodes.Ldarg, i);
                }
            }

            // Push the funcPtr to the stack
            if (IntPtr.Size == 4)
                ilGenerator.Emit(OpCodes.Ldc_I4, funcPtr.ToInt32());
            else if (IntPtr.Size == 8)
                ilGenerator.Emit(OpCodes.Ldc_I8, funcPtr.ToInt64());
            else throw new PlatformNotSupportedException();

            // Call it and return;
            ilGenerator.EmitCall(OpCodes.Call, funcInfo, null);
            ilGenerator.Emit(OpCodes.Ret);

            Type delegateType = Expression.GetDelegateType((from param in @params select param.ParameterType).Concat(new[] { returnType }).ToArray());
            return nativeMethod.CreateDelegate(delegateType);
        }

        private static bool IsAssemblyCached(string[] assemblyNames, out NativeAssemblyBase cachedAssembly)
        {
            bool result = false;
            cachedAssembly = null;
            foreach (string name in assemblyNames)
            {
                if (!Component.Cache.ContainsKey(Path.GetFileNameWithoutExtension(name)))
                {
                    Type asmType = Component.Cache[Path.GetFileNameWithoutExtension(name)].Value1;
                    if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (NativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    else if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (DependencyNativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    else if (asmType == typeof(NativeAssembly))
                        cachedAssembly = (EmbeddedNativeAssembly)Component.Cache[Path.GetFileNameWithoutExtension(name)].Value2;
                    result = true;
                    break;
                }
            }
            return result;
        }

        private static bool TryLoadAssembly(string[] assemblyNames, out NativeAssemblyBase loadedAssembly, out Exception exception)
        {
            bool result = false;
            exception = null;
            try
            {
                loadedAssembly = new NativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            try
            {
                if (exception != null)
                    loadedAssembly = new DependencyNativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            try
            {
                if (exception == null)
                    loadedAssembly = new EmbeddedNativeAssembly(assemblyNames);
            }
            catch (Exception ex)
            {
                exception = ex;
                loadedAssembly = null;
            }
            return result;
        }

        private static unsafe MethodReplacementState ReplaceMethod(MethodInfo targetMethod, MethodInfo replacementMethod)
        {
            if (!(targetMethod.GetMethodBody() == null && targetMethod.IsStatic))
                throw new NativeCallException($"Only the replacement of methods marked 'static extern' is supported.");

#if DEBUG
            RuntimeHelpers.PrepareMethod(targetMethod.MethodHandle);
            RuntimeHelpers.PrepareMethod(replacementMethod.MethodHandle);
#endif
            IntPtr target = targetMethod.MethodHandle.Value;
            IntPtr replacement = replacementMethod.MethodHandle.Value + 8;
            if (!targetMethod.IsVirtual)
                target += 8;
            else
            {
                int i = (int)(((*(long*)target) >> 32) & 0xFF);
                IntPtr classStart = *(IntPtr*)(targetMethod.DeclaringType.TypeHandle.Value + (IntPtr.Size == 4 ? 40 : 64));
                target = classStart + (IntPtr.Size * i);
            }

#if DEBUG
            target = *(IntPtr*)target + 1;
            replacement = *(IntPtr*)replacement + 1;

            MethodReplacementState state = new MethodReplacementState(target, new IntPtr(*(int*)target));
            *(int*)target = *(int*)replacement + (int)(long)replacement - (int)(long)target;
            return state;
#else
            MethodReplacementState state = new MethodReplacementState(target, *(IntPtr*)target);
            * (IntPtr*)target = *(IntPtr*)replacement;
            return state;
#endif
        }

        private readonly struct MethodReplacementState : IDisposable
        {
            private readonly IntPtr Location;
            private readonly IntPtr OriginalValue;

            public MethodReplacementState(IntPtr location, IntPtr origValue)
            {
                Location = location;
                OriginalValue = origValue;
            }
            public void Dispose() => Restore();

            private unsafe void Restore() =>
#if DEBUG
        *(int*)Location = (int)OriginalValue;
#else
        *(IntPtr*)Location = OriginalValue;
#endif
        }
    }
}

Test Code

using TCDFx.InteropServices;

namespace NativeCallExample
{
    internal class Program
    {
        internal static void Main()
        {
            NativeCalls.Load();
            Beep(2000, 400);
        }

        [NativeCall("kernel32.dll")]
        private static extern bool Beep(uint frequency, uint duration);
    }
}

Expected/Actual Results

I expected it to run as it should (as listed above in the description), but it is crashing dotnet.exe with error code -532462766. No breakpoints are hit anywhere in the code (in the test app or API library), and no exception is thrown. I believe the problem to be happening between steps 2.3 and 2.5 above, but I'm pretty stuck at the moment. Any help would be appreciated!

More Information

If you want to see the referenced code that isn't included and a full copy of what I have for this, you can find it in this branch of my project.

tom-corwin
  • 11
  • 3

0 Answers0