Is there a Haskell library for automatic differentiation which works with unboxed vectors? The `grad` function from `Numeric.AD` requires an instance of `Traversable`, which `Data.Vector.Unboxed` is not.
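
To make the constraint concrete, here is a minimal sketch of one possible workaround, not a library feature: copy the unboxed vector into a boxed `Data.Vector` (which is `Traversable`), call `grad`, and copy the result back. It assumes the `ad` and `vector` packages are installed; `sumSquares` and `gradSumSquares` are hypothetical names chosen for illustration, and the two conversions cost an extra copy each.

```haskell
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Numeric.AD (grad)

-- Hypothetical objective, written against boxed vectors and kept
-- polymorphic in the number type so that ad can run it on its own
-- reverse-mode numbers.
sumSquares :: Num a => V.Vector a -> a
sumSquares = V.sum . V.map (\x -> x * x)

-- Convert unboxed -> boxed, differentiate, convert boxed -> unboxed.
-- The gradient of sum x_i^2 is 2 * x_i in each component.
gradSumSquares :: U.Vector Double -> U.Vector Double
gradSumSquares = U.convert . grad sumSquares . U.convert

main :: IO ()
main = print (gradSumSquares (U.fromList [1, 2, 3]))
```

This keeps unboxed storage at the boundaries only; the differentiation itself still runs on boxed data.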

Henry Disoza

Grzegorz Chrupała
- `grad` doesn't need that many changes to make it work for unboxed vectors. You'll have to reimplement `bind` and `unbind` from `Numeric.AD.Internal.Var`. The `Variable` type is not unboxed, but you can replace it with a tuple. (I can't try it out right now because `ad` doesn't build on GHC 7.8) – Sjoerd Visscher Mar 27 '14 at 22:38
- @SjoerdVisscher I think your comment would suffice as a good answer to this question? – sclv Feb 23 '15 at 18:38
1 Answer
I don't know why vectors of pairs are stored as pairs of vectors, but you can easily write instances for your datatype to store the elements sequentially.

    {-# LANGUAGE TypeFamilies, MultiParamTypeClasses #-}

    import qualified Data.Vector.Generic as G
    import qualified Data.Vector.Generic.Mutable as M
    import Control.Monad (liftM, zipWithM_)
    import Data.Vector.Unboxed.Base

    -- A point with three strict, unpacked Int coordinates.
    data Point3D = Point3D {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Int

    -- Both the mutable and the immutable vector wrap one flat vector of Ints.
    newtype instance MVector s Point3D = MV_Point3D (MVector s Int)
    newtype instance Vector Point3D = V_Point3D (Vector Int)

    instance Unbox Point3D

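At this point the `instance Unbox Point3D` line fails to compile, because `Unbox` requires `Data.Vector.Generic.Mutable.MVector` and `Data.Vector.Generic.Vector` instances for `Point3D`, and none exist yet. They can be written as follows: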
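
    instance M.MVector MVector Point3D where
        -- Three Ints per point, so the point count is a third of the Int count.
        basicLength (MV_Point3D v) = M.basicLength v `div` 3
        -- a is the starting point index and b the number of points; both scale by 3.
        basicUnsafeSlice a b (MV_Point3D v) = MV_Point3D $ M.basicUnsafeSlice (a*3) (b*3) v
        basicOverlaps (MV_Point3D v0) (MV_Point3D v1) = M.basicOverlaps v0 v1
        basicUnsafeNew n = liftM MV_Point3D (M.basicUnsafeNew (3*n))
        -- Needed with vector >= 0.11, where basicInitialize is a class method;
        -- drop this line when building against an older vector.
        basicInitialize (MV_Point3D v) = M.basicInitialize v
        -- Explicit reads instead of a list pattern, which would need MonadFail
        -- on current GHC versions.
        basicUnsafeRead (MV_Point3D v) n = do
            a <- M.basicUnsafeRead v (3*n)
            b <- M.basicUnsafeRead v (3*n+1)
            c <- M.basicUnsafeRead v (3*n+2)
            return $ Point3D a b c
        basicUnsafeWrite (MV_Point3D v) n (Point3D a b c) = zipWithM_ (M.basicUnsafeWrite v) [3*n,3*n+1,3*n+2] [a,b,c]

    instance G.Vector Vector Point3D where
        basicUnsafeFreeze (MV_Point3D v) = liftM V_Point3D (G.basicUnsafeFreeze v)
        basicUnsafeThaw (V_Point3D v) = liftM MV_Point3D (G.basicUnsafeThaw v)
        basicLength (V_Point3D v) = G.basicLength v `div` 3
        basicUnsafeSlice a b (V_Point3D v) = V_Point3D $ G.basicUnsafeSlice (a*3) (b*3) v
        basicUnsafeIndexM (V_Point3D v) n = do
            a <- G.basicUnsafeIndexM v (3*n)
            b <- G.basicUnsafeIndexM v (3*n+1)
            c <- G.basicUnsafeIndexM v (3*n+2)
            return $ Point3D a b c
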
I think most of the function definitions are self-explanatory. The vector of points is stored as a flat vector of Ints, and the nth point occupies the Ints at indices 3n, 3n+1 and 3n+2.
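
The index arithmetic can be checked in isolation on an ordinary unboxed vector of Ints. The following is only an illustrative sketch; `flatten`, `pointAt` and `slicePoints` are hypothetical helpers, not part of the instances above or of the `vector` package.

```haskell
import qualified Data.Vector.Unboxed as U

-- Hypothetical helpers that mirror the interleaved layout.

-- Flatten a list of (x, y, z) triples into one contiguous vector.
flatten :: [(Int, Int, Int)] -> U.Vector Int
flatten ps = U.fromList (concatMap (\(a, b, c) -> [a, b, c]) ps)

-- Read the nth point back: its components sit at 3n, 3n+1 and 3n+2.
pointAt :: U.Vector Int -> Int -> (Int, Int, Int)
pointAt v n = (v U.! (3 * n), v U.! (3 * n + 1), v U.! (3 * n + 2))

-- Take `len` points starting at point `start`; the offset and the
-- length are both multiplied by 3 in the underlying vector.
slicePoints :: Int -> Int -> U.Vector Int -> U.Vector Int
slicePoints start len = U.slice (3 * start) (3 * len)

main :: IO ()
main = do
  let v = flatten [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
  print (U.length v `div` 3)            -- 3 points
  print (pointAt v 1)                   -- (4,5,6)
  print (pointAt (slicePoints 1 2 v) 1) -- (7,8,9)
```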

Navnish Bhardwaj
- 1) `basicUnsafeSlice a b` returns the slice of b elements that starts at the ath element, so for points both the offset and the length are multiplied by 3. The start of the nth element is 3*n; the end is 3*n+2. 2) You never call these functions directly, but all other functions on unboxed vectors are defined in terms of them, so they are all you need to define in order to use any function on unboxed vectors. 3) Your "point" type should contain n elements of the same type. This would work with `data Point3D a = P3d {-# UNPACK #-} !a {-# UNPACK #-} !a {-# UNPACK #-} !a` and `newtype instance MVector s (Point3D a) = MV_Point3D (MVector s a)` – Navnish Bhardwaj Nov 12 '14 at 10:16
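
The polymorphic variant mentioned in point 3 might look roughly like the sketch below. It is an untested-in-production illustration under a few assumptions: a recent `vector` (0.11 or later, where `basicInitialize` is a class method), the constructor named `Point3D` rather than `P3d`, and plain strict fields, because GHC cannot unpack a field of polymorphic type.

```haskell
{-# LANGUAGE TypeFamilies, MultiParamTypeClasses #-}
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import Data.Vector.Unboxed.Base (MVector, Vector, Unbox)

-- Three components of the same element type.
data Point3D a = Point3D !a !a !a deriving (Eq, Show)

-- Points are stored interleaved in one flat vector of components.
newtype instance MVector s (Point3D a) = MV_Point3D (MVector s a)
newtype instance Vector (Point3D a) = V_Point3D (Vector a)

instance Unbox a => Unbox (Point3D a)

instance Unbox a => M.MVector MVector (Point3D a) where
  basicLength (MV_Point3D v) = M.basicLength v `div` 3
  -- i is the first point, n the number of points.
  basicUnsafeSlice i n (MV_Point3D v) = MV_Point3D (M.basicUnsafeSlice (3 * i) (3 * n) v)
  basicOverlaps (MV_Point3D v0) (MV_Point3D v1) = M.basicOverlaps v0 v1
  basicUnsafeNew n = fmap MV_Point3D (M.basicUnsafeNew (3 * n))
  basicInitialize (MV_Point3D v) = M.basicInitialize v
  basicUnsafeRead (MV_Point3D v) n = do
    a <- M.basicUnsafeRead v (3 * n)
    b <- M.basicUnsafeRead v (3 * n + 1)
    c <- M.basicUnsafeRead v (3 * n + 2)
    return (Point3D a b c)
  basicUnsafeWrite (MV_Point3D v) n (Point3D a b c) = do
    M.basicUnsafeWrite v (3 * n) a
    M.basicUnsafeWrite v (3 * n + 1) b
    M.basicUnsafeWrite v (3 * n + 2) c

instance Unbox a => G.Vector Vector (Point3D a) where
  basicUnsafeFreeze (MV_Point3D v) = fmap V_Point3D (G.basicUnsafeFreeze v)
  basicUnsafeThaw (V_Point3D v) = fmap MV_Point3D (G.basicUnsafeThaw v)
  basicLength (V_Point3D v) = G.basicLength v `div` 3
  basicUnsafeSlice i n (V_Point3D v) = V_Point3D (G.basicUnsafeSlice (3 * i) (3 * n) v)
  basicUnsafeIndexM (V_Point3D v) n = do
    a <- G.basicUnsafeIndexM v (3 * n)
    b <- G.basicUnsafeIndexM v (3 * n + 1)
    c <- G.basicUnsafeIndexM v (3 * n + 2)
    return (Point3D a b c)

main :: IO ()
main = do
  let ps = U.fromList [Point3D 1 2 3, Point3D 4 5 6] :: U.Vector (Point3D Double)
  print (U.length ps)
  print (ps U.! 1)
  print (U.toList (U.map (\(Point3D x y z) -> x + y + z) ps))
```

With these instances in scope, the usual `Data.Vector.Unboxed` functions such as `U.map`, `U.length` and `U.!` should work on `U.Vector (Point3D a)` for any unboxable `a`.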