I am doing 3D interpolation and would like to get the interpolated function as output so that I can use it in a gradient-descent optimization. My data set has the form:

| x | y | z | loss |
|------|------|------|------|
| val1 | val2 | val3 | val4 |

From my interpolation, I would like to obtain a function of the form f(x, y, z) = loss. I need it in an explicit form like this so that I can run the optimization on it. Any assistance is greatly appreciated.
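A linear interpolant has no compact closed-form formula. It is defined piecewise over a Delaunay triangulation of the samples. What an optimizer actually needs, though, is something it can *call*. The minimal sketch below assumes SciPy's `LinearNDInterpolator`, which is the piecewise-linear scheme `griddata(..., method='linear')` uses internally. Made-up sample data stands in for the `electrode2` DataFrame, and the interpolant is wrapped as an ordinary Python function `f(x, y, z)`:

```python
import numpy as np
from scipy.interpolate import LinearNDInterpolator

# Made-up scattered samples standing in for the (x, y, z, loss) columns.
rng = np.random.default_rng(0)
lo = np.array([4.0, 2.0, 0.7])
hi = np.array([8.0, 7.0, 0.83])
pts = rng.uniform(lo, hi, size=(200, 3))
loss = np.exp(-(pts[:, 1] - 2.0)) + 0.1 * (pts[:, 0] - 6.0) ** 2

# Triangulate once; the returned object can be evaluated at any point afterwards.
interp = LinearNDInterpolator(pts, loss)


def f(x, y, z):
    """Interpolated loss at a single point (NaN outside the data's convex hull)."""
    return float(interp([[x, y, z]])[0])


print(f(6.0, 4.0, 0.75))
```

Because `f` is piecewise linear, its gradient is constant inside each tetrahedron and jumps across the faces between them. Gradient-descent steps based on it can therefore behave erratically near those faces.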
So far my code is as follows:
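```python
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import griddata as gd

# electrode2 is a pandas DataFrame holding the measured samples (loaded elsewhere)
x = electrode2["Base Gap"].values
y = electrode2["Top Gap"].values
z = electrode2["SiO2 Thickness"].values
out = electrode2["Loss"].values

# Regular 20 x 20 x 20 evaluation grid; with method='linear', grid points outside
# the convex hull of the data are returned as NaN
X, Y, Z = np.mgrid[0:8:20j, 0:7:20j, 0:2:20j]
V = gd((x, y, z), out, (X.ravel(), Y.ravel(), Z.ravel()), method='linear')

# Plotting original data
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection='3d')  # fig.gca(projection=...) no longer works in Matplotlib >= 3.6
sc1 = ax1.scatter(x, y, z, c=out, cmap='hot')  # plt.hot() returns None; pass the colormap name
fig1.colorbar(sc1)
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')

# Plotting interpolated data (V is flat, so flatten the grid coordinates to match)
fig2 = plt.figure()
ax2 = fig2.add_subplot(projection='3d')
sc2 = ax2.scatter(X.ravel(), Y.ravel(), Z.ravel(), c=V, cmap='hot')
fig2.colorbar(sc2)
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
plt.show()
```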
How can I get a function definition f(x, y, z) = loss from this? I found some suggestions in "Get the formula of a interpolation function created by scipy", but they all deal with simple 1D interpolation, and none of them suggests that a simple closed-form formula can be extracted. I'm also open to other ways of performing the optimization without an explicit function definition. Any thoughts are appreciated.
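Any optimizer that only needs to evaluate the objective can work directly with a callable interpolant. The sketch below assumes SciPy's `RBFInterpolator` (SciPy >= 1.7) as a smooth surrogate and `scipy.optimize.minimize` with the bounded L-BFGS-B method. When no `jac` is supplied, L-BFGS-B estimates the gradient by finite differences, so no analytic formula is needed. The data and the `target` minimum are made up for illustration. Inputs are rescaled to the unit cube because the z range (about 0.13) is much narrower than the x and y ranges:

```python
import numpy as np
from scipy.interpolate import RBFInterpolator
from scipy.optimize import minimize

# Made-up samples with a known minimum at `target` (illustration only).
rng = np.random.default_rng(1)
lo = np.array([4.0, 2.0, 0.7])
hi = np.array([8.0, 7.0, 0.83])
target = np.array([6.5, 4.0, 0.75])
pts = rng.uniform(lo, hi, size=(150, 3))
loss = (((pts - target) / (hi - lo)) ** 2).sum(axis=1)

# Fit a smooth interpolant in unit-cube coordinates.
surrogate = RBFInterpolator((pts - lo) / (hi - lo), loss, kernel="thin_plate_spline")


def objective(p):
    """Surrogate loss at p = (x, y, z); minimize passes p as a 1-D array."""
    u = (np.asarray(p) - lo) / (hi - lo)
    return float(surrogate(u.reshape(1, -1))[0])


x0 = (lo + hi) / 2  # start in the middle of the sampled box
res = minimize(objective, x0, method="L-BFGS-B", bounds=list(zip(lo, hi)))
print(res.x, res.fun)
```

A smooth RBF surrogate is used here because its gradient is continuous. With the piecewise-linear `LinearNDInterpolator`, a derivative-free method such as `method='Nelder-Mead'` is likely the safer choice.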
Also, how can I check the validity of my interpolation? My original data is nearly, but not entirely, grid-structured, so I hope this method still works. Visualizing the result is difficult because it is 4D (three inputs plus the loss). The scatter plots below look reasonable to me, but I'm not entirely sure:

[Figure: Scatter Plot of Original Data]

[Figure: Scatter Plot of Interpolated Data]
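One common way to check an interpolant against scattered data is leave-one-out cross-validation. You drop each sample in turn, rebuild the interpolant from the rest, and compare its prediction with the held-out value. The minimal sketch below uses `LinearNDInterpolator` (the scheme behind `griddata(..., method='linear')`). The function name `loo_errors` and the test data are hypothetical. A linear test function is used because linear interpolation should reproduce it exactly, which makes the check easy to verify:

```python
import numpy as np
from scipy.interpolate import LinearNDInterpolator


def loo_errors(pts, values):
    """Leave-one-out prediction errors for piecewise-linear interpolation.

    Entry i is NaN when sample i lies outside the convex hull of the others.
    """
    n = len(values)
    errs = np.full(n, np.nan)
    for i in range(n):
        keep = np.arange(n) != i
        interp = LinearNDInterpolator(pts[keep], values[keep])
        errs[i] = interp(pts[i:i + 1])[0] - values[i]
    return errs


# Hypothetical check: a linear function must be reproduced exactly.
rng = np.random.default_rng(2)
pts = rng.uniform([4.0, 2.0, 0.7], [8.0, 7.0, 0.83], size=(60, 3))
values = 2.0 * pts[:, 0] - 3.0 * pts[:, 1] + 10.0 * pts[:, 2]
errs = loo_errors(pts, values)
finite = np.isfinite(errs)
print(f"{finite.sum()} of {len(errs)} points testable, "
      f"max |error| = {np.abs(errs[finite]).max():.2e}")
```

On the real data, the loss in the rows shown ranges from about 4e-7 to 1.2, which is more than six orders of magnitude. Comparing errors in log10(loss) is therefore probably more informative than comparing raw differences.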
Also, here is some of my data (not all) so that you can get a better feel for the type of values I'm using:
| x (Base Gap) | y (Top Gap) | z (SiO2 Thickness) | loss |
|------|------------|---------|------------|
| 4.0 | 2.5 | 0.7 | 0.750341814 |
| 4.0 | 2.75 | 0.7 | 0.64151014 |
| 4.0 | 3.0 | 0.7 | 0.622998451 |
| 4.5 | 2.3 | 0.7 | 0.730188891 |
| 4.5 | 2.6 | 0.7 | 0.236956251 |
| 4.5 | 3.2 | 0.7 | 0.092038571 |
| 4.5 | 3.5 | 0.7 | 0.088483203 |
| 5.0 | 2.25 | 0.7 | 0.948879277 |
| 5.0 | 2.6 | 0.7 | 0.190624075 |
| 5.0 | 2.95 | 0.7 | 0.135275036 |
| 5.0 | 3.3 | 0.7 | 0.032763743 |
| 5.0 | 3.65 | 0.7 | 0.029430211 |
| 5.0 | 4.0 | 0.7 | 0.027537075 |
| 5.5 | 2.5 | 0.7 | 0.281232137 |
| 5.5 | 2.9 | 0.7 | 0.054980707 |
| 5.5 | 3.3 | 0.7 | 0.023085488 |
| 5.5 | 3.7 | 0.7 | 0.031317836 |
| 5.5 | 4.1 | 0.7 | 0.010820878 |
| 5.5 | 4.5 | 0.7 | 0.01016186 |
| 6.0 | 2.3 | 0.7 | 0.397089577 |
| 6.0 | 2.75 | 0.7 | 0.074292346 |
| 6.0 | 3.2 | 0.7 | 0.015918433 |
| 6.0 | 3.65 | 0.7 | 0.004494633 |
| 6.0 | 4.1 | 0.7 | 0.002195262 |
| 6.0 | 4.55 | 0.7 | 0.00175018 |
| 6.0 | 5.0 | 0.7 | 0.001701018 |
| 6.5 | 2.5 | 0.7 | 0.32217322 |
| 6.5 | 3.0 | 0.7 | 0.0378943 |
| 6.5 | 3.5 | 0.7 | 0.005178418 |
| 6.5 | 4.0 | 0.7 | 0.001170227 |
| 6.5 | 4.5 | 0.7 | 0.00104462 |
| 6.5 | 5.0 | 0.7 | 0.000228311 |
| 6.5 | 5.5 | 0.7 | 0.000204873 |
| 7.0 | 2.15 | 0.7 | 0.799562745 |
| 7.0 | 2.7 | 0.7 | 0.169379167 |
| 7.0 | 3.25 | 0.7 | 0.019800399 |
| 7.0 | 3.8 | 0.7 | 0.001902681 |
| 7.0 | 4.35 | 0.7 | 0.000501842 |
| 7.0 | 4.9 | 0.7 | 0.000126206 |
| 7.0 | 5.45 | 0.7 | 2.67e-05 |
| 7.0 | 6.0 | 0.7 | 2.06e-05 |
| 7.5 | 2.3 | 0.7 | 1.21559407 |
| 7.5 | 2.9 | 0.7 | 0.050740772 |
| 7.5 | 3.5 | 0.7 | 0.009225684 |
| 7.5 | 4.1 | 0.7 | 0.000752695 |
| 7.5 | 4.7 | 0.7 | 0.000144944 |
| 7.5 | 5.3 | 0.7 | 1.62e-05 |
| 7.5 | 5.9 | 0.7 | 3.72e-06 |
| 7.5 | 6.5 | 0.7 | 2.67e-06 |
| 8.0 | 2.45 | 0.7 | 0.3163575 |
| 8.0 | 3.1 | 0.7 | 0.049589355 |
| 8.0 | 3.75 | 0.7 | 0.001624453 |
| 8.0 | 4.4 | 0.7 | 0.000321852 |
| 8.0 | 5.05 | 0.7 | 4.38e-05 |
| 8.0 | 5.7 | 0.7 | 3.32e-06 |
| 8.0 | 6.35 | 0.7 | 1.04e-06 |
| 8.0 | 7.0 | 0.7 | 4.25e-07 |
| 4.0 | 2.0 | 0.83 | 1.123030977 |
| 4.0 | 2.25 | 0.83 | 0.643862594 |
| 4.0 | 2.5 | 0.83 | 0.494527641 |
| 4.0 | 2.75 | 0.83 | 0.460772203 |
| 4.0 | 3.0 | 0.83 | 0.449952666 |
| 4.5 | 2.0 | 0.83 | 1.177597462 |
| 4.5 | 2.3 | 0.83 | 0.246286304 |
| 4.5 | 2.6 | 0.83 | 0.126303219 |
| 4.5 | 2.9 | 0.83 | 0.805843164 |
| 4.5 | 3.2 | 0.83 | 0.064124542 |
I should also mention that I have tried `scipy.interpolate.Rbf`, because I read that it is suited to data that is not on a grid. I don't trust its results, though, because the resulting graph looked like this:

[Figure: Rbf Interpolation Results]
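Two properties of this data set commonly distort radial-basis-function fits. First, RBF distances are isotropic, but the axes have very different ranges: in the rows shown, z spans 0.7 to 0.83 while x spans 4 to 8. Second, the loss spans several orders of magnitude, so a fit in linear space can overshoot, even to negative values. The sketch below addresses both by fitting log10(loss) in unit-cube coordinates. It assumes SciPy's `RBFInterpolator`, which SciPy's documentation recommends over the legacy `Rbf` class. `fit_log_rbf` is a hypothetical helper, and the rows used are taken from the sample table above:

```python
import numpy as np
from scipy.interpolate import RBFInterpolator


def fit_log_rbf(pts, loss, lo, hi, smoothing=0.0):
    """Return f(x, y, z) that interpolates log10(loss) in unit-cube coordinates."""
    scale = hi - lo
    rbf = RBFInterpolator((pts - lo) / scale, np.log10(loss),
                          kernel="thin_plate_spline", smoothing=smoothing)

    def f(x, y, z):
        u = (np.array([[x, y, z]]) - lo) / scale
        return float(10.0 ** rbf(u)[0])  # back-transform, so f is always positive

    return f


# A few (x, y, z, loss) rows from the sample table in the question.
rows = np.array([
    [4.0, 2.5, 0.7, 0.750341814],
    [4.5, 3.2, 0.7, 0.092038571],
    [5.0, 4.0, 0.7, 0.027537075],
    [6.0, 5.0, 0.7, 0.001701018],
    [7.0, 6.0, 0.7, 2.06e-05],
    [8.0, 7.0, 0.7, 4.25e-07],
    [4.0, 2.0, 0.83, 1.123030977],
    [4.0, 3.0, 0.83, 0.449952666],
    [4.5, 2.6, 0.83, 0.126303219],
    [4.5, 3.2, 0.83, 0.064124542],
])
lo = rows[:, :3].min(axis=0)
hi = rows[:, :3].max(axis=0)
f = fit_log_rbf(rows[:, :3], rows[:, 3], lo, hi)
print(f(4.5, 3.2, 0.7), f(5.5, 3.5, 0.75))
```

With `smoothing=0` the fit passes through every sample. A positive `smoothing` value trades exact reproduction for a less oscillatory surface, which may help if some measurements are noisy. The returned `f` can be passed to an optimizer like any other Python function.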