Problem with backtracking monad transformer

Guest, Simon
Thu, 30 Jan 2003 13:55:50 -0000

I'm trying to make a backtracking state monad using Ralf Hinze's
backtracking monad transformer.  My problem is that it won't backtrack
very far.

Suppose I try ( a >> b ) `mplus` c.

If b fails, it should try c, but it doesn't rewind past a.

My sample code is below.

GHCI> c [0,1] match_1            -- (1 or 0) then 1, OK
GHCI> c [1,0] match_2            -- (1 then 0) or (1,1), OK
GHCI> c [1,1] match_2            -- (1 then 0) or (1,1), fails

What have I misunderstood?

-- backtracking state monad
-- requires -fglasgow-exts

import qualified Monad as M
import qualified Control.Monad.Trans as MT

-- turn tracing on and off by uncommenting just one of the following =
import Debug.Trace( trace )
--trace s x =3D x

-- Ralf Hinze's efficient backtracking monad transformer

newtype BACKTR m a
  =3D BACKTR { mkBACKTR :: (forall b. (a -> m b -> m b) -> m b -> m b) =

instance (Monad m) =3D> Monad (BACKTR m) where
  return a =3D BACKTR (\c -> c a)
  m >>=3D k  =3D BACKTR (\c -> mkBACKTR m (\a -> mkBACKTR (k a) c))

-- We don't use a Backtr class, but do it with the MonadPlus class,
-- mzero is false (fail),
-- mplus is =A6 (orelse)
instance (Monad m) =3D> M.MonadPlus (BACKTR m) where
  mzero         =3D BACKTR (\c -> id)
  m1 `mplus` m2 =3D BACKTR (\c -> mkBACKTR m1 c . mkBACKTR m2 c)

-- standard MonadTrans class has lift for promote, and doesn't have =
instance MT.MonadTrans BACKTR where
  lift m =3D BACKTR (\c f -> m >>=3D \a -> c a f)

observe :: (Monad m) =3D> BACKTR m a -> m a
observe m =3D mkBACKTR m (\a f -> return a) (fail "false")

-- State Monad

data SM st a =3D SM (st -> (a,st)) -- The monadic type

instance Monad (SM st) where
   -- defines state propagation
   SM c1 >>=3D fc2 =3D SM (\s0 -> let (r,s1) =3D c1 s0
                                  SM c2 =3D fc2 r in
                                 c2 s1)
   return k =3D SM (\s -> (k,s))

-- extracts the state from the monad
readSM :: SM st st
readSM =3D SM (\s -> (s,s))

-- updates the state of the monad
updateSM :: (st -> st) -> SM st () -- alters the state
updateSM f =3D SM (\s -> ((), f s))

-- run a computation in the SM monad
runSM :: st -> SM st a -> (a,st)
runSM s0 (SM c) =3D c s0

-- backtracking state monad
type NDSM st a =3D BACKTR (SM st) a

readNDSM :: NDSM st st
readNDSM =3D MT.lift readSM

updateNDSM :: (st -> st) -> NDSM st ()
updateNDSM f =3D MT.lift (updateSM f)

--run a computation in the NDSM monad
runNDSM :: st -> NDSM st a -> (a,st)
runNDSM s0 m =3D runSM s0 (observe m)

-- the state
type Bit =3D Int

data CState =3D CState
             { ok                :: Bool,
               remaining_data    :: [Bit],
               history           :: [String] -- log, kept in reverse
             } deriving Show

initState xs =3D CState True xs []

-- prepend a message in the log
logit :: CState -> String -> CState
logit s logmsg =3D s { history =3D logmsg : (history s) }

-- matching action
match_bits :: [Bit] -> NDSM CState ()
match_bits xs =3D do
   s <- readNDSM
   let s' =3D logit s ("attempt match_bits " ++ show xs
                     ++ " remaining: " ++ show (remaining_data s))

       s'' =3D if xs =3D=3D take (length xs) (remaining_data s')
                s' { remaining_data =3D drop (length xs) =
(remaining_data s') }
                s' { ok =3D False }
   if ok s''
      then updateNDSM (\s -> s'')
      else trace (unlines $ "MATCH FAILED":(reverse $ history s'')) =

-- test routines

-- just fine
match_1 =3D
   (match_bits [1] `M.mplus` match_bits [0])=20
   >> match_bits [1]

-- this one only rewinds past the [0] attempt, not the [1] attempt
match_2 =3D
   (  (match_bits [1] >> match_bits [0]) 
      `M.mplus` match_bits [1, 1] )

c :: [Bit] -> NDSM CState () -> ([Bit], [String])
c h hspec =3D=20
   let (v, s) =3D runNDSM (initState h) hspec in
   case (ok s) of True -> ([], "ok":(reverse $ history s))
                  _    -> ([(negate)1], ["fail"])

