
Basically I have some gridded meteorological data with dimensions (time, lat, lon).

  1. I need to go through the timeseries at each grid square, identify runs of consecutive days ("events") when the variable is above a threshold, and store the length of each run in a new array (THdays).
  2. Then I look through the new array and find the events that last at least a certain duration (THevents).
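To make the two steps concrete, here is what they produce for a single grid square. The series, threshold and minimum duration are made-up values for illustration, and the run lengths use the same `itertools.groupby` idiom as the loop in the question:

```python
import itertools as it

import numpy as np

# Made-up single-grid-square timeseries and threshold (illustration only)
series = np.array([0.2, 0.9, 0.95, 0.1, 0.97, 0.98, 0.99, 0.3])
threshold = 0.5

# Step 1: lengths of consecutive runs above the threshold ("THdays" for this cell)
above = series > threshold
run_lengths = [sum(1 for _ in group) for key, group in it.groupby(above) if key]

# Step 2: keep only events lasting at least min_duration days ("THevents")
min_duration = 3
long_events = [n for n in run_lengths if n >= min_duration]

print(run_lengths)  # [2, 3]
print(long_events)  # [3]
```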

Currently I have a very scrappy, non-vectorised loop, and I'd appreciate advice on how to speed it up. Thanks!

import numpy as np
import itertools as it
##### Parameters
lg = 2000  # maximum number of events stored per grid square (must be large enough to hold every event in a timeseries)
rl = 180  # number of latitudes
cl = 360  # number of longitudes
pcts = [95, 97, 99] # percentiles which are the thresholds that will be compared
dt = [1,2,3] # duration thresholds, i.e. consecutive values (days) above threshold

##### Data
data = np.random.rand(1000, rl, cl)  # gridded data (time, lat, lon); random placeholder, replace with the real data
# From this data calculate the percentiles at each gridsquare (lat/lon combination) which will act as our thresholds
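histpcts = np.percentile(data, q=pcts, axis = 0)  # shape (len(pcts), rl, cl)
# Note: np.percentile returns NaN for any grid square containing a NaN; np.nanpercentile ignores NaNs instead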


##### Initialize arrays to store the results
THdays = np.zeros((rl, cl, lg, len(pcts)), dtype='int16')  # Consecutive threshold timesteps; 0 = no event (int16 cannot hold NaN)
THevents = np.zeros((rl, cl, lg, len(pcts), len(dt)), dtype='int16')  # Events lasting at least dt[d] timesteps

##### Start iteration to identify events
for p in range(len(pcts)):  # for each threshold value
    br = data>histpcts[p,:,:]  # Make boolean array where data is bigger than threshold

    # for every lat/lon combination
    for r,c in it.product(range(rl),range(cl)): 
        if br[:,r,c].any(): # Skip timeseries with no exceedances; an all-NaN series lands here too, since NaN > threshold is False. Important to keep this or something that ignores an array of nans
            a = [ sum( 1 for _ in group ) for key, group in it.groupby( br[:,r,c] ) if key ] # Lengths of the consecutive runs above threshold

            # Assign to new array; slots after len(a) keep the 0 from np.zeros
            # (filling them with NaN would be invalid, because int16 cannot represent NaN)
            # Note: len(a) must not exceed lg
            THdays[r,c,0:len(a),p] = a  # Consecutive threshold days

            # Now cycle through and identify events 
            # (consecutive values) longer than a certain duration (dt)
            for d in range(len(dt)):
                b = THdays[r,c,THdays[r,c,:,p]>=dt[d],p]
                THevents[r,c,0:len(b),p,d] = b
dreab
  • Just work from the inside out. You have a triple nested for loop. Just extract the inner loop into a function and figure out how to vectorize that. Break the problem down into manageable pieces, same as any other programming task. Here are some previous ones I've done along the same lines: https://stackoverflow.com/questions/17529342 and https://stackoverflow.com/questions/36853770 and https://stackoverflow.com/questions/26251997 – John Zwinck Jul 21 '17 at 11:42
  • @JohnZwinck thanks, I'll take a look – dreab Jul 21 '17 at 12:40
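As one way to apply that advice, here is a minimal sketch of the innermost piece: the run lengths for one timeseries, computed with `np.diff` on a boolean mask padded with `False` at both ends instead of a Python-level `groupby`. The helper name `run_lengths` is hypothetical, not from the question.

```python
import numpy as np


def run_lengths(mask):
    """Lengths of consecutive True runs in a 1-D boolean array (hypothetical helper).

    Vectorised replacement for the itertools.groupby list comprehension.
    """
    mask = np.asarray(mask, dtype=bool)
    # Pad with False on both sides so every run has a rising and a falling edge
    padded = np.concatenate(([False], mask, [False]))
    # Differences are +1 where a run starts and -1 just after it ends
    edges = np.diff(padded.astype(np.int8))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    return ends - starts


mask = np.array([False, True, True, False, True, True, True, False, True])
lengths = run_lengths(mask)
print(lengths)                # [2 3 1]
print(lengths[lengths >= 2])  # [2 3]
```

The result could be written into `THdays[r, c, :len(lengths), p]` just as the list `a` is in the question; step 2 is then the boolean mask shown on the last line.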

1 Answer


Have you tried numba? It can give a large speedup when you are using simple loops over NumPy arrays. All you need to do is put your code inside a function and decorate the function with @jit. That's all!

from numba import jit

@jit
def myfun(inputs):
    # crazy nested loops here
    ...

Of course, the more information you give numba, the better the speedup you will get. You can find more information here: http://numba.pydata.org/numba-doc/0.34.0/user/overview.html
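A sketch of what that can look like for this problem, assuming numba is installed: the per-cell loop rewritten with only array indexing and integer counters (no lists, generators or `itertools`), which is the style nopython mode (`njit`) is designed to compile. `fill_run_lengths` is a hypothetical name, and the `try`/`except` fallback exists only so the sketch still runs, uncompiled, without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs (slowly) without numba
    def njit(func):
        return func


@njit
def fill_run_lengths(br, out):
    """Write run lengths of True values along axis 0 of br (time, lat, lon)
    into out (lat, lon, max_events). Events beyond max_events are dropped."""
    nt, nr, nc = br.shape
    max_events = out.shape[2]
    for r in range(nr):
        for c in range(nc):
            k = 0    # number of events stored for this cell
            run = 0  # length of the current run
            for t in range(nt):
                if br[t, r, c]:
                    run += 1
                else:
                    if run > 0 and k < max_events:
                        out[r, c, k] = run
                        k += 1
                    run = 0
            # A run still open at the final timestep is also an event
            if run > 0 and k < max_events:
                out[r, c, k] = run
    return out


# Usage on made-up data: one percentile's slice of THdays
rng = np.random.default_rng(0)
data = rng.random((365, 4, 5))                   # (time, lat, lon)
thresholds = np.percentile(data, 95, axis=0)     # (lat, lon)
THdays_p = np.zeros((4, 5, 50), dtype=np.int16)
fill_run_lengths(data > thresholds, THdays_p)
n_long = (THdays_p >= 3).sum(axis=-1)            # events of >= 3 days per cell
```

Dropping events beyond the preallocated length is a choice made for the sketch; raising an error instead would also be reasonable.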

Gioelelm
  • If you want optimal speedup you have to use `njit()` instead of `jit()`, and then you have to deal with some limitations. OP's code won't be magically transformed into optimal code. Maybe faster, but not optimal. – John Zwinck Jul 22 '17 at 01:37
  • I agree. But as a way to get started with numba, jit is better, also because the error messages from njit can be difficult to interpret the first time you stumble upon them. – Gioelelm Jul 22 '17 at 09:40
  • Thanks @gioelelm and @JohnZwinck, I tried `jit` and `njit` and then got stuck on the error messages. Any idea what this means? `raise patched_exception LoweringError: make_function(closure=None, code= at 00000000180D87B0, file "", line 24>, name=None, defaults=None)` – dreab Jul 22 '17 at 14:04
    Also, try refactoring everything to use NumPy arrays (instead of lists) and NumPy functions (e.g. np.sum instead of sum) – Gioelelm Jul 22 '17 at 14:08
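Taking the array-only idea to the whole grid, here is a sketch that finds every event in every grid square without any Python loop. It assumes a boolean (time, lat, lon) array like `br` in the question, and `grid_run_lengths` is a hypothetical helper. Instead of the padded (lat, lon, lg) layout of `THdays`, it returns one flat entry per event, so `lg` does not have to be chosen in advance.

```python
import numpy as np


def grid_run_lengths(br):
    """Run lengths of True values along time for every grid cell at once.

    br: boolean array (time, lat, lon). Returns flat arrays (rows, cols, lengths),
    one entry per event, ordered by cell and then by time.
    """
    # Move time to the last axis so np.nonzero (C order) walks each cell in time order
    b = np.moveaxis(br, 0, -1)
    pad = np.zeros(b.shape[:-1] + (1,), dtype=bool)
    padded = np.concatenate((pad, b, pad), axis=-1).astype(np.int8)
    edges = np.diff(padded, axis=-1)
    rows, cols, t_start = np.nonzero(edges == 1)  # run starts
    _, _, t_end = np.nonzero(edges == -1)         # run ends (exclusive)
    return rows, cols, t_end - t_start


# Usage on made-up data for one percentile threshold
rng = np.random.default_rng(0)
data = rng.random((1000, 18, 36))
histpcts = np.percentile(data, q=[95, 97, 99], axis=0)  # (3, lat, lon)
br = data > histpcts[0]
rows, cols, lengths = grid_run_lengths(br)

# Step 2: events lasting at least 3 days, and how many each cell has
long = lengths >= 3
counts = np.zeros(br.shape[1:], dtype=int)
np.add.at(counts, (rows[long], cols[long]), 1)
```

Pairing starts with ends relies on `np.nonzero` returning indices in C (row-major) order: with time moved to the last axis, each cell's starts and ends come out in time order and in equal numbers. The intermediate arrays are the size of the data (as int8/bool), which is the memory cost of avoiding the loops.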