
So I have a dataset consisting of several million rows of trading data that I am trying to filter with the code below. The function trade_quant_filter looks for outliers and adds the index of each outlier to a list to deal with later.

import numpy as np


def trim_moments(arr, alpha):
    arr = np.sort(arr)  # np.sort returns a sorted copy, so it has to be assigned back
    n = len(arr)
    k = int(round(n * float(alpha)) / 2)  # roughly alpha*n of the values are trimmed, split between the two tails
    return np.mean(arr[k + 1:n - k]), np.std(arr[k + 1:n - k])
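
For example (with made-up numbers, just to illustrate the trimming), an obvious outlier gets dropped before the moments are computed:

arr = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 100.])
trim_moments(arr, 0.2)  # -> (6.0, 2.0): the 100.0 is trimmed away and doesn't distort the mean or std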


def trade_quant_filter(dataset, alpha, window, gamma):
    radius = int(round(window / 2))
    bad_values = []
    for count, row in dataset.iterrows():
        if count < radius:  # Starting case: not enough rows before this one for a symmetric window
            local_mean, local_std = trim_moments(
                dataset['price'][: count + window].values, alpha)

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)

        elif count > (dataset.shape[0] - radius):  # Ending case: not enough rows after this one
            local_mean, local_std = trim_moments(
                dataset['price'][count - window: count].values, alpha)

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)

        else:  # Main case: symmetric window centred on the current row
            local_mean, local_std = trim_moments(
                dataset['price'][count - radius: count + radius].values, alpha)

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)

    return bad_values

The problem is that I've written this code too poorly to deal with several million entries. 150k rows takes about 30 seconds:

stats4 = %prun -r trade_quant_filter(trades_reduced[:150000], alpha,window,gamma)

 Ordered by: internal time
   List reduced from 154 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   600002    3.030    0.000    3.030    0.000 {method 'reduce' of 'numpy.ufunc' objects}
   150000    2.768    0.000    4.663    0.000 _methods.py:73(_var)
        1    2.687    2.687   40.204   40.204 <ipython-input-102-4f16164d899e>:8(trade_quant_filter)
   300000    1.738    0.000    1.738    0.000 {pandas.lib.infer_dtype}
  6000025    1.548    0.000    1.762    0.000 {isinstance}
   300004    1.481    0.000    6.937    0.000 internals.py:1804(make_block)
   300001    1.426    0.000   13.157    0.000 series.py:134(__init__)
   300000    1.033    0.000    3.553    0.000 common.py:1862(_possibly_infer_to_datetimelike)
   150000    0.945    0.000    2.562    0.000 _methods.py:49(_mean)
   300000    0.902    0.000   12.220    0.000 series.py:482(__getitem__)

There are a couple of things that make optimizing this function challenging:

  1. As far as I can tell, there is no way to avoid iterating row by row here and still take the trimmed rolling means and standard deviation. I plan on looking into how functions like rolling_mean are implemented in Pandas next.
  2. The lack of hashability of dictionaries also makes it impossible to calculate the trimmed rolling means and standard deviations, so I can't convert the dataframe to a dict.

Using Cython and ndarrays, as recommended here, seems like a possibility, and I'm learning about Cython now.

What is the most straightforward way of optimizing this code? I'm looking for at least a 10x speed improvement.

  • I might be missing something, but this seems like it could be fully vectorized with `rolling_mean` and `rolling_std`? Then you could simply filter the entire dataframe on your `bad_values` criteria in one step. – chrisb Aug 22 '14 at 18:54
  • Sorry that I was unclear. `trimmed_mean` varies from `rolling_mean` in that the most extreme values aren't included in the mean calculation. This is done here: `np.mean(arr[k+1:n-k])`. The idea being that outliers shouldn't be affecting the mean. Using `rolling_median` would actually be a suitable alternative in this case but there's no built-in function that has a [robust](http://en.wikipedia.org/wiki/Robust_measures_of_scale) standard deviation. – FoxRocks Aug 22 '14 at 18:57
  • why not do `rolling_apply` then? – Paul H Aug 22 '14 at 18:58
  • I was just going to suggest rolling_apply as well. – DataSwede Aug 22 '14 at 19:00
  • Thanks Paul. I think this will solve my problem. I'm done with the project for the day but that looks extremely promising! I'm surprised I missed it. – FoxRocks Aug 22 '14 at 19:01
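
For reference, a minimal sketch of the `rolling_apply` idea suggested in the comments, assuming the pandas-0.14-era `pd.rolling_apply` API (newer pandas exposes the same thing as `Series.rolling(...).apply(...)`). The names `trimmed_mean`, `trimmed_std` and `trade_quant_filter_rolling` are illustrative, not part of the original code, and the edge handling differs from the asymmetric windows in the question: rows without a full centred window come back as NaN and are simply never flagged.

import numpy as np
import pandas as pd


def trimmed_mean(arr, alpha):
    # Same trimming as trim_moments, but returning only the mean
    arr = np.sort(arr)
    n = len(arr)
    k = int(round(n * float(alpha)) / 2)
    return np.mean(arr[k + 1:n - k])


def trimmed_std(arr, alpha):
    # Same trimming as trim_moments, but returning only the std
    arr = np.sort(arr)
    n = len(arr)
    k = int(round(n * float(alpha)) / 2)
    return np.std(arr[k + 1:n - k])


def trade_quant_filter_rolling(dataset, alpha, window, gamma):
    price = dataset['price']
    # rolling_apply hands each window to the callback as a plain ndarray,
    # so the per-row Series construction and __getitem__ overhead that
    # dominates the profile above never happens
    local_mean = pd.rolling_apply(price, window,
                                  lambda x: trimmed_mean(x, alpha),
                                  center=True)
    local_std = pd.rolling_apply(price, window,
                                 lambda x: trimmed_std(x, alpha),
                                 center=True)
    # NaN edges compare as False, so edge rows are not flagged here
    mask = np.abs(price - local_mean) > 3 * local_std + gamma
    return list(price.index[mask.values])

This removes the per-row pandas indexing overhead, although `rolling_apply` still calls the Python trimming functions once per row, so the callbacks should stay as lean as possible.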
