I am trying to translate a Matlab "interpn" interpolation of a large, 4D array, but formulations diverge significantly between Matlab and Python. There's a good question/answer from several years ago here that I've been trying to work with. I think I am almost there, but apparently still don't have my grid interpolator properly formulated.
I modeled my code example as closely as I could on the one in the linked answer above, while using the dimensions I am actually working with. The only change is that I replaced rollaxis with moveaxis, since the NumPy documentation recommends moveaxis and keeps rollaxis only for backward compatibility.
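For moving a single axis to the end, the two calls should give identical results: rollaxis rolls axis 0 "until it lies before" position ndim, which is the last slot, and moveaxis moves it straight to -1. A quick check on a small array with the same axis layout as the query mesh (sizes shrunk for illustration):

```python
import numpy as np

# Small stand-in with the query-mesh layout: (ndim, 1, 1, n_index, n_wave)
mesh = np.arange(4 * 1 * 1 * 5 * 3).reshape(4, 1, 1, 5, 3)

old = np.rollaxis(mesh, 0, mesh.ndim)  # roll axis 0 until it sits at the end
new = np.moveaxis(mesh, 0, -1)         # move axis 0 directly to the last position

print(old.shape, new.shape)  # (1, 1, 5, 3, 4) (1, 1, 5, 3, 4)
```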
Essentially, given the 4D array skyrad0 (defined over the four coordinate vectors solzen, aod, index and wave in the code below), together with two scalars (solzen0, aod0) and two 1D arrays (index0, wave0), I want the interpolated 2D result.
```python
from scipy.interpolate import interpn
import numpy as np
# Define the data space in the 4D skyrad0 array
solzen = np.arange(0,70,10) # 7
aod = np.arange(0,0.25,0.05) # 5
index = np.arange(1,92477,1) # 92476
wave = np.arange(350,1050,5) # 140
# Simulated skyrad for the values above
skyrad0 = np.random.rand(
    solzen.size,aod.size,index.size,wave.size) # 7, 5, 92476, 140
# Data space for desired output values of skyrad
# with interpolation between input data space
solzen0 = 30 # 1
aod0 = 0.1 # 1
index0 = index # 92476
wave0 = np.arange(350,1050,10) # 70
# Matlab
# result = squeeze(interpn(solzen, aod, index, wave,
#                          skyrad0,
#                          solzen0, aod0, index0, wave0))
# Scipy
points = (solzen, aod, index, wave) # 7, 5, 92476, 140
interp_mesh = np.array(
    np.meshgrid(solzen0, aod0, index0, wave0)) # 4, 1, 1, 92476, 70
interp_points = np.moveaxis(interp_mesh, 0, -1) # 1, 1, 92476, 70, 4
interp_points = interp_points.reshape(
    (interp_mesh.size // interp_mesh.shape[3],
     interp_mesh.shape[3])) # 280, 92476
result = interpn(points, skyrad0, interp_points)
```
I am expecting a 4D array "result" that I can numpy.squeeze into the 2D answer I need, but interpn raises this error:

```
ValueError: The requested sample points xi have dimension 92476, but this RegularGridInterpolator has dimension 4
```
The part I am least clear on is the structure of the meshgrid of query points, and why its first dimension is moved to the end and then reshaped. There is more on that here, but it's still not clear to me how to apply it to this problem.
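A likely reading of the error: interp_mesh.shape[3] is 92476, so the reshape produces rows of 92476 numbers, whereas interpn expects the last axis of the query array to hold one coordinate per grid dimension (4 here). A minimal sketch of the query construction, at reduced size (index shortened to 20 entries so it runs quickly; other names follow the question). It uses indexing="ij" so the mesh axes follow the input order; the default "xy" swaps the first two, which is harmless here only because both are length 1.

```python
import numpy as np
from scipy.interpolate import interpn

# Reduced-size stand-ins for the question's coordinate vectors
solzen = np.arange(0, 70, 10)       # 7
aod = np.arange(0, 0.25, 0.05)      # 5
index = np.arange(1, 21)            # 20 (instead of 92476)
wave = np.arange(350, 1050, 5)      # 140
skyrad0 = np.random.default_rng(0).random(
    (solzen.size, aod.size, index.size, wave.size))

solzen0, aod0 = 30, 0.1
index0 = index
wave0 = np.arange(350, 1050, 10)    # 70

# Four coordinate arrays, each of shape (1, 1, 20, 70), in input order
mesh = np.meshgrid(solzen0, aod0, index0, wave0, indexing="ij")

# Stack along a new LAST axis: each query point becomes a length-4 vector
interp_points = np.stack(mesh, axis=-1)    # (1, 1, 20, 70, 4)

# interpn accepts xi of shape (..., ndim) and keeps the leading shape
result = interpn((solzen, aod, index, wave), skyrad0, interp_points)
result2d = result.squeeze()                # (20, 70)

# Equivalent flat form: one row per point, one column per dimension
flat = interp_points.reshape(-1, 4)        # (1400, 4)
result_flat = interpn((solzen, aod, index, wave), skyrad0, flat).reshape(
    index0.size, wave0.size)
```

So the fix to the original is to reshape to (-1, 4), or to skip the reshape entirely and pass the (..., 4) array directly.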
As a bonus, I would appreciate it if anyone can identify clear inefficiencies in my formulation. I'll need to run this type of interpolation thousands of times on a number of different structures, some extending to 6D, so efficiency is important.
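One possible speed-up, sketched under two assumptions: linear interpolation is wanted, and the query is always on-grid along index (index0 = index, as in the question). Because multilinear interpolation is separable, the two scalar queries (solzen0, aod0) can be applied first as a weighted sum of four (index, wave) slices, and only then interpolated along wave. This avoids building 92476 * 70 query points of 4 coordinates each. The names interp_on_grid_index and _bracket are hypothetical, and unlike interpn the helper does not raise on out-of-range queries.

```python
import numpy as np


def _bracket(coords, x):
    """Left-neighbour index and fractional weight for linear interpolation at x.

    Assumes coords is strictly increasing and x lies within [coords[0], coords[-1]];
    x may be a scalar or an array. Out-of-range x is NOT checked (it extrapolates).
    """
    coords = np.asarray(coords, dtype=float)
    x = np.asarray(x, dtype=float)
    i = np.clip(np.searchsorted(coords, x) - 1, 0, coords.size - 2)
    t = (x - coords[i]) / (coords[i + 1] - coords[i])
    return i, t


def interp_on_grid_index(skyrad, solzen, aod, wave, solzen0, aod0, wave0):
    """Hypothetical helper: linearly interpolate skyrad[solzen, aod, index, wave]
    at scalar solzen0 and aod0, every stored index, and the wavelengths wave0.
    Returns an array of shape (len(index), len(wave0)).
    """
    i, ti = _bracket(solzen, solzen0)
    k, tk = _bracket(aod, aod0)
    i, k, ti, tk = int(i), int(k), float(ti), float(tk)

    # Bilinear blend over the two scalar axes: four (index, wave) slices
    a = ((1 - ti) * (1 - tk) * skyrad[i, k]
         + (1 - ti) * tk * skyrad[i, k + 1]
         + ti * (1 - tk) * skyrad[i + 1, k]
         + ti * tk * skyrad[i + 1, k + 1])

    # Linear interpolation along wave for all requested wavelengths at once
    j, tw = _bracket(wave, wave0)
    return a[:, j] * (1 - tw) + a[:, j + 1] * tw
```

With the question's arrays this would be called as interp_on_grid_index(skyrad0, solzen, aod, wave, solzen0, aod0, wave0). In a 6D case with five scalar axes the same idea would blend 2**5 corner slices along the remaining vector axis.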
Update: The answer below solved the problem very elegantly. However, as the calculations and arrays become more complicated, another problem creeps in, which appears to involve elements in the array that do not increase monotonically. Here is the problem reframed in 6D:
```python
import numpy as np
import xarray as xr

# Data space in the 6D rad_boa array
azimuth = np.arange(0, 185, 5) # 37
senzen = np.arange(0, 185, 5) # 37
wave = np.arange(350,1050,5) # 140
# wave = np.array([350, 360, 370, 380, 390, 410, 440, 470, 510, 550, 610, 670, 750, 865, 1040, 1240, 1640, 2250]) # 18
solzen = np.arange(0,65,5) # 13
aod = np.arange(0,0.55,0.05) # 11
wind = np.arange(0, 20, 5) # 4
# Simulated rad_boa
rad_boa = np.random.rand(
    azimuth.size,senzen.size,wave.size,solzen.size,aod.size,wind.size) # 37, 37, 140/18, 13, 11, 4
# Desired output values
azimuth0 = 135 # 1
senzen0 = 140 # 1
wave0 = np.arange(350,1010,10) # 66
solzen0 = 30 # 1
aod0 = 0.1 # 1
wind0 = 10 # 1
da = xr.DataArray(name='Radiance_BOA',
                  data=rad_boa,
                  dims=['azimuth','senzen','wave','solzen','aod','wind'],
                  coords=[azimuth,senzen,wave,solzen,aod,wind])
rad_inc_scaXR = da.loc[azimuth0,senzen0,wave0,solzen0,aod0,wind0].squeeze()
```
As it stands, this runs, but if you change the definition of wave to the commented-out line, it raises:

```
KeyError: "not all values found in index 'wave'"
```
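A possible explanation, offered as a suggestion rather than a confirmed diagnosis: the 18-point wave vector is still strictly increasing, so monotonicity may not be the cause. .loc is label-based selection and needs every requested label to exist; with the 18-point wave, most wave0 values (400, 420, ...) are not labels. Float labels can fail the same way, since np.arange(0, 0.55, 0.05) stores 0.15000000000000002 rather than 0.15. DataArray.interp interpolates between labels instead (it relies on SciPy). A sketch rebuilding the simulated array with the 18-point wave:

```python
import numpy as np
import xarray as xr

azimuth = np.arange(0, 185, 5)   # 37
senzen = np.arange(0, 185, 5)    # 37
wave = np.array([350, 360, 370, 380, 390, 410, 440, 470, 510, 550,
                 610, 670, 750, 865, 1040, 1240, 1640, 2250])  # 18
solzen = np.arange(0, 65, 5)     # 13
aod = np.arange(0, 0.55, 0.05)   # 11
wind = np.arange(0, 20, 5)       # 4

rad_boa = np.random.default_rng(0).random(
    (azimuth.size, senzen.size, wave.size, solzen.size, aod.size, wind.size))
da = xr.DataArray(name='Radiance_BOA', data=rad_boa,
                  dims=['azimuth', 'senzen', 'wave', 'solzen', 'aod', 'wind'],
                  coords=[azimuth, senzen, wave, solzen, aod, wind])

wave0 = np.arange(350, 1010, 10)  # 66

# Diagnostics: the coordinate is monotonic, but most requested labels are absent
print(bool(np.all(np.diff(wave) > 0)))   # True
print(int(np.isin(wave0, wave).sum()))   # 13 (of 66)

# Linear interpolation instead of exact-label lookup; scalars drop their dimension
rad_inc = da.interp(azimuth=135, senzen=140, wave=wave0, solzen=30, aod=0.1, wind=10)
print(rad_inc.dims, rad_inc.shape)       # ('wave',) (66,)
```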
Finally, in response to a comment below (and to help improve efficiency), I am including the structure of the HDF5 file (created in Matlab) from which the 6D rad_boa array is actually built; the example above only uses a simulated random array. The actual database is read into xarray as follows:
```python
import xarray as xr

sdb = xr.open_dataset(db_path, group='sdb')
```