{-# LANGUAGE GADTs #-}
module CmmSink (
     cmmSink,
     cmmPeepholeInline
  ) where

import Cmm
import BlockId
import CmmLive
import CmmUtils
import Hoopl

import UniqFM
import Unique
import Outputable

import qualified Data.Set as Set

-- -----------------------------------------------------------------------------
-- Sinking

-- This is an optimisation pass that
--  (a) moves assignments closer to their uses, to reduce register pressure
--  (b) pushes assignments into a single branch of a conditional if possible

-- It is particularly helpful in the Cmm generated by the Stg->Cmm
-- code generator, in which every function starts with a copyIn
-- sequence like:
--
--    x1 = R1
--    x2 = Sp[8]
--    x3 = Sp[16]
--    if (Sp - 32 < SpLim) then L1 else L2
--
-- we really want to push the x1..x3 assignments into the L2 branch.
--
-- Algorithm:
--
--  * Start by doing liveness analysis.
--  * Keep a list of assignments; earlier ones may refer to later ones
--  * Walk forwards through the graph;
--    * At an assignment:
--      * pick up the assignment and add it to the list
--    * At a store:
--      * drop any assignments that the store refers to
--      * drop any assignments that refer to memory that may be written
--        by the store
--      * do this recursively, dropping dependent assignments
--    * At a multi-way branch:
--      * drop any assignments that are live on more than one branch
--      * if any successor has more than one predecessor, drop everything
--        live in that successor
-- 
-- As a side-effect we'll delete some dead assignments (transitively,
-- even).  Maybe we could do without removeDeadAssignments?

-- If we do this *before* stack layout, we might be able to avoid
-- saving some things across calls/procpoints.
--
-- *but*, that will invalidate the liveness analysis, and we'll have
-- to re-do it.

-- | Run the sinking pass over a graph: compute liveness for the graph
-- with 'cmmLiveness' and pass both on to @cmmSink'@.
cmmSink :: CmmGraph -> CmmGraph
cmmSink graph = cmmSink' liveness graph
  where liveness = cmmLiveness graph

-- | A picked-up assignment @r = e@ that is waiting to be re-emitted:
-- the local register being assigned, the right-hand side expression,
-- and the abstract address read by the right-hand side (computed with
-- 'exprAddr' when the assignment is picked up in 'shouldSink').
type Assignment = (LocalReg, CmmExpr, AbsAddr)

-- | The sinking pass proper, given liveness information for the graph.
--
-- The blocks returned by 'postorderDfs' are processed in order.  For each
-- block we pick up the assignments handed down for its label, walk its
-- middle nodes (see 'walk'), emit at the end of the block every pending
-- assignment that cannot travel any further, and hand the remaining ones
-- down to the successors in which they are live.
--
-- A pending assignment cannot travel past the end of a block if
--
--  * it conflicts with the last node, or
--  * its register is live in more than one successor, or
--  * its register is live in a successor that is a join point.
--
-- The last condition is needed for correctness: the map of handed-down
-- assignments holds a single list per label, so a block that can be
-- reached along several edges would receive the assignments of only one
-- of its predecessors and the others would be silently lost.
cmmSink' :: BlockEnv CmmLive -> CmmGraph -> CmmGraph
cmmSink' liveness graph
  = ofBlockList (g_entry graph) $ sink mapEmpty blocks
  where

  blocks = postorderDfs graph

  -- Number of edges into each label, counted over the blocks we visit.
  -- A label that appears twice among the successors of a single block
  -- is counted twice, which is conservative.
  pred_counts :: BlockEnv Int
  pred_counts = foldr (\l m -> mapInsert l (mapFindWithDefault 0 l m + 1) m)
                      mapEmpty (concatMap successors blocks)

  -- A join point is a block that can be entered in more than one way.
  -- The entry block is always treated as one: it is entered from outside
  -- the graph, and it is processed before any block that branches to it.
  is_join :: BlockId -> Bool
  is_join l = l == g_entry graph || mapFindWithDefault 0 l pred_counts > 1

  sink :: BlockEnv [Assignment] -> [CmmBlock] -> [CmmBlock]
  sink _ [] = []
  sink sunk (b:bs) =
    -- NOTE(review): this looks like leftover debugging output; it is kept
    -- here because removing it would leave the Outputable import unused.
    pprTrace "sink" (ppr lbl) $
    blockJoin first final_middle last : sink sunk' bs
    where
      lbl = entryLabel b
      (first, middle, last) = blockSplit b
      (middle', assigs) = walk (blockToList middle) emptyBlock
                               (mapFindWithDefault [] lbl sunk)

      getLive l = mapFindWithDefault Set.empty l liveness
      succs = successors last
      lives = map getLive succs

      -- multilive is a list of registers that are live in more than
      -- one successor branch, and we should therefore drop them here.
      multilive = [ r | (r,n) <- ufmToList livemap, n > 1 ]
         where livemap = foldr (\r m -> addToUFM_C (+) m r (1::Int))
                            emptyUFM (concatMap Set.toList lives)

      -- We cannot sink into join points, so anything live in one of
      -- them has to be dropped here as well.
      joins         = filter is_join succs
      nonjoins      = filter (not . is_join) succs
      live_in_joins = Set.unions (map getLive joins)

      (dropped_last, assigs') = dropAssignments drop_if assigs

      drop_if a@(r,_,_) = a `conflicts` last
                       || getUnique r `elem` multilive
                       || r `Set.member` live_in_joins

      final_middle = foldl blockSnoc middle' dropped_last

      -- Only non-join successors receive assignments; this block is
      -- their sole predecessor, so nothing is overwritten or lost.
      sunk' = mapUnion sunk $
                 mapFromList [ (l, filterAssignments (getLive l) assigs')
                             | l <- nonjoins ]


-- | Keep only the assignments that are needed given a set of live
-- registers.  The list is traversed front to back; an assignment is
-- kept if its register is live, or if it conflicts with (the node form
-- of) an assignment that has already been kept.  The relative order of
-- the kept assignments is the same as in the input.
filterAssignments :: RegSet -> [Assignment] -> [Assignment]
filterAssignments live assigs = reverse (foldl keepIfNeeded [] assigs)
  where
    -- The accumulator holds the kept assignments in reverse order.
    keepIfNeeded retained a@(r,_,_)
      | r `Set.member` live                     = a : retained
      | any ((a `conflicts`) . toNode) retained = a : retained
      | otherwise                               = retained


-- | Walk forwards over the middle nodes of a block.
--
-- A node accepted by 'shouldSink' is not emitted; it is added to the
-- front of the pending list instead.  At any other node, the pending
-- assignments that conflict with the node (and those that depend on
-- them, see 'dropAssignments') are emitted first, followed by the node
-- itself.  The result is the block built so far together with the
-- assignments that are still pending.
walk :: [CmmNode O O] -> Block CmmNode O O -> [Assignment]
     -> (Block CmmNode O O, [Assignment])

walk nodes block pending = case nodes of
  []       -> (block, pending)
  n : rest -> case shouldSink n of
    Just picked -> walk rest block (picked : pending)
    Nothing     ->
      let (flushed, survivors) = dropAssignments (`conflicts` n) pending
          block'               = foldl blockSnoc block flushed `blockSnoc` n
      in  walk rest block' survivors

-- | Decide whether a node is an assignment that we are willing to pick
-- up: an assignment to a local register whose right-hand side uses no
-- local registers.  The abstract address read by the right-hand side
-- is recorded alongside it.
shouldSink :: CmmNode O O -> Maybe Assignment
shouldSink node = case node of
  CmmAssign (CmmLocal r) e
    | usesNoLocalRegs e -> Just (r, e, exprAddr e)
  _                     -> Nothing
  where
    -- The fold starts at True and flips to False at the first local
    -- register that is encountered.
    usesNoLocalRegs expr = foldRegsUsed (\_ _ -> False) True expr

-- | Turn a pending assignment back into the node @r = rhs@; the
-- recorded abstract address is not needed for this.
toNode :: Assignment -> CmmNode O O
toNode (reg, expr, _addr) = CmmAssign (CmmLocal reg) expr

-- | Split the pending assignments into those that must be emitted now
-- and those that can stay pending.
--
-- The list is traversed front to back.  An assignment is dropped if the
-- predicate says so, or if it conflicts with an assignment that has
-- already been dropped.  Dropped assignments are returned as nodes, in
-- the reverse of the order in which they were encountered; the
-- remaining assignments keep their original relative order.
dropAssignments :: (Assignment -> Bool) -> [Assignment] -> ([CmmNode O O], [Assignment])
dropAssignments should_drop assigs = (emitted, reverse retained)
 where
   (emitted, retained) = foldl step ([], []) assigs

   -- Both accumulators are built by consing, i.e. in reverse order.
   step (out, keep) assig
      | should_drop assig || any (assig `conflicts`) out
                  = (toNode assig : out, keep)
      | otherwise = (out, assig : keep)

-- | @conflicts (r,e,addr) stmt@ is @False@ if and only if the assignment
-- @r = e@ can be safely commuted past @stmt@.  Here @addr@ is the
-- abstract address read by @e@.
--
-- We only sink "r = G" assignments right now, so conflicts is very simple:
--
--  (1) @stmt@ assigns a register that @e@ reads;
--  (2) @stmt@ stores to memory that @e@ may read;
--  (3) @stmt@ is an unsafe foreign call and @e@ reads memory: we cannot
--      see what the callee writes, so we assume it may write anywhere;
--  (4) @stmt@ uses @r@.
--
-- NOTE(review): global registers clobbered by a call are not modelled
-- here; confirm whether that can matter for the assignments we sink.
conflicts :: Assignment -> CmmNode O x -> Bool
(_, rhs, _   ) `conflicts` CmmAssign reg  _ | reg `regUsedIn` rhs = True
(_, _,   addr) `conflicts` CmmStore addr' _ | addrConflicts addr (loadAddr addr') = True
(_, _,   addr) `conflicts` CmmUnsafeForeignCall{} | addrConflicts addr AnyAddr = True
(r, _,   _)    `conflicts` node
  = foldRegsUsed (\b r' -> r == r' || b) False node

-- An abstraction of the addresses read or written.
--
--   * NoAddr:    no address; depending on the function this means
--                "no reads" (exprAddr) or "don't know" (absAddr)
--   * HeapAddr:  an address derived from Hp, or assumed to be in the
--                heap (see loadAddr)
--   * StackAddr: an address derived from Sp
--   * AnyAddr:   a combination of heap and stack addresses (see bothAddrs)
data AbsAddr = NoAddr | HeapAddr | StackAddr | AnyAddr

-- | Combine two abstract addresses.  'NoAddr' is the identity, equal
-- heap or stack addresses are preserved, and every other combination
-- is 'AnyAddr'.
bothAddrs :: AbsAddr -> AbsAddr -> AbsAddr
bothAddrs a b = case (a, b) of
  (NoAddr,    other)     -> other
  (other,     NoAddr)    -> other
  (HeapAddr,  HeapAddr)  -> HeapAddr
  (StackAddr, StackAddr) -> StackAddr
  _                      -> AnyAddr

-- | Whether two abstract addresses may overlap.  'NoAddr' never
-- conflicts with anything, heap and stack addresses do not conflict
-- with each other, and every other pair is assumed to conflict.
addrConflicts :: AbsAddr -> AbsAddr -> Bool
addrConflicts a b = case (a, b) of
  (NoAddr,    _)         -> False
  (_,         NoAddr)    -> False
  (HeapAddr,  StackAddr) -> False
  (StackAddr, HeapAddr)  -> False
  _                      -> True

-- | The abstract address read by an expression: the address of a load,
-- or the combination of the addresses read by the arguments of a
-- machine operation.
exprAddr :: CmmExpr -> AbsAddr -- here NoAddr means "no reads"
exprAddr expr = case expr of
  CmmLoad addr _ -> loadAddr addr
  CmmMachOp _ es -> foldr (bothAddrs . exprAddr) NoAddr es
  _              -> NoAddr

-- | The abstract address denoted by an address expression, derived
-- from the registers it mentions and the loads it contains.
absAddr :: CmmExpr -> AbsAddr -- here NoAddr means "don't know"
absAddr expr = case expr of
  CmmLoad addr _ -> bothAddrs HeapAddr (loadAddr addr) -- (1)
  CmmMachOp _ es -> foldr (bothAddrs . absAddr) NoAddr es
  CmmReg r       -> regAddr r
  CmmRegOff r _  -> regAddr r
  _              -> NoAddr

-- | The abstract address accessed when loading from (or storing to)
-- the given address expression.  Like 'absAddr', except that an
-- unknown address is taken to be a heap address.
loadAddr :: CmmExpr -> AbsAddr
loadAddr e = orHeap (absAddr e)
  where
    orHeap NoAddr = HeapAddr -- (2)
    orHeap known  = known

-- (1) we assume that an address read from memory is a heap address.
-- We never read a stack address from memory.
--
-- (2) loading from an unknown address is assumed to be a heap load.

-- | The abstract address held in a register: the global register Sp
-- holds a stack address, Hp holds a heap address, and nothing is known
-- about any other register.
regAddr :: CmmReg -> AbsAddr
regAddr reg = case reg of
  CmmGlobal Sp -> StackAddr
  CmmGlobal Hp -> HeapAddr
  _            -> NoAddr

-- After sinking, if we have an assignment to a temporary that is used
-- exactly once, then it will either be of the form
--
--   x = E
--   .. stmt involving x ..
--
-- OR
--
--   x = E
--   .. stmt conflicting with E ..

-- So the idea in peepholeInline is to spot the first case
-- (recursively) and inline x.  We start with the set of live
-- registers and move backwards through the block.
--
-- Concretely: for each statement we look at the statement immediately
-- before it.  If that is an assignment @x = E@ to a local register
-- that is not live after the statement and that the statement uses
-- exactly once, then E is substituted for x in the statement and the
-- assignment is removed; we then try again with the statement before
-- that.
--
-- ToDo: doesn't inline into the last node
--
cmmPeepholeInline :: CmmGraph -> CmmGraph
cmmPeepholeInline graph = ofBlockList (g_entry graph) $ map do_block (toBlockList graph)
  where
   -- Liveness information for the whole graph, indexed by block label.
   liveness = cmmLiveness graph

   -- Rewrite the middle of a single block; the first and last nodes
   -- are left untouched.
   do_block :: Block CmmNode C C -> Block CmmNode C C
   do_block block = blockJoin first (go rmiddle live_middle) last
     where
       (first, middle, last) = blockSplit block
       -- The middle nodes in reverse order, so that 'go' sees the last
       -- statement of the block first.
       rmiddle = reverse (blockToList middle)
    
       -- Registers live on exit from the block: the union of the live
       -- sets of all successors.
       live = Set.unions [ mapFindWithDefault Set.empty l liveness | l <- successors last ]
    
       -- The live set after pushing it backwards over the last node
       -- with gen_kill, i.e. the set we start from for the middle nodes.
       live_middle = gen_kill last live
    
       -- Process a reversed list of statements, given the set of
       -- registers live after the first statement in the list (the
       -- latest one in program order).  The resulting block is in
       -- program order again.
       go :: [CmmNode O O] -> RegSet -> Block CmmNode O O
       go [] _ = emptyBlock
       -- A single remaining statement has nothing before it that could
       -- be inlined into it.
       go [stmt] _ = blockCons stmt emptyBlock
       go (stmt : rest) live = tryInline stmt usages live rest
         where
           -- How many times each local register is used by stmt.
           usages :: UniqFM Int
           usages = foldRegsUsed addUsage emptyUFM stmt
    
       -- Record one more use of a register.
       addUsage :: UniqFM Int -> LocalReg -> UniqFM Int
       addUsage m r = addToUFM_C (+) m r 1
    
       -- tryInline stmt usages live earlier: 'earlier' is the reversed
       -- list of statements preceding stmt.  If its head is an
       -- assignment @l = rhs@ where l is not in the live set and is
       -- used exactly once by stmt, substitute rhs for l in stmt, drop
       -- the assignment, and carry on with the statement before it.
       -- The registers used by rhs are added to both the live set and
       -- the usage counts before continuing.
       tryInline stmt usages live (CmmAssign (CmmLocal l) rhs : rest)
          | not (l `elemRegSet` live),
            Just 1 <- lookupUFM usages l = tryInline stmt' usages' live' rest
          where live'   = foldRegsUsed extendRegSet live rhs
                usages' = foldRegsUsed addUsage usages rhs
    
                -- stmt with every occurrence of l replaced by rhs; a
                -- register-plus-offset becomes rhs plus that offset.
                stmt' = mapExpDeep inline stmt
                   where inline (CmmReg    (CmmLocal l'))     | l == l' = rhs
                         inline (CmmRegOff (CmmLocal l') off) | l == l' = cmmOffset rhs off
                         inline other = other
    
       -- Nothing (more) to inline: emit stmt after the result of
       -- processing the preceding statements, whose live set is
       -- obtained by applying gen_kill for stmt.
       tryInline stmt _usages live stmts
            = go stmts (gen_kill stmt live) `blockSnoc` stmt