{-# LANGUAGE Rank2Types #-}

module SafeMap (Tagged, unTag, fromMap, insert, lookup, check) where

import qualified Data.Map as M
import Data.Map (Map)
import Data.Maybe
import Prelude hiding  (lookup)

-- | A value of type @x@ carrying a phantom type-level tag @t@.
--
-- The 'Tag' constructor is not exported from this module, so clients can
-- only obtain tagged values through 'fromMap', 'insert' and 'check'. That is
-- what allows a shared tag @t@ to serve as evidence that a key belongs to a
-- particular map.
--
-- 'unTag' is deliberately a plain function rather than a record field: an
-- exported field selector permits record-update syntax
-- (@tagged { unTag = y }@) in client modules even without the constructor,
-- which would let callers replace a proven key or map while keeping its tag.
newtype Tagged t x = Tag x

-- | Forget the tag and return the wrapped value.
unTag :: Tagged t x -> x
unTag (Tag x) = x

-- | Produces the same text the derived instance for the record form
-- @Tag { unTag :: x }@ produced, e.g. @Tag {unTag = 5}@, so existing
-- output is unchanged.
instance Show x => Show (Tagged t x) where
    showsPrec d (Tag x) =
        showParen (d >= 11) $
            showString "Tag {unTag = " . showsPrec 0 x . showChar '}'
-- | Move a tagged value from one phantom tag to another, leaving the payload
-- untouched.
--
-- Not exported: it is only handed to clients by 'insert' and 'check', at the
-- key type, where moving a key proof from the old map to the new one is
-- justified (insertion only adds keys; 'check' keeps the map unchanged).
retag :: Tagged t1 x -> Tagged t2 x
retag tagged = case tagged of
    Tag payload -> Tag payload

-- | Wrap a map under a fresh phantom tag and hand it to a continuation.
--
-- The continuation must be polymorphic in the tag @t@, so the result type
-- @c@ cannot mention it and the tag cannot leak out of this scope.
fromMap :: Map k v -> (forall t. Tagged t (Map k v) -> c) -> c
fromMap m cont = cont $ Tag m

-- | Insert a key/value pair into a tagged map (via 'M.insert').
--
-- The continuation receives, under a fresh tag @t2@:
--
--   * the map with the pair inserted,
--   * a proof that the inserted key is present in that map, and
--   * a converter from key proofs for the old map (@t1@) to the new map,
--     which is sound because inserting never removes keys.
insert :: Ord k => k -> v -> Tagged t1 (Map k v) ->
    (forall t2. Tagged t2 (Map k v) -> Tagged t2 k -> (Tagged t1 k -> Tagged t2 k) -> c) ->
    c
insert key val (Tag old) cont = cont grown keyProof retag
    where
        grown = Tag (M.insert key val old)
        keyProof = Tag key

-- | Look up a key in a map that shares its tag @t@.
--
-- Keys are expected to come from 'insert' or 'check' on the map tagged @t@
-- (possibly carried forward with the converter those functions provide), in
-- which case the key is present and this always returns a value.
lookup :: Ord k => Tagged t k -> Tagged t (Map k v) -> v
lookup (Tag key) (Tag m) =
    case M.lookup key m of
        Just v -> v
        -- Unreachable while the tagging invariant holds; fail with context
        -- instead of fromJust's generic "Maybe.fromJust: Nothing".
        Nothing -> error "SafeMap.lookup: key is not in the map sharing its tag (Tagged invariant broken)"

-- | Test whether a key is present in a tagged map (via 'M.member').
--
-- The continuation receives, under a fresh tag @t2@:
--
--   * the same map, re-tagged,
--   * @Just@ a proof of membership if the key is present, otherwise @Nothing@,
--   * a converter from key proofs for the @t1@ map to the @t2@ map, which is
--     sound because the map contents are unchanged.
check :: Ord k => k -> Tagged t1 (Map k v) ->
    (forall t2. Tagged t2 (Map k v) -> Maybe (Tagged t2 k) -> (Tagged t1 k -> Tagged t2 k) -> c) ->
    c
check key (Tag m) cont = cont (Tag m) membership retag
    where
        membership
            | key `M.member` m = Just (Tag key)
            | otherwise = Nothing
