```python
import numpy as np
from numpy import asarray
from matplotlib import pyplot as plt
import torch

# generate a signal
fs = 50 # sampling freq
ts = np.arange(0, 10, 1/fs) # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts) # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts) # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts) # 6 hz
s = s1 + s2 + s3 # aggregate signal

# generate specgram
spectrum, freqs, t, im = plt.specgram(s, Fs=fs, xextent=((0, len(s)/fs)))

# convert matplotlib image to torch tensor
# bypassing the numpy part would be even better!
torch_tensor = torch.from_numpy(asarray(im, np.float32))

print(torch_tensor)
```

```
>>> TypeError: float() argument must be a string or a number, not 'AxesImage'
```

I should add that the `spectrum` variable is close to what I am looking for, except that it confuses me: it has only two columns for time, whereas the specgram image seems to have many more than two time steps. If there is a way to use `spectrum` to represent the whole image as a torch tensor, that would also work for me.

Mattpats


`plt.specgram` returns the spectrogram itself in the `spectrum` variable, so that is the array you need to pass to `torch.from_numpy`. Additionally, according to the matplotlib `specgram` documentation, the plot shows `10*log10(spectrum)` (power in dB) by default, so you might want to apply that operation to compare the result shown by `specgram` with a plot of your tensor. See code below:

```python
import numpy as np
from matplotlib import pyplot as plt
import torch

# generate a signal
fs = 50  # sampling freq
ts = np.arange(0, 10, 1/fs)  # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts)  # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts)  # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts)  # 6 hz
s = s1 + s2 + s3  # aggregate signal

# generate specgram (left panel)
ax1 = plt.subplot(121)
ax1.set_title('Specgram image')
spectrum, freqs, t, im = ax1.specgram(s, Fs=fs, xextent=(0, len(s)/fs))
ax1.axis('tight')

# spectrum is a plain NumPy array (frequency bins x time segments)
torch_tensor = torch.from_numpy(spectrum)

# Plot torch tensor variable (right panel), in dB like specgram does.
# origin must be 'lower' or 'upper'; extent = [0, duration, 0, Nyquist]
ax2 = plt.subplot(122)
ax2.set_title('Torch tensor image')
ax2.imshow(10*torch.log10(torch_tensor).numpy(), origin='lower',
           extent=[0, len(s)/fs, 0, fs/2])
ax2.axis('tight')

plt.show()
```

And the output gives:

[Figure: side-by-side panels titled "Specgram image" and "Torch tensor image"]
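The `im` returned by `specgram` is the `AxesImage` artist the question passed to `asarray`. The artist itself is not an array, but `AxesImage.get_array()` returns the array it displays. The sketch below assumes matplotlib's `specgram` dB-scales the spectrum and flips it vertically before handing it to `imshow`. That holds in the versions I know of, but it is an implementation detail. The sketch also assumes the off-screen `Agg` backend so that it runs without a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no GUI window needed
import numpy as np
import torch
from matplotlib import pyplot as plt

fs = 50
ts = np.arange(0, 10, 1 / fs)
s = (np.sin(2 * np.pi * 2 * ts)
     + np.sin(2 * np.pi * 3 * ts)
     + np.sin(2 * np.pi * 6 * ts))

fig, ax = plt.subplots()
spectrum, freqs, t, im = ax.specgram(s, Fs=fs, xextent=(0, len(s) / fs))

# get_array() returns the data the AxesImage displays. imshow may store it
# as a masked array, so np.asarray drops the mask; ascontiguousarray makes
# sure the strides are positive, which torch.from_numpy requires.
image_data = np.ascontiguousarray(np.asarray(im.get_array()), dtype=np.float32)
image_tensor = torch.from_numpy(image_data)

# The displayed data is the dB-scaled spectrum, flipped vertically
# (row 0 = highest frequency), so it has the same (129, 2) shape as spectrum.
expected = np.flipud(10 * np.log10(spectrum)).astype(np.float32)
print(image_tensor.shape)                  # torch.Size([129, 2])
print(np.allclose(image_data, expected))   # True
plt.close(fig)
```

This recovers the data behind the plot, not the rendered pixels. It therefore has the same two time columns as `spectrum`.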

jylls
  • Do you know if there is a way to store the "torch tensor image" as an object? Something like `specgram_object = ax2.imshow(10*np.log10(torch_tensor)...)`, with the object being an np.array or torch.tensor. – Mattpats Dec 20 '21 at 23:49
  • You could just use: `torch_tensor = torch.from_numpy(10*np.log10(spectrum))` – jylls Dec 20 '21 at 23:50
  • `imshow` returns an [AxesImage](https://matplotlib.org/stable/api/image_api.html#matplotlib.image.AxesImage) instance, so I don't think you can convert that to a numpy array or a torch tensor. And even if you could, I am not sure it would mean much, since it deals with the properties related to displaying the image rather than the array you are plotting. – jylls Dec 20 '21 at 23:58
  • The problem is that I want the whole "specgram image" as it appears. The torch_tensor you suggest has dimensions (129,2), whereas the image is (375,264). – Mattpats Dec 20 '21 at 23:59
  • The dimension (129,2) comes from the way you are computing the spectrogram. The two plots I show both have only 2 columns and 129 rows. If you want more time windows, you might have to produce a longer signal or change the options of your spectrogram. – jylls Dec 21 '21 at 00:04
  • You could also repeat the columns of your spectrogram [col1,col1,...,col1,col2,col2,...col2] to get the dimension you want – jylls Dec 21 '21 at 00:09
  • I'm probably going to use the repeat-the-columns idea. But you mention making the signal longer or changing the spectrogram options. Which spectrogram options would give the spectrogram more time steps? – Mattpats Dec 21 '21 at 01:33
  • Recommending parameters is a little difficult, since it very much depends on the type of information you want to see in your spectrograms (frequency bands, power or amplitude, etc.). I'd recommend reading the doc [here](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.specgram.html). If you have additional questions, you could ask them on the Signal Processing Stack Exchange [here](https://dsp.stackexchange.com/). – jylls Dec 21 '21 at 17:58
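The comments above trace the (129, 2) shape to the spectrogram settings and suggest repeating columns. With the defaults `NFFT=256` and `noverlap=128`, a 500-sample signal gives `256//2 + 1 = 129` frequency bins and `1 + (500 - 256) // (256 - 128) = 2` time segments. The sketch below checks that arithmetic with `matplotlib.mlab.specgram`, which is the function `plt.specgram` uses to compute its values. It also shows that a shorter, more overlapped window gives more columns. Finally, it shows how `torch.repeat_interleave`, or a nearest-neighbour resize with `torch.nn.functional.interpolate`, reaches a fixed size such as (375, 264). The helper `expected_shape` is hypothetical. Its formula assumes `pad_to` is left at its default and the signal has at least `NFFT` samples.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: no GUI needed
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import mlab

# Same signal as in the question: 10 s sampled at 50 Hz (500 samples)
fs = 50
ts = np.arange(0, 10, 1 / fs)
s = (np.sin(2 * np.pi * 2 * ts)
     + np.sin(2 * np.pi * 3 * ts)
     + np.sin(2 * np.pi * 6 * ts))


def expected_shape(n_samples, nfft, noverlap):
    """Hypothetical helper: (frequency bins, time columns) of a one-sided
    spectrogram, assuming default pad_to and n_samples >= nfft."""
    n_freqs = nfft // 2 + 1
    n_times = 1 + (n_samples - nfft) // (nfft - noverlap)
    return n_freqs, n_times


# specgram defaults: NFFT=256, noverlap=128 -> only 2 time columns
spec_default, _, _ = mlab.specgram(s, NFFT=256, Fs=fs, noverlap=128)

# Shorter window, heavier overlap -> many more time columns, at the cost of
# coarser frequency bins (fs / NFFT = 50 / 64 Hz instead of 50 / 256 Hz)
spec_fine, _, _ = mlab.specgram(s, NFFT=64, Fs=fs, noverlap=56)

print(spec_default.shape, expected_shape(len(s), 256, 128))  # (129, 2) (129, 2)
print(spec_fine.shape, expected_shape(len(s), 64, 56))       # (33, 55) (33, 55)

# Column repetition from the comments: [col1 x132, col2 x132] -> (129, 264)
tensor = torch.from_numpy(10 * np.log10(spec_default)).float()
repeated = torch.repeat_interleave(tensor, repeats=132, dim=1)

# Nearest-neighbour resize to an arbitrary size (here 375 x 264);
# interpolate expects (batch, channel, H, W), hence the [None, None]
resized = F.interpolate(tensor[None, None], size=(375, 264), mode="nearest")[0, 0]
```

Repetition and resizing only stretch the existing 2 columns. They do not add time resolution, which only the `NFFT`/`noverlap` choice (or a longer signal) provides.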