{- Simplified conjugate gradient (CG), as outlined by Jan-Willem Maessen:

     for j <- seq(1#cgit) do
       q     = A p
       alpha = rho / (p DOT q)
       z    += alpha p
       rho0  = rho
       r    -= alpha q
       rho  := r DOT r
       beta  = rho / rho0
       p    := r + beta p
     end

   Here p, q, r and z are vectors, A is the derivative of our function (in
   this case a sparse symmetric positive-definite matrix, but we can really
   think of it as a higher-order function of type Vector->Vector), and the
   Greek letters are scalars. The "answer" is z.

   In practice we would not run a fixed number of iterations but would
   instead do a convergence test. All the hard work is in the line
   "q = A p", but the storage consumption is mostly in the surrounding code.
-}

{- OK, let's translate that into Haskell for concreteness, in three variants:
     - naive, list-based code
     - strict-list-based code
     - array-based, in-place-update code
   All three run the same test problem (defined below) for a fixed number
   of iterations and return z.
-}

-- NOTE(review): the `!` argument patterns used below (modArray and the
-- array-loop helpers) need the BangPatterns extension. No LANGUAGE pragma
-- is visible in this file; GHC2021 enables it by default. Confirm how this
-- file is built.
import Data.Array.Base            -- unsafeRead / unsafeWrite (no bounds checks)
import Data.Array.IArray
import Control.Monad(unless)
import Control.Monad.ST
import Data.Array.ST
import Debug.Trace(trace)         -- NOTE(review): trace is unused in the visible code
import System.Environment(getArgs)

-- Element type of every vector and matrix.
type Elem = Double

-- Problem size: vectors have n entries, indexed 1..n. The constant is
-- polymorphic so that it can serve both as an Int loop bound (array code)
-- and inside Elem-valued lists (test data).
n :: Num a => a
n = 40

-- Test matrix A: the n x n identity, as a list of rows.
a :: Matrix
a = [[if i==j then 1 else 0|i<-[1..n]]|j<-[1..n]]

-- Initial values of the vectors p, r and z from the outline above.
-- z holds the answer.
p,r,z :: Vector
p = [1..n]
r = [1..n]
z = [1|i<-[1..n]]

-- Initial rho.
-- NOTE(review): textbook CG starts from rho = r DOT r. Here it is fixed at
-- 1, presumably on purpose for this benchmark; confirm.
rho :: Elem
rho = 1

-- Stop criterion for the main loops. Stop once the iteration counter n
-- reaches 0; z is ignored. Every loop counts down by 1 from its start
-- value, so a negative start value never reaches 0 and the loop does not
-- terminate.
stop z n = n==0

------------------------ naive list-based code

type Vector = [Elem]
type Matrix = [Vector]   -- list of rows

-- Scalar times vector.
s .* v = map (s*) v
-- Element-wise sum and difference.
v1 `add` v2 = zipWith (+) v1 v2
v1 `sub` v2 = zipWith (-) v1 v2
-- Inner product.
v1 `dot` v2 = sum $ zipWith (*) v1 v2
-- Matrix-vector product: dot each row of m with v.
m `mat` v = map (`dot` v) m

-- Perform one CG step per call, transcribed from the outline.
-- The incoming rho plays the role of rho0.
-- The counter n counts down to 0, after which z is returned.
-- Only n is inspected before the end (see stop), so the lazy p, r, z and
-- rho values are not forced until the final z is demanded.
loop a p r z rho n | stop z n = z
loop a p r z rho n = loop a p' r' z' rho' (n-1)
  where
    q = a `mat` p
    alpha = rho / (p `dot` q)
    z' = z `add` (alpha .* p)
    r' = r `sub` (alpha .* q)
    rho' = r' `dot` r'
    beta = rho' / rho
    p' = r' `add` (beta .* p)

-- Run c iterations on the test problem and return z.
test c = loop a p r z rho c

------------------------ strict list-based code

-- A list with strict head and tail fields. Building `x :< xs` forces x and
-- xs to WHNF, which (recursively) evaluates the whole spine and every
-- element.
data StrictList a = Nil | !a :< !(StrictList a) deriving Show

type VectorS = StrictList Elem
-- NOTE(review): testS builds its matrix as a StrictList of StrictLists,
-- whereas this alias is a StrictList of plain lists. Neither VectorS nor
-- MatrixS is referenced in the visible code.
type MatrixS = StrictList Vector

-- Fold over a StrictList. Only the Nil equation is intact; see the NOTE
-- below.
foldS f n Nil = n
-- NOTE(review): SOURCE is corrupted here, and the next line does not parse.
-- The second foldS equation breaks off after "(x:", and the text resumes
-- partway through what looks like another zipWithS-based definition,
-- presumably addmulS. A span starting with "<" appears to have been
-- stripped as markup.
-- Helpers used below but not visible anywhere here were presumably defined
-- in the lost span: zipWithS, fromList, toList, dotS, matS and addmulS.
-- Restore the span from upstream before building.
foldS f n (x:a+x*b) v1 v2
-- Presumably v1 - x*v2 element-wise, via zipWithS (definition not visible).
v1 `submulS` (x,v2) = zipWithS (\a b->a-x*b) v1 v2

-- Same CG step as loop, but on StrictLists. The addmulS/submulS calls
-- stand in for the separate add/scale steps shown in the trailing
-- comments.
loopS a p r z rho n | stop z n = z
loopS a p r z rho n = loopS a p' r' z' rho' (n-1)
  where
    q = a `matS` p
    alpha = rho / (p `dotS` q)
    z' = z `addmulS` (alpha,p) -- i.e. z `addS` (alpha `smulS` p)
    r' = r `submulS` (alpha,q) -- i.e. r `subS` (alpha `smulS` q)
    rho' = r' `dotS` r'
    beta = rho' / rho
    p' = r' `addmulS` (beta,p) -- i.e. r' `addS` (beta `smulS` p)

-- Run c iterations on the test problem, converted to StrictLists
-- (the matrix becomes a StrictList of row StrictLists).
testS c = loopS (fromList (map fromList a)) (fromList p) (fromList r) (fromList z) rho c

------------------------ array-based, update-in-place code

-- Mutable unboxed arrays in ST.
-- The loops below only touch indices 1..n, bounded by the global n rather
-- than by the array bounds. testA allocates bounds 0..n, and slot 0 is
-- unused padding.
-- unsafeRead/unsafeWrite take a raw 0-based offset and do no bounds
-- checking. Every vector must therefore hold at least n+1 elements, and the
-- matrix at least (n+1)*(n+1).
type VectorA s = STUArray s Int Elem
type MatrixA s = STUArray s (Int,Int) Elem

-- Replace the element at raw offset i of a with f applied to it, in place.
modArray !a !i f = unsafeRead a i >>= (unsafeWrite a i . f)

-- In-place axpy-style updates, for i in 1..n:
--   v1 +*= (x,v2):  v1[i] := v1[i] + x*v2[i]
--   v1 -*= (x,v2):  v1[i] := v1[i] - x*v2[i]
-- `$!` forces the new value before it is written.
(+*=),(-*=) :: VectorA s -> (Elem,VectorA s) -> ST s ()
v1 +*= (x,v2) = l v1 x v2 1
  where
    l !v1 !x !v2 !i = unless (i>n) $ do { a<-unsafeRead v1 i; b<-unsafeRead v2 i; unsafeWrite v1 i $! (a+(x*b)); l v1 x v2 (i+1) }
v1 -*= (x,v2) = l v1 x v2 1
  where
    l !v1 !x !v2 !i = unless (i>n) $ do { a<-unsafeRead v1 i; b<-unsafeRead v2 i; unsafeWrite v1 i $! (a-(x*b)); l v1 x v2 (i+1) }

-- In-place update for i in 1..n:  v1[i] := v2[i] + x*v1[i].
-- Note that it is v1, the vector inside the tuple, that gets overwritten.
-- loopA uses this as  p := r + beta p.
(*+=) :: (Elem,VectorA s) -> VectorA s -> ST s ()
(x,v1) *+= v2 = x `seq` v1 `seq` v2 `seq` l 1
  where
    l !i = unless (i>n) $ do { e2 <- unsafeRead v2 i; modArray v1 i ((e2+).(x*)); l (i+1) }

-- Inner product over indices 1..n, with a strict accumulator.
dotA :: VectorA s -> VectorA s -> ST s Elem
v1 `dotA` v2 = v1 `seq` v2 `seq` l 1 0
  where
    l !i !s | i>n = return s
    l i s = do { a<-unsafeRead v1 i; b<-unsafeRead v2 i;l (i+1) $! (s+a*b) }

-- Matrix-vector product into tmp: tmp[i] := sum over j in 1..n of
-- m[i][j] * v[j], for i in 1..n. Returns tmp.
-- Element (i,j) of m is read at raw offset i*(n+1)+j, i.e. row-major with
-- n+1 columns, which matches the ((0,0),(n,n)) bounds used in testA.
-- Once j passes n, the accumulated row sum s is stored and the next row
-- starts.
matA :: MatrixA s -> VectorA s -> VectorA s -> ST s (VectorA s)
(m `matA` v) tmp = m `seq` v `seq` tmp `seq` l 1 1 0
  where
    l !i !j !s | i>n = return tmp
    l i j s | j>n = unsafeWrite tmp i s >> l (i+1) 1 0
    l i j s = do
      a<-unsafeRead m $! (i*(n+1)+j)
      b<-unsafeRead v j
      l i (j+1) $! (s+a*b)

-- Same CG step as loop, but p, r and z are updated in place and q is
-- scratch space that receives A p.
-- Returns the z array once the counter reaches 0.
loopA a p r z q rho n | stop z n = return z
loopA a p r z q rho n = do
  (a `matA` p) q
  alpha <- fmap (rho/) (p `dotA` q)
  z +*= (alpha,p)
  r -*= (alpha,q)
  rho'<- r `dotA` r
  let beta = rho' / rho
  (beta,p) *+= r
  loopA a p r z q rho' (n-1)

-- Run c iterations on the test problem using arrays with bounds 0..n.
-- Slot 0 is a 0 pad so that indices line up with the 1..n loops.
-- The matrix is the identity over 0..n, stored row by row.
-- Returns z as an immutable unboxed array, still including the pad at
-- index 0.
testA c = runSTUArray (do
    aA <- newListArray ((0,0),(n,n)) (concat [[if i==j then 1 else 0|i<-[0..n]]|j<-[0..n]])
    pA <- newListArray (0,n) (0:p)
    rA <- newListArray (0,n) (0:r)
    zA <- newListArray (0,n) (0:z)
    qA <- newArray (0,n) 0
    loopA aA pA rA zA qA rho c
  )

-----------------------

-- Usage: <prog> (list|listS|array|check) <iterations>
--   list / listS / array : run that variant and print z.
--   check                : run all three variants. For each pair, report
--                          exact (==) equality of the Double results and
--                          the largest absolute element-wise difference.
-- The trailing comments on the first three cases appear to be run times
-- for an iteration count of 100000.
-- NOTE(review): three inputs are not handled:
--   - fewer than two arguments fails the getArgs pattern;
--   - a non-numeric count makes `read` fail;
--   - any other version string falls off the case with a
--     non-exhaustive-pattern error.
main = do
  (version:count:_) <- getArgs
  case version of
    "list" -> print $ test (read count) -- 100000: 2m3s
    "listS" -> print $ testS (read count) -- 100000: 12s
    "array" -> print $ testA (read count) -- 100000: 33s
    "check" -> do
      let c = read count
          l = test c
          ts = toList $ testS c
          ea = tail $ elems $ testA c -- drop the unused index-0 pad
          diff a b = maximum $ map abs $ zipWith (-) a b
      putStrLn $ "list==listS? "++show (l==ts)++" "++show (diff l ts)
      putStrLn $ "list==array? "++show (l==ea)++" "++show (diff l ea)
      putStrLn $ "listS==array? "++show (ts==ea)++" "++show (diff ts ea)