{-# OPTIONS -fglasgow-exts #-}
-- Just for pattern guards


module SudokuPJ where

import Data.FiniteMap
import Data.List
import Text.PrettyPrint
import Char

import TestPJ

-- | How much detail the solver trace contains.
--
--   * 'All'   - the starting board, every step, and the board after each step.
--   * 'Terse' - every step (runs of deletions summarised as a count) and
--               the board at the end.
--   * 'Final' - only the closing Stuck/Success line (or a consistency error).
data Verbosity
  = All
  | Terse
  | Final
  deriving (Eq)

------------------------
-- | Solve each named puzzle and report only its outcome: one line per
--   puzzle, prefixed by the puzzle's name.
run :: [(String, [String])] -> Doc
run tests = vcat (map runOne tests)
  where
    runOne (name, puzzle) = text name <> colon <+> run1 Final puzzle

-- | Solve one puzzle, given as one string per row, and report the solving
--   process at the requested 'Verbosity'.  With 'All' the starting board
--   is shown before the first step.
run1 :: Verbosity -> [String] -> Doc
run1 verb tst = vcat [opening, run_moves verb 0 start]
  where
    start = initBoard tst
    opening
      | verb == All = ppBoard start
      | otherwise   = empty

-- | Apply 'move' repeatedly, numbering steps from @n@, until either the
--   board is found inconsistent or no strategy makes progress.  The final
--   line says "Success" when every cell is known and "Stuck" otherwise.
run_moves :: Verbosity -> Int -> Board -> Doc
run_moves verb n board =
  case inconsistent board of
    Just problem ->
      vcat [text "Consistency error:", nest 2 problem]
    Nothing ->
      case move verb board of
        Just (next, why) ->
          vcat [ stepDoc why
               , if verb == All then ppBoard next else empty
               , run_moves verb (n+1) next ]
        Nothing ->
          vcat [ outcome <+> text "after" <+> int n <+> text "moves"
               , if verb == Final then empty else ppBoard board ]
  where
    -- One step's description followed by a blank line; omitted for 'Final'.
    stepDoc why
      | verb == Final = empty
      | otherwise     = text "Step" <+> int n <> colon <+> why $$ text ""

    outcome
      | any (isUnk . getBoard board) allPos = text "Stuck"
      | otherwise                           = text "Success"

------------------------
-- | Try the solving strategies in a fixed order of preference ('improve',
--   'oneChoice', 'population', 'groups', 'rowCols') and return the result
--   of the first one that applies: the new board and a description of
--   what it did.  'Nothing' means no strategy applies.
move :: Verbosity -> Board -> Maybe (Board,Doc)
move verb b = first ($ b) strategies
  where
    strategies = [improve verb, oneChoice, population, groups, rowCols]

------------------------------------------
--	Positions, boxes
------------------------------------------

-- | 1-based index of a cell along a row or column.
type CellIx = Int

-- | 1-based index of a box along one axis of the grid.
type BoxIx = Int

-- | A cell position, as (row, column).
type Pos = (CellIx, CellIx)

-- | A box position, as (box row, box column).
type BoxPos = (BoxIx, BoxIx)

-- | Number of cells along each side of the grid.
cellN :: CellIx
cellN = 9

-- | Number of boxes along each side of the grid, which is also the number
--   of cells along each side of one box.
boxN :: BoxIx
boxN = 3

-- | Every row index and every column index, in increasing order.
allRows, allCols :: [CellIx]
allRows = enumFromTo 1 cellN
allCols = enumFromTo 1 cellN

-- | The cell indices covered by box index @b@ along one axis; with
--   'boxN' = 3, box 2 covers cells 4, 5 and 6.
boxCells :: BoxIx -> [CellIx]
boxCells b = [lo .. lo + boxN - 1]
  where
    lo = (b - 1) * boxN + 1

-- | Does cell index @i@ fall inside box index @b@ (along one axis)?
inBox :: BoxIx -> CellIx -> Bool
inBox b i = i >= start && i - start < boxN
  where
    start = (b - 1) * boxN + 1

-- | Every cell position on the board, row by row, left to right.
allPos :: [Pos]
allPos = concatMap (\r -> map ((,) r) allCols) allRows

-- | Render a position as @(row,col)@.
ppPos :: Pos -> Doc
ppPos (row, col) = parens (hcat [int row, comma, int col])

------------------------------------------
--	The Board
------------------------------------------

-- | What is known about every cell, keyed by position.
type Board = FiniteMap Pos Cts

-- | Contents of one cell.
data Cts
  = Known Val Bool  -- ^ Settled value; the flag is True iff it was given in the puzzle
  | Unk [Val]       -- ^ Values still possible, kept in increasing order

-- | True for a cell whose value has not been settled yet.
isUnk :: Cts -> Bool
isUnk cts = case cts of
  Unk _     -> True
  Known _ _ -> False

-- | A value that can be written in a cell.
type Val = Int

-- | Every legal cell value, in increasing order.
allVals :: [Val]
allVals = [1 .. maxVal]

-- | The largest legal cell value.
maxVal :: Val
maxVal = 9

-- | Overwrite the cell at @pos@ with the given list of still-possible values.
setUnknown :: Board -> Pos -> [Val] -> Board
setUnknown b pos = addToFM b pos . Unk

-- | Every unsettled cell with its still-possible values, in the order
--   produced by 'fmToList'.
boardUnks :: Board -> [(Pos,[Val])]
boardUnks b = [ (pos, vs) | (pos, cts) <- fmToList b, Unk vs <- [cts] ]

-- | Is @v@ still a possibility for this cell?  Always False for a settled cell.
elemUnk :: Val -> Cts -> Bool
elemUnk v cts = case cts of
  Unk vs -> v `elem` vs
  _      -> False

-- | Remove @v@ from the possibilities of the cell at @p@.  A settled cell
--   is written back unchanged.
delUnk :: Board -> Pos -> Val -> Board
delUnk b p v = addToFM b p (without (getBoard b p))
  where
    without (Unk vs) = Unk (Data.List.delete v vs)
    without cts      = cts

-- | Apply 'delUnk' for every (position, value) pair in the list.
delUnks :: Board -> [(Pos,Val)] -> Board
delUnks b pvs = foldr (\(p, v) acc -> delUnk acc p v) b pvs

-- | Record @v@ as the deduced (not given) value of the cell at @pos@.
setFound :: Board -> Pos -> Val -> Board
setFound b pos v = let cts = Known v False in addToFM b pos cts

-- | Record @v@ as a value given in the puzzle for the cell at @pos@.
setGiven :: Board -> Pos -> Val -> Board
setGiven b pos v = let cts = Known v True in addToFM b pos cts

-- | Contents of the cell at @pos@.  Calls 'error' (naming the position) if
--   the board has no entry there.
getBoard :: Board -> Pos -> Cts
getBoard b pos = maybe missing id (lookupFM b pos)
  where
    missing = error (show (text "getBoard" <+> ppPos pos))

------------------------------------------
--	Sets
------------------------------------------
-- | A constraint group of the puzzle: one row, one column, or one box.
data Set
  = Row CellIx
  | Col CellIx
  | Box BoxPos
  deriving (Eq, Ord, Show)

-- Aliases naming the intended kind of 'Set'; the type checker does not
-- distinguish them.
type BoxSet = Set
type RowSet = Set
type ColSet = Set

-- | The positions two sets have in common.
--
--   A row and a column share exactly one cell.  A row (or column) and a
--   box share a strip of 'boxN' cells when the box lies across it, and
--   nothing otherwise.  Two sets of the same kind share all their cells
--   when they are equal, and nothing otherwise.
setIntersect :: Set -> Set -> [Pos]
setIntersect (Row r) (Col c) = [(r,c)]
setIntersect (Col c) (Row r) = [(r,c)]
setIntersect (Row r) (Box (br,bc))
  | boxBase r == br = [(r,c) | c <- boxCells bc]
  | otherwise       = []
setIntersect (Col c) (Box (br,bc))
  | boxBase c == bc = [(r,c) | r <- boxCells br]
  | otherwise       = []
setIntersect (Box b) (Row r) = setIntersect (Row r) (Box b)
setIntersect (Box b) (Col c) = setIntersect (Col c) (Box b)
-- Only same-kind pairs reach this equation.  Without it, Row/Row and
-- Col/Col fell off the end of the definition, and Box/Box recursed
-- forever through the symmetric Box equation.
setIntersect s1 s2
  | s1 == s2  = setPosns s1
  | otherwise = []

-- | The box index (along one axis) containing cell index @i@; with
--   'boxN' = 3, cells 1-3 are in box 1 and cells 4-6 in box 2.
boxBase :: CellIx -> BoxIx
boxBase i = (i - 1) `div` boxN + 1

-- | Every position in a set, in increasing (row, column) order.
setPosns :: Set -> [Pos]
setPosns set = case set of
  Row r       -> map (\c -> (r, c)) allCols
  Col c       -> map (\r -> (r, c)) allRows
  Box (br,bc) -> [ (r, c) | r <- boxCells br, c <- boxCells bc ]

-- | All rows, all columns, all boxes (row-major by box position), and the
--   three lists concatenated in that order.
rowSets, colSets, boxSets, allSets :: [Set]
rowSets = [ Row r | r <- allRows ]
colSets = [ Col c | c <- allCols ]
boxSets = map Box [ (br, bc) | br <- [1 .. boxN], bc <- [1 .. boxN] ]
allSets = concat [rowSets, colSets, boxSets]

-- | Transpose a set: rows and columns swap, and a box is reflected across
--   the main diagonal.
txSet :: Set -> Set
txSet set = case set of
  Row i       -> Col i
  Col i       -> Row i
  Box (i, j)  -> Box (j, i)

-- | Transpose a position by swapping its two components.
txPos :: (a, b) -> (b, a)
txPos (row, col) = (col, row)

-- | Render a list of sets of one kind, e.g. @rows 1, 4@.  A single set is
--   rendered as by 'ppSet', and an empty list as 'empty'.  Sets of a
--   different kind from the first element are silently omitted.
ppSets :: [Set] -> Doc
ppSets []           = empty  -- previously a pattern-match failure
ppSets [s]          = ppSet s
ppSets (Row r : ss) = text "rows"    <+> ppList int    (r : [r' | Row r' <- ss])
ppSets (Col c : ss) = text "columns" <+> ppList int    (c : [c' | Col c' <- ss])
ppSets (Box b : ss) = text "boxes"   <+> ppList pp_box (b : [b' | Box b' <- ss])

-- | Render one set, e.g. @row 3@, @column 7@ or @box (2,1)@.
ppSet :: Set -> Doc
ppSet set = case set of
  Row r -> text "row"    <+> int r
  Col c -> text "column" <+> int c
  Box b -> text "box"    <+> pp_box b

-- | Render a box position as @(row,col)@.
pp_box :: BoxPos -> Doc
pp_box (r, c) = parens (hcat [int r, comma, int c])

-- | The singular name of a set's kind: @row@, @column@ or @box@.
ppSetName :: Set -> Doc
ppSetName set = text $ case set of
  Row _ -> "row"
  Col _ -> "column"
  Box _ -> "box"

--------------------------------------
-- | Build the starting board from one string per row.
--
--   Character @c@ of row @r@ describes cell @(r,c)@.  A digit @1@..@9@ is a
--   given value.  Any other character leaves the cell blank, including
--   @0@, which is not a legal value.  Missing rows or characters are blank
--   too.  Each blank cell starts with every value not already ruled out by
--   a given in its row, column or box.
initBoard :: [String] -> Board
initBoard setup = foldr fill givenBoard allPos
  where
    givenBoard = foldr addRow emptyFM (allRows `zip` setup)

    addRow (r, cs) b = foldr (addItem r) b (allCols `zip` cs)

    addItem r (c, ch) b
      | isGivenDigit ch = setGiven b (r, c) (ord ch - ord '0')
      | otherwise       = b

    -- '0' is a digit but not a legal cell value; treat it as blank rather
    -- than recording an impossible given of 0.
    isGivenDigit ch = isDigit ch && ch /= '0'

    fill pos b
      | pos `elemFM` b = b   -- already set in givenBoard
      | otherwise      = setUnknown b pos (allVals \\ cantBe b pos)

--------------------------------------
-- | Check every row, column and box.  Returns 'Nothing' when the board is
--   consistent.  Otherwise returns a description of the first problem found:
--   a value settled in more than one cell of a set, or a value that is
--   neither settled nor still possible anywhere in a set.
inconsistent :: Board -> Maybe Doc
inconsistent b = first checkSet allSets
  where
    checkSet set =
      case dups settled of
        (v:_) -> Just (int v <+> text "appears more than once in" <+> ppSet set)
        []    -> case [ v | v <- allVals, v `notElem` reachable ] of
                   (v:_) -> Just (int v <+> text "does not appear in" <+> ppSet set)
                   []    -> Nothing
      where
        contents  = map (getBoard b) (setPosns set)
        settled   = [ v | Known v _ <- contents ]
        reachable = concatMap candidates contents

    -- Values a cell could hold: its settled value, or its remaining options.
    candidates (Known v _) = [v]
    candidates (Unk vs)    = vs


--------------------------------------
-- | Render the whole board: a top rule, every row, then a blank line.
ppBoard :: Board -> Doc
ppBoard b = vcat [hline '=', vcat [ ppRow b r | r <- allRows ], text ""]

-- | Render board row @r@ as 'boxN' text lines, followed by a rule: @=@
--   below the last row of a band of boxes, @-@ otherwise.
ppRow :: Board -> CellIx -> Doc
ppRow b r = vcat ([ ppSubRow b r sr | sr <- [0 .. boxN - 1] ] ++ [hline rule])
  where
    rule | r `mod` boxN == 0 = '='
         | otherwise         = '-'

-- | Render text line @sr@ of board row @r@: a leading 'fatVbar' followed by
--   that line of every cell.
ppSubRow :: Board -> CellIx -> Int -> Doc
ppSubRow b r sr = hcat (fatVbar : [ ppSubRowCell b r sr c | c <- allCols ])

-- | Render text line @sr@ (counting from 0) of cell @(r,c)@, padded by
--   'hmargin' on both sides and followed by the bar on its right (fat at a
--   box boundary).
--
--   A known cell shows its value on the middle line only, wrapped in @-@
--   for a given value or @*@ for a deduced one.  An unknown cell shows, on
--   line @sr@, which of the values @sr*boxN+1 .. sr*boxN+boxN@ are still
--   possible.
ppSubRowCell :: Board -> CellIx -> Int -> CellIx -> Doc
ppSubRowCell b r sr c
  = hmargin <> payload <> hmargin <> rightBar
  where
    rightBar
      | c `mod` boxN == 0 = fatVbar
      | otherwise         = thinVbar

    -- The middle line of a cell.  This was hard-coded as 1, which is the
    -- middle only when boxN == 3.
    midRow = boxN `div` 2

    payload = case getBoard b (r,c) of
      Known v given
        | sr == midRow -> marker given <> int v <> marker given
        | otherwise    -> text (replicate boxN ' ')
      Unk vs -> hcat [ candidate vs v | v <- take boxN [sr*boxN + 1 ..] ]

    marker given = char (if given then '-' else '*')

    candidate vs v
      | v `elem` vs = int v
      | otherwise   = space

-- | A horizontal rule as wide as one rendered board line.  That width is
--   the leading 'fatVbar' printed by 'ppSubRow' plus 'colWidth' characters
--   for each of the 'cellN' cells.  Without the 1, the rule was one
--   character shorter than the rows it separates.
hline :: Char -> Doc
hline ch = text (replicate (1 + colWidth * cellN) ch)

-- | Width of one rendered cell: 'boxN' payload characters, a margin on
--   each side, and the bar on its right.
colWidth :: Int
colWidth = boxN + 2 * hmarginWidth + 1

-- | Number of spaces on each side of a cell's payload.
hmarginWidth :: Int
hmarginWidth = 2

-- | The padding printed on each side of a cell's payload.
hmargin :: Doc
hmargin = text (replicate hmarginWidth ' ')

-- | Separator between cells inside a box.
thinVbar :: Doc
thinVbar = char '|'

-- | Separator at box boundaries and at the left edge of each board line.
fatVbar :: Doc
fatVbar = char '$'

-- | Render a list as a comma-separated sequence, e.g. @1, 2, 3@.  An empty
--   list renders as 'empty'.
ppList :: (a -> Doc) -> [a] -> Doc
ppList pp xs = hsep (punctuate comma (map pp xs))

--------------------------------------
-- | Apply 'improveStep' for as long as it makes progress, and bundle all the
--   resulting deletions into a single move.  'Terse' reports only the number
--   of deletions; otherwise each one is listed.  Returns 'Nothing' if not
--   even one deletion is possible.
improve :: Verbosity -> Board -> Maybe (Board, Doc)
improve verb b0 = fmap finish (improveStep b0)
  where
    finish (b1, d1) = (bFinal, summary)
      where
        (count, bFinal, details) = loop 1 b1 d1
        summary
          | verb == Terse = text "Improvements" <+> parens (int count)
          | otherwise     = text "Improvements" $$ nest 2 details

    -- Keep improving, counting steps and accumulating their descriptions.
    loop k b d = case improveStep b of
      Nothing       -> (k, b, d)
      Just (b', d') -> loop (k + 1) b' (d $$ d')

-- | Perform the first single-value deletion available at any cell, scanning
--   cells in 'allPos' order.
improveStep :: Board -> Maybe (Board, Doc)
improveStep b = firstJust [ improveOne b pos | pos <- allPos ]

-- | If the unknown cell at @pos@ still lists a value that is already settled
--   in its row, column or box, delete the first such value and describe the
--   deletion.  Known cells, and unknown cells with nothing to delete, give
--   'Nothing'.
improveOne :: Board -> Pos -> Maybe (Board, Doc)
improveOne b pos
  = case getBoard b pos of
      Known _ _ -> Nothing
      Unk vs    -> fmap (deleteVal vs) (find (`elem` excluded) vs)
  where
    -- Computed once per cell.  It was previously recomputed for every
    -- candidate value.
    excluded = cantBe b pos

    deleteVal vs v = ( setUnknown b pos (Data.List.delete v vs)
                     , text "Delete" <+> int v <+> text "from" <+> ppPos pos )

------------------------
-- | Values already settled in the row, column and box of @(r,c)@, including
--   the cell itself, possibly with repeats.  Positions missing from the
--   board are skipped, because 'initBoard' calls this while the board is
--   still only partly filled.
cantBe :: Board -> Pos -> [Val]
cantBe b (r,c) =
  [ v
  | set <- [Row r, Col c, Box (boxBase r, boxBase c)]
  , p <- setPosns set
  , Just (Known v _) <- [lookupFM b p]
  ]

------------------------
-- | Settle the first unknown cell (in 'allPos' order) that has exactly one
--   possible value.
oneChoice :: Board -> Maybe (Board,Doc)
oneChoice b =
  case [ (pos, v) | pos <- allPos, Unk [v] <- [getBoard b pos] ] of
    []           -> Nothing
    (pos, v) : _ -> Just (setFound b pos v, explain pos v)
  where
    explain pos v =
      text "The only possibility at" <+> ppPos pos <+> text "is" <+> int v

------------------------
-- | Every row, column and box must contain every value.  Find a set and a
--   value not yet settled in it that is possible in exactly one of the
--   set's unknown cells, and settle it there.
population :: Board -> Maybe (Board,Doc)
population b = first fillFrom allSets
  where
    fillFrom set = first (onlyPlaceFor set unks) missing
      where
        contents = [ (pos, getBoard b pos) | pos <- setPosns set ]
        unks     = [ (pos, vs) | (pos, Unk vs) <- contents ]
        missing  = allVals \\ [ v | (_, Known v _) <- contents ]

    onlyPlaceFor set unks v =
      case [ pos | (pos, vs) <- unks, v `elem` vs ] of
        [pos] -> Just (setFound b pos v, explain set pos v)
        _     -> Nothing

    explain set pos v = text "In" <+> ppSet set <> comma <+>
                        text "the only place for" <+> int v <+> text "is" <+> ppPos pos

------------------------
-- | Naked and hidden groups, tried for sizes 2 to cellN-1 over every set:
--
--   * If N cells of a set can only hold the same N values between them,
--     no other cell in the set can hold those values.
--   * Dually, if N values of a set can only go in the same N cells, those
--     cells cannot hold any other value.
--
--   Size 1 would subsume 'population', but 'population' does the same job
--   faster, in a single step.
groups :: Board -> Maybe (Board,Doc)
groups b = first (uncurry (groupsSet b)) [ (n, s) | n <- [2 .. cellN - 1], s <- allSets ]

-- | Look for a group of size @n@ within @set@ (see 'findGroup'), in two forms:
--
--   1. @n@ unknown cells whose possible values together number exactly @n@.
--      Those values are deleted from the set's other unknown cells.
--
--   2. @n@ values that together are possible in exactly @n@ cells.  All
--      other values are deleted from those cells.
--
--   Returns the updated board and an explanation, or 'Nothing'.
groupsSet :: Board -> Int -> Set -> Maybe (Board,Doc)
groupsSet b n set
  | Just (unequal, pos_s, vs, del_items) <- findGroup n unks
    -- A cell losing several values appears once per value in del_items,
    -- so de-duplicate before listing cells in the message.
  = Just (delUnks b del_items, msg1 unequal vs pos_s (nub (map fst del_items)))
  | Just (unequal, vs, pos_s, del_items) <- findGroup n val_pos_s
  = Just (delUnks b [(p,v) | (v,p) <- del_items], msg2 unequal vs pos_s)
  | otherwise
  = Nothing
  where
    cts_s = [(pos, getBoard b pos) | pos <- setPosns set]
    unks  = [(pos,vs) | (pos, Unk vs) <- cts_s]

    -- Every (value, cell) pair where the value is still possible, then
    -- regrouped as value -> cells, for the dual search.
    val_pos :: [(Val, Pos)]
    val_pos = [(v,p) | (p,vs) <- unks, v <- vs]

    val_pos_s :: [(Val, [Pos])]
    val_pos_s = groupByFst val_pos

    msg1 unequal vs our_pos_s del_pos_s
      = vcat [ text "In" <+> ppSet set <> comma <+> text "only the values" <+> ppList int vs
             , text "can be in cells" <+> ppList ppPos our_pos_s
             , text "and hence can be deleted from cells" <+> ppList ppPos del_pos_s
             , unequal_msg unequal ]

    msg2 unequal vs our_pos_s
      = vcat [ text "In" <+> ppSet set <> comma <+> text "the values" <+> ppList int vs
             , text "can only be in cells" <+> ppList ppPos our_pos_s
             , text "and hence other values can be deleted from these cells"
             , unequal_msg unequal ]

    unequal_msg True  = text "[Note: the cells do not all have the same possibilities]"
    unequal_msg False = empty


-----------------
-- | Try 'rowCols1' for each group size and each value in turn.  Size 1 is
--   only used by 'rowCols1' for row/box and column/box pairings.  For
--   row/column pairs, size 1 would duplicate 'population', so those start
--   at size 2.
rowCols :: Board -> Maybe (Board, Doc)
rowCols b = first (uncurry (rowCols1 b)) [ (n, v) | n <- [1 .. cellN - 1], v <- allVals ]

-- | For value @v@, look for @n@ sets of one kind (say rows) in which @v@ can
--   only go within the same @n@ sets of another kind (say columns).  Then
--   @v@ can be deleted from those columns' cells in every other row.
--
--   The pairings tried, in both directions, are rows/columns (sizes 2 to
--   cellN-1), rows/boxes and columns/boxes (sizes 1 to boxN-1).
rowCols1 :: Board -> Int -> Val -> Maybe (Board, Doc)
rowCols1 b n v
  | Just (_unequal, rows, cols, del_items)
      <- first try [ (2, cellN, rowSets, colSets)
                   , (1, boxN,  rowSets, boxSets)
                   , (1, boxN,  colSets, boxSets) ]
  = let -- Keep only cells where v is still a candidate.  Other cells of the
        -- intersections are unaffected by the deletion, and listing them
        -- in the message would claim deletions that do nothing.
        del_specs = [ (p,v) | (r,c) <- del_items, p <- setIntersect r c, v_is_in p ]
    in Just (delUnks b del_specs, msg rows cols del_specs)
  | otherwise
  = Nothing
  where
    -- Sizes outside [lo, hi) are not tried for this pairing.
    try (lo, hi, rows, cols)
      | n < lo || n >= hi = Nothing
      | otherwise         = findGroup n (make rows cols)
                              `andThen` findGroup n (make cols rows)

    -- For each set in rows, the sets in cols whose overlap with it can
    -- still hold v; rows with no such overlap are dropped.
    make :: [Set] -> [Set] -> [(Set, [Set])]
    make rows cols
      = [ (r, cs) | r <- rows
                  , let cs = [c | c <- cols, any v_is_in (setIntersect r c)]
                  , not (null cs) ]

    v_is_in p = v `elemUnk` getBoard b p

    msg rows cols del_specs
      = vcat [ text "In" <+> ppSets rows <> comma <+> text "the value" <+> int v
             , text "can only be in" <+> ppSets cols
             , text "and hence can be deleted from other" <+>
                 ppSetName (head rows) <> text "s in those" <+> ppSetName (head cols) <> text "s"
             , text "namely" <+> ppList ppPos (map fst del_specs)
             ]

-----------------
-- | Search for @n@ keys whose value lists, taken together, contain exactly
--   @n@ distinct values, such that some key outside the group also maps to
--   one of those values.
--
--   The first such group is returned, trying @n@-element sublists of the
--   input in 'sublists' order.  The result is
--   @(unequal, group keys, group values, del-items)@.  del-items are the
--   @(key, val)@ pairs from keys outside the group whose @val@ is one of the
--   group's values; these are the ones the caller deletes.
--
--   Each incoming value list is assumed to have no repeats.
--
--   @unequal@ is True when the group's keys do not each map to all @n@
--   values, a case that is harder to spot by eye.
findGroup :: (Eq key, Ord val)
          => Int -> [(key,[val])]
          -> Maybe (Bool, [key], [val], [(key,val)])
findGroup n all_items
  | length all_items <= n = Nothing
  | otherwise             = first candidate (sublists n all_items)
  where
    candidate items
      | length items == length vals && not (null outsiders)
      = Just (unequal, ks, vals, outsiders)
      | otherwise
      = Nothing
      where
        (ks, vss) = unzip items
        vals      = unionL vss
        unequal   = any ((/= length items) . length) vss
        outsiders = [ (k, w) | (k, ws) <- all_items
                             , k `notElem` ks
                             , w <- ws
                             , w `elem` vals ]

------------------------
-- | Apply @f@ to each element in turn and return the first 'Just' result.
--   Later elements are not examined once a result is found.
first :: (a -> Maybe b) -> [a] -> Maybe b
first f = foldr step Nothing
  where
    step x rest = case f x of
      Just y  -> Just y
      Nothing -> rest

-- | The first 'Just' in the list, or 'Nothing' if there is none.  Elements
--   after the first 'Just' are never inspected.
firstJust :: [Maybe a] -> Maybe a
firstJust []              = Nothing
firstJust (Just a  : _)   = Just a
firstJust (Nothing : mbs) = firstJust mbs

-- | Left-biased choice: the first argument if it is a 'Just', otherwise the
--   second.  The second argument is not evaluated when the first succeeds.
andThen :: Maybe a -> Maybe a -> Maybe a
andThen ma fallback = case ma of
  Just _  -> ma
  Nothing -> fallback

-- | Gather pairs by their first component.  The result is in increasing key
--   order, and each key's second components keep their input order.
groupByFst :: Ord a => [(a,b)] -> [(a,[b])]
groupByFst prs =
  [ (key, map snd grp)
  | grp@((key, _) : _) <- groupBy sameKey (sortBy byKey prs) ]
  where
    byKey x y   = compare (fst x) (fst y)
    sameKey x y = byKey x y == EQ

-- | Gather pairs by their second component.  The result is in increasing key
--   order, and each key's first components keep their input order.
groupBySnd :: Ord b => [(a,b)] -> [(b,[a])]
groupBySnd prs =
  [ (key, map fst grp)
  | grp@((_, key) : _) <- groupBy sameKey (sortBy byKey prs) ]
  where
    byKey x y   = compare (snd x) (snd y)
    sameKey x y = byKey x y == EQ

-- | Compare two pairs by their first components only.
cmpFst :: Ord a => (a, b) -> (a, c) -> Ordering
cmpFst (x, _) (y, _) = compare x y

-- | Compare two pairs by their second components only.
cmpSnd :: Ord a => (b, a) -> (c, a) -> Ordering
cmpSnd (_, x) (_, y) = compare x y

-- | Sort with @cmp@ (stably), then split the result into runs of elements
--   that @cmp@ considers equal.
equivBy :: (a -> a -> Ordering) -> [a] -> [[a]]
equivBy cmp = groupBy (\x y -> cmp x y == EQ) . sortBy cmp
    
-- | Each element that occurs more than once, listed once, in increasing order.
dups :: Ord a => [a] -> [a]
dups xs = [ y | (y : _ : _) <- groupBy (\a b -> compare a b == EQ) (sortBy compare xs) ]

-- | All sublists of exactly @k@ elements, each keeping the input order.
--   Sublists that include the head come first.
sublists :: Int -> [a] -> [[a]]
sublists k ys
  | k == 0    = [[]]
  | otherwise = case ys of
      []     -> []
      (z:zs) -> [ z : rest | rest <- sublists (k - 1) zs ] ++ sublists k zs

-- | Union of a collection of lists, returned in increasing order without
--   repeats.  An empty collection gives the empty list.
unionL :: Ord a => [[a]] -> [a]
-- sort + group yields increasing, duplicate-free output whatever the input
-- order, in O(m log m) for m total elements.  The previous foldr1/elem
-- version returned elements in input order despite promising an increasing
-- merge, was quadratic, and crashed on an empty collection.
unionL xss = [ x | (x:_) <- group (sort (concat xss)) ]
