import numpy as np
import matplotlib.pyplot as plt

def vis_square(data):
    """Take an array of shape (n, height, width) or (n, height, width, 3)
       and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
    # normalize data for display
    data = (data - data.min()) / (data.max() - data.min())

    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
                (0, 1), (0, 1))                  # add some space between filters
               + ((0, 0),) * (data.ndim - 3))    # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

    # tile the filters into an image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    # plt.imshow(data); plt.axis('off')   ### THIS LINE DOESN'T WORK IN TERMINAL
    # plt.imshow returns a single AxesImage, so create the figure and axes explicitly
    fig, ax = plt.subplots()
    ax.imshow(data)
    ax.axis('off')
    fig.savefig('fig.png')
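To see what the padding, reshape and transpose steps in `vis_square` actually produce without needing any display, here is a minimal sketch that pulls the tiling logic into a hypothetical `tile_filters` helper (not part of the original code) which returns the grid array instead of plotting it, then checks the shapes on random data.

```python
import numpy as np

def tile_filters(data):
    """Hypothetical helper: same normalize/pad/tile steps as vis_square,
    but returns the tiled grid instead of drawing it."""
    data = (data - data.min()) / (data.max() - data.min())   # scale to [0, 1]
    n = int(np.ceil(np.sqrt(data.shape[0])))                  # grid side length
    # pad the filter count up to n*n and add one row/column of space per filter
    padding = (((0, n ** 2 - data.shape[0]), (0, 1), (0, 1))
               + ((0, 0),) * (data.ndim - 3))                  # leave any colour axis alone
    data = np.pad(data, padding, mode='constant', constant_values=1)  # white padding
    # (n*n, h, w, ...) -> (n, n, h, w, ...) -> (n, h, n, w, ...) -> (n*h, n*w, ...)
    data = data.reshape((n, n) + data.shape[1:]).transpose(
        (0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    return data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

# 5 grayscale 3x4 filters: n = 3, each tile becomes 4x5, so the grid is 12x15
gray = tile_filters(np.random.rand(5, 3, 4))
print(gray.shape)  # (12, 15)

# 5 RGB 3x4 filters: the trailing colour axis is kept
rgb = tile_filters(np.random.rand(5, 3, 4, 3))
print(rgb.shape)   # (12, 15, 3)
```

Because `n` is rounded up, the unused grid cells (here filters 5 to 8) are filled entirely with the padding value 1, so they show up as white tiles.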
I work on a shared server with no GUI, but I want to see what the features look like.
The commented-out `plt.imshow` line (line 23 in my script) doesn't work in the terminal.
How can I save the data to an image file instead? Please advise.
I worked around it by saving the data to a pickle file and drawing the picture on another computer, but I'm still interested in how to render the figure off-screen and save the image directly on the server.
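The pickle workaround mentioned above can be sketched roughly as follows; the file name `features.pkl` and the random stand-in array are hypothetical, and in practice you would dump whatever array you pass to `vis_square`.

```python
import pickle
import numpy as np

# On the server: dump the array to disk (random stand-in data for illustration).
data = np.random.rand(5, 3, 4)
with open('features.pkl', 'wb') as f:   # hypothetical file name
    pickle.dump(data, f)

# On the machine with a display: load it back and plot it as usual.
with open('features.pkl', 'rb') as f:
    loaded = pickle.load(f)

print(np.array_equal(data, loaded))  # True
```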
Use a non-interactive `backend`, such as Agg, and select it before `pyplot` is imported. Check this doc.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
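Putting the answer together, a minimal headless sketch might look like this. It assumes a random 12x15 array as a stand-in for the grid that `vis_square` builds, and the output file names are only examples. It shows two options: saving a whole figure with `savefig`, or writing the array straight to a PNG with `plt.imsave`, which skips the figure entirely.

```python
import matplotlib
matplotlib.use('Agg')            # non-interactive backend: renders to files, needs no display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for the tiled array built by vis_square (random values, illustration only).
grid = np.random.rand(12, 15)

# Option 1: build a figure and save it; no window is ever opened.
fig, ax = plt.subplots()
ax.imshow(grid, cmap='gray')
ax.axis('off')
fig.savefig('fig.png', bbox_inches='tight')
plt.close(fig)                   # release the figure, useful in long-running scripts

# Option 2: write the array directly to an image, one pixel per array element.
plt.imsave('grid.png', grid, cmap='gray')
```

Option 2 keeps the image at the array's native resolution, while option 1 lets you add titles, colorbars or multiple subplots before saving. You can then copy the PNG back with `scp` or view it through any file browser that can reach the server.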