I found a solution by combining methods from:
import matplotlib.pyplot as plt
import numpy as np
# Draw several parallel vertical surfaces (each an X-Z plane at a fixed Y),
# all coloured by the same scalar field through the BrBG colormap.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Y positions at which a constant-Y surface is drawn; also used as the Y ticks.
yticks = [3, 2, 1, 0]

# 21 x 21 grid over the unit square in the X-Z plane.
X, Z = np.meshgrid(np.linspace(0, 1, 21), np.linspace(0, 1, 21))
# Scalar field that determines the colour of every grid cell.
data = np.cos(X) * np.cos(X) + np.sin(Z) * np.sin(Z)

# The field is the same for every layer, so map it to RGBA colours once.
# Dividing by the maximum scales the largest value to 1 before the colormap lookup.
facecolors = plt.cm.BrBG(data / data.max())

# Iterate over the Y positions directly. The original zipped with an undefined
# `colors` name whose value was never used.
for k in yticks:
    # Constant-Y plane: every grid point of this layer lies at y == k.
    Y = np.full_like(X, k)
    # shade=False keeps the colormap colours exactly as computed.
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                    facecolors=facecolors, shade=False)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_yticks(yticks)
plt.show()