1

I am implementing a segment tree in cython and comparing it with the python implementation.

The cython version seems only 1.5x times faster and I am wanting to make it even faster.

Both implementations can be assumed to be correct.

Here's the cython code:

# distutils: language = c++
from libcpp.vector cimport vector

cdef struct Result:
    int range_sum  
    int range_min 
    int range_max



cdef class SegmentTree:
    cdef vector[int] nums
    cdef vector[Result] tree 

    def __init__(self, vector[int] nums):
        self.nums = nums
        self.tree.resize(4 * len(nums)) #just a safe upper bound 
        self._build(1, 0, len(nums)-1)

    cdef Result _build(self, int index, int left, int right):
        cdef Result result

        if left == right:
            value = self.nums[left]
            result.range_max, result.range_min, result.range_sum = value, value, value 
            self.tree[index] = result
            return self.tree[index]
        else:
            mid = (left+right)//2
            left_range_result = self._build(index*2, left, mid)
            right_range_result = self._build(index*2+1, mid+1, right)
            self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
            return self.tree[index]

    cdef Result range_query(self, int query_i, int query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums)-1, 1)

    cdef Result _range_query(self, int query_i, int query_j, int current_i, int current_j, int index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j)//2 
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid, index*2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid+1, current_j, index*2+1 )  
            else:
                left_range_result = self._range_query(query_i, mid, current_i, mid, index*2)
                right_range_result = self._range_query(mid+1, query_j, mid+1, current_j, index*2+1)
                return self.combine_range_results(left_range_result, right_range_result)


    cpdef int range_sum(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_sum 
    cpdef int range_min(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_min
    cpdef int range_max(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_max

    cpdef void  update(self, int i, int new_value):
        self._update(i, new_value, 1, 0, len(self.nums)-1)

    cdef Result _update(self, int i, int new_value, int index, int left, int right):
        if left == right == i:
            self.tree[index] = [new_value, new_value, new_value]
            return self.tree[index]
        if left == right:
            return self.tree[index]
        mid = (left+right)//2 
        left_range_result = self._update(i, new_value, index*2, left, mid)
        right_range_result = self._update(i, new_value, index*2+1, mid+1, right)
        self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
        return self.tree[index]

    cdef Result combine_range_results(self, Result r1, Result r2):
        cdef Result result;
        result.range_min = min(r1.range_min, r2.range_min)
        result.range_max = max(r1.range_max, r2.range_max)
        result.range_sum = r1.range_sum + r2.range_sum
        return result 
        

Here's the python version:




class PurePythonSegmentTree:
    def __init__(self, nums):
        self.nums = nums
        self.tree = [0] * (len(nums) * 4)
        self._build(1, 0, len(nums) - 1)

    def _build(self, index, left, right):
        if left == right:
            value = self.nums[left]
            self.tree[index] = (value, value, value)
            return self.tree[index]
        else:
            mid = (left + right) // 2
            left_range_result = self._build(index * 2, left, mid)
            right_range_result = self._build(index * 2 + 1, mid + 1, right)
            self.tree[index] = self._combine_range_results(
                left_range_result, right_range_result)
            return self.tree[index]

    def range_query(self, query_i, query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums) - 1, 1)

    def _range_query(self, query_i, query_j, current_i, current_j, index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j) // 2
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid,
                                         index * 2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid + 1, current_j,
                                         index * 2 + 1)
            else:
                left_range_result = self._range_query(query_i, mid, current_i,
                                                      mid, index * 2)
                right_range_result = self._range_query(mid + 1, query_j,
                                                       mid + 1, current_j,
                                                       index * 2 + 1)
                return self._combine_range_results(left_range_result,
                                                   right_range_result)

    def range_sum(self, query_i, query_j):
        return self.range_query(query_i, query_j)[0]

    def range_min(self, query_i, query_j):
        return self.range_query(query_i, query_j)[1]

    def range_max(self, query_i, query_j):
        return self.range_query(query_i, query_j)[2]

    def _combine_range_results(self, r1, r2):
        return (r1[0] + r2[0], min(r1[1], r2[1]), max(r1[2], r2[2]))


The benchmarking code:

import pytest
from segment_tree import SegmentTree

def _test_all_ranges(nums, correct_fn, test_fn, threshold=float("inf")):
    count = 0
    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if count > threshold:
                break
            expected = correct_fn(nums[i:j + 1])
            actual = test_fn(i, j)
            assert actual == expected
            count += 1


def test_cython_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = SegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)


def test_python_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = PurePythonSegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)

The stats:

-------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------
Name (time in us)                 Min                   Max                  Mean              StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_cython_tree_speed       708.0450 (1.0)      1,534.6150 (1.0)        739.7052 (1.0)       59.9436 (1.0)        717.7565 (1.0)      21.0070 (1.0)       116;200  1,351.8900 (1.0)        1290           1
test_python_tree_speed     1,625.1940 (2.30)     2,676.9020 (1.74)     1,696.8420 (2.29)     135.9121 (2.27)     1,644.7810 (2.29)     79.6613 (3.79)        36;37    589.3300 (0.44)        391           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

How do I make the cythonized version faster?

nz_21
  • 6,140
  • 7
  • 34
  • 80
  • 1
    ( this might belong on https://codereview.stackexchange.com/ ) Without actually looking at your code, I know this is often an issue in python. Have you tried replacing the recursive calls with a stack? That will handle the logic of the recursion without actually making function calls. Here's the best existing answer I found on a brief search: https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion – Kenny Ostrom Jun 20 '20 at 13:38
  • 2
    I assume you've tried generating Cython's annotated HTML to see what it thinks might be slow. I wonder if the recursive functions are a bit of a red herring and aren't actually your problem. – DavidW Jun 20 '20 at 14:16
  • @KennyOstrom I am not convinced that the bottleneck is the recursive calls. The cython version of fibonacci is magnitudes faster than the python version https://blog.nelsonliu.me/2016/04/29/gsoc-week-0-cython-vs-python/ It stands to reason it should more or less or true for this too, given that it's recursion heavy – nz_21 Jun 20 '20 at 14:17
  • I'd be worried that your calls to `len` on a vector are causing it to be converted to a list then `len` to be called on that, for example. – DavidW Jun 20 '20 at 14:19
  • @DavidW ah good spot. I changed it to `nums.size()` (which is constant time) but the stats remain almost the same. – nz_21 Jun 20 '20 at 14:21
  • To really shine, cython needs input as memoryview/(numpy) array. I didn’t check, but my guess is that the bottle neck of cython’s version is converting list into std::vector. – ead Jun 20 '20 at 15:06
  • @ead Thanks for the suggestion - are there any easy ways to check for bottlenecks? – nz_21 Jun 20 '20 at 15:09
  • Also, I just tweaked the benchmark test so that it would only test the query calls, not the constructor invocation. The result shows only a very slight improvement – nz_21 Jun 20 '20 at 15:13
  • DavidWs suggestion to build with annotations is a good first step – ead Jun 20 '20 at 15:13
  • 1
    Have you tried using compiler directives? In your case, that you are doing a division, I recommend adding @cython.cdivision(False) before the definition of your class. See more details here https://stackoverflow.com/questions/19537673/slow-division-in-cython – sebacastroh Jun 21 '20 at 16:30

1 Answers1

3

When trying to optimize cython code, first step is to build with annotations (see for example this part of Cython-documentation), i.e.

 cython -a xxx.pyx

or similar. It generates a html, in which one can see which parts of code use Python-functionality.

In you case one can see, that mid = (current_i + current_j)//2 is a problem.

It generates the following C-code:

  /*else*/ {
    __pyx_t_3 = __Pyx_PyInt_From_long(__Pyx_div_long((__pyx_v_current_i + __pyx_v_current_j), 2)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 42, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_3);
    __pyx_v_mid = __pyx_t_3;
    __pyx_t_3 = 0;

I.e. mid is a Python-integer (due to __Pyx_PyInt_From_long), and all operation with it will lead to more conversion to Python-integer and slow operations.

Make mid cdef int. Investigate other yellow lines (interaction with Python) in the annotated code.

ead
  • 32,758
  • 6
  • 90
  • 153
  • Thanks for the pointer! I have fixed that, and I notice that the `range_sum`, `range_min` and `range_max` functions are dark yellow, indicating heavy python interaction. Is there any way I can fix those? – nz_21 Jun 20 '20 at 15:31
  • Make them cpdef which declare they return type (int), the same for combine-function. Cython will pick the cdef function without the conversion of the result to python integer which is costly – ead Jun 20 '20 at 15:37
  • Sorry I looked up the wrong version. Don’t worry then: because they are partly def, python interaction is needed. Concentrate on the code first. How does benchmark looks like? – ead Jun 20 '20 at 15:46
  • they are almost the same, cython is about 1.18x faster – nz_21 Jun 20 '20 at 15:47
  • Getting some improvement! Instead of storing the input list a vector, I tweaked the code so that it just stores its length. For the same benchmark, it's about 4.5x faster – nz_21 Jun 20 '20 at 15:55
  • @nz_21 For some of your calculations, you are writing to intermediate variables (e.g. left_range_result). These should be given cdef int types as well, otherwise you are going to create python objects within your function calls. – CodeSurgeon Jun 28 '20 at 18:54