5

I want to count the number of combinations (a, b, c) that fulfill the following condition:

 a < b < a+d < c < b+d

Where a, b, c are elements of a list, and d is a fixed delta.

Here is a vanilla implementation:

def count(l, d):
    s = 0
    for a in l:
        for b in l:
            for c in l:
                if a < b < a + d < c < b + d:
                    s += 1
    return s

Here is a test:

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    assert(32 == count(l, 4)) # Gone through everything by hand.

Question

How can I speed this up? I am looking at list sizes of 2 million.

Supplementary Information

I am dealing with floats in the range [-pi, pi], with d = pi. For example, this limits a to a < 0, since a + d < c <= pi forces a < pi - d = 0.

What I have so far:

I have an implementation where I build indices that I use for b and c. However, the code below fails some cases (i.e. it is wrong).

def count(l, d=pi):
    low = lower(l, d)
    high = upper(l, d)
    s = 0
    for indA in range(len(l)):
        for indB in range(indA+1, low[indA]+1):
            s += low[indB] + 1 - high[indA]
    return s

def lower(l, d=pi):
    '''Returns ind, s.t l[ind[i]] < l[i] + d and l[ind[i]+1] >= l[i] + d, for all i
    Input must be sorted!
    '''
    ind = []
    x = 0
    length = len(l)
    for  elem in l:
        while x < length and l[x] < elem + d:
            x += 1
        if l[x-1] < elem + d:
            ind.append(x-1)
        else:
            assert(x == length)
            ind.append(x)
    return ind


def upper(l, d=pi):
    '''Returns ind, s.t. ind[i] is the first index x with l[x] > l[i] + d, for all i.'''
    ind = []
    x = 0
    length = len(l)
    for elem in l:
        while x < length and l[x] <= elem + d:
            x += 1
        ind.append(x)
    return ind
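
For reference, the same index arrays can also be computed with the standard-library `bisect` module. This is only a sketch, assuming `l` is sorted; the names `lower_bisect`/`upper_bisect` are illustrative, and edge cases where no element qualifies come out as -1 or len(l) here rather than the special-casing above:

from bisect import bisect_left, bisect_right

def lower_bisect(l, d):
    # Largest index j with l[j] < x + d, for each x in l (-1 if no such element).
    return [bisect_left(l, x + d) - 1 for x in l]

def upper_bisect(l, d):
    # First index j with l[j] > x + d, for each x in l (len(l) if no such element).
    return [bisect_right(l, x + d) for x in l]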

Original Problem

The original problem is from a well-known math/comp-sci competition. The competition asks that solutions not be posted on the net, but the problem is from two weeks ago.

I can generate the list with this function:

from math import atan2, pi

def points(n):
    x = 1
    y = 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))  # points() is a generator, so use sorted() rather than .sort()
    return count(angles, pi)
Unapiedra
  • Is the list always sorted? – keyser Feb 08 '14 at 15:45
  • Initially it isn't sorted, but we can make it so. – Unapiedra Feb 08 '14 at 15:46
  • @FMc In the real deal, `d=pi` and all elements of `l` will be [-pi, +pi]. I realise this puts an upper bound on `a`. – Unapiedra Feb 08 '14 at 16:03
  • To me, this is one test: `a < (c - d) < b`. So for a start, you never have to look at `b` that is less than or equal to `a`. That would cut down your comparisons a lot. – hughdbrown Feb 08 '14 at 16:03
  • So you're working with floating point numbers here and not integers? That's a pretty important piece of information – Niklas B. Feb 08 '14 at 16:10
  • @NiklasB. Why is it important that I am working with floats? I'll add it into the question. – Unapiedra Feb 08 '14 at 16:12
  • Also, can you tell us the original problem? Maybe you made the problem harder by reducing it to this general form. – Niklas B. Feb 08 '14 at 16:17
  • Because for integers in a given range you can use segment trees. You can't do that with floats. In general, for problems like these, provide all the information you have. – Niklas B. Feb 08 '14 at 16:18
  • Can you link to the problem statement or describe it briefly? I think it's possible to solve this in polylinear time even in the general formulation you gave, but I don't think the contest would require that – Niklas B. Feb 08 '14 at 19:56
  • @PeterdeRivaz: Thanks a lot. Looks indeed like OP's generalization makes the problem harder. I guess the intended approach involves some kind of [inversion](http://en.wikipedia.org/wiki/Inversive_geometry). – Niklas B. Feb 09 '14 at 00:46
  • @OP: I don't think using doubles here will be exact enough to find out the precise number. You will probably need a different approach that allows you to work with integers – Niklas B. Feb 09 '14 at 03:43

5 Answers

2
from bisect import bisect_left, bisect_right
from collections import Counter

def count(l, d):
    # cdef long bleft, bright, cleft, cright, ccount, s
    s = 0

    # Find the unique elements and their counts
    cc = Counter(l)

    l = sorted(cc.keys())

    # Generate a cumulative sum array
    cumulative = [0] * (len(l) + 1)
    for i, key in enumerate(l, start=1):
        cumulative[i] = cumulative[i-1] + cc[key]

    # Pregenerate all the left and right lookups
    lefthand = [bisect_right(l, a + d) for a in l]
    righthand = [bisect_left(l, a + d) for a in l]

    aright = bisect_left(l, l[-1] - d)
    for ai in range(len(l)):
        bleft = ai + 1
        # Search only the values of a that have a+d in range
        if bleft > aright:
            break
        # This finds b such that a < b < a + d.
        bright = righthand[ai]
        for bi in range(bleft, bright):
            # This finds the range for c such that a+d < c < b+d.
            cleft = lefthand[ai]
            cright = righthand[bi]
            if cleft != cright:
                # Find the count of c elements in the range cleft..cright.
                ccount = cumulative[cright] - cumulative[cleft]
                s += cc[l[ai]] * cc[l[bi]] * ccount
    return s

def testCount():
    l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
    result = count(l, 4)
    assert(32 == result)

testCount()
  1. gets rid of repeated, identical values

  2. iterates over only the required range for a value

  3. uses a cumulative count across two indices to eliminate the loop over c

  4. caches lookups on x + d

This is no longer O(n^3) but more like O(n^2).

This clearly does not yet scale up to 2 million. Here are my times on smaller floating-point data sets (i.e. few or no duplicates), using Cython to speed up execution:

50: 0:00:00.157849 seconds
100: 0:00:00.003752 seconds
200: 0:00:00.022494 seconds
400: 0:00:00.071192 seconds
800: 0:00:00.253750 seconds
1600: 0:00:00.951133 seconds
3200: 0:00:03.508596 seconds
6400: 0:00:10.869102 seconds
12800: 0:00:55.986448 seconds

Here is my benchmarking code (not including the operative code above):

from math import atan2, pi

def points(n):
    x, y = 1, 1
    for _ in range(n):
        x = (x * 1248) % 32323
        y = (y * 8421) % 30103
        yield atan2(x - 16161, y - 15051)

def C(n):
    angles = sorted(points(n))
    return count(angles, pi)

def test_large():
    from datetime import datetime
    for n in [50, 100, 200, 400, 800, 1600, 3200, 6400, 12800]:
        s = datetime.now()
        C(n)
        elapsed = datetime.now() - s
        print("{1}: {0} seconds".format(elapsed, n))

if __name__ == '__main__':
    testCount()
    test_large()
hughdbrown
2

There is an approach to your problem that yields an O(n log n) algorithm. Let X be the set of values. Now let's fix b. Let A_b be the set of values { x in X : b - d < x < b } and C_b be the set of values { x in X : b < x < b + d }. If we can find |{ (a, c) in A_b × C_b : c > a + d }| fast, we have solved the problem.

If we sort X, we can represent A_b and C_b as pointers into the sorted array, because they are contiguous. If we process the b candidates in non-decreasing order, we can thus maintain these sets using a sliding window algorithm. It goes like this:

  1. sort X. Let X = { x_1, x_2, ..., x_n }, x_1 <= x_2 <= ... <= x_n.
  2. Set left = i = 1 and set right so that C_b = { x_{i + 1}, ..., x_right }. Set count = 0
  3. Iterate i from 1 to n. In every iteration we find out the number of valid triples (a,b,c) with b = x_i. To do that, increase left and right as much as necessary so that A_b = { x_left, ..., x_{i-1} } and C_b = { x_{i + 1}, ..., x_right } still holds. In the process, you basically add and remove elements from the imaginary sets A_b and C_b. If you remove or add an element to one of the sets, check how many pairs (a, c) with c > a + d, a from A_b and c from C_b you add or destroy (this can be achieved by a simple binary search in the other set). Update count accordingly so that the invariant count = |{ (x,y) : A_b X C_b | y > x + d }| still holds.
  4. sum up the values of count in every iteration. This is the final result.

The complexity is O(n log n).

If you want to solve the Euler problem with this algorithm, you have to avoid floating point issues. I suggest sorting the points by angle using a custom comparison function that uses integer arithmetic only (using 2D vector geometry). Implementing the |a-b| < d comparisons can also be done using integer operations only. Also, since you are working modulo 2*pi, you would probably have to introduce three copies of every angle a: a - 2*pi, a, and a + 2*pi. You then only look for b in the range [0, 2*pi) and divide the result by three.
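
To illustrate the integer-only sort, here is a minimal sketch (the names `half` and `angle_cmp` are illustrative, not part of the problem): points are ordered first by half-plane and then by the sign of their cross product, so no angle is ever computed. Note that this orders angles in [0, 2*pi) rather than atan2's (-pi, pi], which only shifts where the sorted sequence starts.

from functools import cmp_to_key

def half(p):
    # 0 for angles in [0, pi), 1 for angles in [pi, 2*pi); integer tests only.
    x, y = p
    return 0 if (y > 0 or (y == 0 and x > 0)) else 1

def angle_cmp(p, q):
    # Order two integer points by polar angle without computing any angle.
    if half(p) != half(q):
        return -1 if half(p) < half(q) else 1
    cross = p[0] * q[1] - p[1] * q[0]
    return -1 if cross > 0 else (1 if cross < 0 else 0)

# Illustrative usage with made-up integer points:
pts = [(3, 1), (-2, 5), (0, -4), (1, 1)]
pts_by_angle = sorted(pts, key=cmp_to_key(angle_cmp))

For d = pi in particular, a test like angle(c) - angle(a) < pi likewise reduces to the sign of a single cross product (with some care for collinear points).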

UPDATE: The OP implemented this algorithm in Python. Apparently it contains some bugs, but it demonstrates the general idea:

from bisect import bisect_left, bisect_right

def count(X, d):
    X.sort()
    count = 0
    s = 0
    length = len(X)
    a_l = 0
    a_r = 1
    c_l = 0
    c_r = 0
    for b in X:
        if X[a_r-1] < b:
            # find boundaries of A s.t. b -d < a < b
            while a_r < length and X[a_r] < b:
                a_r += 1  # This adds an element to A_b. 
                ind = bisect_right(X, X[a_r-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count += (ind - c_l)
            while a_l < length and X[a_l] <= b - d:
                a_l += 1  # This removes an element from A_b
                ind = bisect_right(X, X[a_l-1]+d, c_l, c_r)
                if c_l <= ind < c_r:
                    count -= (c_r - ind)
            # Find boundaries of C s.t. b < c < b + d
            while c_l < length and X[c_l] <= b:
                c_l += 1  # this removes an element from C_b
                ind = bisect_left(X, X[c_l-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count -= (ind - a_l)
            while c_r  < length and X[c_r] < b + d:
                c_r += 1 # this adds an element to C_b
                ind = bisect_left(X, X[c_r-1]-d, a_l, a_r)
                if a_l <= ind <= a_r:
                    count += (ind - a_l)
            s += count
    return s
Niklas B.
  • Not following this because it is not written in python. Re: "...introduce three copies of every angle a: a - pi, a and a + pi." Isn't `a - pi` the same as `a + pi` if this is modulo `2*pi`? – hughdbrown Feb 10 '14 at 17:20
  • @hughdbrown: First of all, I meant `a - 2*pi` and `a + 2*pi`. Sure, but the algorithm is unaware of modulus so we need to simulate it to use sliding windows. I can't write this in Python right now because I'm on vacation and I also don't think it's necessary to see source code to understand the general idea (rather to the contrary). – Niklas B. Feb 10 '14 at 17:21
  • @hughdbrown: Actually the last paragraph doesn't really apply to the question as asked here, but more to the solution of the underlying Euler Project problem. – Niklas B. Feb 10 '14 at 17:27
  • Okay, so as I understand this, iterate over values of `b` in `O(n)` time and values for `a` and `c` can be selected in `O(logn)` time? I've shown that counts of elements over arbitrary ranges can be calculated in constant time (the cumulative sum lookup table). – hughdbrown Feb 10 '14 at 17:31
  • You iterate over `b` and maintain the valid set of values for `a` and the valid set of values for `c` at any time (that is, those `a` and `c` that are within range `d` of `b`). That's easy. Now we want to know how many `(a,c)` pairs there are at any time that fulfill `c - a > d`. We can maintain that count as well. Every time we insert into the set of `a` candidates, we look at how many `c` candidates in the other set there are with `c - a > d`. We can do that by binary search (log n). Same thing the other way round, if we insert into the set of `c` candidates. Deletion is very similar as well. – Niklas B. Feb 10 '14 at 17:35
  • The key observation is that since `d` is constant and if we process the `b` candidates in non-decreasing order, we only need `O(n)` inserts and deletes into the sets of `a` and `c` candidates. We can process these operations in `O(log n)` as described and get `O(n log n)` runtime (optimal if we assume a comparison-based sort) – Niklas B. Feb 10 '14 at 17:39
  • It always seemed to me that if you fixed 'a' and wanted to find out the sets of candidates 'b' and 'c', you needed to iterate over 'b' because the set of candidates for 'c' was a function of 'a' and 'b' values. That's why I always ended up with `O(n^2)` code. I understand that you can fix 'b' and find a base case for counts of 'a' and 'c', but I'd have thought that you need to iterate over 'a' to get the sets of 'c' for each (a, b) pair. So I'd have thought this is still `O(n^2)`. – hughdbrown Feb 10 '14 at 17:56
  • @hughdbrown: "but I'd have thought that you need to iterate over 'a' to get the sets of 'c' for each (a, b) pair" That assumption is wrong, since if you use a sliding window, you can process the `a`s as they are inserted or deleted from the window. You never care about the "inner" elements not at the front or back of the sliding window. I guess you can also implement something similar by fixing `a` and figuring out the (b,c) counts, but I guess that's less intuitive to implement (choosing `b` gives you a nice symmetry). – Niklas B. Feb 10 '14 at 17:59
  • I mean the algorithm as stated is obviously not n^2, but `O(n log n)` since it does `O(n)` iterations and `O(n)` inserts/deletes overall. The question is whether it's correct and I very much think so – Niklas B. Feb 10 '14 at 18:00
  • @NiklasB. because of the strictness of `<` shouldn't I find the bounds closest to b (i.e. `x_{i+1}`, `x_{i-1}` ) as well? Otherwise, we'd have `x_{i+1} == x_i` quite often. – Unapiedra Feb 12 '14 at 15:38
  • @Unapiedra Yes you're right, you need to keep 4 pointers or merge duplicates into a single point with an associated count. It shouldn't make a difference though unless the test set contains 3 collinear points on a line through the origin. In this case the statement is ambiguous because it doesn't mention whether degenerate triangles should be counted – Niklas B. Feb 12 '14 at 16:55
  • @NiklasB. I implemented this. See my answer for more. – Unapiedra Feb 14 '14 at 13:23
  • nice, I added the implementation to the answer. @hughbrown now there's python code ;) – Niklas B. Feb 14 '14 at 15:53
1

Since l is sorted and a < b < c must be true, you could use itertools.combinations() to do fewer loops:

from itertools import combinations

sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)

Looking at combinations only reduces this loop to 816 iterations.

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> d = 4
>>> sum(1 for a, b, c in combinations(l, r=3))
816
>>> sum(1 for a, b, c in combinations(l, r=3) if a < b < a + d < c < b + d)
32

where the a < b test is redundant.

Martijn Pieters
1

1) To reduce the number of iterations at each level, you can skip elements that don't pass the condition at that level.
2) Using a set together with collections.Counter, you can reduce iterations by removing duplicates:

from collections import Counter
def count(l, d):
    n = Counter(l)
    l = set(l)
    s = 0
    for a in l:
        for b in (i for i in l if a < i < a+d):
            for c in (i for i in l if a+d < i < b+d):
                s += (n[a] * n[b] * n[c])
    return s

>>> l = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 8, 9, 9, 10, 10]
>>> count(l, 4)
32

Tested counts of iterations (at the a, b, and c levels) for your version:

>>> count1(l, 4)
18 324 5832

my version:

>>> count2(l, 4)
9 16 7
ndpu
  • The important thing is to count how many inner iterations you are doing versus the original code. – hughdbrown Feb 08 '14 at 16:07
  • And: `i for i in l if a < (i-d) < b` – hughdbrown Feb 08 '14 at 16:08
  • @hughdbrown the number of inner iterations fully depends on the list contents... both algorithms are O(n^3), but my version should be faster – ndpu Feb 08 '14 at 16:20
  • And still both are too slow. O(n^(2-eps)) is needed at least. – Niklas B. Feb 08 '14 at 16:28
  • @NiklasB. do you know a better solution? why the -1? – ndpu Feb 08 '14 at 16:32
  • OP asks for a faster algorithm, you give basically the same algorithm with some constant factor optimizations... – Niklas B. Feb 08 '14 at 16:44
  • your generators are still going to loop through all of `l`, you are *not* eliminating loops here. – Martijn Pieters Feb 08 '14 at 16:47
  • @MartijnPieters I already tested it and posted the result: my version's innermost loop iterations: 32, OP version: 5832 – ndpu Feb 08 '14 at 16:49
  • But you are *adding* loops in the form of generators. – Martijn Pieters Feb 08 '14 at 16:54
  • @ndpu 2 million list elements is what the OP stated. So something around a 4e12 speedup is required... – Niklas B. Feb 08 '14 at 16:59
  • This is better than OP's solution. Adding the generators considerably cuts down on the number of comparisons that must be made. It may not be possible to solve this in better than n^2, assuming an n^2 solution exists. – IceArdor Feb 09 '14 at 18:42
  • @IceArdor: `O(n^2 log n)` is trivial. You just fix `a` and `b` and binary search on `c` (see the sketch after this comment thread). That's already asymptotically better than most answers here... I think `O(n log^2 n)` or something is also possible, but non-trivial – Niklas B. Feb 10 '14 at 16:01
  • @IceArdor: check my answer for an example of a `O(n log n)` algorithm – Niklas B. Feb 10 '14 at 17:28
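
For completeness, here is a minimal sketch of the fix-`a`-and-`b`-and-binary-search-on-`c` idea mentioned in the comment above (assuming a sorted input; the function name `count_fix_ab` is illustrative). It should reproduce the expected 32 on the test list from the question, but is still quadratic in the worst case:

from bisect import bisect_left, bisect_right

def count_fix_ab(l, d):
    # Sketch: fix a and b, then binary search for the range of valid c.
    # Assumes l is sorted; counts triples with a < b < a + d < c < b + d.
    s = 0
    n = len(l)
    for i in range(n):
        a = l[i]
        j_hi = bisect_left(l, a + d, i + 1)   # first index with value >= a + d
        c_lo = bisect_right(l, a + d)         # first index with value > a + d
        for j in range(i + 1, j_hi):
            b = l[j]
            if b <= a:                        # skip duplicates of a
                continue
            c_hi = bisect_left(l, b + d)      # first index with value >= b + d
            if c_hi > c_lo:
                s += c_hi - c_lo
    return s
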
0

The basic ideas are:

  1. Get rid of repeated, identical values
  2. Have each value iterate only over the range it needs to.

As a result you can increase s unconditionally and the performance is roughly O(N), with N being the size of the array.

import collections

def count(l, d):
    s = 0
    # at first we get rid of repeated items
    counter = collections.Counter(l)
    # sort the list
    uniq = sorted(set(l))
    n = len(uniq)
    # kad is the index of the first element > a+d
    kad = 0 
    # ka is the index of a
    for ka in range(n):
        a = uniq[ka]
        while uniq[kad] <= a+d:
            kad += 1
            if kad == n:
                return s

        for kb in range( ka+1, kad ):
            # b only runs in the range [a..a+d)
            b = uniq[kb]
            if b  >= a+d:
                break
            for kc in range( kad, n ):
                # c only runs from (a+d..b+d)
                c = uniq[kc]
                if c >= b+d:
                    break
                print( a, b, c )
                s += counter[a] * counter[b] * counter[c]
    return s

EDIT: Sorry, I messed up the submission. Fixed.

pentadecagon
  • This is definitely not O(n). – Unapiedra Feb 08 '14 at 17:23
  • it depends on what you take for n. It is O(M), with M being the final result. – pentadecagon Feb 08 '14 at 17:24
  • That is a peculiar approach to algorithm estimation. Fibonacci calculation produces a single number regardless of implementation, but that is not to say that all implementations are `O(1)`, right? Looking at the number of outputs is not an appropriate way to estimate the running time. Looking at the size of the input is the more usual approach. – hughdbrown Feb 09 '14 at 02:49
  • In this case here the complexity estimation tells you there is no counting-based algorithm much faster than the one described here. – pentadecagon Feb 09 '14 at 05:18
  • @pentadecagon both Niklas B and hughdbrown have algorithms that are much, much faster. – Unapiedra Feb 14 '14 at 13:11