bracket, (un)block and MonadIO

sebc@macs.hw.ac.uk sebc@macs.hw.ac.uk
Thu, 4 Sep 2003 19:20:44 +0100


--QKdGvSO+nmPlgiQ/
Content-Type: multipart/mixed; boundary="7JfCtLOvnd9MIVvH"
Content-Disposition: inline


--7JfCtLOvnd9MIVvH
Content-Type: text/plain; charset=us-ascii
Content-Disposition: inline
Content-Transfer-Encoding: quoted-printable

Hi,

On Thu, Sep 04, 2003 at 10:45:17AM +0100, Simon Marlow wrote:
> >
> > Would anything prevent block, unblock, bracket (and other similar
> > functions working on IO actions) from being generalized to all
> > instances of MonadIO?
>
> I'm afraid I can't see a way to generalise the types of block and
> unblock, since they are based on underlying primitives that really do
> have type (IO a -> IO a).  Perhaps if your monad is isomorphic to IO, it
> could be done, but otherwise I don't think it's possible.  Unless I'm
> missing something.

It can be done by adding the right methods to the MonadIO class, with
these rank-2 types:

>     liftIO' :: (forall a. IO a -> IO a) -> m a -> m a
>     liftIO'' :: (forall a. IO a -> (b -> IO a) -> IO a) -> m a -> (b -> m a) -> m a

See the attached patch for the details.

This solution is perhaps a bit ugly, since these methods are fairly
specialized: liftIO' is needed to generalize block and unblock, and
liftIO'' is needed to generalize catchException.

But it does allow catch, bracket, etc. to be used with monads built
on top of IO by monad transformers, which is quite nice:

> import Control.Monad.Reader
>
> type M = ReaderT Int IO
>
> main' :: M ()
> main' = catch
>     (do n <- ask
>         liftIO (putStrLn (show n)))
>     (\ e -> return ())
>
> main :: IO ()
> main =
>     runReaderT main' 1

-- 
Sebastien

P.S.: The patch moves the MonadIO class to GHC.IOBase, which already
contains an unrelated function also named liftIO (of type
IO a -> State# RealWorld -> STret RealWorld a) that does not appear to
be used anywhere, so I just commented it out.

--7JfCtLOvnd9MIVvH
Content-Type: text/plain; charset=us-ascii
Content-Disposition: attachment; filename="MonadIO.patch"
Content-Transfer-Encoding: quoted-printable

diff -r -u ghc-6.0.1.orig/libraries/base/Control/Exception.hs ghc-6.0.1/lib=
raries/base/Control/Exception.hs
--- ghc-6.0.1.orig/libraries/base/Control/Exception.hs	2003-05-12 11:16:27.=
000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Exception.hs	2003-09-04 18:14:06.00000=
0000 +0100
@@ -112,7 +112,7 @@
 import GHC.Base		( assert )
 import GHC.Exception 	as ExceptionBase hiding (catch)
 import GHC.Conc		( throwTo, ThreadId )
-import GHC.IOBase	( IO(..) )
+import GHC.IOBase	( IO(..), MonadIO(..) )
 #endif
=20
 #ifdef __HUGS__
@@ -173,9 +173,10 @@
 -- "Control.Exception", or importing
 -- "Control.Exception" qualified, to avoid name-clashes.
=20
-catch  	:: IO a 		-- ^ The computation to run
-  	-> (Exception -> IO a)	-- ^ Handler to invoke if an exception is raised
-  	-> IO a		=09
+catch  	:: MonadIO m
+	=3D> m a			-- ^ The computation to run
+  	-> (Exception -> m a)	-- ^ Handler to invoke if an exception is raised
+  	-> m a
 catch =3D  ExceptionBase.catchException
=20
 -- | The function 'catchJust' is like 'catch', but it takes an extra
@@ -370,10 +371,11 @@
 -- > withFile name =3D bracket (openFile name) hClose
 --
 bracket=20
-	:: IO a		-- ^ computation to run first (\"acquire resource\")
-	-> (a -> IO b)  -- ^ computation to run last (\"release resource\")
-	-> (a -> IO c)	-- ^ computation to run in-between
-	-> IO c		-- returns the value from the in-between computation
+	:: MonadIO m
+        =3D> m a		-- ^ computation to run first (\"acquire resource\")
+	-> (a -> m b)	-- ^ computation to run last (\"release resource\")
+	-> (a -> m c)	-- ^ computation to run in-between
+	-> m c		-- returns the value from the in-between computation
 bracket before after thing =3D
   block (do
     a <- before=20
@@ -383,7 +385,7 @@
     after a
     return r
  )
-  =20
+
=20
 -- | A specialised variant of 'bracket' with just a computation to run
 -- afterward.
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/Cont.hs ghc-6.0.1/li=
braries/base/Control/Monad/Cont.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/Cont.hs	2003-05-14 18:31:47=
=2E000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/Cont.hs	2003-09-04 18:46:10.0000=
00000 +0100
@@ -77,6 +77,8 @@
=20
 instance (MonadIO m) =3D> MonadIO (ContT r m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D ContT $ \ k -> liftIO' f (runContT m k)
+	liftIO'' f m1 m2 =3D ContT $ \ k -> liftIO'' f (runContT m1 k) (\ e -> ru=
nContT (m2 e) k)
=20
 instance (MonadReader r' m) =3D> MonadReader r' (ContT r m) where
 	ask       =3D lift ask
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/Error.hs ghc-6.0.1/l=
ibraries/base/Control/Monad/Error.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/Error.hs	2003-05-14 18:31:4=
7.000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/Error.hs	2003-09-04 18:45:18.000=
000000 +0100
@@ -167,6 +167,8 @@
=20
 instance (Error e, MonadIO m) =3D> MonadIO (ErrorT e m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D ErrorT $ liftIO' f (runErrorT m)
+	liftIO'' f m1 m2 =3D ErrorT $ liftIO'' f (runErrorT m1) (\ e -> runErrorT=
 (m2 e))
=20
 instance (Error e, MonadReader r m) =3D> MonadReader r (ErrorT e m) where
 	ask       =3D lift ask
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/List.hs ghc-6.0.1/li=
braries/base/Control/Monad/List.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/List.hs	2003-05-14 18:31:47=
=2E000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/List.hs	2003-09-04 18:43:48.0000=
00000 +0100
@@ -61,6 +61,8 @@
=20
 instance (MonadIO m) =3D> MonadIO (ListT m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D ListT $ liftIO' f (runListT m)
+	liftIO'' f m1 m2 =3D ListT $ liftIO'' f (runListT m1) (\ e -> runListT (m=
2 e))
=20
 instance (MonadReader s m) =3D> MonadReader s (ListT m) where
 	ask       =3D lift ask
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/RWS.hs ghc-6.0.1/lib=
raries/base/Control/Monad/RWS.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/RWS.hs	2003-05-14 18:31:47.=
000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/RWS.hs	2003-09-04 18:43:08.00000=
0000 +0100
@@ -142,7 +142,8 @@
=20
 instance (Monoid w, MonadIO m) =3D> MonadIO (RWST r w s m) where
 	liftIO =3D lift . liftIO
-
+	liftIO' f m =3D RWST $ \ r s -> liftIO' f (runRWST m r s)
+	liftIO'' f m1 m2 =3D RWST $ \ r s -> liftIO'' f (runRWST m1 r s) (\ e -> =
runRWST (m2 e) r s)
=20
 evalRWST :: (Monad m) =3D> RWST r w s m a -> r -> s -> m (a, w)
 evalRWST m r s =3D do
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/Reader.hs ghc-6.0.1/=
libraries/base/Control/Monad/Reader.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/Reader.hs	2003-05-14 18:31:=
47.000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/Reader.hs	2003-09-04 18:50:42.00=
0000000 +0100
@@ -130,6 +130,8 @@
=20
 instance (MonadIO m) =3D> MonadIO (ReaderT r m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D ReaderT $ \ r -> liftIO' f (runReaderT m r)
+	liftIO'' f m1 m2 =3D ReaderT $ \ r -> liftIO'' f (runReaderT m1 r) (\ e -=
> runReaderT (m2 e) r)
=20
 mapReaderT :: (m a -> n b) -> ReaderT w m a -> ReaderT w n b
 mapReaderT f m =3D ReaderT $ f . runReaderT m
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/State.hs ghc-6.0.1/l=
ibraries/base/Control/Monad/State.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/State.hs	2003-05-14 18:31:4=
7.000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/State.hs	2003-09-04 18:39:29.000=
000000 +0100
@@ -211,6 +211,8 @@
=20
 instance (MonadIO m) =3D> MonadIO (StateT s m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D StateT $ \ s -> liftIO' f (runStateT m s)
+	liftIO'' f m1 m2 =3D StateT $ \ s -> liftIO'' f (runStateT m1 s) (\ e -> =
runStateT (m2 e) s)
=20
 instance (MonadReader r m) =3D> MonadReader r (StateT s m) where
 	ask       =3D lift ask
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/Trans.hs ghc-6.0.1/l=
ibraries/base/Control/Monad/Trans.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/Trans.hs	2003-03-08 19:02:3=
9.000000000 +0000
+++ ghc-6.0.1/libraries/base/Control/Monad/Trans.hs	2003-09-04 16:59:31.000=
000000 +0100
@@ -20,12 +20,13 @@
=20
 module Control.Monad.Trans (
 	MonadTrans(..),
-	MonadIO(..), =20
+	MonadIO(..),
   ) where
=20
 import Prelude
=20
 import System.IO
+import GHC.IOBase ( MonadIO(..) )
=20
 -- -----------------------------------------------------------------------=
----
 -- MonadTrans class
@@ -36,9 +37,3 @@
=20
 class MonadTrans t where
 	lift :: Monad m =3D> m a -> t m a
-
-class (Monad m) =3D> MonadIO m where
-	liftIO :: IO a -> m a
-
-instance MonadIO IO where
-	liftIO =3D id
diff -r -u ghc-6.0.1.orig/libraries/base/Control/Monad/Writer.hs ghc-6.0.1/=
libraries/base/Control/Monad/Writer.hs
--- ghc-6.0.1.orig/libraries/base/Control/Monad/Writer.hs	2003-05-14 18:31:=
47.000000000 +0100
+++ ghc-6.0.1/libraries/base/Control/Monad/Writer.hs	2003-09-04 18:39:11.00=
0000000 +0100
@@ -142,6 +142,8 @@
=20
 instance (Monoid w, MonadIO m) =3D> MonadIO (WriterT w m) where
 	liftIO =3D lift . liftIO
+	liftIO' f m =3D WriterT $ liftIO' f (runWriterT m)
+	liftIO'' f m1 m2 =3D WriterT $ liftIO'' f (runWriterT m1) (\ e -> runWrit=
erT (m2 e))
=20
 instance (Monoid w, MonadReader r m) =3D> MonadReader r (WriterT w m) where
 	ask       =3D lift ask
diff -r -u ghc-6.0.1.orig/libraries/base/GHC/Exception.lhs ghc-6.0.1/librar=
ies/base/GHC/Exception.lhs
--- ghc-6.0.1.orig/libraries/base/GHC/Exception.lhs	2003-01-16 14:38:40.000=
000000 +0000
+++ ghc-6.0.1/libraries/base/GHC/Exception.lhs	2003-09-04 18:49:15.00000000=
0 +0100
@@ -44,10 +44,11 @@
 have to work around that in the definition of catchException below).
=20
 \begin{code}
-catchException :: IO a -> (Exception -> IO a) -> IO a
-catchException (IO m) k =3D  IO $ \s -> catch# m (\ex -> unIO (k ex)) s
+catchException :: MonadIO m =3D> m a -> (Exception -> m a) -> m a
+catchException =3D liftIO'' catchException'
+    where catchException' (IO m) k =3D IO $ \s -> catch# m (\ex -> unIO (k=
 ex)) s
=20
-catch           :: IO a -> (IOError -> IO a) -> IO a=20
+catch           :: MonadIO m =3D> m a -> (IOError -> m a) -> m a
 catch m k	=3D  catchException m handler
   where handler (IOException err)   =3D k err
 	handler other               =3D throw other
@@ -69,17 +70,19 @@
 -- no need to worry about re-enabling asynchronous exceptions; that is
 -- done automatically on exiting the scope of
 -- 'block'.
-block :: IO a -> IO a
+block :: MonadIO m =3D> m a -> m a
=20
 -- | To re-enable asynchronous exceptions inside the scope of
 -- 'block', 'unblock' can be
 -- used.  It scopes in exactly the same way, so on exit from
 -- 'unblock' asynchronous exception delivery will
 -- be disabled again.
-unblock :: IO a -> IO a
+unblock :: MonadIO m =3D> m a -> m a
=20
-block (IO io) =3D IO $ blockAsyncExceptions# io
-unblock (IO io) =3D IO $ unblockAsyncExceptions# io
+block =3D liftIO' block'
+    where block' (IO io) =3D IO $ blockAsyncExceptions# io
+unblock =3D liftIO' block'
+    where block' (IO io) =3D IO $ unblockAsyncExceptions# io
 \end{code}
=20
=20
diff -r -u ghc-6.0.1.orig/libraries/base/GHC/IOBase.lhs ghc-6.0.1/libraries=
/base/GHC/IOBase.lhs
--- ghc-6.0.1.orig/libraries/base/GHC/IOBase.lhs	2003-05-23 12:05:33.000000=
000 +0100
+++ ghc-6.0.1/libraries/base/GHC/IOBase.lhs	2003-09-04 18:25:35.000000000 +=
0100
@@ -88,11 +88,21 @@
     m >>=3D k     =3D bindIO m k
     fail s	=3D failIO s
=20
+class (Monad m) =3D> MonadIO m where
+    liftIO :: IO a -> m a
+    liftIO' :: (forall a. IO a -> IO a) -> m a -> m a
+    liftIO'' :: (forall a. IO a -> (b -> IO a) -> IO a) -> m a -> (b -> m =
a) -> m a
+
+instance MonadIO IO where
+    liftIO =3D id
+    liftIO' =3D id
+    liftIO'' =3D id
+
 failIO :: String -> IO a
 failIO s =3D ioError (userError s)
=20
-liftIO :: IO a -> State# RealWorld -> STret RealWorld a
-liftIO (IO m) =3D \s -> case m s of (# s', r #) -> STret s' r
+-- liftIO :: IO a -> State# RealWorld -> STret RealWorld a
+-- liftIO (IO m) =3D \s -> case m s of (# s', r #) -> STret s' r
=20
 bindIO :: IO a -> (a -> IO b) -> IO b
 bindIO (IO m) k =3D IO ( \ s ->

--7JfCtLOvnd9MIVvH--

--QKdGvSO+nmPlgiQ/
Content-Type: application/pgp-signature
Content-Disposition: inline

-----BEGIN PGP SIGNATURE-----
Version: GnuPG v1.2.3 (GNU/Linux)

iD8DBQE/V4J8vtNcI2aw9NwRAoOZAJ9UDqSkOhhA1XirxtZc3xrhw6Uc9wCgoNlQ
KMREfOCdN7tUuI7GJ7uT2RM=
=PdIr
-----END PGP SIGNATURE-----

--QKdGvSO+nmPlgiQ/--