Hi all,<br><br>In preparation for students working on concurrent-data-structure GSoC projects, I wanted to make sure they could rely on CAS for array elements as well as for IORefs. The following patch is my first attempt:<br>
<br> <a href="https://github.com/rrnewton/ghc/commit/18ed460be111b47a759486677960093d71eef386">https://github.com/rrnewton/ghc/commit/18ed460be111b47a759486677960093d71eef386</a><br><br>It passes a simple test (Appendix 2 below), but I am very unsure whether the GC write barrier is correct. Could someone review the following few lines of Cmm, which run after a successful CAS:<br>
<br><font class="Apple-style-span" face="'courier new', monospace"> if (GET_INFO(arr) == stg_MUT_ARR_PTRS_CLEAN_info) {<br> SET_HDR(arr, stg_MUT_ARR_PTRS_DIRTY_info, CCCS);<br> len = StgMutArrPtrs_ptrs(arr);<br>
// The write barrier. We must write a byte into the mark table:<br> I8[arr + SIZEOF_StgMutArrPtrs + WDS(len) + (ind >> MUT_ARR_PTRS_CARD_BITS)] = 1;<br> }<br></font><br>Thanks,<br> -Ryan<br>
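One way to check the barrier is a toy model of card marking. The Haskell sketch below is not RTS code: `Hdr`, `Arr`, `writeDraft`, `writeAlways`, `scannedCards` and `missed` are hypothetical names. It assumes two things about GHC's card marking for arrays on the mutable list, which reviewers should confirm: MUT_ARR_PTRS_CARD_BITS is 7 (one card per 128 elements), and at the next GC a CLEAN array is skipped while a DIRTY array has only its marked cards rescanned. Under those assumptions, marking a card only on the CLEAN-to-DIRTY transition loses the second of two writes that land in different cards, such as index 0 followed by index 200.

```haskell
import Data.Bits (shiftR)
import qualified Data.Set as Set

-- ASSUMPTION: mirrors MUT_ARR_PTRS_CARD_BITS in GHC's RTS headers,
-- i.e. one card covers 2^7 = 128 array elements.
cardBits :: Int
cardBits = 7

-- Toy stand-ins for stg_MUT_ARR_PTRS_CLEAN_info / _DIRTY_info.
data Hdr = Clean | Dirty deriving (Eq, Show)

-- A toy array: its header plus the set of marked cards.
-- Element values are irrelevant to the barrier, so they are omitted.
data Arr = Arr { hdr :: Hdr, cards :: Set.Set Int } deriving Show

-- Card index for element 'ind', as in (ind >> MUT_ARR_PTRS_CARD_BITS).
cardOf :: Int -> Int
cardOf ind = ind `shiftR` cardBits

-- Draft barrier: dirty the header and mark a card only if it was CLEAN.
writeDraft :: Int -> Arr -> Arr
writeDraft ind a
  | hdr a == Clean = Arr Dirty (Set.insert (cardOf ind) (cards a))
  | otherwise      = a

-- Alternative barrier: dirty the header and mark the card on every write.
writeAlways :: Int -> Arr -> Arr
writeAlways ind a = Arr Dirty (Set.insert (cardOf ind) (cards a))

-- ASSUMPTION: at the next GC a CLEAN array is skipped and a DIRTY
-- array has only its marked cards rescanned.
scannedCards :: Arr -> Set.Set Int
scannedCards a
  | hdr a == Dirty = cards a
  | otherwise      = Set.empty

-- Indices written by successful CASes since the last GC (starting from
-- a CLEAN array) whose cards the next GC would not rescan.
missed :: (Int -> Arr -> Arr) -> [Int] -> [Int]
missed barrier inds =
  [ i | i <- inds, cardOf i `Set.notMember` scannedCards final ]
  where
    final = foldl (flip barrier) (Arr Clean Set.empty) inds

main :: IO ()
main = do
  -- Index 0 lives in card 0; index 200 lives in card 1.
  print (missed writeDraft  [0, 200])  -- prints [200]
  print (missed writeAlways [0, 200])  -- prints []
```

If the model's assumptions hold, the card byte should be written after every successful CAS, not only when the header was CLEAN. As far as I can tell, that is also what the code generated for writeArray# does: it sets the header to DIRTY and marks the card unconditionally.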
<br>-- Appendix 1: First-draft Cmm definition of casArray#<br>-------------------------------------------------------------------<br>stg_casArrayzh<br>/* MutableArray# s a -> Int# -> a -> a -> State# s -> (# State# s, Int#, a #) */<br>
{<br> W_ arr, p, ind, old, new, h, len;<br> arr = R1; // anything else?<br> ind = R2;<br> old = R3;<br> new = R4;<br><br> p = arr + SIZEOF_StgMutArrPtrs + WDS(ind);<br> (h) = foreign "C" cas(p, old, new) [];<br>
<br> if (h != old) {<br> // Failure, return what was there instead of 'old':<br> RET_NP(1,h);<br> } else {<br> // Compare and Swap Succeeded:<br> if (GET_INFO(arr) == stg_MUT_ARR_PTRS_CLEAN_info) {<br>
SET_HDR(arr, stg_MUT_ARR_PTRS_DIRTY_info, CCCS);<br> len = StgMutArrPtrs_ptrs(arr);<br> // The write barrier. We must write a byte into the mark table:<br> I8[arr + SIZEOF_StgMutArrPtrs + WDS(len) + (ind >> MUT_ARR_PTRS_CARD_BITS)] = 1;<br>
}<br> RET_NP(0,h);<br> }<br>}<br><br>-- Appendix 2: Simple test file; when run it should print:<br>-------------------------------------------------------------------<br>-- Perform a CAS within a MutableArray#<br>
-- 1st try should succeed: (True,33)<br>-- 2nd should fail: (False,44)<br>-- Printing array:<br>-- 33 33 33 44 33<br>-- Done.<br>-------------------------------------------------------------------<br>{-# Language MagicHash, UnboxedTuples #-}<br>
<br>import GHC.IO<br>import GHC.IORef<br>import GHC.ST<br>import GHC.STRef<br>import GHC.Prim<br>import GHC.Base<br>import Data.Primitive.Array<br>import Control.Monad<br>
<br>------------------------------------------------------------------------<br><br>-- | Compare-and-swap the element at the given index.<br>-- Returns (True, old) on success and (False, current element) on failure.<br>casArrayST :: MutableArray s a -> Int -> a -> a -> ST s (Bool, a)<br>
casArrayST (MutableArray arr#) (I# i#) old new = ST$ \s1# -><br> case casArray# arr# i# old new s1# of <br> (# s2#, x#, res #) -> (# s2#, (x# ==# 0#, res) #)<br><br>------------------------------------------------------------------------<br>
{-# NOINLINE mynum #-}<br>mynum :: Int<br>mynum = 33<br><br>main = do <br> putStrLn "Perform a CAS within a MutableArray#"<br> arr <- newArray 5 mynum<br><br> res <- stToIO$ casArrayST arr 3 mynum 44<br>
res2 <- stToIO$ casArrayST arr 3 mynum 44<br> putStrLn$ "1st try should succeed: "++show res<br> putStrLn$ "2nd should fail: "++show res2<br><br> putStrLn "Printing array:"<br> forM_ [0..4] $ \i -> readArray arr i >>= \x -> putStr (" "++show x)<br> putStrLn ""<br> putStrLn "Done."<br><br>
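For students writing property tests against casArray#, a reference model of the contract that casArrayST exposes may be handy. The following is a pure sketch, not a concurrent implementation: `casPure` is a hypothetical name, and it uses `Data.Sequence` from the containers package. It compares with (==), whereas the primop compares heap pointers; that difference is presumably why the test passes the same NOINLINE'd `mynum` both as the initial element and as the expected value.

```haskell
import Data.Foldable (toList)
import qualified Data.Sequence as Seq

-- Hypothetical pure reference model of the casArrayST contract:
-- (True, old) on success, (False, current element) on failure.
-- NOTE: uses (==); the real primop compares pointers, so two equal but
-- separately allocated values would NOT match there.
casPure :: Eq a => Int -> a -> a -> Seq.Seq a -> ((Bool, a), Seq.Seq a)
casPure i old new xs
  | cur == old = ((True, cur), Seq.update i new xs)
  | otherwise  = ((False, cur), xs)
  where
    cur = Seq.index xs i

main :: IO ()
main = do
  -- Replays the Appendix 2 scenario against the model.
  let xs0       = Seq.replicate 5 (33 :: Int)
      (r1, xs1) = casPure 3 33 44 xs0
      (r2, xs2) = casPure 3 33 44 xs1
  print r1            -- prints (True,33)
  print r2            -- prints (False,44)
  print (toList xs2)  -- prints [33,33,33,44,33]
```

A model like this can be run side by side with casArrayST on random index/value sequences, as long as the generated values are shared closures so that pointer equality and (==) agree.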