
I have a function that takes an argument y_0 that might be a list or array, or might just be a single float. I wanted it to work when y_0 is a single float, so I tried wrapping it in np.asarray(y_0) in the code below, hoping the loop would still run over the one item. However, I got the error TypeError: iteration over a 0-d array. I could use an if statement to check whether y_0 is a singleton and act accordingly, but I was curious whether there is a way to make the loop iterate over the single object directly.

import numpy as np
import matplotlib.pyplot as plt

def vf_grapher(fn, t_0, t_n, dt, y_0, lintype='-r', sup_title=None,
               title=None, xlab=None, ylab=None):

    t = np.arange(t_0, t_n, dt)
    y_min = .0
    y_max = .0

    fig, axs = plt.subplots()
    fig.suptitle(sup_title)

    axs.set_title(title)
    axs.set_ylabel(ylab)
    axs.set_xlabel(xlab)

    for iv in np.asarray(y_0):
        soln = rk4(dt, t, fn, iv)
        plt.plot(t, soln, lintype)
        if y_min > np.min(soln):
            y_min = np.min(soln)
        if y_max < np.max(soln):
            y_max = np.max(soln)

For a minimal working example, include the following function:

def rk4(dt, t, field, y_0):
    """
    :param dt: float - the timestep
    :param t: array - the time mesh
    :param field: callable - the vector field y' = f(t, y)
    :param y_0: float or array - contains initial conditions
    :return: ndarray - solution
    """

    # Initialize solution matrix. Each row is the solution to the system
    # for a given time step. Each column is the full solution for a single
    # equation. dtype=float keeps integer initial conditions from
    # silently truncating the RK4 updates.
    y = np.asarray(len(t) * [y_0], dtype=float)

    for i in np.arange(len(t) - 1):
        k1 = dt * field(t[i], y[i])
        k2 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k1)
        k3 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k2)
        k4 = dt * field(t[i] + dt, y[i] + k3)
        y[i + 1] = y[i] + (k1 + 2 * k2 + 2 * k3 + k4) / 6

    return y

if __name__ == '__main__':
    def f(t, x): return x**2 - x

    vf_grapher(f, 0, 4, 0.1, (-0.9, 0.5, 1.01), xlab='t', ylab='x(t)',
               sup_title=r'Solution Field for $\dot{x} = x^2 - x$')
    plt.show()  # display the figure when run as a script
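As a quick headless sanity check (a sketch, not part of the original question), rk4 on its own already accepts a single float initial condition, which suggests only the loop in vf_grapher needs changing. The copy of rk4 below mirrors the one above; the test problem x' = -x with x(0) = 1 and exact solution exp(-t) is an assumed example, not taken from the question.

```python
import numpy as np


def rk4(dt, t, field, y_0):
    # Same classical RK4 scheme as the rk4 above; one row of y per time step.
    y = np.asarray(len(t) * [y_0], dtype=float)
    for i in np.arange(len(t) - 1):
        k1 = dt * field(t[i], y[i])
        k2 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k1)
        k3 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k2)
        k4 = dt * field(t[i] + dt, y[i] + k3)
        y[i + 1] = y[i] + (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return y


# Assumed test problem: x' = -x, x(0) = 1, exact solution x(t) = exp(-t).
dt = 0.1
t = np.arange(0, 4, dt)
soln = rk4(dt, t, lambda t, x: -x, 1.0)  # a bare float works here

max_err = np.max(np.abs(soln - np.exp(-t)))
print(soln.shape == t.shape, max_err < 1e-5)  # True True
```

The scalar only trips up the `for iv in np.asarray(y_0)` line, because that is the one place where y_0 itself is iterated over.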
rocksNwaves

2 Answers


You can use the ndmin argument to np.array to ensure that your array is in fact iterable:

np.array(y_0, ndmin=1, copy=False)

np.asarray is essentially np.array with different defaults (most notably, it only copies its input when it has to), but it does not accept ndmin, so np.array is needed here.

ndmin pads the shape with leading unit dimensions until the array has at least that many dimensions. This helps because np.array turns a bare scalar into a zero-dimensional array, which behaves much like a scalar; one annoying side effect is that it isn't iterable. ndmin=1 means that a scalar input is treated as a one-dimensional, one-element array, which is what you're looking for.
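A small sketch of the shapes involved, using the values from the question. It reproduces the original TypeError and then shows what ndmin=1 does; copy=False is left out so the snippet behaves the same on NumPy 1.x and 2.x, where that keyword's meaning differs.

```python
import numpy as np

y_0 = 0.5  # a bare float, as in the question

plain = np.asarray(y_0)
print(plain.ndim, plain.shape)  # 0 ()

# Iterating over a 0-d array fails: this is the error from the question.
try:
    for iv in plain:
        pass
    iterated = True
except TypeError as exc:
    iterated = False
    print("TypeError:", exc)

# ndmin=1 prepends unit dimensions until the array is at least 1-d.
padded = np.array(y_0, ndmin=1)
print(padded.ndim, padded.shape)  # 1 (1,)
values = [float(iv) for iv in padded]
print(values)  # [0.5]

# Input that is already 1-d keeps its shape.
many = np.array((-0.9, 0.5, 1.01), ndmin=1)
print(many.shape)  # (3,)
```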

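copy=False tells numpy to use existing arrays as-is instead of making a copy. That way, if a user passes in an actual array (as opposed to a list or a scalar), it is used without duplicating the data. I often pair this argument with subok=True, which also passes subclasses of ndarray through without copying. Note that NumPy 2.0 changed copy=False to mean "never copy": it raises a ValueError when a copy cannot be avoided, which is the case for a scalar or list input. On NumPy 2.x, use copy=None to get the behavior described here.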
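A related option, swapped in here as an alternative rather than part of the answer above, is NumPy's np.atleast_1d. It reshapes 0-d input to shape (1,) and returns inputs that already have at least one dimension without copying them, so it sidesteps the copy keyword entirely. A minimal sketch, assuming the same initial values as the question:

```python
import numpy as np

arr = np.array([-0.9, 0.5, 1.01])

single = np.atleast_1d(0.5)                 # scalar -> shape (1,)
several = np.atleast_1d([-0.9, 0.5, 1.01])  # list -> shape (3,)
same = np.atleast_1d(arr)                   # existing 1-d array passed through

print(single)                       # [0.5]
print(several.shape)                # (3,)
print(np.shares_memory(same, arr))  # True: the data was not copied

# The loop from the question now runs once for a bare float.
seen = [float(iv) for iv in np.atleast_1d(0.5)]
print(seen)  # [0.5]
```

In vf_grapher this would amount to writing `for iv in np.atleast_1d(y_0):`.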

Mad Physicist
  • So, that worked well but left me with two questions: 1) When I call np.array(y_0) on a float, why is the dimension set to zero by default? 2) When would a copy be necessary? I'm not very well versed in how memory works, but does it have something to do with that? – rocksNwaves Oct 29 '19 at 18:48
  • @rocksNwaves. You don't ever need a copy here, and numpy treats scalars weirdly. That's all the args are doing. I've updated my answer with more details. – Mad Physicist Oct 29 '19 at 19:12

Iterating over a one-element list works, but np.asarray on a bare float gives a zero-dimensional array rather than a one-element array, and a 0-d array cannot be iterated over. To work around this, I found a way to determine whether an item is iterable in this answer to a different question: https://stackoverflow.com/a/1952481/3696204

Then I used a try/except block as follows:

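# Wrap anything that can't be iterated over (such as a bare float) in a list.
try:
    iter(y_0)
except TypeError:
    y_0 = [y_0]

# Loop body from vf_grapher, unchanged.
for iv in y_0:
    soln = rk4(dt, t, fn, iv)
    plt.plot(t, soln, lintype)
    if y_min > np.min(soln):
        y_min = np.min(soln)
    if y_max < np.max(soln):
        y_max = np.max(soln)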
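The same try/except check can be packaged as a small reusable function. This is a sketch; the name as_iterable is hypothetical and not from the answer above.

```python
import numpy as np


def as_iterable(y_0):
    """Return y_0 unchanged if it can be iterated over, else wrap it in a list.

    Hypothetical helper packaging the try/except approach.
    """
    try:
        iter(y_0)
    except TypeError:
        # Floats, ints and 0-d NumPy arrays end up here.
        return [y_0]
    return y_0


print(as_iterable(0.5))                # [0.5]
print(as_iterable((-0.9, 0.5, 1.01)))  # (-0.9, 0.5, 1.01)
print(as_iterable(np.asarray(0.5)))    # [array(0.5)]
```

One caveat of the iter() test: strings are iterable, so a string would be passed through and looped over character by character rather than wrapped.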

Thanks for the helpful comments, everyone.

rocksNwaves