[Haskell-cafe] folds with escapes

Donald Bruce Stewart dons at cse.unsw.edu.au
Wed Jul 4 20:35:13 EDT 2007


dm.maillists:
> On Thursday 05 July 2007 11:20, Michael Vanier wrote:
> > Again, I'm sure this has been done before (and no doubt better); I'd
> > appreciate any pointers to previous work along these lines.
> 
> Takusen is, if I recall correctly, based around a generalised fold supporting 
> accumulation and early termination.  Maybe have a look at that.
> 

Streams are a similar idea: a generalised unfold supporting early
termination, skipping, and accumulation. They are useful for coding up
lots of list functions over the same underlying type, so that you can
fuse them all with a single rule.

A data type to encode this unfold:

    -- | A stream: a step function over a hidden (existentially quantified)
    -- state type @s@, paired with the initial state. Both fields are strict.
    -- NOTE: the @forall s.@ syntax needs ExistentialQuantification.
    data Stream a = forall s.  Stream !(s -> Step a s)  -- ^ a stepper function
                                      !s                -- ^ an initial state

    -- | The outcome of a single step: emit an element and continue
    -- ('Yield'), continue without emitting anything ('Skip'), or stop
    -- ('Done'). The successor state is a strict field.
    data Step a s = Yield a !s
                  | Skip    !s
                  | Done

We need a way to introduce and eliminate streams, converting to and from lists:

    -- | Turn a list into a stream. The state is the remaining suffix of the
    -- list: each step yields its head and moves on to its tail.
    stream :: [a] -> Stream a
    stream = Stream step
      where
        step rest = case rest of
          []       -> Done
          (y : ys) -> Yield y ys

    -- | Run a stream to completion and collect the yielded elements into a
    -- lazy list. 'Skip' steps contribute nothing, and 'Done' ends the list.
    unstream :: Stream a -> [a]
    unstream (Stream step start) = go start
      where
        go !st = case step st of
          Yield y st' -> y : go st'
          Skip    st' -> go st'
          Done        -> []

With these in place, we can roll quite a few list functions:

    -- folds
    -- | Left fold over a stream. Only the stream state is forced at each
    -- step. The accumulator is passed along unevaluated, as with the lazy
    -- Prelude 'foldl', so a long stream can build a chain of thunks. A strict
    -- variant would bang the accumulator.
    foldl :: (b -> a -> b) -> b -> Stream a -> b
    foldl f z0 (Stream step start) = go z0 start
      where
        go acc !st = case step st of
          Yield x st' -> go (f acc x) st'
          Skip    st' -> go acc st'
          Done        -> acc

    -- | Right fold over a stream. The recursive result is handed to @f@
    -- unevaluated, so a combining function that ignores its second argument
    -- stops the traversal early.
    --
    -- The successor state @s'@ is a strict field of 'Step', so it is already
    -- in WHNF when matched. No extra forcing is needed before recursing, and
    -- none requiring a constraint on the hidden state type is possible,
    -- because 'Stream' carries no such constraint.
    foldr :: (a -> b -> b) -> b -> Stream a -> b
    foldr f z (Stream next s0) = loop_foldr s0
      where
        loop_foldr !s = case next s of
          Done       -> z
          Skip    s' -> loop_foldr s'
          Yield x s' -> f x (loop_foldr s')

    -- short circuiting:
    -- | Does any element of the stream satisfy @p@? Stepping stops at the
    -- first element for which @p@ holds, and an exhausted stream gives
    -- 'False'.
    any :: (a -> Bool) -> Stream a -> Bool
    any p (Stream step start) = go start
      where
        go !st = case step st of
          Yield x st' -> if p x then True else go st'
          Skip    st' -> go st'
          Done        -> False

    -- maps
    -- | Apply @f@ to every yielded element. The state type is reused as-is,
    -- and 'Skip' and 'Done' steps pass through unchanged.
    map :: (a -> b) -> Stream a -> Stream b
    map f (Stream step start) = Stream step' start
      where
        step' !st = case step st of
          Yield x st' -> Yield (f x) st'
          Skip    st' -> Skip st'
          Done        -> Done

    -- filters
    -- | Keep only the elements satisfying @p@. A rejected element becomes a
    -- 'Skip', so the step function never loops internally and the state type
    -- is unchanged.
    filter :: (a -> Bool) -> Stream a -> Stream a
    filter p (Stream step start) = Stream step' start
      where
        step' !st = case step st of
          Yield x st' -> if p x then Yield x st' else Skip st'
          Skip    st' -> Skip st'
          Done        -> Done

    -- taking
    -- | Yield elements while @p@ holds. The first element failing @p@ ends the
    -- stream, and later elements are never examined.
    takeWhile :: (a -> Bool) -> Stream a -> Stream a
    takeWhile p (Stream step start) = Stream step' start
      where
        step' !st = case step st of
          Yield x st' -> if p x then Yield x st' else Done
          Skip    st' -> Skip st'
          Done        -> Done

    -- dropping
    -- | Drop the longest prefix of elements satisfying @p@, then pass the rest
    -- of the stream through unchanged.
    --
    -- The state pairs a phase tag with the underlying state: 'S1' while still
    -- dropping, and 'S2' once the first element failing @p@ has been yielded.
    -- ('S1', 'S2' and ':!:' are not defined in this excerpt.)
    dropWhile :: (a -> Bool) -> Stream a -> Stream a
    dropWhile p (Stream step start) = Stream step' (S1 :!: start)
      where
        step' (S1 :!: st) = dropping st
        step' (S2 :!: st) = passing  st

        -- Phase 1: skip elements while p holds. The first element that fails
        -- p is yielded and switches the state to phase 2.
        dropping st = case step st of
          Yield x st'
            | p x       -> Skip    (S1 :!: st')
            | otherwise -> Yield x (S2 :!: st')
          Skip    st' -> Skip (S1 :!: st')
          Done        -> Done

        -- Phase 2: forward every step, staying in phase 2.
        passing st = case step st of
          Yield x st' -> Yield x (S2 :!: st')
          Skip    st' -> Skip    (S2 :!: st')
          Done        -> Done

    -- zips
    -- | Combine two streams element-wise with @f@, stopping as soon as either
    -- stream is 'Done'.
    --
    -- The state holds the first stream's state, the second stream's state,
    -- and a one-element buffer. With an empty buffer ('Nothing') we step the
    -- first stream. When it yields @a@, the element is stored as
    -- @Just (L a)@, a 'Skip' is emitted, and we switch to stepping the second
    -- stream. When that yields @b@, we emit @f a b@ and empty the buffer.
    -- A 'Skip' from either stream becomes a 'Skip' of the result.
    --
    -- NOTE(review): 'L' and ':!:' are defined elsewhere. 'L' presumably boxes
    -- the element lazily so buffering it does not force it; confirm. The
    -- patterns below also rely on the fixity declared for ':!:'.
    zipWith :: (a -> b -> c) -> Stream a -> Stream b -> Stream c
    zipWith f (Stream next0 sa0) (Stream next1 sb0) 
            = Stream next (sa0 :!: sb0 :!: Nothing)
      where
        -- Buffer empty: advance the first stream.
        next (sa :!: sb :!: Nothing)     = case next0 sa of
            Done        -> Done
            Skip    sa' -> Skip (sa' :!: sb :!: Nothing)
            Yield a sa' -> Skip (sa' :!: sb :!: Just (L a))

        -- Holding an element of the first stream: advance the second.
        next (sa' :!: sb :!: Just (L a)) = case next1 sb of
            Done        -> Done
            Skip    sb' -> Skip          (sa' :!: sb' :!: Just (L a))
            Yield b sb' -> Yield (f a b) (sa' :!: sb' :!: Nothing)

    -- concat
    -- | Flatten a stream of lists into one lazy list. Each yielded inner list
    -- is emitted element by element before the outer stream is stepped again.
    concat :: Stream [a] -> [a]
    concat (Stream step start) = pull start
      where
        -- Step the outer stream until it yields a list or finishes.
        pull !st = case step st of
          Yield ys st' -> drain ys st'
          Skip     st' -> pull st'
          Done         -> []

        -- Emit the current inner list, then resume stepping the outer stream.
        drain ys !st = case ys of
          []         -> pull st
          (y : rest) -> y : drain rest st

The nice thing is that once all your functions are written in terms of these
(usually non-recursive) stream functions, and you have a rewrite rule:

    -- Fusion rule: when a stream is converted to a list and then immediately
    -- back into a stream, use the original stream directly. The intermediate
    -- list between two stream-based functions is therefore never built.
    {-# RULES

    "STREAM stream/unstream fusion" forall s.
        stream (unstream s) = s

      #-}

GHC will do all the loop fusion for you. This is particularly nice with strict
arrays, since fusion eliminates one O(n) intermediate array allocation per function.

See http://www.cse.unsw.edu.au/~dons/streams.html

Cheers,
  Don


More information about the Haskell-Cafe mailing list