[Haskell-cafe] memorize function with number parameterized types in GHC

oleg at okmij.org oleg at okmij.org
Wed Nov 9 07:05:29 CET 2011


It seems GHC can be persuaded to do proper specialization and
memoization. We can see this first using trace:

> class (Ord b, Integral b, Num b, Bits b) => PositiveN a b where
>     p2num :: Dep a b
>
> instance (Ord b, Integral b, Num b, Bits b) => PositiveN One b where
>     p2num = trace "p2num 1" $ Dep 1

If we define 

> tttt :: PositiveN p Int => ModP2 p Int -> ModP2 p Int
> tttt x = x * x * x * x
>
> ssss :: PositiveN p Int => ModP2 p Int -> ModP2 p Int
> ssss x = x + x + x + x
>
> test2 = tttt x + ssss x
>  where x = 1 :: ModP2 (D1 One) Int

and run test2 in GHCi, we see

   *Math.Montg> test2
   p2num 1 
   1+p2num 1
   4Z

That is, p2num was invoked only twice. Presumably one invocation is for
converting 1 to a modular number, and the other invocation was used
for all 4 additions and 3 multiplications, which are distributed across
multiple functions.

Looking at the core is a better test:
	ghc -O2 -c -ddump-prep Montg.hs


Math.Montg.test2
  :: Math.Montg.ModP2 (Math.Montg.D1 Math.Montg.One) GHC.Types.Int
[GblId, Str=DmdType]
Math.Montg.test2 =
  case Math.Montg.test14 of _ { GHC.Types.I# ww_s1Ww ->
  case Math.Montg.$w$stttt ww_s1Ww of ww1_s1WB { __DEFAULT ->
  case Math.Montg.$w$sssss ww_s1Ww of ww2_s1WC { __DEFAULT ->
  case Math.Montg.$fNumModP2_$spmask
       `cast` (Math.Montg.NTCo:Dep
                 (Math.Montg.D1 Math.Montg.One) GHC.Types.Int
               :: Math.Montg.Dep (Math.Montg.D1 Math.Montg.One) GHC.Types.Int
                    ~
                  GHC.Types.Int)
  of _ { GHC.Types.I# y#_s1WF ->
  case GHC.Prim.int2Word# y#_s1WF of sat_s2m8 { __DEFAULT ->
  case GHC.Prim.+# ww1_s1WB ww2_s1WC of sat_s2m9 { __DEFAULT ->
  case GHC.Prim.int2Word# sat_s2m9 of sat_s2ma { __DEFAULT ->
  case GHC.Prim.and# sat_s2ma sat_s2m8 of sat_s2mb { __DEFAULT ->
  case GHC.Prim.word2Int# sat_s2mb of sat_s2mc { __DEFAULT ->
  (GHC.Types.I# sat_s2mc)

As you can see, the program used Math.Montg.$fNumModP2_$spmask.
Here is its definition in Core:


Math.Montg.$fNumModP2_$spmask
  :: Math.Montg.Dep (Math.Montg.D1 Math.Montg.One) GHC.Types.Int
[GblId, Str=DmdType]
Math.Montg.$fNumModP2_$spmask =
  case Math.Montg.$wbitLen
         @ GHC.Types.Int
         GHC.Base.$fEqInt
         GHC.Num.$fNumInt_$cfromInteger

You only need to look at the type to see that GHC has specialized
pmask to the particular type Dep (D1 One) Int -- just as we
wanted.

Here is the prefix of your code with my modifications:

module Math.Montg where

import Data.Bits
import Debug.Trace

-- | A value of type @b@ tagged with a phantom type @a@ (e.g. a type-level
-- number such as @D1 One@). The tag lets an argument-less class method such
-- as 'p2num' be selected purely by type, via annotations like @:: Dep p b@.
newtype Dep a b = Dep { unDep :: b }

-- | Type-level binary numerals, least significant digit outermost:
-- 'One' denotes 1, @D0 p@ denotes 2*p and @D1 p@ denotes 2*p + 1 (see the
-- 'PositiveN' instances), so e.g. @D1 (D1 One)@ denotes 7.
data One = One

-- | Binary digit 0 appended below @a@.
data D0 a = D0 a
-- | Binary digit 1 appended below @a@.
data D1 a = D1 a

-- | Type-level positive numbers @a@ that can be reflected into a value of
-- the numeric type @b@.
class (Ord b, Integral b, Num b, Bits b) => PositiveN a b where
    -- | The number denoted by the type @a@, as a value of type @b@.
    p2num :: Dep a b

-- | Base case: 'One' denotes 1. The 'trace' message is emitted whenever
-- this value is actually evaluated, which makes sharing of 'p2num'
-- observable.
instance (Ord b, Integral b, Num b, Bits b) => PositiveN One b where
    p2num = Dep (trace "p2num 1" 1)

-- | @D0 p@ denotes twice the number denoted by @p@.
instance PositiveN p b => PositiveN (D0 p) b where
    p2num = Dep doubled
      where
        inner   = unDep (p2num :: Dep p b)
        doubled = inner * 2

-- | @D1 p@ denotes twice the number denoted by @p@, plus one.
instance PositiveN p b => PositiveN (D1 p) b where
    p2num = Dep (shifted + 1)
      where
        shifted = unDep (p2num :: Dep p b) * 2

-- | Count trailing zero bits: the index of the lowest set bit, e.g.
-- @ctz 1 == 0@, @ctz 12 == 2@. Negative two's-complement values work too
-- (@ctz (-4) == 2@), since 'shiftR' sign-extends signed types.
--
-- Zero has no set bit and shifting it right yields 0 again, so the search
-- would never end; it is rejected with an explicit error instead.
ctz :: (Num a, Bits a) => a -> Int
ctz 0 = error "ctz: argument is zero (it has no set bit)"
ctz x = go x
  where
    -- Every zero bit shifted out contributes one to the count.
    go v
      | testBit v 0 = 0
      | otherwise   = 1 + go (v `shiftR` 1)

-- | Number of bits needed to represent a non-negative value, i.e. the
-- position of its highest set bit plus one (@bitLen 0 == 0@,
-- @bitLen 5 == 3@).
--
-- Negative inputs are rejected with an error: 'shiftR' sign-extends a
-- negative signed value, so it never reaches 0 and the recursion would
-- loop forever.
bitLen :: (Num a, Bits a) => a -> Int
bitLen x
  | signum x == -1 = error "bitLen: negative argument"
  | otherwise      = go x
  where
    go 0 = 0
    go v = go (v `shiftR` 1) + 1

-- | Bit mask used to reduce residues for the type-level number @p@.
-- With @n = p2num@: when @bitLen n == ctz n + 1@ (i.e. @n@ is a power of
-- two, assuming 'ctz' yields the index of the lowest set bit) the mask is
-- @2^(ctz n) - 1@; otherwise it is @2^(bitLen n) - 1@.
pmask :: forall p b. (PositiveN p b) => Dep p b
pmask = Dep (bit width - 1)
  where
    n     = unDep (p2num :: Dep p b)
    len   = bitLen n
    low   = ctz n
    width = if len == low + 1 then low else len

-- | Modular addition: the raw sum is reduced by keeping only the bits
-- selected by 'pmask'.
addmod2 :: forall p b. (PositiveN p b) => Dep p b -> Dep p b -> Dep p b
addmod2 (Dep x) (Dep y) = Dep (total .&. mask)
  where
    total = x + y
    mask  = unDep (pmask :: Dep p b)
{-# INLINE addmod2 #-}

-- | Modular subtraction by masking with 'pmask'. The first argument is
-- never inspected; it only fixes the type @p@. For signed @b@ a negative
-- difference is brought into range by the two's-complement '.&.'.
submod2 :: forall p b. (PositiveN p b) => p -> b -> b -> b
submod2 _proxy x y = difference .&. mask
  where
    difference = x - y
    mask       = unDep (pmask :: Dep p b)
{-# INLINE submod2 #-}

-- | Modular multiplication: the raw product is reduced by keeping only the
-- bits selected by 'pmask'.
mulmod2 :: forall p b. (PositiveN p b) => Dep p b -> Dep p b -> Dep p b
mulmod2 (Dep x) (Dep y) = Dep (prod .&. mask)
  where
    prod = x * y
    mask = unDep (pmask :: Dep p b)
{-# INLINE mulmod2 #-}

-- | Addition modulo the number denoted by @p@ (not a power-of-two mask):
-- the modulus is subtracted once when the sum reaches it. The first
-- argument only fixes the type @p@.
-- NOTE(review): assumes both operands are already in [0, p) — confirm at
-- call sites.
addmod :: forall p b. (PositiveN p b) => p -> b -> b -> b
addmod _ x y
  | s >= m    = s - m
  | otherwise = s
  where
    s = x + y
    m = unDep (p2num :: Dep p b)
{-# INLINE addmod #-}

-- | Subtraction modulo the number denoted by @p@: when the subtrahend is
-- larger, the modulus is added back. The first argument only fixes the
-- type @p@.
submod :: forall p b. (PositiveN p b) => p -> b -> b -> b
submod _ x y
  | x >= y    = x - y
  | otherwise = x + modulus - y
  where
    modulus = unDep (p2num :: Dep p b)
{-# INLINE submod #-}

-- | Extended Euclidean algorithm.
--
-- @extgcd a b@ returns @(g, x, y)@ such that @g = gcd a b@ and
-- @a*x + b*y = g@. Negative arguments are made positive first, and the
-- sign is then pushed back into the matching coefficient.
extgcd :: Integral a => a -> a -> (a, a, a)
extgcd a b
  | a < 0     = negFst (extgcd (negate a) b)
  | b < 0     = negSnd (extgcd a (negate b))
  | b == 0    = (a, 1, 0)
  | otherwise =
      let (q, r)    = a `divMod` b
          -- From  b*s + r*t = g  and  r = a - q*b :
          --       a*t + b*(s - q*t) = g
          (g, s, t) = extgcd b r
      in  (g, t, s - q * t)
  where
    -- Lazy patterns keep the result tuple as lazy as a let-binding would.
    negFst ~(g, x, y) = (g, negate x, y)
    negSnd ~(g, x, y) = (g, x, negate y)

-- | A residue whose modulus is determined by the phantom type-level number
-- @p@ (the modulus used by the 'Num' and 'Show' instances is 'pmask' + 1).
-- Keeping @p@ in the type prevents mixing values with different moduli.
--
-- No datatype context is given: DatatypeContexts is deprecated and adds
-- nothing, because every function using 'ModP2' states @PositiveN p a@
-- itself. Dropping it also lets the derived 'Eq' require only @Eq a@.
newtype ModP2 p a = ModP2 { unModP2 :: a } deriving Eq

-- | Renders a residue as @r+mZ@, where @m@ is the modulus ('pmask' + 1).
instance PositiveN p a => Show (ModP2 p a) where
    show (ModP2 r) = concat [show r, "+", show modulus, "Z"]
      where
        modulus = unDep (pmask :: Dep p a) + 1

-- ModP2 and Dep wrap the same underlying value (both are newtypes), so the
-- two conversions below only re-tag it; in principle the two types could
-- be merged into one.
modP2_Dep :: PositiveN p a => ModP2 p a -> Dep p a
modP2_Dep = Dep . unModP2

dep_ModP2 :: PositiveN p a => Dep p a -> ModP2 p a
dep_ModP2 = ModP2 . unDep

-- | Arithmetic modulo 'pmask' + 1. Addition, subtraction and multiplication
-- reduce their result by masking with 'pmask'; literals are reduced with
-- 'mod'. 'abs' is the identity and 'signum' is always 1.
instance PositiveN p a => Num (ModP2 p a) where
    x + y = dep_ModP2 (addmod2 (modP2_Dep x) (modP2_Dep y))
    x - y = ModP2 (submod2 (undefined :: p) (unModP2 x) (unModP2 y))
    x * y = dep_ModP2 (mulmod2 (modP2_Dep x) (modP2_Dep y))
    fromInteger n = ModP2 (fromInteger n `mod` modulus)
      where
        modulus = unDep (pmask :: Dep p a) + 1
    abs x = x
    signum _ = 1

-- .... 
-- A few tests

-- | Fourth powers of a few residues for the type-level number
-- @D1 (D1 One)@ (which denotes 7), computed with 'ttt'.
test1 :: [ModP2 (D1 (D1 One)) Int]
test1 = map ttt l1
  where
    l1 :: [ModP2 (D1 (D1 One)) Int]
    l1 = [10, 11, 12, 13, 14, 15]

-- | Fourth power, fixed to the type-level number @D1 (D1 One)@.
ttt :: ModP2 (D1 (D1 One)) Int -> ModP2 (D1 (D1 One)) Int
ttt x = cubed * x
  where
    cubed = x * x * x

-- | Fourth power, polymorphic in the type-level number @p@.
tttt :: PositiveN p Int => ModP2 p Int -> ModP2 p Int
tttt x = cubed * x
  where
    cubed = x * x * x

-- | Four-fold sum of the argument, polymorphic in the type-level number @p@.
ssss :: PositiveN p Int => ModP2 p Int -> ModP2 p Int
ssss x = tripled + x
  where
    tripled = x + x + x

-- | @tttt 1 + ssss 1@ at the type-level number @D1 One@. Evaluating it
-- (e.g. in GHCi) shows, via the trace in the 'PositiveN One' instance,
-- how often the base number is actually computed.
test2 :: ModP2 (D1 One) Int
test2 = tttt x + ssss x
  where
    x = 1 :: ModP2 (D1 One) Int




More information about the Haskell-Cafe mailing list