
The problem comes from fitting a diffraction pattern to data, where the number of slits is known beforehand. I have given a simplified version below that shows the same issue. The function should fit the values of a and b to the data, while passing n to the function as a fixed value. I could use a global n, which would solve my issue; however, I would like to do this using **kwargs, as shown in the scipy.optimize.curve_fit() reference.

Here is an example of the issue. The code generates the curve 4sin(2x) + 3cos(2x) with Gaussian noise (standard deviation 0.2) added, and uses that as the data:


import numpy as np
import scipy
import matplotlib.pyplot as plt

def curve(x,a,b,**kwargs):
    n = kwargs["n"]
    return a*np.sin(n*x)+b*np.cos(n*x)

x = np.linspace(-5,5,1000)
y = np.random.normal(loc=curve(x, 4, 3, n=2), scale=0.2, size=None)
result = scipy.optimize.curve_fit(curve, x, y, n = 2)
y2 = curve(x, *result[0], n=2)

plt.plot(x, y2)
plt.plot(x,y)
plt.show()

This returns the error

  File "C:\Users\HP\OneDrive\Documents\Uni\lab year 2\diffraction\kwargs.py", line 13, in <module>
    result = scipy.optimize.curve_fit(curve, x, y, n = 2)

  File "C:\Users\HP\anaconda3\lib\site-packages\scipy\optimize\_minpack_py.py", line 834, in curve_fit
    res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)

TypeError: leastsq() got an unexpected keyword argument 'n'
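The traceback shows where the extra keyword goes: curve_fit does not hand `n` to the model, it forwards its **kwargs to leastsq (the default `method='lm'`; the curve_fit docs say least_squares receives them for the other methods), and leastsq has no parameter called `n`. So `n` has to be bound before the model reaches curve_fit. A minimal sketch using a lambda, assuming `curve` is rewritten to take `n` as an ordinary argument; the seeded generator and the `maxfev` value are illustrative choices, not part of the original code:

```python
import numpy as np
from scipy.optimize import curve_fit


def curve(x, a, b, n):
    """Model a*sin(n*x) + b*cos(n*x); n is a known constant, not a fit parameter."""
    return a * np.sin(n * x) + b * np.cos(n * x)


n = 2  # known beforehand (e.g. the number of slits)
rng = np.random.default_rng(0)
x = np.linspace(-5, 5, 1000)
y = rng.normal(loc=curve(x, 4, 3, n), scale=0.2)

# The lambda exposes only (x, a, b), so curve_fit fits a and b while n stays fixed.
# maxfev is an example of what curve_fit's **kwargs are for: it is forwarded to leastsq.
popt, pcov = curve_fit(lambda x, a, b: curve(x, a, b, n), x, y, maxfev=2000)
print(popt)  # expected to be close to [4, 3]
```

In other words, the **kwargs in the reference configure the optimizer; they are not a channel for extra model arguments.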
ray
  • Have you tried `args=(....,)`? Many `scipy` functions use `args` to pass extra values to the function. `least_squares` documents this; `curve_fit` doesn't, but talks about passing the task on to `least_squares`. Or you might have to use `least_squares` directly. The source of `curve_fit` appears to explicitly disallow `args`. – hpaulj Mar 29 '23 at 16:03
  • I tried `args` as you suggested, since it is used in other scipy functions such as `fmin`. However, it returns the error: `ValueError: 'args' is not a supported keyword argument.` – Adam Eaton Mar 29 '23 at 16:39
  • Yes, I saw that test in the `curve_fit` code; that's why I suggested `least_squares` instead. – hpaulj Mar 29 '23 at 19:00
  • It's true that might work. I managed to get a couple of other methods to work, such as using a lambda function or defining a class. However, I was wondering if there is a way to pass a variable directly into curve_fit, since it annoys me that I can't figure it out. There must be a way to do it, as it is a reasonably common thing to do and it is even described on the method's reference page. I use curve_fit a lot in labs, and figuring out how to pass additional optional parameters would lead to a much more elegant solution. Thanks for your help though. – Adam Eaton Mar 29 '23 at 19:31
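hpaulj's suggestion can be sketched with scipy.optimize.least_squares, which, unlike curve_fit, does accept `args` and `kwargs` and passes them on to the residual function. This is a minimal sketch, assuming a hand-written residual function (`residuals` is a hypothetical name) and starting guesses of 1.0 for a and b:

```python
import numpy as np
from scipy.optimize import least_squares


def curve(x, a, b, n):
    return a * np.sin(n * x) + b * np.cos(n * x)


def residuals(params, x, y, n):
    # least_squares varies only `params`; x, y and n arrive via args/kwargs.
    a, b = params
    return curve(x, a, b, n) - y


rng = np.random.default_rng(1)
x = np.linspace(-5, 5, 1000)
y = rng.normal(loc=curve(x, 4, 3, 2), scale=0.2)

# The data go through args and the known constant n through kwargs,
# so least_squares calls residuals(params, x, y, n=2).
res = least_squares(residuals, x0=[1.0, 1.0], args=(x, y), kwargs={"n": 2})
a_fit, b_fit = res.x
```

One trade-off: least_squares returns an OptimizeResult rather than `(popt, pcov)`, so you do not get the covariance matrix that curve_fit computes; the result does include the Jacobian in `res.jac` if you want to estimate it yourself.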

1 Answer


I had a similar issue with curve_fit not accepting the `args` argument and, after following the suggestion here:, I managed to get something working with a curried function, i.e. an outer function that takes n and returns a model with n already fixed:

import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt

def curve_curry(n):
    # Return a model with n fixed, so curve_fit only sees the fit parameters a and b
    def curve(x, a, b):
        return a * np.sin(n * x) + b * np.cos(n * x)

    return curve

n = 2
x = np.linspace(-5, 5, 1000)
y = np.random.normal(loc=curve_curry(n)(x, 4, 3), scale=0.2, size=None)
result = scipy.optimize.curve_fit(curve_curry(n), x, y)
y2 = curve_curry(n)(x, *result[0])  # result[0] holds the fitted [a, b]

plt.plot(x, y, '.', color='lightgrey')
plt.plot(x, y2)
plt.show()
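The same binding can also be done without writing the outer function by hand, using `functools.partial` from the standard library. A minimal sketch, assuming `curve` takes `n` as an ordinary argument; `p0` is passed explicitly so that curve_fit does not have to infer the number of parameters from the partial object's signature:

```python
import functools

import numpy as np
from scipy.optimize import curve_fit


def curve(x, a, b, n):
    return a * np.sin(n * x) + b * np.cos(n * x)


rng = np.random.default_rng(2)
x = np.linspace(-5, 5, 1000)
y = rng.normal(loc=curve(x, 4, 3, 2), scale=0.2)

# partial binds n=2 as a keyword; curve_fit then calls model(x, a, b).
model = functools.partial(curve, n=2)
popt, pcov = curve_fit(model, x, y, p0=[1.0, 1.0])
```

The closure and the partial are equivalent here; partial just saves defining a nested function for each model.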

In my case, I had three fit parameters (p[0], p[1] and p[2] in the code), constants c1, c2 and c3, and predefined functions f1 and f2. My code looked something like this:

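import scipy.optimize

def func_curry(c1, c2, c3, f1, f2):
    # Bind the constants and helper functions; the returned func takes only x and the fit parameters
    def func(x, *p):
        yinv = f1(c1) + x * c2 * (c3 + p[1] + p[2] * x) / p[0]
        y = f2(yinv)
        return y

    return func


# Curve fit using scipy
args = (c1, c2, c3, f1, f2)
# Initial guesses for p[0], p[1], p[2]; p0 is required here because
# func(x, *p) does not tell curve_fit how many parameters there are
p0 = [p0, p1, p2]
popt, pcov = scipy.optimize.curve_fit(func_curry(*args), xdata, ydata, p0=p0)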
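Because `func(x, *p)` is variadic, curve_fit cannot count the fit parameters from its signature, which is why `p0` is passed above; in the SciPy versions I know of, leaving `p0` out raises a ValueError. A runnable sketch of the same curried, variadic pattern, with a hypothetical constant `c` and a cubic polynomial standing in for the f1/f2 model (neither is from the original):

```python
import numpy as np
from scipy.optimize import curve_fit


def poly_curry(c):
    # c is a hypothetical known constant baked into the model
    def func(x, *p):
        # p has as many entries as p0; curve_fit cannot infer that count from *p
        return c + sum(coef * x ** (k + 1) for k, coef in enumerate(p))

    return func


c = 0.5  # hypothetical known offset
rng = np.random.default_rng(3)
x = np.linspace(0, 2, 200)
y = c + 1.0 * x - 2.0 * x**2 + 0.3 * x**3 + rng.normal(scale=1e-3, size=x.size)

model = poly_curry(c)
# Three starting values, so curve_fit fits three coefficients
popt, pcov = curve_fit(model, x, y, p0=[0.0, 0.0, 0.0])
```

The length of `p0` alone decides how many parameters are fitted, so the same `func` could be reused for a different number of coefficients.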
PetGriffin