{-# language TemplateHaskell #-}

import Control.Exception (bracket)
import Control.Monad (replicateM, forM)
import Data.List (transpose)
import Prelude hiding (and,or,not)

import Control.Lens
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State as S

import SimpleSMT

-- | Context threaded through the 'SMT' monad.
data State = State
  { _solver :: !Solver  -- ^ connection to the running solver
  , _top    :: !Int     -- ^ next index used when naming fresh variables
  }
$(makeLenses ''Main.State)

-- | Working monad of this script: 'IO' carrying a 'Main.State' (solver
-- handle plus fresh-name counter) along as state.
type SMT = StateT Main.State IO

-- | Declare one 3x3 matrix of 3-bit variables per letter, require each to be
-- 'monotone' and each rewrite rule listed below to strictly decrease
-- ('gtm'), then print the solver's answer and, unless it is unsat, the
-- model values of all matrices.
main :: IO ()
main = do
  logger <- newLogger 0
  -- 'bracket' makes sure the solver is told to exit, even if something
  -- below throws. Alternative solver:
  -- newSolver "yices-smt2" [ "--smt2-model-format" ] (Just logger)
  bracket (newSolver "z3" [ "-in" ] (Just logger)) stop $ \ s -> do
    setLogic s "QF_BV"
    flip S.evalStateT (State {_solver = s, _top = 0}) $ do
      -- The variable width (3) must agree with the literal width that is
      -- hard-coded in 'fromInteger' for 'SExpr'.
      a <- matrix_width 3 3
      b <- matrix_width 3 3
      c <- matrix_width 3 3
      let letters = [a, b, c]
      mapM_ (lift . assert s . monotone) letters
      mapM_ (\ (lhs, rhs) -> lift $ assert s $ gtm lhs rhs)
        [ -- aa -> aba
          (times a a, times a (times b a))
        ]
        -- z086 :
        -- [ (times a a, times b c), (times b b, times a c), (times c c, times a b)]
      result <- lift $ check s
      lift $ print result
      -- SMT-LIB only permits get-value after a sat (or unknown) answer, so
      -- do not ask for a model when the constraints are unsatisfiable.
      case result of
        Unsat -> return ()
        _ -> do
          vs <- forM letters getm
          lift $ print $ map Vertical vs

-- | Declare a fresh bit-vector constant of the given width on the solver and
-- return the term naming it. Names are @r0@, @r1@, ..., taken from the
-- '_top' counter, which is bumped after each declaration.
declare_bv :: Integer -> SMT SExpr
declare_bv width = do
  idx <- use top
  s <- use solver
  var <- lift (declare s (freshName idx) (tBits width))
  top %= (+ 1)
  pure var
  where
    freshName i = "r" <> show i

-- | Arithmetic on bit-vector terms, so that 'sum', '(*)' and numeric
-- literals can be used directly on 'SExpr'.
--
-- This is wrong, because it ignores overflow: the bit-vector operations wrap
-- around silently. It is also impractical, because the bit width is fixed:
-- every literal is a 3-bit constant, including the @0@ that 'sum' starts
-- from. So it only combines with 3-bit terms.
--
-- NOTE(review): orphan instance (neither 'Num' nor 'SExpr' is defined here).
instance Num SExpr where
  fromInteger = bvBin 3 -- FIXME: fixed width
  (+) = bvAdd
  (-) = bvSub
  (*) = bvMul
  negate = bvNeg
  -- Values are compared unsigned elsewhere in this file, so treat them as
  -- unsigned here too: 'abs' is the identity and 'signum' is 0 for zero and
  -- 1 otherwise, which keeps @abs x * signum x == x@.
  abs = id
  signum x = ite (eq x 0) 0 1

-- | Declare a @dim@ x @dim@ matrix of fresh bit-vector variables of width
-- @w@, returned as a list of rows; variables are declared row by row.
matrix_width :: Int -> Integer -> SMT [[SExpr]]
matrix_width dim w = replicateM dim row
  where
    row = replicateM dim (declare_bv w)

-- | Side condition on a matrix: its top-left and bottom-right entries must
-- both be at least 1 (unsigned).
--
-- Partial: fails on an empty matrix or an empty first/last row.
monotone :: [[SExpr]] -> SExpr
monotone m = atLeastOne topLeft `and` atLeastOne bottomRight
  where
    topLeft = head (head m)
    bottomRight = last (last m)
    atLeastOne e = bvUGeq e 1

-- | Matrix product over any 'Num' (in this file: bit-vector terms). Entry
-- (i, j) is the dot product of row i of @a@ with column j of @b@; mismatched
-- inner dimensions are truncated by 'zipWith' rather than rejected.
times :: Num n => [[n]] -> [[n]] -> [[n]]
times a b = [ [ sum (zipWith (*) row col) | col <- cols ] | row <- a ]
  where
    cols = transpose b

-- | Strict decrease between two matrices: @a@ is pointwise @>=@ @b@ ('geqm')
-- and the top-right entry of @a@ (last element of its first row) is
-- strictly greater; all comparisons are unsigned.
--
-- Partial: fails if either matrix or its first row is empty.
gtm :: [[SExpr]] -> [[SExpr]] -> SExpr
gtm a b = geqm a b `and` bvUGt (topRight a) (topRight b)
  where
    topRight = last . head

-- | Pointwise unsigned @>=@ between corresponding entries of two matrices,
-- combined with 'and' starting from @true@. Entries outside the overlap of
-- the two shapes are ignored, because 'zip' truncates.
geqm :: [[SExpr]] -> [[SExpr]] -> SExpr
geqm a b = foldr and (bool True)
  [ bvUGeq x y | (xs, ys) <- zip a b, (x, y) <- zip xs ys ]

-- | Ask the solver for the model value of every entry of a matrix, keeping
-- its row/column layout. Intended to be used after a 'check' that produced
-- a model.
getm :: [[SExpr]] -> SMT [[Value]]
getm a = do
  s <- use solver
  lift $ traverse (traverse (getExpr s)) a

-- Unsigned comparisons missing from SimpleSMT, obtained from the ones it
-- does provide by swapping the operands.

-- | Unsigned @x >= y@ on bit vectors, expressed as @y <= x@.
bvUGeq :: SExpr -> SExpr -> SExpr
bvUGeq = flip bvULeq

-- | Unsigned @x > y@ on bit vectors, expressed as @y < x@.
bvUGt :: SExpr -> SExpr -> SExpr
bvUGt = flip bvULt

-- | Crude formatting: wraps a list so that 'show' prints one element per
-- line, in leading-comma layout, e.g.
--
-- > [ x1
-- >  , x2
-- >  ]
newtype Vertical a = Vertical [a]

instance Show a => Show (Vertical a) where
  -- Without this case the empty list would print as a lone " ]".
  show (Vertical []) = "[ ]\n"
  show (Vertical xs) =
    unlines $ zipWith (<>) ("[ " : repeat " , ") (map show xs)
            <> [ " ]" ]
