I don't know how you're getting on, so you may have already solved it. But, based on the link from Paul's comment, you could do something like this. We pass the colors we want through the `facecolors` argument of `plot_surface`.
(I've modified the surface3d demo from the matplotlib docs)
EDIT: As Stefan noted in his comment, my answer can be simplified to:
```python
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') fails on Matplotlib >= 3.6
X = np.arange(-5, 5, 0.25)
xlen = len(X)
Y = np.arange(-5, 5, 0.25)
ylen = len(Y)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
maxR = np.amax(R)
Z = np.sin(R)

# Note that the R values must still be normalized to [0, 1].
# cm.jet maps the whole array at once, returning RGBA values of shape R.shape + (4,).
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=cm.jet(R/maxR),
                       linewidth=0)
plt.show()
```
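Stefan's simplification works because a Matplotlib colormap is callable on an entire array: floats in [0, 1] come back as RGBA values with an extra trailing axis of length 4, which is exactly the per-face color array `facecolors` expects. A minimal sketch to confirm this, assuming a recent Matplotlib (3.6 or later) and the headless Agg backend so it runs without a display (`fig.canvas.draw()` stands in for `plt.show()`):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np

# Same grid as the answer: np.arange(-5, 5, 0.25) has 40 points per axis.
X, Y = np.meshgrid(np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25))
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

# Calling the colormap on a float array in [0, 1] returns RGBA values,
# adding a trailing axis of length 4.
rgba = cm.jet(R / R.max())
print(rgba.shape)  # (40, 40, 4)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=rgba, linewidth=0)
fig.canvas.draw()  # render off-screen instead of calling plt.show()
```

The normalization step matters because values outside [0, 1] are clipped to the colormap's end colors, so without dividing by `maxR` most of the surface would come out a single dark red.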
And here is (the end of) my needlessly complicated original version. It uses the same setup code as above, minus the `matplotlib.cm` import:
```python
# We will store (R, G, B, alpha)
colorshape = R.shape + (4,)
colors = np.empty(colorshape)
for y in range(ylen):
    for x in range(xlen):
        # Normalize the radial value.
        # 'jet' could be any of the built-in colormaps (or your own).
        # After meshgrid, R has shape (ylen, xlen), so index it as [y, x].
        colors[y, x] = plt.cm.jet(R[y, x] / maxR)
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
                       linewidth=0)
plt.show()
```
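To check that the loop version and the one-line colormap call produce the same colors, here is a small sketch that only compares the arrays (no plotting). Since the loop's comment says your own colormap works too, it swaps `jet` for a hypothetical custom colormap built with `LinearSegmentedColormap.from_list`; the name `blue_white_red` and its three colors are illustrative only. It indexes R as `[row, col]`, matching the `(len(Y), len(X))` shape that `np.meshgrid` produces by default.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

X, Y = np.meshgrid(np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25))
R = np.sqrt(X**2 + Y**2)
maxR = R.max()

# Hypothetical custom colormap standing in for 'jet': blue -> white -> red.
cmap = LinearSegmentedColormap.from_list("blue_white_red", ["blue", "white", "red"])

# Loop version: fill one RGBA entry per grid point.
# R has shape (len(Y), len(X)), so the row index runs over Y.
colors_loop = np.empty(R.shape + (4,))
for row in range(R.shape[0]):
    for col in range(R.shape[1]):
        colors_loop[row, col] = cmap(R[row, col] / maxR)

# Vectorized version: the colormap accepts the whole array at once.
colors_vec = cmap(R / maxR)

print(np.allclose(colors_loop, colors_vec))  # True
```

Either array can be passed as `facecolors`; the vectorized call is shorter and avoids a Python-level loop over every grid point.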