3

I am researching a way to make Visual Studio issue a warning if I override a specific method in a base class but forget to call the base method in the override. E.g.:

// Base class from the question. A() is marked with a placeholder attribute whose purpose
// would be to make any override of A() that does not call base.A() produce a warning.
class Foo
{
   [SomeAttributeToMarkTheMethodToFireTheWarning]
   public virtual void A() { ... }
}

// An override that does not call base.A(); this is the situation the desired warning should flag.
class Bar : Foo
{
   public override void A()
   {
      // base.A(); // the warning should fire when this base call is missing
      // ...
   }
}

So far I couldn't find a way and probably it is not possible to make the compiler fire such a warning directly. Any ideas for a way to do it, even if it's a 3rd-party tool or using some API from the new Roslyn .NET compiler platform?

UPDATE: For example, in Android Studio (IntelliJ), if you override onCreate() in any activity but forget to call the base method super.onCreate(), you get a warning. That's the behavior I need in VS.

Liam
  • 27,717
  • 28
  • 128
  • 190
Borislav Borisov
  • 378
  • 1
  • 3
  • 12
  • But you don't have to call the base method. So no, there is no way for the compiler to do this, as it's not a compilation error – Liam Jun 22 '16 at 11:09
  • IntelliJ is Java. You're coding in C#. This functionality does not exist in C# – Liam Jun 22 '16 at 11:16
  • I DO understand that my code is in C# and not Java, and that's why I am searching for a way to achieve the same result as in IntelliJ but with VS – Borislav Borisov Jun 22 '16 at 11:21

3 Answers

6

I finally had some time to experiment with Roslyn, and it looks like I found a solution using an analyzer. This is my solution.

The attribute that marks a base method whose overrides must call the base implementation:

/// <summary>
/// Marks a virtual method whose overrides are expected to call the base implementation
/// (<c>base.Method(...)</c>). Overrides that do not are reported by the
/// <c>RequiredBaseMethodCallAnalyzer</c> diagnostic analyzer.
/// </summary>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)]
public sealed class RequireBaseMethodCallAttribute : Attribute
{
    // Intentionally empty: the compiler-generated public parameterless constructor is all
    // that is needed, since the attribute carries no data and is only a marker.
}

The analyzer:

/// <summary>
/// Roslyn analyzer that warns when a method overrides a base method marked with
/// <c>[RequireBaseMethodCall]</c> but never calls the base implementation through
/// <c>base.Name(...)</c>.
/// </summary>
/// <remarks>
/// The attribute is looked up on every method in the override chain, so in
/// <c>Baz : Bar : Foo</c> an override in <c>Baz</c> is still checked when only
/// <c>Foo.A</c> carries the attribute. A call counts as the base call when it is a
/// <c>base.</c> member access that the compiler binds to a method in that chain.
/// Overload, optional-parameter and conversion rules are therefore left to the compiler
/// instead of being re-implemented here.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class RequiredBaseMethodCallAnalyzer : DiagnosticAnalyzer
{
    public const string DiagnosticId = "RequireBaseMethodCall";

    // You can change these strings in the Resources.resx file. If you do not want your analyzer to be localize-able, you can use regular strings for Title and MessageFormat.
    // See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/Localizing%20Analyzers.md for more on localization
    private static readonly LocalizableString Title = new LocalizableResourceString(nameof(Resources.AnalyzerTitle), Resources.ResourceManager, typeof(Resources));
    private static readonly LocalizableString MessageFormat = new LocalizableResourceString(nameof(Resources.AnalyzerMessageFormat), Resources.ResourceManager, typeof(Resources));
    private static readonly LocalizableString Description = new LocalizableResourceString(nameof(Resources.AnalyzerDescription), Resources.ResourceManager, typeof(Resources));
    private const string Category = "Usage";

    // readonly: the descriptor is shared by every analysis callback and must never be reassigned.
    private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Warning, isEnabledByDefault: true, description: Description);

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get { return ImmutableArray.Create(Rule); } }

    public override void Initialize(AnalysisContext context)
    {
        // Each method declaration is analyzed on its own and no mutable state is shared,
        // so concurrent execution is safe. Generated code is skipped because the developer
        // cannot sensibly edit it to add a base call.
        context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
        context.EnableConcurrentExecution();

        // No per-compilation state is needed, so the node action is registered directly
        // instead of from a compilation-start action.
        context.RegisterSyntaxNodeAction(AnalyzeMethodDeclaration, SyntaxKind.MethodDeclaration);
    }

    /// <summary>
    /// Reports <see cref="Rule"/> on an override whose overridden method (or any method
    /// further up its override chain) requires a base call, when the override's body
    /// contains no <c>base.Name(...)</c> call that binds to that chain.
    /// </summary>
    private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context)
    {
        var mds = context.Node as MethodDeclarationSyntax;
        if (mds == null)
        {
            return;
        }

        // Declarations without a body (abstract override, extern) cannot contain a base call.
        if (mds.Body == null && mds.ExpressionBody == null)
        {
            return;
        }

        var symbol = context.SemanticModel.GetDeclaredSymbol(mds, context.CancellationToken) as IMethodSymbol;
        if (symbol == null || !symbol.IsOverride)
        {
            return;
        }

        var overriddenMethod = symbol.OverriddenMethod;

        // An abstract overridden method has no implementation that base.Name() could reach.
        if (overriddenMethod == null || overriddenMethod.IsAbstract)
        {
            return;
        }

        if (!IsBaseCallRequired(overriddenMethod))
        {
            return;
        }

        if (ContainsBaseCall(mds, overriddenMethod, context))
        {
            return;
        }

        // Report on the method name only, rather than squiggling the whole method.
        var diag = Diagnostic.Create(Rule, mds.Identifier.GetLocation(), overriddenMethod.Name);
        context.ReportDiagnostic(diag);
    }

    /// <summary>
    /// Returns true when <paramref name="method"/>, or any method it (transitively)
    /// overrides, carries <see cref="RequireBaseMethodCallAttribute"/>.
    /// </summary>
    /// <remarks>
    /// GetAttributes() reports only the attributes applied directly to a symbol, so the
    /// override chain is walked explicitly.
    /// </remarks>
    private static bool IsBaseCallRequired(IMethodSymbol method)
    {
        for (var current = method; current != null; current = current.OverriddenMethod)
        {
            // Matched by exact (ordinal, case-sensitive) type name.
            if (current.GetAttributes().Any(ad => ad.AttributeClass != null
                                               && ad.AttributeClass.MetadataName == nameof(RequireBaseMethodCallAttribute)))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Returns true when <paramref name="mds"/> contains an invocation of the form
    /// <c>base.Name(...)</c> that the compiler binds to <paramref name="overriddenMethod"/>
    /// or to a method further up its override chain.
    /// </summary>
    private static bool ContainsBaseCall(MethodDeclarationSyntax mds, IMethodSymbol overriddenMethod, SyntaxNodeAnalysisContext context)
    {
        foreach (var invocation in mds.DescendantNodes().OfType<InvocationExpressionSyntax>())
        {
            var memberAccess = invocation.Expression as MemberAccessExpressionSyntax;
            if (memberAccess == null || !(memberAccess.Expression is BaseExpressionSyntax))
            {
                continue;
            }

            // Let the compiler resolve the overload. This correctly handles optional
            // parameters, named arguments, implicit conversions, params arrays and generics,
            // which a hand-written argument/parameter type comparison gets wrong.
            var called = context.SemanticModel.GetSymbolInfo(invocation, context.CancellationToken).Symbol as IMethodSymbol;
            if (called != null && IsInOverrideChain(called, overriddenMethod))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Returns true when <paramref name="candidate"/> is <paramref name="method"/> or a method
    /// that it (transitively) overrides.
    /// </summary>
    /// <remarks>
    /// The whole chain is accepted so the match works whichever method of the chain the
    /// binder reports for a <c>base.</c> call. Original definitions are compared so that
    /// constructed generic methods and types match their declarations.
    /// </remarks>
    private static bool IsInOverrideChain(IMethodSymbol candidate, IMethodSymbol method)
    {
        var target = candidate.OriginalDefinition;
        for (var current = method; current != null; current = current.OverriddenMethod)
        {
            // ISymbol.Equals is symbol equality on every Roslyn version;
            // on Roslyn 3.3+ SymbolEqualityComparer.Default is the preferred spelling.
            if (current.OriginalDefinition.Equals(target))
            {
                return true;
            }
        }

        return false;
    }
}

The CodeFix provider:

/// <summary>
/// Code fix for <see cref="RequiredBaseMethodCallAnalyzer.DiagnosticId"/>. It inserts
/// <c>base.Name(p1, p2, ...);</c> as the first statement of the flagged override,
/// forwarding the override's own parameters (with their ref/out keywords) and, for
/// generic overrides, its type parameters.
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(BaseMethodCallCodeFixProvider)), Shared]
public class BaseMethodCallCodeFixProvider : CodeFixProvider
{
    private const string title = "Add base method invocation";

    public sealed override ImmutableArray<string> FixableDiagnosticIds
    {
        get { return ImmutableArray.Create(RequiredBaseMethodCallAnalyzer.DiagnosticId); }
    }

    public sealed override FixAllProvider GetFixAllProvider()
    {
        // See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/FixAllProvider.md for more information on Fix All Providers
        return WellKnownFixAllProviders.BatchFixer;
    }

    public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

        var diagnostic = context.Diagnostics.First();
        var diagnosticSpan = diagnostic.Location.SourceSpan;

        // A statement can only be prepended to a block body. Expression-bodied overrides,
        // and spans that no longer map to a method, get no fix rather than a crash.
        var method = FindMethodDeclaration(root, diagnosticSpan);
        if (method == null || method.Body == null)
        {
            return;
        }

        // Register a code action that will invoke the fix.
        context.RegisterCodeFix(
            CodeAction.Create(
                title: title,
                createChangedDocument: c => AddBaseMethodCallAsync(context.Document, diagnosticSpan, c),
                equivalenceKey: title),
            diagnostic);
    }

    /// <summary>
    /// Returns a copy of <paramref name="document"/> in which the flagged override starts
    /// with a call to its base implementation. If the method cannot be located or has no
    /// block body, the document is returned unchanged.
    /// </summary>
    private async Task<Document> AddBaseMethodCallAsync(Document document, TextSpan diagnosticSpan, CancellationToken cancellationToken)
    {
        var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var node = FindMethodDeclaration(root, diagnosticSpan);
        if (node == null || node.Body == null)
        {
            return document;
        }

        var argsList = SyntaxFactory.SeparatedList(
            node.ParameterList.Parameters.Select(p => CreateForwardingArgument(p)));

        var invocation = SyntaxFactory.InvocationExpression(
            SyntaxFactory.MemberAccessExpression(
                SyntaxKind.SimpleMemberAccessExpression,
                SyntaxFactory.BaseExpression(),
                CreateMethodName(node)),
            SyntaxFactory.ArgumentList(argsList));

        // Annotate only the inserted statement. Annotating the whole root asks the
        // simplifier to rewrite unrelated code elsewhere in the document.
        var baseCall = SyntaxFactory.ExpressionStatement(invocation)
            .WithAdditionalAnnotations(Simplifier.Annotation, Microsoft.CodeAnalysis.Formatting.Formatter.Annotation);

        // WithStatements keeps the body's existing braces and trivia intact.
        var newBody = node.Body.WithStatements(node.Body.Statements.Insert(0, baseCall));
        var newRoot = root.ReplaceNode(node.Body, newBody);

        return document.WithSyntaxRoot(newRoot);
    }

    /// <summary>
    /// Locates the method declaration a diagnostic was reported on. This works whether the
    /// diagnostic spans the whole method or only its identifier.
    /// </summary>
    private static MethodDeclarationSyntax FindMethodDeclaration(SyntaxNode root, TextSpan span)
    {
        return root.FindNode(span).FirstAncestorOrSelf<MethodDeclarationSyntax>();
    }

    /// <summary>
    /// Builds the member name for the base call. Generic overrides forward their own type
    /// parameters (<c>base.M&lt;T&gt;(...)</c>) so the call compiles even when the type
    /// arguments cannot be inferred; the simplifier may remove them where they are redundant.
    /// </summary>
    private static SimpleNameSyntax CreateMethodName(MethodDeclarationSyntax method)
    {
        var identifier = method.Identifier.WithoutTrivia();
        if (method.TypeParameterList == null)
        {
            return SyntaxFactory.IdentifierName(identifier);
        }

        var typeArguments = method.TypeParameterList.Parameters
            .Select(tp => (TypeSyntax)SyntaxFactory.IdentifierName(tp.Identifier.WithoutTrivia()));
        return SyntaxFactory.GenericName(identifier, SyntaxFactory.TypeArgumentList(SyntaxFactory.SeparatedList(typeArguments)));
    }

    /// <summary>
    /// Creates an argument that forwards <paramref name="parameter"/> unchanged.
    /// </summary>
    /// <remarks>
    /// The identifier token is reused, rather than its ValueText, so verbatim names such as
    /// <c>@class</c> stay valid. ref/out parameters are passed with the matching keyword.
    /// </remarks>
    private static ArgumentSyntax CreateForwardingArgument(ParameterSyntax parameter)
    {
        var expression = SyntaxFactory.IdentifierName(parameter.Identifier.WithoutTrivia());

        foreach (var modifier in parameter.Modifiers)
        {
            if (modifier.Kind() == SyntaxKind.RefKeyword || modifier.Kind() == SyntaxKind.OutKeyword)
            {
                return SyntaxFactory.Argument(null, SyntaxFactory.Token(modifier.Kind()), expression);
            }
        }

        return SyntaxFactory.Argument(expression);
    }
}

And a demo of how it works: http://screencast.com/t/4Jgm989TI

Since I am totally new to the .NET Compiler Platform, I would love to have any feedback and suggestions on how to improve my solution. Thank you in advance!

Borislav Borisov
  • 378
  • 1
  • 3
  • 12
3

If you want to ensure some code is run then you should change your design:

// Template-method design: A() is non-virtual, so its own code always runs.
// Subclasses customise behaviour only through the abstract PostA() hook,
// which A() calls after its own code.
abstract class Foo
{
   // Hook that every concrete subclass must implement.
   protected abstract void PostA();  

   public void A() { 
      ... 
      PostA();
   }
}


/// <summary>
/// Concrete subclass. It customises only the PostA() step, so the code inside the
/// non-virtual Foo.A() always runs and cannot be skipped by the subclass.
/// </summary>
class Bar : Foo
{
   /// <summary>Subclass-specific work that Foo.A() runs after its own code.</summary>
   protected override void PostA() { }
}

// The public method signature remains the same; callers still invoke A() on an instance
// (A() is an instance method, so it cannot be called through the type name Bar).
new Bar().A();

In this way, the code in A() always runs before your overridden PostA().

To support multiple levels of inheritance while still ensuring that A() runs, you would have to make Bar abstract as well:

/// <summary>
/// Intermediate class in a multi-level hierarchy. It stays abstract and leaves PostA()
/// unimplemented, so a concrete class further down (Baz) has to supply it.
/// </summary>
abstract class Bar : Foo
{
   // PostA() is deliberately not overridden here.
}

/// <summary>Concrete leaf class; it supplies the PostA() step that Bar leaves abstract.</summary>
class Baz : Bar
{
   protected override void PostA() { }
}

There is no way to do exactly what you want in C#. This isn't a Visual Studio issue. This is how C# works.

A virtual method can be overridden or not, and an override can call the base implementation or not. You have two options: virtual or abstract. You're using virtual, and I've given you an abstract solution. It's up to you to choose which one you want to use.

The nearest thing I can think of to what you want would be a #warning. See this answer. But this will only produce the warning in the Output window, not in IntelliSense. Basically, C# does not support custom compiler warnings.

Community
  • 1
  • 1
Liam
  • 27,717
  • 28
  • 128
  • 190
  • Plus one, absolutely. Can you prevent `A` from being overridden in C#? – Bathsheba Jun 22 '16 at 11:12
  • The problem is that my PostA method in the base class can not be abstract since it has a body. – Borislav Borisov Jun 22 '16 at 11:13
  • 4
    You put the body into A(). Only the code you want to override goes into PostA() – Liam Jun 22 '16 at 11:14
  • That would work with just 1 subclass but if I need to add class Baz : Bar and ensure that A() from Bar is called in the overridden A() in Baz, such design solution won't help me. – Borislav Borisov Jun 22 '16 at 11:19
  • 2
    As pointed out in the comments to the question you link to related to custom compiler warnings - this is definitely closer to being doable these days, if you write a Roslyn analyser. – Damien_The_Unbeliever Jun 22 '16 at 11:52
  • I'll have to take your word on that @Damien_The_Unbeliever. Not something I have experience with. – Liam Jun 22 '16 at 12:12
0

I have solved this problem with runtime checks.

I created an object which is able to assert that the basemost method is invoked for any of its overridable methods.

Of course this is not ideal; the ideal is to make this a compile-time check, as Borislav's solution suggests. But Borislav's solution represents an awful lot of knowledge, an awful lot of work, an intervention in the build system, an intervention in the editor, aaaargh! Roslyn Analyzers, CodeFix providers, that's exotic stuff. I do not even know where to begin implementing such a thing.

So, here is my rather simple runtime checking approach:

/// <summary>
/// Base class for making sure that descendants always invoke overridable
/// methods of base.
/// </summary>
/// <remarks>
/// Protocol: the code that calls an overridable method wraps the call in
/// <c>using( NewOverridableGuard( nameof(Method) ) )</c>, and the base implementation of
/// that method calls <see cref="OverridableWasInvoked"/>. When the guard is disposed, it
/// asserts that the base implementation was reached. Guards are kept on a stack so that
/// recursive invocations are tracked independently. Only method names are compared,
/// not parameter lists.
/// </remarks>
public abstract class Overridable
{
    /// <summary>
    /// One pending guarded invocation. <see cref="Invoked"/> is set by
    /// <see cref="Overridable.OverridableWasInvoked"/> when the base implementation runs.
    /// </summary>
    private sealed class InvocationGuard : IDisposable
    {
        private readonly Overridable overridable;
        public readonly  string      MethodName;
        public           bool        Invoked;

        public InvocationGuard( Overridable overridable, string methodName )
        {
            this.overridable = overridable;
            MethodName       = methodName;
        }

        public void Dispose()
        {
            Assert( overridable.invocationStack.Count > 0 && ReferenceEquals( overridable.invocationStack.Peek(), this ) );
            // Pop before checking Invoked. If that assertion throws, the stack must not keep
            // this stale guard, otherwise every later check on this object would also fail.
            overridable.invocationStack.Pop();
            Assert( Invoked );
        }
    }

    // Innermost guarded invocation on top; a stack is needed to handle recursive invocations.
    private readonly Stack<InvocationGuard> invocationStack = new Stack<InvocationGuard>();

    /// <summary>
    /// Starts guarding one invocation of the overridable method <paramref name="methodName"/>.
    /// Dispose the returned object right after the invocation.
    /// </summary>
    /// <param name="methodName">Name of the overridable method about to be invoked.</param>
    /// <returns>A guard that asserts, on disposal, that the base implementation was invoked.</returns>
    public IDisposable NewOverridableGuard( string methodName )
    {
        Assert( ReflectionHelpers.MethodExistsAssertion( GetType(), methodName ) );
        var invocationGuard = new InvocationGuard( this, methodName );
        invocationStack.Push( invocationGuard );
        return invocationGuard;
    }

    /// <summary>
    /// Must be called by the base implementation of every guarded overridable method.
    /// Marks the innermost pending guard as satisfied.
    /// </summary>
    /// <param name="methodName">Name of the calling method; filled in via CallerMemberName.</param>
    public void OverridableWasInvoked( [CanBeNull][CallerMemberName] string methodName = null )
    {
        Assert( methodName != null );
        Assert( ReflectionHelpers.MethodExistsAssertion( GetType(), methodName ) );
        // An empty stack means the overridable ran without a guard (e.g. it was called
        // directly instead of through its guarded caller). Report that as an assertion
        // failure rather than as Stack.Peek()'s generic InvalidOperationException.
        Assert( invocationStack.Count > 0 );
        InvocationGuard invocationGuard = invocationStack.Peek();
        Assert( invocationGuard.MethodName == methodName );
        Assert( !invocationGuard.Invoked );
        invocationGuard.Invoked = true;
    }
}

For the sake of brevity the implementation of ReflectionHelpers.MethodExistsAssertion() is left as an exercise to the reader; it is an optional assertion anyway.

Use it as follows:

(I picked OnPropertyChanged( string propertyName ) as an example, since many developers may already be familiar with it, and the problems associated with forgetting to invoke base when using it.)

/// <summary>
/// Invokes OnPropertyChanged for <paramref name="propertyName"/> (the caller's member name
/// by default). The virtual call is wrapped in a guard that asserts base OnPropertyChanged ran.
/// </summary>
protected internal void RaisePropertyChanged( [CallerMemberName] string propertyName = null )
{
    //Add this guard to your existing code: on disposal it checks that base.OnPropertyChanged() was reached.
    using( var guard = NewOverridableGuard( nameof(OnPropertyChanged) ) )
    {
        OnPropertyChanged( propertyName );
    }
}

/// <summary>
/// Base implementation: records that the base method was reached (required by the guard
/// created in RaisePropertyChanged), then raises PropertyChanged.
/// </summary>
protected virtual void OnPropertyChanged( string propertyName )
{
    OverridableWasInvoked(); //Add this to your existing code
    var handler = PropertyChanged;
    if( handler != null )
    {
        handler( this, new PropertyChangedEventArgs( propertyName ) );
    }
}

Note: the Stack<> is necessary in order to handle recursive invocations.

Note: the checking only takes into account the name of the overridable method, not its parameters, so for best results you had better not have overloaded overridables.

Mike Nakis
  • 56,297
  • 11
  • 110
  • 142