
I have a file containing data in the format:

0.0 x1
0.1 x2
0.2 x3
0.0 x4
0.1 x5
0.2 x6
0.3 x7
...

The data consists of multiple datasets, each starting with 0 in the first column (so x1,x2,x3 would be one set and x4,x5,x6,x7 another one). I need to plot each dataset separately, so I need to split the data somehow. What would be the easiest way to accomplish this?

I realize I could go through the data line-by-line and split the data every time I encounter a 0 in the first column but this seems very inefficient.

pafcu

4 Answers


I actually liked Benjamin's answer; a slightly shorter solution would be:

import numpy as np
B = np.split(A, np.where(A[:, 0] == 0.)[0][1:])
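A quick way to check that this one-liner behaves like Benjamin's `argwhere` version is to run both on the same small array. The sample values below are made up for illustration.

```python
import numpy as np

# Illustrative sample: the first column restarts at 0.0 for each dataset
A = np.array([[0.0, 1], [0.1, 2], [0.2, 3],
              [0.0, 4], [0.1, 5], [0.2, 6], [0.3, 7]])

# np.where(...)[0] and np.argwhere(...).flatten() return the same row indices
starts = np.where(A[:, 0] == 0.)[0]
assert (starts == np.argwhere(A[:, 0] == 0.).flatten()).all()

# Skip the first start (row 0) so np.split does not emit an empty leading array
B = np.split(A, starts[1:])
for block in B:
    print(block[:, 1])
```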
eat
    If there is one thing I know for sure, it's that no matter what you write in Python, there is always a shorter way of doing it! – Benjamin Mar 11 '11 at 15:32
    @pafcu: I honestly think the credit really should go to Benjamin (instead of to me). I was merely 'fine tuning' his answer. Thanks – eat Mar 11 '11 at 19:28

Once you have the data in a 2-D NumPy array (one row per line of the file), just do:

import numpy as np

A = np.array([[0.0, 1], [0.1, 2], [0.2, 3], [0.0, 4], [0.1, 5], [0.2, 6], [0.3, 7], [0.0, 8], [0.1, 9], [0.2, 10]])
B = np.split(A, np.argwhere(A[:,0] == 0.0).flatten()[1:])

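which gives you a list B containing three arrays B[0], B[1] and B[2] (in this case; I added a third "section" to prove to myself that it was working correctly). The [1:] drops the first split point, row 0, which would otherwise produce an empty first array.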
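The answers above assume the data is already in a NumPy array. A minimal sketch of getting it there with `np.loadtxt`, assuming the second column is numeric (the x1, x2, … in the question are placeholders). `io.StringIO` stands in for the real file so the snippet is self-contained; in practice you would pass the file name instead.

```python
import io
import numpy as np

# Stand-in for the real data file; assumes whitespace-separated numeric columns
text = io.StringIO("""0.0 10
0.1 11
0.2 12
0.0 20
0.1 21
0.2 22
0.3 23
""")

A = np.loadtxt(text)  # 2-D array, one row per line

# Each dataset starts at a row whose first column is 0.0; skip row 0
B = np.split(A, np.argwhere(A[:, 0] == 0.0).flatten()[1:])

for i, block in enumerate(B):
    print(i, block[:, 0], block[:, 1])
```

Each block can then be plotted on its own, for example with matplotlib's `plt.plot(block[:, 0], block[:, 1])`.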

Benjamin

You don't need a Python loop to find the location of each split. Take the difference along the first column and find where the values decrease.

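import numpy

# read the array: column 'f0' is a float, column 'f1' a short string label
arry = numpy.loadtxt('split.txt', dtype=[('f0', float), ('f1', 'U16')])

# determine where the data "splits" should be: a new set starts at row 0
# and at every row whose first column is smaller than the previous row's
col1 = arry['f0']
diff = numpy.diff(col1)
idxs = numpy.concatenate(([0], numpy.where(diff < 0)[0] + 1))

# only loop through the "splits"
strts = idxs
stops = list(idxs[1:]) + [None]
groups = [arry[strt:stop] for strt, stop in zip(strts, stops)]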
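The same "first column decreases" test can also feed `np.split` directly, which avoids the explicit start/stop bookkeeping and does not depend on an exact `== 0.0` float comparison. A sketch on made-up sample values:

```python
import numpy as np

# Illustrative sample: two datasets, the first column resets each time
A = np.array([[0.0, 1], [0.1, 2], [0.2, 3],
              [0.0, 4], [0.1, 5], [0.2, 6], [0.3, 7]])

# np.diff gives col[i+1] - col[i]; a negative value means row i+1 starts a new set
breaks = np.flatnonzero(np.diff(A[:, 0]) < 0) + 1

groups = np.split(A, breaks)
```

Because row 0 never appears in `breaks`, there is no empty leading group to drop.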
Paul
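def getDataSets(fname):
    """Split the file into datasets; a new set starts whenever column 1 decreases."""
    data_sets = []
    data = []
    prev = None
    with open(fname) as inf:
        for line in inf:
            if not line.strip():
                continue  # skip blank lines
            index, rem = line.strip().split(None, 1)
            index = float(index)  # compare numerically, not as strings
            if prev is not None and index < prev:
                data_sets.append(data)
                data = []
            data.append(rem)
            prev = index
        if data:
            data_sets.append(data)
    return data_sets

def main():
    data = getDataSets('split.txt')
    print(data)

if __name__ == "__main__":
    main()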

which, for the first seven lines of the sample data, results in

[['x1', 'x2', 'x3'], ['x4', 'x5', 'x6', 'x7']]
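For a large file, the same single-pass idea can be written as a generator that yields one dataset at a time instead of building the whole list first. A sketch with the hypothetical name `iter_datasets`; it accepts any iterable of lines (an open file or a list of strings) and assumes every non-blank line has at least two whitespace-separated fields.

```python
def iter_datasets(lines):
    """Yield one list of second-column strings per dataset.

    A new dataset starts whenever the first column decreases.
    Assumes each non-blank line has at least two whitespace-separated fields.
    """
    data = []
    prev = None
    for line in lines:
        if not line.strip():
            continue  # ignore blank lines
        index, rem = line.split(None, 1)
        index = float(index)
        if prev is not None and index < prev and data:
            yield data
            data = []
        data.append(rem.strip())
        prev = index
    if data:
        yield data


sample = ["0.0 x1\n", "0.1 x2\n", "0.2 x3\n",
          "0.0 x4\n", "0.1 x5\n", "0.2 x6\n", "0.3 x7\n"]
print(list(iter_datasets(sample)))
```

With a real file this becomes `with open('split.txt') as inf:` followed by `for ds in iter_datasets(inf):`, so only one dataset is held in memory at a time.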
Hugh Bothwell