3

The formula below is the special case of the Wasserstein distance (optimal transport) in which the source and target distributions, x and y (also called marginal distributions), are 1D, that is, vectors.

$$W_p(u, v) = \left( \int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz \right)^{1/p}$$

where F^{-1} are the inverse cumulative distribution functions (quantile functions) of the marginals u and v, derived from real data called x and y, both generated from the normal distribution:

import numpy as np
from numpy.random import randn
import scipy.stats as ss

n = 100
x = randn(n)
y = randn(n)

How can the integral in the formula be coded in Python with SciPy? I'm guessing x and y have to be converted to ranked marginals, which are non-negative and sum to 1, while SciPy's ppf could be used to calculate the inverse F^{-1}'s?

develarist
  • Integrals are a representation of sums, so you should represent this as a summation and use `dz` as small as possible. Check a book related to Mathematical Methods you could probably find a better representation of this integral – EnriqueBet Dec 07 '20 at 02:28
  • [`scipy.stats.wasserstein_distance`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html), see also a Numba compiled version at [`GitHub: dist-matrix`](https://github.com/sparks-baird/dist-matrix) – Sterling Apr 23 '22 at 02:59

2 Answers

2

Note that as n gets large, a sorted set of n samples approaches the inverse CDF sampled at 1/n, 2/n, ..., n/n. E.g.:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
plt.plot(norm.ppf(np.linspace(0, 1, 1000)), label="invcdf")
plt.plot(np.sort(np.random.normal(size=1000)), label="sortsample")
plt.legend()
plt.show()

[Plot: the `invcdf` and `sortsample` curves produced by the code above]

Also note that your integral from 0 to 1 can be approximated as a sum over 1/n, 2/n, ..., n/n.
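Written out, the approximation reads as follows. Here u_(i) and v_(i) denote the i-th smallest values of the two samples (notation introduced here for illustration), and both samples are assumed to have the same size n:

```latex
\int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz
\;\approx\;
\frac{1}{n} \sum_{i=1}^{n} \left| u_{(i)} - v_{(i)} \right|^p
```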

Thus we can simply answer your question:

def W(p, u, v):
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)
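As a sanity check on the equal-length version, the sketch below compares it with `scipy.stats.wasserstein_distance` for p = 1. The name `w_sorted`, the seed, and the sample sizes are illustrative choices, not part of the answer. The two values should agree up to floating-point error, which is why the comparison uses `np.isclose` rather than `==`.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def w_sorted(p, u, v):
    """Empirical p-Wasserstein distance for two equal-length 1D samples."""
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v)) ** p) ** (1 / p)

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
u = rng.normal(size=1000)
v = rng.normal(loc=0.5, size=1000)

ours = w_sorted(1, u, v)
ref = wasserstein_distance(u, v)  # SciPy computes the p = 1 case
print(ours, ref, np.isclose(ours, ref))
```

Note that `wasserstein_distance` only covers p = 1, so for other values of p there is no direct SciPy reference to compare against.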

Note that if len(u) != len(v) you can still apply the method with linear interpolation:

def W(p, u, v):
    u = np.sort(u)
    v = np.sort(v)
    if len(u) != len(v):
        if len(u) > len(v): u, v = v, u
        us = np.linspace(0, 1, len(u))
        vs = np.linspace(0, 1, len(v))
        u = np.interp(vs, us, u)  # resample the shorter sorted sample onto the longer grid
    return np.mean(np.abs(u - v)**p)**(1/p)

An alternative method, if you have prior information about the family of distributions your data comes from but not its parameters, is to fit that distribution to both u and v (e.g. with scipy.stats.norm.fit) and then evaluate the integral to the desired precision. E.g.:

from scipy.stats import norm as gauss
def W_gauss(p, u, v, num_steps):
    ud = gauss(*gauss.fit(u))
    vd = gauss(*gauss.fit(v))
    z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)
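One way to check the quantile integration is against a case with a known answer: for two 1D Gaussians, the 2-Wasserstein distance is sqrt((m1 - m2)^2 + (s1 - s2)^2). The sketch below skips the fitting step and plugs in fixed parameters directly; `w_quantile` and the parameter values are hypothetical names and numbers chosen for illustration.

```python
import numpy as np
from scipy.stats import norm

def w_quantile(p, ud, vd, num_steps):
    """Midpoint-rule approximation of the quantile integral for two frozen distributions."""
    # Midpoints of num_steps equal cells in (0, 1); this avoids ppf(0) and ppf(1), which are infinite
    z = (np.arange(num_steps) + 0.5) / num_steps
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z)) ** p) ** (1 / p)

m1, s1 = 0.0, 1.0  # illustrative mean and standard deviation of the first Gaussian
m2, s2 = 1.0, 2.0  # illustrative mean and standard deviation of the second Gaussian

numeric = w_quantile(2, norm(m1, s1), norm(m2, s2), num_steps=100_000)
closed_form = np.sqrt((m1 - m2) ** 2 + (s1 - s2) ** 2)
print(numeric, closed_form)
```

The numerical value approaches the closed form as `num_steps` grows; the remaining error comes mostly from the tails, where the quantile function is steep.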
orlp
  • I compared the 2nd code snippet above using `W(1, u, v)`, where `u` and `v` are vectors generated with `norm.rvs`, but when I compare the result of this function with `scipy.stats.wasserstein_distance(u, v)`, your function multiplies the output of the latter by `1000`. For example, `W` will give `120` if `wasserstein_distance` gives `0.12` – develarist Dec 07 '20 at 13:22
  • As for the last code snippet above, function `W_gauss` returns `nan` as output value, and the error `RuntimeWarning: invalid value encountered in subtract return np.linalg.norm(ud.ppf(z) - vd.ppf(z), ord=p)`. What is `num_steps` and how is it supposed to be set? – develarist Dec 07 '20 at 13:26
  • @develarist Ah, the issue was that `gauss.ppf` is (correctly) infinite at `0` and `1`. I'll fix that. `num_steps` is just an integer parameter that represents how many steps should be used in evaluating the integral, the higher the more precision. – orlp Dec 07 '20 at 15:18
  • any idea why (the first) function `W` doesn't return the same thing as `wasserstein_distance`? – develarist Dec 07 '20 at 15:20
  • @develarist I actually misused `np.linalg.norm`, that's now also fixed. Note that `W(1, u, v)` is equivalent to (unweighted) `wasserstein_distance`. – orlp Dec 07 '20 at 15:29
  • but `wasserstein_distance` is unweighted by default. `W` still doesn't match it, this is what I'm trying to figure out – develarist Dec 07 '20 at 15:41
  • @develarist From my tests with the latest version (make sure you have included my edits) we have that `W(1, u, v)` matches `scipy.stats.wasserstein_distance(u, v)`. Make sure you compare with `np.isclose` not with `==` since floating point isn't perfect. – orlp Dec 07 '20 at 15:42
  • If i have several different CDFs of different variables instead of just one, but I would like to weight them together to form one "optimal source distribution", can't this framework be modified into the closed-form 1D Wasserstein by simply adding in a weight vector somehow? https://stats.stackexchange.com/questions/499748/weighted-combination-of-several-cumulative-distribution-functions – develarist Dec 07 '20 at 15:46
0

I guess I am a bit late, but this is what I would do for an exact solution (using only numpy):

import numpy as np
from numpy.random import randn
n = 100
m = 80
p = 2
x = np.sort(randn(n))
y = np.sort(randn(m))
a = np.ones(n)/n
b = np.ones(m)/m
# cdfs
ca = np.cumsum(a)
cb = np.cumsum(b)

# points on which we need to evaluate the quantile functions
cba = np.sort(np.hstack([ca, cb]))
# weights for integral
h = np.diff(np.hstack([0, cba]))

# construction of first quantile function
bins = ca + 1e-10  # small tolerance so that rounding errors do not misbin values equal to a bin edge
index_qx = np.digitize(cba, bins, right=True)  # right=True uses intervals (a, b], matching the
                                               # left-continuous quantile function
qx = x[index_qx]  # quantile function F^{-1}

# construction of second quantile function
bins = cb + 1e-10
index_qy = np.digitize(cba, bins, right=True)  # same interval convention as above
qy = y[index_qy]  # quantile function G^{-1}

# transport cost W_p^p; np.abs keeps it valid for odd p, and ot_cost**(1/p) gives the distance W_p
ot_cost = np.sum(np.abs(qx - qy)**p * h)
print(ot_cost)
        

In case you are interested, a more detailed numpy-based implementation of the optimal transport problem on the real line, with dual and primal solutions as well, is available here: https://github.com/gnies/1d-optimal-transport. (I am still working on it, though.)
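The quantile construction in this answer can also be wrapped in a function and checked against SciPy for p = 1, where `scipy.stats.wasserstein_distance` handles samples of different sizes. This is a sketch under the same uniform-weight assumption; the function name `wasserstein_1d_exact` and the seed are made up for the example.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def wasserstein_1d_exact(x, y, p=1):
    """Exact p-Wasserstein distance between two uniformly weighted 1D samples."""
    x, y = np.sort(x), np.sort(y)
    n, m = len(x), len(y)
    ca = np.arange(1, n + 1) / n  # CDF levels reached by the sorted x values
    cb = np.arange(1, m + 1) / m  # CDF levels reached by the sorted y values
    cba = np.sort(np.concatenate([ca, cb]))  # breakpoints of both quantile functions
    h = np.diff(np.concatenate([[0.0], cba]))  # interval lengths, used as integration weights
    # Value of each quantile function on the interval ending at each breakpoint
    qx = x[np.digitize(cba, ca + 1e-10, right=True)]
    qy = y[np.digitize(cba, cb + 1e-10, right=True)]
    return np.sum(np.abs(qx - qy) ** p * h) ** (1 / p)

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
x = rng.normal(size=100)
y = rng.normal(size=80)

exact = wasserstein_1d_exact(x, y, p=1)
ref = wasserstein_distance(x, y)
print(exact, ref, np.isclose(exact, ref))
```

Unlike the snippet above, this function returns the distance itself (the p-th root of the transport cost), so its p = 1 output is directly comparable with SciPy's.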

Giacomo