
I have searched the internet for the last couple of days for a solution to this and haven't found what I need. Basically, here is my problem:

  1. I have an interface I need to implement that has a method that returns an IQueryable (I don't have access to the interface, so I cannot change this)
  2. I would like the method to return the concatenation of (a) an IQueryable that points to a very large database table, and (b) a large IEnumerable of the same entity type that has been computed in memory
  3. I cannot do queryableA.Concat(enumerableB).Where(condition) because it will try to send the entire array to the server (and, aside from that, I get an exception that it only supports primitive types)
  4. I cannot do enumerableB.Concat(queryableA).Where(condition) because it will pull the entirety of the table into memory and treat it as an IEnumerable
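Point 3 follows from how `Queryable.Concat` builds its expression tree. The sketch below uses in-memory stand-ins: an array's `AsQueryable()` plays the database table and a `List<int>` plays the computed entities. It shows that `Concat` does not enumerate the second sequence. Instead it stores the whole sequence in the tree as a `ConstantExpression`, which a database provider would then have to translate into SQL. `Queryable` builds the tree the same way regardless of provider, and translation happens only after the provider receives it.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

static class ConcatShapeDemo
{
    static void Main()
    {
        // Stand-ins (assumptions): an in-memory queryable for the "table"
        // and a list for the entities computed in memory.
        IQueryable<int> queryableA = new[] { 1, 2 }.AsQueryable();
        List<int> enumerableB = new List<int> { -1, 1 };

        // Queryable.Concat records enumerableB in the tree instead of reading it.
        var call = (MethodCallExpression)queryableA.Concat(enumerableB).Expression;

        Console.WriteLine(call.Method.Name);            // Concat
        Console.WriteLine(call.Arguments[1].NodeType);  // Constant

        // The constant holds the very list object, which a remote provider
        // would have to turn into query text.
        var second = (ConstantExpression)call.Arguments[1];
        Console.WriteLine(ReferenceEquals(second.Value, enumerableB)); // True
    }
}
```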

So, after some searching, I think a good way to approach this problem is to write my own ConcatenatingQueryable implementation of IQueryable that takes two IQueryables, executes the expression tree on each independently, and then concatenates the results. However, it throws a StackOverflowException. Based on http://blogs.msdn.com/b/mattwar/archive/2007/07/30/linq-building-an-iqueryable-provider-part-i.aspx, this is what I've implemented so far:

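using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

class Program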
{
    static void Main(string[] args)
    {
        var source1 = new[] {  1, 2 }.AsQueryable();
        var source2 = new[] { -1, 1 }.AsQueryable();
        var matches = new ConcatenatingQueryable<int>(source1, source2).Where(x => x <= 1).ToArray();
        Console.WriteLine(string.Join(",", matches));
        Console.ReadKey();
    }

    public class ConcatenatingQueryable<T> : IQueryable<T>
    {
        private readonly ConcatenatingQueryableProvider<T> provider;
        private readonly Expression expression;

        public ConcatenatingQueryable(IQueryable<T> source1, IQueryable<T> source2)
            : this(new ConcatenatingQueryableProvider<T>(source1, source2))
        {}

        public ConcatenatingQueryable(ConcatenatingQueryableProvider<T> provider)
        {
            this.provider = provider;
            this.expression = Expression.Constant(this);
        }

        public ConcatenatingQueryable(ConcatenatingQueryableProvider<T> provider, Expression expression)
        {
            this.provider = provider;
            this.expression = expression;
        }

        Expression IQueryable.Expression
        {
            get { return expression; }
        }

        Type IQueryable.ElementType
        {
            get { return typeof(T); }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return provider; }
        }

        public IEnumerator<T> GetEnumerator()
        {
            // This line is calling Execute below
            return ((IEnumerable<T>)provider.Execute(expression)).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)provider.Execute(expression)).GetEnumerator();
        }
    }

    public class ConcatenatingQueryableProvider<T> : IQueryProvider
    {
        private readonly IQueryable<T> source1;
        private readonly IQueryable<T> source2;

        public ConcatenatingQueryableProvider(IQueryable<T> source1, IQueryable<T> source2)
        {
            this.source1 = source1;
            this.source2 = source2;
        }

        IQueryable<TS> IQueryProvider.CreateQuery<TS>(Expression expression)
        {
            var elementType = TypeSystem.GetElementType(expression.Type);
            try
            {
                return (IQueryable<TS>)Activator.CreateInstance(typeof(ConcatenatingQueryable<>).MakeGenericType(elementType), new object[] { this, expression });
            }
            catch (TargetInvocationException tie)
            {
                throw tie.InnerException;
            }
        }

        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            var elementType = TypeSystem.GetElementType(expression.Type);
            try
            {
                return (IQueryable)Activator.CreateInstance(typeof(ConcatenatingQueryable<>).MakeGenericType(elementType), new object[] { this, expression });
            }
            catch (TargetInvocationException tie)
            {
                throw tie.InnerException;
            }
        }

        TS IQueryProvider.Execute<TS>(Expression expression)
        {
            return (TS)Execute(expression);
        }

        object IQueryProvider.Execute(Expression expression)
        {
            return Execute(expression);
        }

        public object Execute(Expression expression)
        {
            // This is where I suspect the problem lies, as executing the 
            // Expression.Constant from above here will call Enumerate again,
            // which then calls this, and... you get the point
            dynamic results1 = source1.Provider.Execute(expression);
            dynamic results2 = source2.Provider.Execute(expression);
            // dynamic cannot bind extension methods, so call Concat statically
            return Enumerable.Concat(results1, results2);
        }
    }

    internal static class TypeSystem
    {
        internal static Type GetElementType(Type seqType)
        {
            var ienum = FindIEnumerable(seqType);
            if (ienum == null)
                return seqType;
            return ienum.GetGenericArguments()[0];
        }

        private static Type FindIEnumerable(Type seqType)
        {
            if (seqType == null || seqType == typeof(string))
                return null;
            if (seqType.IsArray)
                return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType());
            if (seqType.IsGenericType)
            {
                foreach (var arg in seqType.GetGenericArguments())
                {
                    var ienum = typeof(IEnumerable<>).MakeGenericType(arg);
                    if (ienum.IsAssignableFrom(seqType))
                    {
                        return ienum;
                    }
                }
            }
            var ifaces = seqType.GetInterfaces();
            if (ifaces.Length > 0)
            {
                foreach (var iface in ifaces)
                {
                    var ienum = FindIEnumerable(iface);
                    if (ienum != null)
                        return ienum;
                }
            }
            if (seqType.BaseType != null && seqType.BaseType != typeof(object))
            {
                return FindIEnumerable(seqType.BaseType);
            }
            return null;
        }
    }
}

I don't have much experience with this interface, and am a bit lost as to what to do from here. Does anyone have any suggestions on how to do this? I'm also open to abandoning this approach entirely if need be.

Just to reiterate, I'm getting a StackOverflowException, and the stack trace is simply a series of calls alternating between the two commented lines above, with "[External Code]" between each pair of calls. I have added an example Main method that uses two tiny enumerables, but you can imagine these were larger data sources that take a very long time to enumerate.

Thank you very much in advance for your help!

cfred
  • Sounds like you may need to use some abstract classes, if I am understanding this correctly. If so, take a look at this article: http://stackoverflow.com/questions/3311788/when-to-implement-an-interface-and-when-to-extend-a-superclass – MethodMan Sep 15 '14 at 19:15
  • Thank you for your response. Do you have any suggestions on which classes I should be using? My reason for assuming I need to implement IQueryable is because that is what the interface I'm using requires, and I don't know of any abstract (or concrete, for that matter) classes that give me the concatenating properties I require. – cfred Sep 15 '14 at 19:30
  • Off the top of my head, all I can think of is what I have seen in this MSDN article: http://msdn.microsoft.com/en-us/library/vstudio/bb534644(v=vs.100).aspx – MethodMan Sep 15 '14 at 19:33
  • Does this have to work in the general case? I mean, can't you just do `enumerableB.Where(condition).Concat(queryableA.Where(condition))` for this particular case? – Lucas Trzesniewski Sep 15 '14 at 20:30
  • @Lucas Unfortunately, yes, it needs to be a general solution because the IQueryable is used in other code. – cfred Sep 16 '14 at 07:08

1 Answer


When you break down the expression tree that gets passed into the IQueryProvider, you will see the call chain of LINQ methods. Remember that generally LINQ works by chaining extension methods, where the return value of the previous method is passed into the next method as the first argument.

If we follow that logically, that means the very first LINQ method in the chain must have a source argument, and it's plain from the code that its source is, in fact, the very same IQueryable that kicked the whole thing off in the first place (your ConcatenatingQueryable).
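You can check this shape with an in-memory queryable. The sketch assumes `AsQueryable()` over an array, which returns an `EnumerableQuery<T>`. Its own `Expression` is a constant pointing back at itself, the same pattern as `Expression.Constant(this)` in the question's constructor. After one `Where`, the first argument of the outermost call is that constant.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

static class ExpressionShapeDemo
{
    static void Main()
    {
        IQueryable<int> source = new[] { 1, 2 }.AsQueryable();
        IQueryable<int> filtered = source.Where(x => x <= 1);

        // Queryable.Where builds Call(Where, source.Expression, Quote(lambda)).
        var call = (MethodCallExpression)filtered.Expression;
        Console.WriteLine(call.Method.Name);            // Where
        Console.WriteLine(call.Arguments[0].NodeType);  // Constant

        // The constant's value is the queryable the chain started from. With a
        // custom provider, executing this tree therefore enumerates that same
        // queryable again, which is the recursion seen in the question.
        var constant = (ConstantExpression)call.Arguments[0];
        Console.WriteLine(ReferenceEquals(constant.Value, source)); // True
    }
}
```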

You pretty much got the idea right when you built this - you just need to go one small step further. What we need to do is re-point that first LINQ method to use the actual source, then allow the execution to follow its natural path.

Here is some example code that does this:

    public object Execute(Expression expression)
    {
        var query1 = ChangeQuerySource(expression, Expression.Constant(source1));
        var query2 = ChangeQuerySource(expression, Expression.Constant(source2));
        dynamic results1 = source1.Provider.Execute(query1);
        dynamic results2 = source2.Provider.Execute(query2);
        return Enumerable.Concat(results1, results2);
    }

    private static Expression ChangeQuerySource(Expression query, Expression newSource)
    {
        // step 1: cast the Expression as a MethodCallExpression.
        // This will usually work, since a chain of LINQ statements
        // is generally a chain of method calls, but I would not
        // make such a blind assumption in production code.
        var methodCallExpression = (MethodCallExpression)query;

        // step 2: Create a new MethodCallExpression, passing in
        // the existing one's MethodInfo so we're calling the same
        // method, but just changing the parameters. Remember LINQ
        // methods are extension methods, so the first argument is
        // always the source. We carry over any additional arguments.
        // Only the outermost call is rewritten, so this assumes the
        // ConcatenatingQueryable is its direct source (e.g. one Where).
        query = Expression.Call(
            methodCallExpression.Method,
            new Expression[] { newSource }.Concat(methodCallExpression.Arguments.Skip(1)));

        // step 3: We call .AsEnumerable() at the end, to get an
        // ultimate return type of IEnumerable<T> instead of
        // IQueryable<T>, so we can safely use this new expression
        // tree in any IEnumerable statement.
        query = Expression.Call(
            typeof(Enumerable).GetMethod("AsEnumerable", BindingFlags.Static | BindingFlags.Public)
            .MakeGenericMethod(
                TypeSystem.GetElementType(methodCallExpression.Arguments[0].Type)
            ),
            query);
        return query;
    }
Rex M
  • This won't fix the problem. `source1.Provider.Execute` returns a `ConcatenatingQueryable`, which is the cause of the infinite recursion. This won't change that. – Servy Sep 15 '14 at 20:36
  • @Servy ah, yes. This is a little more complex. We need to peel off that layer of the expression tree. – Rex M Sep 15 '14 at 20:49
  • Thanks for the comments, guys. It seems that I may need to detect the end of the expression tree and end the recursion somehow, as I really only need to concatenate once, and then it's no longer needed. In the Execute() method I can detect when I've reached the ConstantExpression using: "if (expression.GetType() == typeof(ConstantExpression) && expression.Type.IsGenericType && expression.Type.GetGenericTypeDefinition() == typeof(ConcatenatingQueryable<>))", but at that point I essentially want to call .Where(expression) instead of .Provider.Execute(), and expression isn't the correct type. – cfred Sep 16 '14 at 08:27
  • Thank you - that's it! Pretty slick piece of code. Although this doesn't seem to work with LINQ-to-Entities. If I pass in an `IQueryable` for `source1`, for example, I get: **The specified cast from a materialized 'TableObject' type to the 'System.Linq.IQueryable`1[TableObject]' type is not valid** on `source1.Provider.Execute(query1)`. Any ideas? It works if I pull out a subset of the table into memory. You can reproduce with: `q.Provider.Execute(q.Expression)` for any EntitySet q, but this works fine for any in-memory IQueryable. Thanks. – cfred Sep 16 '14 at 16:40
  • @cfred What is the exact query you're trying to run from that? – Rex M Sep 16 '14 at 18:06
  • @Rex M Just a simple .Where() clause on one of the properties. If you don't know, I can mark this as answered and look into it myself, since you answered the original question... just thought I'd try. I have `q = context.Entities.Where(x => x.Property == Value)`, and if I do `q.Provider.Execute(q.Expression)`, I get the above exception. – cfred Sep 16 '14 at 18:30
  • @cfred I don't know off the top of my head, but I'd suggest debugging the `Execute` method and start inspecting the expression tree very closely... there can be expressions that look valid at first, second and third glance. – Rex M Sep 16 '14 at 19:22
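The comments raise two further points. Rewriting only the outermost call misses a root buried under chained operators, and `Provider.Execute` fails for Entity Framework sequences. The sketch below addresses both under stated assumptions, using a technique that differs from the answer's code. An `ExpressionVisitor` (hypothetical `SourceReplacer`) replaces the root constant at any depth with the real source's own `Expression`. Each rewritten tree is then run through `CreateQuery<T>` and enumerated. The `IQueryProvider.Execute` documentation describes `Execute` as intended for queries that return a single value, which is a plausible cause of the cast error in the comments. This has not been verified against Entity Framework. Class names are hypothetical. Only element-wise operators such as `Where` distribute over a concatenation. `OrderBy`, `Take`, `Distinct` or `Count` run per source would give different results, so the sketch rejects projections and scalar queries.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Replaces any constant that is a queryable owned by `owner` (the
// concatenating root) with the real source's expression, at any depth.
internal sealed class SourceReplacer : ExpressionVisitor
{
    private readonly IQueryProvider owner;
    private readonly Expression replacement;

    public SourceReplacer(IQueryProvider owner, Expression replacement)
    {
        this.owner = owner;
        this.replacement = replacement;
    }

    protected override Expression VisitConstant(ConstantExpression node) =>
        node.Value is IQueryable q && ReferenceEquals(q.Provider, owner) ? replacement : node;
}

public sealed class ConcatenatingQueryable<T> : IQueryable<T>
{
    private readonly ConcatenatingQueryProvider<T> provider;

    public ConcatenatingQueryable(IQueryable<T> source1, IQueryable<T> source2)
    {
        provider = new ConcatenatingQueryProvider<T>(source1, source2);
        Expression = Expression.Constant(this);
    }

    internal ConcatenatingQueryable(ConcatenatingQueryProvider<T> provider, Expression expression)
    {
        this.provider = provider;
        Expression = expression;
    }

    public Expression Expression { get; }
    public Type ElementType => typeof(T);
    public IQueryProvider Provider => provider;
    public IEnumerator<T> GetEnumerator() => provider.ExecuteSequence(Expression).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public sealed class ConcatenatingQueryProvider<T> : IQueryProvider
{
    private readonly IQueryable<T> source1, source2;

    public ConcatenatingQueryProvider(IQueryable<T> source1, IQueryable<T> source2)
    {
        this.source1 = source1;
        this.source2 = source2;
    }

    // Only element-type-preserving operators (e.g. Where) are handled here.
    public IQueryable<TS> CreateQuery<TS>(Expression expression)
    {
        if (typeof(TS) != typeof(T))
            throw new NotSupportedException("Projections are not handled in this sketch.");
        return (IQueryable<TS>)(object)new ConcatenatingQueryable<T>(this, expression);
    }

    public IQueryable CreateQuery(Expression expression) => CreateQuery<T>(expression);

    // Scalar operators (Count, First, ...) can't be run per source and concatenated.
    public TResult Execute<TResult>(Expression expression) =>
        throw new NotSupportedException("Scalar queries are not handled in this sketch.");
    public object Execute(Expression expression) => Execute<object>(expression);

    internal IEnumerable<T> ExecuteSequence(Expression expression)
    {
        // CreateQuery (not Execute) hands each real provider a sequence-valued query.
        IEnumerable<T> results1 = source1.Provider.CreateQuery<T>(Rewrite(expression, source1));
        IEnumerable<T> results2 = source2.Provider.CreateQuery<T>(Rewrite(expression, source2));
        return results1.Concat(results2);
    }

    private Expression Rewrite(Expression expression, IQueryable<T> source) =>
        new SourceReplacer(this, source.Expression).Visit(expression);
}

public static class Program
{
    public static void Main()
    {
        var source1 = new[] { 1, 2 }.AsQueryable();
        var source2 = new[] { -1, 1 }.AsQueryable();
        // Two chained Where calls: the visitor finds the root at any depth.
        var matches = new ConcatenatingQueryable<int>(source1, source2)
            .Where(x => x <= 1)
            .Where(x => x != -1)
            .ToArray();
        Console.WriteLine(string.Join(",", matches)); // 1,1
    }
}
```

Passing `source.Expression` instead of `Expression.Constant(source)` lets a source that is itself a query, such as `context.Entities.Where(...)`, keep its own tree.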