```%
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.

\begin{code}
module FloatOut ( floatOutwards ) where

import CoreSyn
import CoreUtils

import DynFlags	( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils		( dumpIfSet_dyn )
import CostCentre	( dupifyCC, CostCentre )
import Id		( Id, idType )
import Type		( isUnLiftedType )
import SetLevels	( Level(..), LevelledExpr, LevelledBind,
setLevels, isTopLvl, tOP_LEVEL )
import UniqSupply       ( UniqSupply )
import Bag
import Util
import Maybes
import UniqFM
import Outputable
import FastString
\end{code}

-----------------
Overall game plan
-----------------

The Big Main Idea is:

To float out sub-expressions that can thereby get outside
a non-one-shot value lambda, and hence may be shared.

To achieve this we may need to do two things:

a) Let-bind the sub-expression:

f (g x)  ==>  let lvl = f (g x) in lvl

Now we can float the binding for 'lvl'.

b) More than that, we may need to abstract wrt a type variable

\x -> ... /\a -> let v = ...a... in ....

Here the binding for v mentions 'a' but not 'x'.  So we
abstract wrt 'a', to give this binding for 'v':

vp = /\a -> ...a...
v  = vp a

Now the binding for vp can float out unimpeded.
I can't remember why this case seemed important enough to
deal with, but I certainly found cases where important floats
didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.

~~~~~~~~~~~~~~~

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

@
\x y -> let t = x+x in ...
===>
\x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
@
elEm2 x ys
= elem' x ys
where
elem' _ []	= False
elem' x (y:ys)	= x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
\deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
\deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe.  We don't do this at the moment.

%************************************************************************
%*									*
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
%*									*
%************************************************************************

\begin{code}
-- Entry point of the float-out pass.
--
-- The program is first annotated with levels by 'setLevels'; each
-- annotated top-level binding is then processed by 'floatTopBind', and
-- the resulting groups of bindings are concatenated in order.
--
-- Two optional dumps are produced through 'dumpIfSet_dyn':
--   * the level-annotated program (under Opt_D_verbose_core2core)
--   * the accumulated float statistics (under Opt_D_dump_simpl_stats)
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> [CoreBind] -> IO [CoreBind]

floatOutwards float_sws dflags us pgm
  = do {
        let { annotated_w_levels = setLevels float_sws pgm us ;
              (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
            } ;

        -- The pretty-printed program must be handed to a dump action;
        -- on its own it is an SDoc, not an IO statement.
        dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr annotated_w_levels));

        let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };

        dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
                        int ntlets, ptext (sLit " Lets floated elsewhere; from "),
                        int lams,   ptext (sLit " Lambda groups")]);

        return (concat binds_s')
    }

-- Float one (level-annotated) top-level binding.  Everything that
-- floatBind produces, the binding itself included, is flattened into a
-- plain list of bindings.
floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
floatTopBind bind =
  case floatBind bind of
    (stats, floated) -> (stats, bagToList (flattenFloats floated))
\end{code}

%************************************************************************
%*									*
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
%*									*
%************************************************************************

\begin{code}
-- Float bindings out of the right-hand side(s) of a binding.  The
-- (rebuilt) binding itself is returned as one of the floats, filed
-- under its own destination level.
floatBind :: LevelledBind -> (FloatStats, FloatBinds)

floatBind (NonRec (TB bndr dest_lvl) rhs) =
  case floatRhs dest_lvl rhs of
    (stats, floats_from_rhs, new_rhs) ->
      ( stats
      , floats_from_rhs `plusFloats` unitFloat dest_lvl (NonRec bndr new_rhs) )

floatBind bind@(Rec pairs) =
  case unzip3 (map float_pair pairs) of
    (stats_s, floats_s, pairs') ->
      let all_floats = foldr1 plusFloats floats_s in
      if isTopLvl dest_lvl
        then
          -- A recursive group *destined for* the top level: floats from
          -- the right-hand sides may mention the bound things, e.g.
          --      f = ...(let v = ...f... in b) ...
          -- might get floated to
          --      v = ...f...
          --      f = ... b ...
          -- So (pessimistically) every float joins one big Rec with the
          -- original pairs; later dependency analysis will unravel it.
          --
          -- This can only happen for bindings destined for the top
          -- level, because only then will partitionByMajorLevel allow
          -- through a binding that only differs in its minor level.
          ( sum_stats stats_s
          , unitFloat tOP_LEVEL
              (Rec (floatsToBindPairs (flattenFloats all_floats) pairs')) )
        else
          -- Split off the floats that escape at least one lambda beyond
          -- this group.  Those can't mention the binders, because they
          -- couldn't be escaping a major level if so.  The ones staying
          -- here join the letrec; they may not really be mutually
          -- recursive, but the occurrence analyser will find that out.
          case partitionByMajorLevel dest_lvl all_floats of
            (escaping, staying) ->
              ( sum_stats stats_s
              , escaping `plusFloats`
                  unitFloat dest_lvl
                    (Rec (floatsToBindPairs staying pairs')) )
  where
    dest_lvl = getBindLevel bind

    -- Float one (binder, rhs) pair of the group
    float_pair (TB bndr lvl, rhs) =
      case floatRhs lvl rhs of
        (stats, floats_from_rhs, new_rhs) ->
          (stats, floats_from_rhs, (bndr, new_rhs))
\end{code}

%************************************************************************
%*									*
\subsection[FloatOut-Expr]{Floating in expressions}
%*									*
%************************************************************************

\begin{code}
floatExpr, floatRhs, floatCaseAlt
  :: Level
  -> LevelledExpr
  -> (FloatStats, FloatBinds, CoreExpr)

-- floatCaseAlt: used for case-alternative right-hand sides and for the
-- body of an unlifted let.
--
-- Floats that are not going to escape a lambda are dumped right here,
-- wrapped around the expression; in particular this catches the ones
-- that mention variables bound by the enclosing alternative.
floatCaseAlt lvl arg =
  case floatExpr lvl arg of
    (stats, floats, arg') ->
      case partitionByMajorLevel lvl floats of
        (escaping, staying) -> (stats, escaping, install staying arg')

-- floatRhs: used for the right-hand sides of bindings and for function
-- arguments.  See Note [Floating out of RHSs]
--
-- When the processed expression is cheap (exprIsCheap) all floats are
-- passed on untouched; otherwise the floats that do not escape a major
-- level are re-installed around the expression.
floatRhs lvl arg =
  case floatExpr lvl arg of
    (stats, floats, arg')
      | exprIsCheap arg' -> (stats, floats, arg')
      | otherwise ->
          case partitionByMajorLevel lvl floats of
            (escaping, staying) -> (stats, escaping, install staying arg')

-- Note [Floating out of RHSs]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- Dump bindings that aren't going to escape from a lambda
-- This isn't a scoping issue (the binder isn't in scope in the RHS
--	of a non-rec binding)
-- Rather, it is to avoid floating the x binding out of
--	f (let x = e in b)
-- unnecessarily.  But we first test for values or trivial rhss,
-- because (in particular) we don't want to insert new bindings between
-- the "=" and the "\".  E.g.
--	f = \x -> let <bind> in <body>
-- We do not want
--	f = let <bind> in \x -> <body>
-- (a) The simplifier will immediately float it further out, so we may
--	as well do so right now; in general, keeping rhss as manifest
--	values is good
-- (b) If a float-in pass follows immediately, it might add yet more
--	bindings just after the '='.  And some of them might (correctly)
--	be strict even though the 'let f' is lazy, because f, being a value,
--	gets its demand-info zapped by the simplifier.
--
-- We use exprIsCheap because that is also what's used by the simplifier
-- to decide whether to float a let out of a let

-- Atoms: nothing to float, nothing to count
floatExpr _ (Var var)  = (zeroStats, emptyFloats, Var var)
floatExpr _ (Type t)   = (zeroStats, emptyFloats, Type t)
floatExpr _ (Lit l)    = (zeroStats, emptyFloats, Lit l)

-- Application: the function part is floated as an ordinary expression,
-- the argument is treated like a right-hand side (floatRhs)
floatExpr lvl (App fun arg) =
  case floatExpr lvl fun of
    (fun_stats, fun_floats, fun') ->
      case floatRhs lvl arg of
        (arg_stats, arg_floats, arg') ->
          ( fun_stats `add_stats` arg_stats
          , fun_floats `plusFloats` arg_floats
          , App fun' arg' )

-- Lambda group: the body is floated at the level of the last binder;
-- the floats are then partitioned against the level of the first one.
floatExpr _ lam@(Lam _ _) =
  case floatExpr last_lvl body of
    (stats, floats, body') ->
      -- Dump any bindings which absolutely cannot go any further
      case split_floats first_lvl floats of
        (escaping, staying) ->
          ( add_to_stats stats escaping
          , escaping
          , mkLams bndrs (install staying body') )
  where
    (tagged_bndrs, body) = collectBinders lam
    bndrs     = map unTagBndr tagged_bndrs
    lvls      = [l | TB _ l <- tagged_bndrs]
    first_lvl = head lvls
    last_lvl  = last lvls

    -- For the all-tyvar case we are prepared to pull the lets out, to
    -- implement the float-out-of-big-lambda transform; but otherwise
    -- we only float bindings that are going to escape a value lambda.
    -- In particular, for one-shot lambdas we don't float things out;
    -- we get no saving by so doing.
    split_floats
      | all isTyVar bndrs = partitionByLevel
      | otherwise         = partitionByMajorLevel

-- SCC note: bindings floated outwards past an scc expression are
-- annotated with the cost centre, which is first marked as
-- "duplicated" (dupifyCC).
floatExpr lvl (Note note@(SCC cc) expr) =
  case floatExpr lvl expr of
    (stats, floats, expr') ->
      (stats, wrapCostCentre (dupifyCC cc) floats, Note note expr')

-- INLINE note: do no floating at all inside it.
-- The SetLevels pass did not clone the bindings, so it's unsafe to do
-- any floating, even if we dump the results inside the Note (which is
-- what we used to do).  The body just has its level tags stripped.
floatExpr _ (Note InlineMe expr) =
  (zeroStats, emptyFloats, Note InlineMe (unTag expr))

-- Any other note: float through it, leaving the note in place
floatExpr lvl (Note note expr) =
  case floatExpr lvl expr of
    (stats, floats, expr') -> (stats, floats, Note note expr')

-- Cast: float through it, leaving the coercion in place
floatExpr lvl (Cast expr co) =
  case floatExpr lvl expr of
    (stats, floats, expr') -> (stats, floats, Cast expr' co)

-- Unlifted non-recursive let: treated just like a case.
-- I.e. floatExpr for the rhs, floatCaseAlt for the body, and the
-- binding stays where it is instead of becoming a float.
-- Only the statistics of the body are returned; those of the rhs are
-- dropped.
floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)
  = case floatExpr lvl rhs of
      (_, rhs_floats, rhs') ->
        case floatCaseAlt bndr_lvl body of
          (fs, body_floats, body') ->
            ( fs
            , rhs_floats `plusFloats` body_floats
            , Let (NonRec bndr rhs') body' )

-- Any other let: the binding itself is handed to floatBind, which
-- returns it among the floats, so only the processed body remains as
-- the resulting expression.  The statistics of the binding and of the
-- body are added together.
floatExpr lvl (Let bind body)
  = case floatBind bind of
      (fsb, bind_floats) ->
        case floatExpr lvl body of
          (fse, body_floats, body') ->
            ( add_stats fsb fse
            , bind_floats `plusFloats` body_floats
            , body' )

-- Case: the scrutinee is floated as an ordinary expression at the
-- current level; each alternative is floated at the level of the case
-- binder.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts) =
  case floatExpr lvl scrut of
    (scrut_stats, scrut_floats, scrut') ->
      case floatList float_alt alts of
        (alts_stats, alts_floats, alts') ->
          ( add_stats scrut_stats alts_stats
          , alts_floats `plusFloats` scrut_floats
          , Case scrut' case_bndr ty alts' )
  where
    -- Use floatCaseAlt for the alternatives, so that we don't
    -- gratuitously float bindings out of the RHSs
    float_alt (con, tagged_bs, rhs) =
      case floatCaseAlt case_lvl rhs of
        (stats, floats, rhs') ->
          (stats, floats, (con, map unTagBndr tagged_bs, rhs'))

-- Apply a floating function to every element of a list, adding up the
-- statistics and combining the floats (in list order).
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f xs = case xs of
  [] -> (zeroStats, emptyFloats, [])
  y : ys ->
    case f y of
      (stats_y, floats_y, r) ->
        case floatList f ys of
          (stats_ys, floats_ys, rs) ->
            ( stats_y `add_stats` stats_ys
            , floats_y `plusFloats` floats_ys
            , r : rs )

-- The level attached to a binding: that of its binder, or of the first
-- binder of a recursive group.  Panics on an empty Rec.
getBindLevel :: Bind (TaggedBndr Level) -> Level
getBindLevel bind = case bind of
  NonRec (TB _ lvl) _     -> lvl
  Rec ((TB _ lvl, _) : _) -> lvl
  Rec []                  -> panic "getBindLevel Rec []"

-- Drop the tag from a tagged binder, keeping the binder itself
unTagBndr :: TaggedBndr tag -> CoreBndr
unTagBndr tagged = case tagged of
  TB bndr _ -> bndr

-- Strip the tags from every binder of an expression, leaving its
-- structure otherwise unchanged.
unTag :: TaggedExpr tag -> CoreExpr
unTag expr = case expr of
  Var v                -> Var v
  Lit l                -> Lit l
  Type ty              -> Type ty
  Note n e             -> Note n (unTag e)
  App fun arg          -> App (unTag fun) (unTag arg)
  Lam b body           -> Lam (unTagBndr b) (unTag body)
  Cast e co            -> Cast (unTag e) co
  Let bind body        -> Let (untag_bind bind) (unTag body)
  Case scrut b ty alts -> Case (unTag scrut) (unTagBndr b) ty (map untag_alt alts)
  where
    untag_bind (Rec prs)    = Rec [(unTagBndr b, unTag r) | (b, r) <- prs]
    untag_bind (NonRec b r) = NonRec (unTagBndr b) (unTag r)

    untag_alt (con, bs, rhs) = (con, map unTagBndr bs, unTag rhs)
\end{code}

%************************************************************************
%*									*
\subsection{Utility bits for floating stats}
%*									*
%************************************************************************

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)

\begin{code}
-- Counters accumulated while floating
data FloatStats
  = FlS
      Int  -- Number of top-floats * lambda groups they've been past
      Int  -- Number of non-top-floats * lambda groups they've been past
      Int  -- Number of lambda (groups) seen

-- Unpack the three counters, in constructor order
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS top_floats other_floats lam_groups) =
  (top_floats, other_floats, lam_groups)

-- Statistics with every counter at zero; the starting value used by
-- sum_stats and returned for expressions that float nothing.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- Add up a list of statistics, counter by counter
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- Counter-wise sum of two statistics records
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops1 others1 lams1) (FlS tops2 others2 lams2) =
  FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)

-- Record one more lambda group, together with the floats that got past
-- it: top-level floats bump the first counter, all others the second.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS top_n other_n lam_n) (FB tops others) =
  FlS (top_n + n_tops) (other_n + n_others) (lam_n + 1)
  where
    n_tops   = lengthBag tops
    n_others = lengthBag (flattenMajor others)
\end{code}

%************************************************************************
%*									*
\subsection{Utility bits for floating}
%*									*
%************************************************************************

Note [Representation of FloatBinds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The FloatBinds type is somewhat important.  We can get very large numbers
of floating bindings, often all destined for the top level.  A typical example
is     x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.

The trouble is that
(a) we partition these floating bindings *at every binding site*
(b) SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!

That is why MajorEnv is represented as a finite map.

We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
partitionByMajorLevel.

\begin{code}
-- A single floating binding (a plain Core binding)
type FloatBind = CoreBind      -- INVARIANT: a FloatBind is always lifted

-- The collection of bindings currently floating outwards.
-- See Note [Representation of FloatBinds]
data FloatBinds
  = FB
      !(Bag FloatBind)  -- Destined for top level
      !MajorEnv         -- Levels other than top

-- Two-layer finite map holding the non-top-level floats: the outer map
-- is indexed by the major component of a float's level, the inner one
-- by the minor component (see unitFloat).
type MajorEnv = UniqFM MinorEnv                        -- Keyed by major level
type MinorEnv = UniqFM (Bag FloatBind)         -- Keyed by minor level

-- All the floats as one bag: the top-level ones first, then the rest
flattenFloats :: FloatBinds -> Bag FloatBind
flattenFloats (FB tops others) = unionBags tops (flattenMajor others)

-- Collapse a whole MajorEnv into a single bag of bindings
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor major_env = foldUFM add_minor emptyBag major_env
  where
    add_minor minor_env acc = flattenMinor minor_env `unionBags` acc

-- Collapse a MinorEnv into a single bag of bindings
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor minor_env = foldUFM unionBags emptyBag minor_env

-- No floats at all: an empty top-level bag and an empty level map
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag emptyUFM

-- A FloatBinds holding exactly one binding, filed under the given
-- level.  Both InlineCtxt and any level satisfying isTopLvl put the
-- binding in the top-level bag; every other level files it under its
-- major and minor components.
unitFloat :: Level -> FloatBind -> FloatBinds
unitFloat lvl b = case lvl of
  InlineCtxt -> top_only
  Level major minor
    | isTopLvl lvl -> top_only
    | otherwise    -> FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
  where
    top_only = FB (unitBag b) emptyUFM

-- Combine two sets of floats: the top-level bags are unioned and the
-- level maps merged
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops1 env1) (FB tops2 env2) =
  FB (unionBags tops1 tops2) (plusMajor env1 env2)

-- Merge two MajorEnvs; entries under the same major key have their
-- MinorEnvs merged with plusMinor
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env1 env2 = plusUFM_C plusMinor env1 env2

-- Merge two MinorEnvs; bags under the same minor key are unioned
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env1 env2 = plusUFM_C unionBags env1 env2

-- Turn a bag of floated bindings into (binder, rhs) pairs and put them
-- in front of an existing list of pairs.  Recursive groups are simply
-- opened up.
floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
floatsToBindPairs floats binds = foldrBag prepend binds floats
  where
    prepend float acc = case float of
      Rec pairs         -> pairs ++ acc
      NonRec binder rhs -> (binder, rhs) : acc

-- Wrap an expression in a Let for each binding group of the bag
-- (a right fold, so the first group ends up outermost)
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr = foldrBag Let expr defn_groups

partitionByMajorLevel, partitionByLevel
  :: Level                -- Partitioning level
  -> FloatBinds           -- Defns to be divided into 2 piles...
  -> ( FloatBinds         -- Defns  with level strictly < partition level,
     , Bag FloatBind )    -- The rest

--      ---- partitionByMajorLevel ----
-- Float it if we escape a value lambda, *or* if we get to the top level
-- If we can get to the top level, say "yes" anyway. This means that
--     x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be
--
-- Only the major component of the level is looked at: floats filed
-- under the same major level stay here, whatever their minor level.
-- The top-level floats always carry on outwards.

partitionByMajorLevel lvl (FB tops defns) = case lvl of
  InlineCtxt -> (FB tops emptyUFM, flattenMajor defns)
  Level major _ ->
    let (outer, mb_heres, inner) = splitUFM defns major
        heres                    = maybe emptyBag flattenMinor mb_heres
    in  (FB tops outer, heres `unionBags` flattenMajor inner)

--      ---- partitionByLevel ----
-- Like partitionByMajorLevel, but the minor component is taken into
-- account too: within the partitioning major level, floats filed under
-- an outer minor level keep going.

partitionByLevel InlineCtxt (FB tops defns) =
  (FB tops emptyUFM, flattenMajor defns)

partitionByLevel (Level major minor) (FB tops defns) =
  (FB tops still_floating, staying)
  where
    still_floating = outer_maj `plusMajor` unitUFM major outer_min
    staying        = (here_min `unionBags` flattenMinor inner_min)
                       `unionBags` flattenMajor inner_maj

    (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
    (outer_min, mb_here_min, inner_min) =
      case mb_here_maj of
        Nothing        -> (emptyUFM, Nothing, emptyUFM)
        Just min_defns -> splitUFM min_defns minor

    here_min = case mb_here_min of
      Nothing -> emptyBag
      Just bs -> bs

-- Apply mkSCC with the given cost centre to the right-hand side of
-- every floating binding, at every level
wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
wrapCostCentre cc (FB tops defns) =
  FB (mapBag add_scc tops) (mapUFM (mapUFM (mapBag add_scc)) defns)
  where
    add_scc (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
    add_scc (Rec pairs)         = Rec [(b, mkSCC cc rhs) | (b, rhs) <- pairs]
\end{code}

```