
I am having trouble with how Python evaluates vector-valued functions of the form (f1(x), f2(x)) over arrays of x when curve fitting.

import numpy as np
from scipy.optimize import curve_fit

def func(x,a,b,c):
    return np.array([a*x**b+c,a*x**b+c+1])

ydata = np.array([[1,2],[3,4],[5,6],[7,8]],dtype=float)
xdata=np.array([1,2,3,4], dtype=float)
popt,pcov = curve_fit(func, xdata, ydata)

This gives "ValueError: operands could not be broadcast together with shapes (2,4) (4,2)". Transposing the data to be fitted:

ydata=np.array([[1,2],[3,4],[5,6],[7,8]],dtype=float).transpose()

gives "TypeError: Improper input: N=3 must not exceed M=2", because now I have fewer function values than parameters. OK, I understand why that cannot be fitted. So I need to transpose the function values:

def func(x,a,b,c):
    return np.array([a*x**b+c,a*x**b+c+1]).transpose()

This gives me "Result from function call is not a proper array of floats."
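The first error comes down to array shapes, which a short sketch can make visible. It reuses func, xdata and ydata from above; the parameter values 2.0, 1.0 and -1.0 are arbitrary illustrative choices, not a fit result assumed in advance.

```python
import numpy as np

def func(x, a, b, c):
    # Same model as above: two outputs stacked along axis 0.
    return np.array([a * x**b + c, a * x**b + c + 1])

xdata = np.array([1, 2, 3, 4], dtype=float)
ydata = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=float)

# Illustrative parameter values (an assumption, chosen only to get numbers).
model = func(xdata, 2.0, 1.0, -1.0)
print(model.shape)  # (2, 4): one row per output function
print(ydata.shape)  # (4, 2): one row per x value

# Subtracting the two directly fails, which matches the first error message.
try:
    model - ydata
    broadcast_ok = True
except ValueError:
    broadcast_ok = False
print(broadcast_ok)  # False
```

After transposing ydata the shapes agree, but the model values and the data are both still 2D.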

How do I get a solution to such a problem? Mathematically, it should be well determined if the data can fit the model.

  • [related question](http://stackoverflow.com/questions/41090791/how-do-i-optimize-and-find-the-coefficients-for-two-equations-simultaneously-in) – Stelios Dec 23 '16 at 19:39
  • You might also be interested in my answer here: http://stackoverflow.com/questions/40829791/fitting-a-vector-function-with-curve-fit-in-scipy/40961491#40961491 – tBuLi Dec 24 '16 at 17:23

1 Answer


curve_fit expects func to return a 1D array, so the output should be flattened. In this case, you should pass ydata.T.ravel() to curve_fit so that its elements are in the same order as the elements returned by func(x,a,b,c).

import numpy as np
from scipy.optimize import curve_fit

def func(x,a,b,c):
    output = np.array([a*(x**b)+c,a*(x**b)+c+1])
    return output.ravel()

ydata = np.array([[1,2],[3,4],[5,6],[7,8]],dtype=float)
xdata=np.array([1,2,3,4], dtype=float)
popt,pcov = curve_fit(func, xdata, ydata.T.ravel())
# print (popt)
# [ 2.,  1., -1.]

Testing the results,

func(xdata,*popt).reshape(-1,len(xdata)).T
#  [[ 1.,  2.],
#   [ 3.,  4.],
#   [ 5.,  6.],
#   [ 7.,  8.]]
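To see why the transpose matters, here is a small sketch comparing the flattened model output with the two possible flattenings of ydata. The names flat_model, flat_data and wrong_order are introduced here for illustration only, and the parameters are set by hand to 2, 1, -1 rather than taken from a fit.

```python
import numpy as np

def func(x, a, b, c):
    # Flattened model: all f1 values first, then all f2 values.
    output = np.array([a * (x**b) + c, a * (x**b) + c + 1])
    return output.ravel()

xdata = np.array([1, 2, 3, 4], dtype=float)
ydata = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=float)

flat_model = func(xdata, 2.0, 1.0, -1.0)  # hand-picked parameters
flat_data = ydata.T.ravel()               # column by column: f1 data, then f2 data
wrong_order = ydata.ravel()               # row by row: f1 and f2 interleaved

print(flat_model)   # [1. 3. 5. 7. 2. 4. 6. 8.]
print(flat_data)    # [1. 3. 5. 7. 2. 4. 6. 8.]
print(wrong_order)  # [1. 2. 3. 4. 5. 6. 7. 8.]
```

Without the transpose, ydata.ravel() interleaves the two outputs, so each data point would be compared against the wrong model value.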
Mahdi