module Vectorise.Utils (
  module Vectorise.Utils.Base,
  module Vectorise.Utils.Closure,
  module Vectorise.Utils.Hoisting,
  module Vectorise.Utils.PADict,
  module Vectorise.Utils.PRDict,
  module Vectorise.Utils.Poly,

  -- * Annotated Exprs
  collectAnnTypeArgs,
  collectAnnTypeBinders,
  collectAnnValBinders,
  isAnnTypeArg,

  -- * PD Functions
  replicatePD, emptyPD, packByTagPD,
  combinePD, liftPD,

  -- * Scalars
  zipScalars, scalarClosure,

  -- * Naming
  newLocalVar
) 
where
import Vectorise.Utils.Base
import Vectorise.Utils.Closure
import Vectorise.Utils.Hoisting
import Vectorise.Utils.PADict
import Vectorise.Utils.PRDict
import Vectorise.Utils.Poly
import Vectorise.Monad
import Vectorise.Builtins
import CoreSyn
import CoreUtils
import Type
import Var
import Control.Monad


-- Annotated Exprs ------------------------------------------------------------
-- | Strip the type arguments off an annotated application.
--
-- Returns the head that remains once every outer 'AnnApp' whose argument
-- is an 'AnnType' has been removed, together with those type arguments in
-- left-to-right (source) order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = strip []
  where
    -- The outermost application carries the last argument, so prepending
    -- while walking inwards yields the arguments in source order.
    strip acc (_, AnnApp fun (_, AnnType arg_ty)) = strip (arg_ty : acc) fun
    strip acc other                                = (other, acc)

-- | Peel off the leading lambdas of an annotated expression whose binders
-- satisfy 'isTyCoVar'.
--
-- Returns those binders, outermost first, and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = peel id
  where
    -- 'acc' is a difference list, so appending each binder at the end is
    -- cheap and no final reverse is needed.
    peel acc (_, AnnLam bndr body)
      | isTyCoVar bndr = peel (acc . (bndr :)) body
    peel acc rest      = (acc [], rest)

-- | Peel off the leading lambdas of an annotated expression whose binders
-- satisfy 'isId'.
--
-- Returns those binders, outermost first, and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders expr
  = case expr of
      (_, AnnLam bndr body)
        | isId bndr -> let (bndrs, rest) = collectAnnValBinders body
                       in  (bndr : bndrs, rest)
      -- Also reached when the lambda's binder is not an 'Id'.
      _             -> ([], expr)

-- | Does this annotated expression consist of an 'AnnType' node?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, expr)
  = case expr of
      AnnType _ -> True
      _         -> False


-- PD Functions ---------------------------------------------------------------
-- | Build a call to the @replicatePD@ PA method.
--
-- The method is obtained with 'paMethod' at the type of @x@ and is then
-- applied to the length expression and to @x@.
replicatePD :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePD len x
  = do
      method <- paMethod replicatePDVar "replicatePD" (exprType x)
      return (method `mkApps` [len, x])

-- | The @emptyPD@ PA method, obtained with 'paMethod' at the given type.
emptyPD :: Type -> VM CoreExpr
emptyPD ty = paMethod emptyPDVar "emptyPD" ty


-- | Build a call to the @packByTagPD@ PA method at type @ty@.
--
-- The method is applied, in this order, to @xs@, @len@, @tags@ and @t@.
packByTagPD :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packByTagPD ty xs len tags t
  = do
      method <- paMethod packByTagPDVar "packByTagPD" ty
      return $ mkApps method [xs, len, tags, t]


-- | Build a call to the @combine<n>PD@ PA method at type @ty@, where @n@ is
-- the number of expressions in @xs@.
--
-- The method is applied to the length, the selector and then every element
-- of @xs@.
combinePD :: Type -> CoreExpr -> CoreExpr -> [CoreExpr] -> VM CoreExpr
combinePD ty len sel xs
  = do
      method <- paMethod (combinePDVar arity) method_name ty
      return (mkApps method (len : sel : xs))
  where
    arity       = length xs
    method_name = "combine" ++ show arity ++ "PD"


-- | Like `replicatePD`, but the length argument is the variable bound to the
-- builtin 'liftingContext' in the vectoriser state.
liftPD :: CoreExpr -> VM CoreExpr
liftPD x = builtin liftingContext >>= \lc -> replicatePD (Var lc) x


-- Scalars --------------------------------------------------------------------
-- | Build an application of the builtin @scalarZip@ function whose arity is
-- the number of argument types.
--
-- The function is instantiated at all argument types followed by the result
-- type. It is then applied to the dictionary functions that 'lookupInst'
-- finds for the builtin 'scalarClass' at each of those types, in the same
-- order.
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty
  = do
      scalar <- builtin scalarClass
      insts  <- mapM (\ty -> lookupInst scalar [ty]) all_tys
      zipf   <- builtin (scalarZip (length arg_tys))
      -- Only the first component of each lookup result (the dfun) is used.
      let dfun_args = map (Var . fst) insts
      return $ mkApps (mkTyApps (Var zipf) all_tys) dfun_args
    where
      all_tys = arg_tys ++ [res_ty]


-- | Build a closure from a scalar function and its array counterpart.
--
-- The closure is built with the builtin closure constructor whose arity is
-- the number of argument types. The constructor is instantiated at all
-- argument types followed by the result type. It is then applied to the PA
-- dictionaries for every argument type except the last, then to
-- @scalar_fun@, then to @array_fun@.
--
-- NOTE(review): @init@ is partial, so an empty @arg_tys@ errors.
-- NOTE(review): if 'paDictOfType' yields 'Nothing' for any of those types,
-- the @Just@ pattern fails and the monad's 'fail' is invoked.
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun
  = do
      ctr      <- builtin (closureCtrFun (length arg_tys))
      Just pas <- sequence `liftM` mapM paDictOfType (init arg_tys)
      let ty_args = arg_tys ++ [res_ty]
          args    = pas ++ [scalar_fun, array_fun]
      return (mkApps (mkTyApps (Var ctr) ty_args) args)



{-
boxExpr :: Type -> VExpr -> VM VExpr
boxExpr ty (vexpr, lexpr)
  | Just (tycon, []) <- splitTyConApp_maybe ty
  , isUnLiftedTyCon tycon
  = do
      r <- lookupBoxedTyCon tycon
      case r of
        Just tycon' -> let [dc] = tyConDataCons tycon'
                       in
                       return (mkConApp dc [vexpr], lexpr)
        Nothing     -> return (vexpr, lexpr)
-}