import numpy as np
from numpy import asarray
from matplotlib import pyplot as plt
import torch
# generate a signal
fs = 50  # sampling frequency (Hz)
ts = np.arange(0, 10, 1/fs)  # times at which the signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts)  # 2 Hz
s2 = np.sin(2 * np.pi * 3 * ts)  # 3 Hz
s3 = np.sin(2 * np.pi * 6 * ts)  # 6 Hz
s = s1 + s2 + s3  # aggregate signal
# generate the spectrogram
spectrum, freqs, t, im = plt.specgram(s, Fs=fs, xextent=(0, len(s)/fs))
# convert matplotlib image to torch tensor
# bypassing the numpy part would be even better!
torch_tensor = torch.from_numpy(asarray(im, np.float32))
print(torch_tensor)
>>> TypeError: float() argument must be a string or a number, not 'AxesImage'
I should add that the 'spectrum' variable is close to what I am looking for, but it confuses me: it has only two columns along the time axis, whereas I would expect the specgram image to contain many more than two time steps. If there is a way to use 'spectrum' to represent the whole image as a torch tensor, that would also work for me.
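For reference, here is a minimal sketch of the conversion I am after, under some assumptions. It uses matplotlib.mlab.specgram (the non-plotting routine that, as far as I understand, plt.specgram relies on) rather than the AxesImage object. The names spec_default, spec_short, spectrum_tensor, and image_tensor are illustrative and not part of the code above. With the default settings (NFFT=256, noverlap=128), a 500-sample signal seems to yield only two time columns, which may explain the two-column 'spectrum'. The NFFT=64 and noverlap=32 values are an arbitrary choice to get a finer time grid.

```python
import numpy as np
import torch
from matplotlib import mlab

# Same test signal as above: 2 Hz + 3 Hz + 6 Hz sampled at 50 Hz for 10 s.
fs = 50
ts = np.arange(0, 10, 1 / fs)
s = sum(np.sin(2 * np.pi * f * ts) for f in (2, 3, 6))

# Default-style settings: long windows, so very few time columns.
spec_default, freqs, t = mlab.specgram(s, NFFT=256, Fs=fs, noverlap=128)

# Assumed alternative settings: shorter windows give more time columns
# at the cost of coarser frequency resolution.
spec_short, freqs_short, t_short = mlab.specgram(s, NFFT=64, Fs=fs, noverlap=32)

# The spectrum is a plain NumPy array (frequencies x time steps),
# so it converts to a tensor directly.
spectrum_tensor = torch.from_numpy(spec_short.astype(np.float32))

# dB-scaled values, similar to what the plotted image shows
# (the plot additionally flips the rows for display).
image_tensor = torch.from_numpy((10 * np.log10(spec_short)).astype(np.float32))

print(spec_default.shape, tuple(spectrum_tensor.shape))
```

The point of this sketch is that the tensor is built from the numeric spectrum array, not from the image object returned as the fourth value of plt.specgram.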