I am trying to display the prediction results of my LSTM and RNN models with the following code:

import matplotlib.pyplot as plt

plt.figure(figsize=(5, 3))

plt.plot(y_test, c="orange", linewidth=3, label="Original values")
plt.plot(lstm_pred, c="red", linewidth=3, label="LSTM predictions")
plt.plot(rnn_pred, alpha=0.5, c="green", linewidth=3, label="RNN predictions")
plt.legend()
plt.xticks(rotation=45)
plt.title("Predictions vs actual data", fontsize=20)
plt.show()

If I plot them one by one, the lines are displayed correctly.

But when I draw all the lines in one plot, they are not displayed correctly. Does anyone know how to fix this? Thanks.

ah bon
    The data is displaying properly - the magnitude of the values in the red data is just way larger than that of the other data. – CDJB Nov 25 '19 at 11:23
    Exactly, so create right Y axis and plot red line there. – crayxt Nov 25 '19 at 11:39
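
One quick way to check the diagnosis in these comments is to compare the value ranges of the three series before plotting. This is only a sketch: the arrays below are made-up stand-ins for y_test, rnn_pred and lstm_pred (the real data is not shown in the question), and the helper name value_ranges is hypothetical.

```python
import numpy as np


def value_ranges(series):
    """Return {name: (min, max)} for a dict of array-likes (hypothetical helper)."""
    return {name: (float(np.min(v)), float(np.max(v))) for name, v in series.items()}


# Made-up stand-in data, NOT the values from the question
x = np.linspace(0, 1, 50)
y_test = np.sin(2 * np.pi * x)   # roughly within [-1, 1]
rnn_pred = y_test + 0.05         # same scale as y_test
lstm_pred = 1000 * y_test        # assumed to be on a much larger scale

ranges = value_ranges(
    {"y_test": y_test, "rnn_pred": rnn_pred, "lstm_pred": lstm_pred}
)
for name, (low, high) in ranges.items():
    print(f"{name}: min={low:.2f}, max={high:.2f}")
```

If one series spans a range that is orders of magnitude wider than the others, a shared y-axis will squash the smaller series into what looks like a flat line.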

1 Answer

As mentioned in the comments, you need to create a second y-axis for the series with the much larger values. Then you need to merge the legend entries from both axes into a single legend.

import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()

# left y-axis: the two series with comparable magnitudes
line1 = ax1.plot(y_test, c="orange", linewidth=3, label="Original values")
line2 = ax1.plot(rnn_pred, alpha=0.5, c="green", linewidth=3, label="RNN predictions")
plt.xticks(rotation=45)
ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
# right y-axis: the series with the much larger magnitude
line3 = ax2.plot(lstm_pred, c="red", linewidth=3, label="LSTM predictions")

# added these three lines: merge the line handles from both axes into one legend
lines = line1 + line2 + line3
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels)

plt.title("Predictions vs actual data", fontsize=20)
plt.show()
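
For anyone who wants to try the twin-axis approach without the original data, here is a self-contained sketch. The arrays are made-up stand-ins (the asker's y_test, rnn_pred and lstm_pred are not available), and the function name plot_with_twin_axis is hypothetical. As a variation on the code above, it collects the legend entries with Matplotlib's Axes.get_legend_handles_labels() instead of adding up the lists returned by plot(); both give a single combined legend.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_with_twin_axis(y_test, rnn_pred, lstm_pred):
    """Plot two series on the left y-axis and one on a twin right y-axis."""
    fig, ax1 = plt.subplots(figsize=(5, 3))
    ax1.plot(y_test, c="orange", linewidth=3, label="Original values")
    ax1.plot(rnn_pred, alpha=0.5, c="green", linewidth=3, label="RNN predictions")
    ax1.tick_params(axis="x", labelrotation=45)

    # Second axes sharing the same x-axis, with its own y-scale on the right
    ax2 = ax1.twinx()
    ax2.plot(lstm_pred, c="red", linewidth=3, label="LSTM predictions")

    # Gather handles and labels from both axes and draw one legend
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(handles1 + handles2, labels1 + labels2)

    ax1.set_title("Predictions vs actual data", fontsize=20)
    return fig, ax1, ax2


# Made-up stand-in data, NOT the values from the question
x = np.linspace(0, 1, 50)
y_test = np.sin(2 * np.pi * x)
rnn_pred = y_test + 0.05
lstm_pred = 1000 * y_test  # assumed to be on a much larger scale

fig, ax1, ax2 = plot_with_twin_axis(y_test, rnn_pred, lstm_pred)
legend_labels = [t.get_text() for t in ax1.get_legend().get_texts()]
```

Call plt.show() (or fig.savefig with a file name of your choice) afterwards to render the figure. Coloring the right-hand tick labels red, for example with ax2.tick_params(axis="y", labelcolor="red"), can help readers see which axis belongs to which line.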
CAPSLOCK