I'm trying to implement Gaussian processes (GPs) in Julia, following krasserm's implementation.
The NumPy implementation works like a charm:
```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    '''Isotropic squared exponential kernel.

    Computes a covariance matrix from points in X1 and X2.
    Args: X1: Array of m points (m x d). X2: Array of n points (n x d).
    Returns: Covariance matrix (m x n).
    '''
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
mu = np.zeros(X.shape)
cov = kernel(X, X)
samples = np.random.multivariate_normal(mu.ravel(), cov, 3)
```
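Both kernels compute squared distances with the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a·b. This formula can suffer from cancellation when points are close. The minimal NumPy sketch below assumes inputs shaped (n, d) as above and compares the expansion with a direct pairwise difference. The two agree up to round-off, but only the direct form is exactly symmetric and never negative. The helper names `sqdist_expanded` and `sqdist_direct` are hypothetical.

```python
import numpy as np

def sqdist_expanded(X1, X2):
    # Same expansion as kernel(): |a|^2 + |b|^2 - 2 a.b (prone to cancellation)
    return np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)

def sqdist_direct(X1, X2):
    # Broadcast pairwise differences: (m, 1, d) - (1, n, d) -> (m, n, d)
    diff = X1[:, None, :] - X2[None, :, :]
    return np.sum(diff**2, axis=-1)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
D_exp = sqdist_expanded(X, X)
D_dir = sqdist_direct(X, X)

print(np.allclose(D_exp, D_dir))       # the formulas agree up to round-off
print(np.array_equal(D_dir, D_dir.T))  # the direct form is exactly symmetric
print(D_dir.min() >= 0.0)              # and never negative
```

The direct form costs an extra (m, n, d) temporary. That is negligible here, but the expansion is usually preferred for large inputs.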
In Julia, however, I get an error when creating a multivariate normal distribution with Distributions.jl.
Here is what I have tried:
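```julia
using LinearAlgebra
using Distributions

# Isotropic squared exponential kernel (same expansion as the NumPy version)
function kernel(x₁, x₂; l=1.0, σ=1.0)
    t₁ = sum(x₁.^2, dims=2)   # squared norms of the rows of x₁
    t₂ = sum(x₂.^2, dims=2)   # squared norms of the rows of x₂
    t₃ = 2*x₁*x₂'             # cross terms
    t = (t₁ .+ t₂') - t₃      # pairwise squared distances
    return σ^2 * exp.(-0.5/l^2 * t)
end

x = -1:0.1:0.9
μ = zeros(length(x))
σ = kernel(x, x)
D = MvNormal(μ, σ)
```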
The problem seems to be that my covariance matrix is not positive (semi)definite, even though it looks identical to the NumPy one:
Julia:
```
julia> σ
20×20 Array{Float64,2}:
1.0 0.995012 0.980199 0.955997 … 0.235746 0.197899 0.164474
0.995012 1.0 0.995012 0.980199 0.278037 0.235746 0.197899
0.980199 0.995012 1.0 0.995012 0.324652 0.278037 0.235746
0.955997 0.980199 0.995012 1.0 0.375311 0.324652 0.278037
0.923116 0.955997 0.980199 0.995012 0.429557 0.375311 0.324652
0.882497 0.923116 0.955997 0.980199 … 0.486752 0.429557 0.375311
0.83527 0.882497 0.923116 0.955997 0.546074 0.486752 0.429557
0.782705 0.83527 0.882497 0.923116 0.606531 0.546074 0.486752
0.726149 0.782705 0.83527 0.882497 0.666977 0.606531 0.546074
0.666977 0.726149 0.782705 0.83527 0.726149 0.666977 0.606531
0.606531 0.666977 0.726149 0.782705 … 0.782705 0.726149 0.666977
0.546074 0.606531 0.666977 0.726149 0.83527 0.782705 0.726149
0.486752 0.546074 0.606531 0.666977 0.882497 0.83527 0.782705
0.429557 0.486752 0.546074 0.606531 0.923116 0.882497 0.83527
0.375311 0.429557 0.486752 0.546074 0.955997 0.923116 0.882497
0.324652 0.375311 0.429557 0.486752 … 0.980199 0.955997 0.923116
0.278037 0.324652 0.375311 0.429557 0.995012 0.980199 0.955997
0.235746 0.278037 0.324652 0.375311 1.0 0.995012 0.980199
0.197899 0.235746 0.278037 0.324652 0.995012 1.0 0.995012
0.164474 0.197899 0.235746 0.278037 0.980199 0.995012 1.0
```
Python:
```
>>> cov
array([[1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ,
0.32465247, 0.2780373 , 0.23574608, 0.1978987 , 0.16447446],
[0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736,
0.3753111 , 0.32465247, 0.2780373 , 0.23574608, 0.1978987 ],
[0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226,
0.42955736, 0.3753111 , 0.32465247, 0.2780373 , 0.23574608],
[0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443,
0.48675226, 0.42955736, 0.3753111 , 0.32465247, 0.2780373 ],
[0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066,
0.54607443, 0.48675226, 0.42955736, 0.3753111 , 0.32465247],
[0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681,
0.60653066, 0.54607443, 0.48675226, 0.42955736, 0.3753111 ],
[0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904,
0.66697681, 0.60653066, 0.54607443, 0.48675226, 0.42955736],
[0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454,
0.72614904, 0.66697681, 0.60653066, 0.54607443, 0.48675226],
[0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021,
0.78270454, 0.72614904, 0.66697681, 0.60653066, 0.54607443],
[0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ,
0.83527021, 0.78270454, 0.72614904, 0.66697681, 0.60653066],
[0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635,
0.8824969 , 0.83527021, 0.78270454, 0.72614904, 0.66697681],
[0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748,
0.92311635, 0.8824969 , 0.83527021, 0.78270454, 0.72614904],
[0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867,
0.95599748, 0.92311635, 0.8824969 , 0.83527021, 0.78270454],
[0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248,
0.98019867, 0.95599748, 0.92311635, 0.8824969 , 0.83527021],
[0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ,
0.99501248, 0.98019867, 0.95599748, 0.92311635, 0.8824969 ],
[0.32465247, 0.3753111 , 0.42955736, 0.48675226, 0.54607443,
0.60653066, 0.66697681, 0.72614904, 0.78270454, 0.83527021,
0.8824969 , 0.92311635, 0.95599748, 0.98019867, 0.99501248,
1. , 0.99501248, 0.98019867, 0.95599748, 0.92311635],
[0.2780373 , 0.32465247, 0.3753111 , 0.42955736, 0.48675226,
0.54607443, 0.60653066, 0.66697681, 0.72614904, 0.78270454,
0.83527021, 0.8824969 , 0.92311635, 0.95599748, 0.98019867,
0.99501248, 1. , 0.99501248, 0.98019867, 0.95599748],
[0.23574608, 0.2780373 , 0.32465247, 0.3753111 , 0.42955736,
0.48675226, 0.54607443, 0.60653066, 0.66697681, 0.72614904,
0.78270454, 0.83527021, 0.8824969 , 0.92311635, 0.95599748,
0.98019867, 0.99501248, 1. , 0.99501248, 0.98019867],
[0.1978987 , 0.23574608, 0.2780373 , 0.32465247, 0.3753111 ,
0.42955736, 0.48675226, 0.54607443, 0.60653066, 0.66697681,
0.72614904, 0.78270454, 0.83527021, 0.8824969 , 0.92311635,
0.95599748, 0.98019867, 0.99501248, 1. , 0.99501248],
[0.16447446, 0.1978987 , 0.23574608, 0.2780373 , 0.32465247,
0.3753111 , 0.42955736, 0.48675226, 0.54607443, 0.60653066,
0.66697681, 0.72614904, 0.78270454, 0.83527021, 0.8824969 ,
0.92311635, 0.95599748, 0.98019867, 0.99501248, 1. ]])
```
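Both printouts show the same visibly symmetric matrix. The values themselves are therefore not the issue. The problem is how close the matrix is to singular. The small NumPy check below looks at the symmetric eigenvalues and the condition number. It mirrors the Python kernel above, and Julia's `Float64` is subject to the same round-off. For a squared exponential kernel on closely spaced points, the smallest eigenvalues fall to round-off level (around 1e-15), where their sign is essentially arbitrary.

```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    # Isotropic squared exponential kernel, same as in the question
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
K = kernel(X, X)

# eigvalsh treats K as symmetric and returns real eigenvalues in ascending order
eig = np.linalg.eigvalsh(K)
print("largest eigenvalue:  ", eig[-1])
print("smallest eigenvalues:", eig[:3])
# Ratio of the largest to the smallest singular value; a huge value means "numerically singular"
print("condition number:    ", np.linalg.cond(K))
```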
NumPy accepts it without complaint, but `MvNormal` throws either
```
PosDefException: matrix is not Hermitian; Cholesky factorization failed.
```
or
```
PosDefException: matrix is not positive definite; Cholesky factorization failed.
```
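The two libraries factor the covariance differently. By default, NumPy samples through an SVD of the covariance (see the `method` parameter of `Generator.multivariate_normal`). An SVD tolerates tiny or slightly negative eigenvalues, and NumPy only warns when the matrix is not PSD within a tolerance. As the error message says, `MvNormal` relies on a Cholesky factorization, which fails as soon as a pivot is not strictly positive. The sketch below shows SVD-based sampling. `svd_factor` and `sample_svd` are hypothetical helpers for illustration, not NumPy's actual source.

```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    # Isotropic squared exponential kernel, same as in the question
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

def svd_factor(cov):
    # For symmetric PSD cov = V diag(s) V^T, A = diag(sqrt(s)) V^T gives A.T @ A ~= cov.
    # Singular values are never negative, so sqrt never produces NaN.
    _, s, vt = np.linalg.svd(cov)
    return np.sqrt(s)[:, None] * vt

def sample_svd(mean, cov, size, rng):
    # Each row z @ A has covariance A.T @ A ~= cov
    A = svd_factor(cov)
    z = rng.standard_normal((size, len(mean)))
    return mean + z @ A

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
K = kernel(X, X)
rng = np.random.default_rng(0)
samples = sample_svd(np.zeros(len(X)), K, 3, rng)
print(samples.shape)
```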
Since this looked like a floating-point precision problem, I tried sol2:
```julia
σ = σ + maximum([0.0, -minimum(eigvals(σ))])*I
D = MvNormal(μ, σ)
```
which I expected to make the matrix positive definite, but it did not help.
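Shifting by -λ_min only moves the smallest eigenvalue to roughly zero. The result is at best positive *semi*definite. Cholesky needs every eigenvalue to be strictly positive and comfortably above round-off. If λ_min happens to be a tiny positive number, the shift is zero and nothing changes at all. The NumPy sketch below applies the same shift and assumes the kernel from the question.

```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    # Isotropic squared exponential kernel, same as in the question
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
K = kernel(X, X)

lam_min = np.linalg.eigvalsh(K).min()
# Same idea as sol2: shift the diagonal by max(0, -lambda_min)
K_shift = K + max(0.0, -lam_min) * np.eye(len(K))
# The smallest eigenvalue is now about 0 (within round-off), not safely positive
print(np.linalg.eigvalsh(K_shift).min())
```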
Neither sol1 nor a combination of sol1 and sol2 worked:
```julia
σ = Symmetric(σ)
D = MvNormal(σ.data)
```
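In Julia, `Symmetric(σ).data` is the original parent matrix, so `MvNormal(σ.data)` receives the unsymmetrized array again. Explicit symmetrization, such as (Σ + Σᵀ)/2, can address the "not Hermitian" error. It cannot address the "not positive definite" error, because the near-zero eigenvalues remain. The NumPy sketch below demonstrates this and assumes the kernel from the question.

```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    # Isotropic squared exponential kernel, same as in the question
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
K = kernel(X, X)

# Floating-point addition is commutative, so this is exactly symmetric
K_sym = 0.5 * (K + K.T)
print(np.array_equal(K_sym, K_sym.T))
# ...but the smallest eigenvalue is still at round-off level
print(np.linalg.eigvalsh(K_sym).min())
```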
Following sol3, I also tried forcing the matrix to be Hermitian:
```julia
D = MvNormal(Matrix(Hermitian(σ)))
```
Do you have any insight into what I should do?
After further testing, I found that
```julia
σ = σ + 0.00000000001*I
D = MvNormal(σ)
```
works, but I find this solution really awful. Why does it even work?
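Adding a small constant c to the diagonal is often called "jitter" or a "nugget". It is equivalent to assuming a tiny observation-noise variance. The operation raises every eigenvalue by exactly c. Once c clearly exceeds the round-off level (roughly 1e-15 here), Cholesky succeeds. This is presumably why 1e-11 already works, and it is a common practice in GP code. The sketch below samples the GP prior through Cholesky. It assumes a jitter of 1e-8, which is a tunable choice rather than a prescribed value.

```python
import numpy as np

def kernel(X1, X2, l=1.0, sigma_f=1.0):
    # Isotropic squared exponential kernel, same as in the question
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

X = np.arange(-1, 1, 0.1).reshape(-1, 1)
K = kernel(X, X)

jitter = 1e-8  # assumed value: far above round-off, tiny compared with the kernel scale
# Lower-triangular factor with L @ L.T == K + jitter * I (up to round-off)
L = np.linalg.cholesky(K + jitter * np.eye(len(X)))

rng = np.random.default_rng(0)
z = rng.standard_normal((len(X), 3))  # 3 independent standard-normal vectors
samples = (L @ z).T                   # each row is one zero-mean draw from N(0, K + jitter*I)
print(samples.shape)
```

For a nonzero mean, add μ to each row. The jitter slightly inflates the prior variance, so it should stay small relative to σ_f².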