I want to extend NumPy's structured array so that I can easily add new fields (named columns).
For example, take a simple structured array:
>>> import numpy as np
>>> x=np.ndarray((2,),dtype={'names':['A','B'],'formats':['f8','f8']})
>>> x['A']=[1,2]
>>> x['B']=[3,4]
I would like to add a new field simply with x['C']=[5,6], but this raises an error because no field named 'C' exists.
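For comparison, NumPy already ships a helper, numpy.lib.recfunctions.append_fields, that builds a new structured array with an extra field. The sketch below first reproduces the failing assignment and then calls the helper. Note that it returns a new array rather than modifying x in place, which is the same limitation this question runs into. The use of np.zeros instead of np.ndarray is an assumption made only so the memory is initialized.

```python
import numpy as np
import numpy.lib.recfunctions as rfn

# Two-element structured array with float fields 'A' and 'B'
x = np.zeros(2, dtype={'names': ['A', 'B'], 'formats': ['f8', 'f8']})
x['A'] = [1, 2]
x['B'] = [3, 4]

# Assigning to a field that is not in the dtype fails: the dtype
# (and therefore the size of each record) is fixed at creation time.
try:
    x['C'] = [5, 6]
except (ValueError, KeyError) as err:
    print("cannot add field:", err)

# append_fields builds a *new* array containing the extra field;
# usemask=False returns a plain ndarray instead of a masked array.
y = rfn.append_fields(x, 'C', [5, 6], dtypes='f8', usemask=False)
print(y.dtype.names)  # ('A', 'B', 'C')
```

The original x keeps only fields 'A' and 'B'; the widened data lives in y.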
Just adding a new method to a subclass of np.ndarray works:
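import numpy as np

class sndarray(np.ndarray):
    def column_stack(self, i, x):
        # Assumes every existing field is float64 ('f8')
        formats = ['f8'] * len(self.dtype.names)
        # Allocate a new array whose dtype has one extra field named i
        new = sndarray(shape=self.shape,
                       dtype={'names': list(self.dtype.names) + [i],
                              'formats': formats + ['f8']})
        # Copy the existing fields, then fill the new one
        for key in self.dtype.names:
            new[key] = self[key]
        new[i] = x
        return new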
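The method above hard-codes 'f8' for every field. A slightly more general sketch, written as a free function with the hypothetical name add_field (not a NumPy API), copies each existing field's own dtype and infers the new field's dtype from the data, so integer or string fields survive the copy. Like column_stack, it still allocates a new array; the .view(type(arr)) call keeps a subclass such as sndarray if one is passed in.

```python
import numpy as np

def add_field(arr, name, values):
    """Return a copy of structured array `arr` with an extra field `name`.

    Hypothetical helper: existing fields keep their dtypes and the new
    field's dtype is inferred from `values` (assumed to match arr.shape).
    """
    values = np.asarray(values)
    if name in arr.dtype.names:
        raise ValueError(f"field {name!r} already exists")
    # Rebuild the dtype: the old (name, dtype) pairs plus the new field
    fields = [(n, arr.dtype.fields[n][0]) for n in arr.dtype.names]
    fields.append((name, values.dtype))
    new = np.empty(arr.shape, dtype=fields).view(type(arr))
    for n in arr.dtype.names:
        new[n] = arr[n]  # copy existing columns
    new[name] = values   # fill the new column
    return new

x = np.zeros(2, dtype=[('A', 'f8'), ('B', 'i4')])
x['A'] = [1, 2]
x['B'] = [3, 4]
y = add_field(x, 'C', ['u', 'v'])
print(y.dtype)
```

Here 'B' stays a 4-byte integer and 'C' becomes a one-character Unicode field instead of being forced to float.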
Then,
>>> x=sndarray((2,),dtype={'names':['A','B'],'formats':['f8','f8']})
>>> x['A']=[1,2]
>>> x['B']=[3,4]
>>> x=x.column_stack('C',[4,4])
>>> x
sndarray([(1.0, 3.0, 4.0), (2.0, 4.0, 4.0)],
         dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])
Is there any way to add the new field in a dictionary-like way, e.g.
>>> x['C']=[4,4]
>>> x
sndarray([(1.0, 3.0, 4.0), (2.0, 4.0, 4.0)],
         dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])
Update:
By overriding __setitem__ I am one step closer to the ideal solution, but I still don't know how to change the object that self refers to:
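import numpy as np

class sndarray(np.ndarray):
    def __setitem__(self, i, x):
        if isinstance(i, str) and i not in self.dtype.names:
            # Unknown field name: build a copy with one extra 'f8' field
            formats = ['f8'] * len(self.dtype.names)
            new = sndarray(shape=self.shape,
                           dtype={'names': list(self.dtype.names) + [i],
                                  'formats': formats + ['f8']})
            for key in self.dtype.names:
                new[key] = self[key]
            new[i] = x
            # Rebinding self here would not affect the caller,
            # so the result is stashed in an attribute instead
            self.with_new_column = new
        else:
            # Existing field, integer index or slice: normal assignment
            super(sndarray, self).__setitem__(i, x)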
Then
>>> x=sndarray((2,),dtype={'names':['A','B'],'formats':['f8','f8']})
>>> x['A']=[1,2]
>>> x['B']=[3,4]
>>> x['C']=[4,4]
>>> x=x.with_new_column  # extra, ugly step!
>>> x
sndarray([(1.0, 3.0, 4.0), (2.0, 4.0, 4.0)],
         dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])
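Rebinding self inside a method never changes the caller's variable, and an existing ndarray cannot grow its record size in place, so subclassing alone cannot make x['C'] = ... work without the extra step. One workaround, which swaps subclassing for composition and is shown here only as a sketch, is a small wrapper class (hypothetically named StructTable) that holds the structured array in an attribute and replaces that attribute with a wider array whenever an unknown field name is assigned.

```python
import numpy as np

class StructTable:
    """Hypothetical wrapper: dict-like field assignment over a structured array."""

    def __init__(self, arr):
        self.arr = arr

    def __getitem__(self, key):
        return self.arr[key]

    def __setitem__(self, key, value):
        if isinstance(key, str) and key not in self.arr.dtype.names:
            value = np.asarray(value)
            # Build a wider dtype: the existing fields plus the new one
            fields = [(n, self.arr.dtype.fields[n][0]) for n in self.arr.dtype.names]
            fields.append((key, value.dtype))
            new = np.empty(self.arr.shape, dtype=fields)
            for n in self.arr.dtype.names:
                new[n] = self.arr[n]
            new[key] = value
            # Rebinding the attribute (not self) is what makes this work
            self.arr = new
        else:
            self.arr[key] = value

    def __repr__(self):
        return f"StructTable({self.arr!r})"

t = StructTable(np.zeros(2, dtype=[('A', 'f8'), ('B', 'f8')]))
t['A'] = [1, 2]
t['B'] = [3, 4]
t['C'] = [4, 4]
print(t.arr.dtype.names)  # ('A', 'B', 'C')
```

The trade-off is that t is no longer an ndarray itself, so NumPy functions have to be called on t.arr.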
Update 2
After seeing the correct implementation in the accepted answer, I realized that the pandas DataFrame object already solves this problem:
>>> import numpy as np
>>> import pandas as pd
>>> x=np.ndarray((2,),dtype={'names':['A','B'],'formats':['f8','f8']})
>>> x=pd.DataFrame(x)
>>> x['A']=[1,2]
>>> x['B']=[3,4]
>>> x['C']=[4,4]
>>> x
   A  B  C
0  1  3  4
1  2  4  4
>>>
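If a NumPy structured array is still needed downstream, the DataFrame can be converted back. The sketch below assumes pandas is installed and uses DataFrame.to_records(index=False), which returns a NumPy record array without the index column.

```python
import numpy as np
import pandas as pd

x = np.zeros(2, dtype={'names': ['A', 'B'], 'formats': ['f8', 'f8']})
df = pd.DataFrame(x)    # one column per field
df['A'] = [1.0, 2.0]
df['B'] = [3.0, 4.0]
df['C'] = [4.0, 4.0]    # dict-like column addition

rec = df.to_records(index=False)  # back to a NumPy record array
print(rec.dtype.names)  # ('A', 'B', 'C')
```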