{-# language TemplateHaskell #-}

import SimpleSMT

import Control.Monad.Trans.State as S
import Control.Monad.Trans.Class (lift)

import Data.List (transpose)
import Control.Monad (replicateM, forM, forM_, when)
import Prelude hiding (and,or,not)
import Control.Lens

-- | State threaded through the 'SMT' monad.
data State = State { _solver :: !Solver
                     -- ^ solver handle that every 'declare'/'assert' below talks to
                   , _top :: !Int
                     -- ^ counter for fresh variable names; 'declare_nat' names
                     -- its variables @r0@, @r1@, ... and increments this
                   }
-- Generates the lenses 'solver' and 'top' used throughout this file.
-- 'Main.State' is written qualified because the unqualified import of
-- "Control.Monad.Trans.State" also brings a type named 'State' into scope.
$(makeLenses ''Main.State)

-- | Computations that issue solver commands while carrying a 'Main.State'.
type SMT = S.StateT Main.State IO

-- | Search for three 3x3 natural-number matrices @a@, @b@, @c@ such that each
-- one satisfies 'monotone' and every rule pair @(lhs, rhs)@ listed below
-- satisfies @'gtm' lhs rhs@, where each letter of a rule is interpreted as
-- the matrix of the same name.
--
-- The solver's verdict is always printed. The matrices are read back and
-- printed only when the verdict is 'Sat', because SMT-LIB only guarantees a
-- model after a satisfiable @check-sat@.
main :: IO ()
main =
  do l <- newLogger 0 -- prints a lot, change to 1 to reduce
     s <- newSolver
        "z3" [ "-in" ]
        -- "yices-smt2" [ "--smt2-model-format" ]
        (Just l)
     setLogic s "QF_NIA"
     flip S.evalStateT (State {_solver= s, _top = 0}) $ do
       [a,b,c] <- replicateM 3 $
         -- matrix_bounded 5 3
         matrix 3
       mapM_ (lift . assert s . monotone) [a,b,c]
       -- Each pair is (lhs, rhs) of a rule lhs -> rhs.
       let rules =
             [ -- aa -> aba
               (times a a , times a (times b a))
               -- z086 :
               -- (times a a, times b c) , (times b b, times a c) , (times c c, times a b)
             ]
       forM_ rules $ \ (lhs, rhs) -> lift $ assert s $ gtm lhs rhs
       verdict <- lift $ check s
       lift $ print verdict
       case verdict of
         Sat -> do
           vs <- forM [a,b,c] getm
           lift $ print $ map Vertical vs
         -- No model is guaranteed after Unsat/Unknown; asking for values
         -- would make the solver report an error.
         _ -> return ()

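{- Independent verification of a candidate model (matrices a, b) for the
   rule aa -> aba; paste into `rlwrap maxima`:
a : matrix([11,7,7],[15,9,8],[1,12,5]);
b : matrix([1,0,0],[0,0,0],[1,0,1]);
a . a;
a . b . a;
   Every entry of a.a should be >= the corresponding entry of a.b.a,
   and the top-right entry strictly greater (168 > 161).
-}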

-- | Declare a fresh SMT integer variable, constrain it to be non-negative,
-- and advance the name counter. The variable is named @r@ followed by the
-- current value of 'top' (so @r0@, @r1@, ...).
declare_nat :: SMT SExpr
declare_nat = do
  slv <- use solver
  idx <- use top
  var <- lift $ do
    v <- declare slv ("r" ++ show idx) tInt
    assert slv (geq v 0)
    return v
  top += 1
  pure var

-- | Declare a @dim@-by-@dim@ matrix (a list of rows) whose entries are fresh
-- non-negative integer unknowns created by 'declare_nat'.
matrix :: Int -> SMT [[SExpr]]
matrix dim =
  replicateM dim $ replicateM dim declare_nat

-- | Like 'declare_nat', but additionally asserts that the new variable is at
-- most @b@.
declare_nat_bounded :: Integer -> SMT SExpr
declare_nat_bounded b = do
  var <- declare_nat
  slv <- use solver
  lift (assert slv (geq (int b) var))
  pure var

-- | Orphan 'Num' instance so SMT integer terms can be written with ordinary
-- arithmetic syntax: numeric literals, '+', '*', and 'sum' (as used by
-- 'times').
--
-- Every class method is defined, mapping to the matching SMT-LIB operator.
-- This avoids a missing-method runtime crash if e.g. 'negate' or '-' is
-- ever used on an 'SExpr'.
instance Num SExpr where
  fromInteger = int
  (+) = add
  (*) = mul
  (-) = sub
  negate = neg
  abs = SimpleSMT.abs
  -- Sign as a nested if-then-else term: 1, -1 or 0.
  signum x = ite (gt x (int 0)) (int 1) (ite (lt x (int 0)) (int (-1)) (int 0))

-- | Like 'matrix', but every entry is additionally constrained to lie in
-- @[0, b]@ (via 'declare_nat_bounded').
matrix_bounded :: Int -> Integer -> SMT [[SExpr]]
matrix_bounded dim b =
  replicateM dim $ replicateM dim $ declare_nat_bounded b

-- | Constraint that both the top-left and the bottom-right entries of the
-- matrix @m@ (a list of rows) are at least 1.
-- NOTE(review): presumably the monotonicity condition for matrix
-- interpretations of string rewriting — confirm.
--
-- Partial: @m@ must have at least one row, and its first and last rows must
-- be non-empty.
monotone :: [[SExpr]] -> SExpr
monotone m =
  and (geq (head $ head m) 1) (geq (last $ last m) 1)

-- | Matrix product, with matrices given as lists of rows. The function works
-- for any 'Num' element type; this file uses it on 'SExpr' matrices to build
-- symbolic products.
--
-- Entry @(i, j)@ is the dot product of row @i@ of @a@ with column @j@ of @b@.
-- Mismatched inner dimensions are silently truncated by 'zipWith'.
times :: Num a => [[a]] -> [[a]] -> [[a]]
times a b =
  [ [ sum (zipWith (*) row col) | col <- cols ] | row <- a ]
  where
    -- The columns of b are the same for every row of a; compute them once.
    cols = transpose b

-- | @gtm a b@: @a@ is entrywise @>=@ @b@ (see 'geqm'), and the top-right entry
-- of @a@ (last element of its first row) is strictly greater than that of @b@.
--
-- Partial: both matrices must have a non-empty first row.
gtm :: [[SExpr]] -> [[SExpr]] -> SExpr
gtm a b = and (geqm a b) (gt (last $ head a) (last $ head b))

-- | Entrywise @>=@ between two matrices (lists of rows), conjoined into a
-- single constraint in row-major order.
--
-- Entries are paired with 'zipWith', so only the overlapping rows/columns
-- are compared; if nothing overlaps the result is @true@.
geqm :: [[SExpr]] -> [[SExpr]] -> SExpr
geqm a b = foldr and (bool True) $ concat $ zipWith (zipWith geq) a b

-- | Ask the solver for the current value of every entry of a matrix of
-- terms, in row-major order, preserving the row structure.
-- NOTE(review): presumably only meaningful right after a 'Sat' check — confirm.
getm :: [[SExpr]] -> SMT [[Value]]
getm a = do
  slv <- use solver
  lift $ traverse (traverse (getExpr slv)) a

-- Crude pretty-printing: show a list with one element per line.

-- | Wrapper whose 'Show' instance prints the wrapped list one element per
-- line, for example:
--
-- > [ 1
-- >  , 2
-- >  ]
--
-- An empty list is shown as @[ ]@.
newtype Vertical a = Vertical [a]
instance Show a => Show (Vertical a) where
  -- Without this case an empty list would render as " ]" (no opening bracket).
  show (Vertical []) = "[ ]\n"
  show (Vertical xs) =
    unlines $ zipWith (<>) ("[ " : repeat " , ") (map show xs)
            <> [ " ]" ]
