{-# options_ghc -fdefer-typed-holes #-}
{-# language GHC2021, DataKinds, TemplateHaskell #-}

import qualified SimpleSMT as SMT

import Control.Monad.Trans.State as S
import Control.Monad.Trans.Class (lift)

import Data.List (transpose)
import Control.Monad (replicateM, forM)
import Prelude hiding (and,or,not)
import Control.Lens

import Data.Kind
import GHC.TypeLits
import Data.Proxy

-- | Context threaded through the 'SMT' monad.
data State = State { _solver :: !SMT.Solver -- ^ solver handle used by 'declare' and the model-query helpers
                   , _top :: !Int -- ^ counter 'declare' uses to name fresh constants r0, r1, ...
                   }
-- Generates the lenses 'solver' and 'top'. The type is written Main.State,
-- presumably because Control.Monad.Trans.State (imported unqualified above)
-- also exports a type named State.
$(makeLenses ''Main.State)


-- | Ask the solver for three 5x5 matrices with 3-bit entries (in the
-- restricted shape built by 'restricted_matrix') that are 'monotone' and
-- satisfy the strict matrix inequalities listed below. Print the solver's
-- answer and, only when it is Sat, the model values of the matrices.
main :: IO ()
main =
  do l <- SMT.newLogger 0
     -- s <- SMT.newSolver "yices-smt2" [ "--smt2-model-format" ] (Just l)
     s <- SMT.newSolver "z3" [ "-in" ] (Just l)
     SMT.setLogic s "QF_BV"
     flip S.evalStateT (State {_solver= s, _top = 0}) $ do
       -- type argument 3 = bit width of entries, 5 = matrix dimension
       [a,b,c] <- replicateM 3 $ restricted_matrix @3 5
       mapM_ (lift . SMT.assert s . monotone) [a,b,c]
       -- lift $ assert s $ gtm (times a b) (times b a)
       mapM_ (\ (lhs, rhs) -> lift $ SMT.assert s $ gtm lhs rhs)
         -- [ (times a a, times a (times b a)) ]
         -- z086 :
         [ (times a a, times b c), (times b b, times a c), (times c c, times a b)]
       result <- lift $ SMT.check s
       lift $ print result
       -- A model (and therefore get-value) only exists after a Sat answer;
       -- querying it after Unsat/Unknown would make the solver report an error.
       case result of
         SMT.Sat -> do
           vs <- forM [a,b,c] getm
           lift $ print $ map Vertical vs
         _ -> return ()
     -- Shut down the solver process we started.
     _ <- SMT.stop s
     return ()

-- | A symbolic unsigned number of bit-width @w@, paired with a flag that
-- records whether the value has left the representable range.
--
-- 'valid' is the negation of 'overflow'; the comparisons 'gtn' and 'geqn'
-- only hold when both operands are valid.
data Number (w :: Natural) = Number
  { contents :: SMT.SExpr -- ^ bit-vector term of width w (the low w bits of the value)
  , overflow :: SMT.SExpr -- ^ Boolean term (not a bit-vector): true once the value no longer fits in w bits
  }

-- | Boolean term that holds exactly when the number has not overflowed.
valid :: Number w -> SMT.SExpr
valid = SMT.not . overflow

-- | The monad for building constraints: a 'Main.State' (solver handle and
-- fresh-name counter) over IO; solver calls are 'lift'ed into it.
type SMT = StateT Main.State IO 

-- | Declare a fresh unknown @w@-bit number in the solver.
--
-- Only the contents become a solver constant. A freshly declared
-- bit-vector is by construction a genuine @w@-bit value, so its overflow
-- flag is the constant false. A free Boolean here could only make the
-- number invalid, and all comparisons in this file require valid operands.
make
  :: forall (w :: Natural)
  . KnownNat w
  => SMT (Number w)
make = do
  let w = natVal (Proxy @w)
  c <- declare $ SMT.tBits w
  return $ Number { contents = c, overflow = SMT.bool False }

-- | Declare a new solver constant of sort @ty@.
--
-- The constant is named @r\<k\>@, where @k@ is the current value of the
-- 'top' counter. The counter is advanced afterwards, so names never repeat.
declare :: SMT.SExpr -> SMT SMT.SExpr
declare ty = do
  slv <- use solver
  name <- ("r" <>) . show <$> use top
  fresh <- lift $ SMT.declare slv name ty
  top += 1
  pure fresh

-- | Unsigned @w@-bit arithmetic with sticky overflow tracking.
--
-- Each operation keeps the low @w@ bits of the exact result in 'contents'.
-- It sets 'overflow' if an operand had already overflowed or if the exact
-- result does not fit in @w@ bits.
instance forall (w :: Natural) . KnownNat w => Num (Number w) where
  -- Literals outside [0, 2^w) are flagged as overflowed. SMT.bvBin itself
  -- only emits w bits of the value, so without the flag they would wrap silently.
  fromInteger i =
    let w = natVal (Proxy @w)
    in  Number { contents = SMT.bvBin (fromIntegral w) i
               , overflow = SMT.bool (i < 0 || i >= 2 ^ w) }

  -- Add in w+1 bits; the extra top bit is the carry out of the w-bit sum.
  a + b =
    let w = natVal (Proxy @w)
        s = SMT.bvAdd (SMT.zeroExtend 1 $ contents a)
                      (SMT.zeroExtend 1 $ contents b)
    in  Number { contents = SMT.extract s (w-1) 0
               , overflow = SMT.or (overflow a)
                          $ SMT.or (overflow b)
                          $ bvUGt (SMT.extract s w w)
                                  (SMT.bvBin (fromIntegral 1) 0)
               }

  -- Multiply in 2w bits, which is exact for two w-bit factors. Any non-zero
  -- bit in the upper half (bits 2w-1 .. w) means the product does not fit.
  a * b =
    let w = natVal (Proxy @w)
        p = SMT.bvMul (SMT.zeroExtend w $ contents a)
                      (SMT.zeroExtend w $ contents b)
        high = SMT.extract p (2*w-1) w
    in  Number { contents = SMT.extract p (w-1) 0
               , overflow = SMT.or (overflow a)
                          $ SMT.or (overflow b)
                          $ bvUGt high (SMT.bvBin (fromIntegral w) 0)
               }

  -- Numbers are unsigned, so abs is the identity.
  abs = id

  -- 0 for a zero value, 1 otherwise; an overflowed input stays overflowed.
  signum a =
    let w = natVal (Proxy @w)
        bits = SMT.bvBin (fromIntegral w)
    in  Number { contents = SMT.ite (SMT.eq (contents a) (bits 0)) (bits 0) (bits 1)
               , overflow = overflow a }

  -- Must be defined explicitly: base's default negate is written via (-),
  -- and default (-) via negate, so leaving both out makes them loop forever.
  negate _ = error "Number.negate: negation is not supported for unsigned numbers"



-- | A @dim@ x @dim@ matrix (list of rows) whose entries are all fresh
-- unknowns created by 'make'. Entries are declared row by row, left to right.
matrix
  :: forall (w :: Natural) . KnownNat w
  => Int -> SMT [[Number w]]
matrix dim = replicateM dim row
  where
    row = replicateM dim (make @w)

-- | A @dim@ x @dim@ matrix with a fixed border and unknowns elsewhere.
--
-- The first column is (1, 0, ..., 0) and the last row is (0, ..., 0, 1).
-- The remaining (dim-1) x (dim-1) block (rows 1..dim-1, columns 2..dim)
-- consists of fresh unknowns from 'matrix'. For @dim == 1@ the result is [[1]].
restricted_matrix
  :: forall (w :: Natural) . KnownNat w
  => Int -> SMT [[Number w]]
restricted_matrix dim = do
  inner <- matrix (dim - 1)
  -- The literals are 'Number w' constants (overflow flag false).
  let first_col = [1] <> replicate (dim-2) 0
      last_row =  replicate (dim-1) 0 <> [1]
  return $ zipWith (:) first_col inner
         <> [ last_row ]
  


-- | Constraint that the top-left and bottom-right entries of @m@ are both
-- valid and at least 1. Uses 'head'/'last', so @m@ must be non-empty.
monotone
  :: forall (w :: Natural) . KnownNat w
  => [[Number w]] -> SMT.SExpr
monotone m = SMT.and (atLeastOne topLeft) (atLeastOne bottomRight)
  where
    one :: Number w
    one = 1
    atLeastOne x = geqn x one
    topLeft = head (head m)
    bottomRight = last (last m)

-- | Matrix product of two matrices given as lists of rows.
-- Entry (i, j) is the 'sum' of the pairwise products of row i of @a@ with
-- column j of @b@.
times :: Num a => [[a]] -> [[a]] -> [[a]]
times a b =
  [ [ sum (zipWith (*) row col) | col <- cols ] | row <- a ]
  where
    cols = transpose b

-- | Matrix "greater than": every entry of @a@ is >= the matching entry of
-- @b@, and the top-right entries are strictly greater.
gtm a b = SMT.and everywhereGeq cornerGt
  where
    everywhereGeq = geqm a b
    cornerGt = gtn (topRight a) (topRight b)
    topRight m = last (head m)

-- | Entry-wise "greater or equal": the conjunction of 'geqn' over matching
-- entries of @a@ and @b@. Extra rows/columns in either operand are
-- ignored, as with 'zip'.
geqm a b = foldr SMT.and (SMT.bool True)
  [ geqn x y | (row, row') <- zip a b, (x, y) <- zip row row' ]

-- | Read back the solver's model value for every entry of matrix @a@
-- (row by row, via 'getn'). Only meaningful after a Sat check.
getm a = forM a (mapM getn)

-- | Query the solver's current model for the value of @n@'s contents.
-- The overflow flag is not queried: the previous code fetched it and threw
-- the answer away, costing an extra solver round trip.
getn n = do
  s <- use solver
  lift $ SMT.getExpr s (contents n)
  
-- | Unsigned strict comparison @a > b@ that additionally requires both
-- operands to be valid (not overflowed).
gtn a b =
  valid a `SMT.and` (valid b `SMT.and` bvUGt (contents a) (contents b))

-- | Unsigned comparison @a >= b@ that additionally requires both operands
-- to be valid (not overflowed).
geqn a b = SMT.and (valid a) rest
  where
    rest = SMT.and (valid b) (bvUGeq (contents a) (contents b))

-- missing in SimpleSMT:

-- | Unsigned @x >= y@, expressed as SimpleSMT's @bvULeq@ with swapped operands.
bvUGeq :: SMT.SExpr -> SMT.SExpr -> SMT.SExpr
bvUGeq = flip SMT.bvULeq

-- | Unsigned @x > y@, expressed as SimpleSMT's @bvULt@ with swapped operands.
bvUGt :: SMT.SExpr -> SMT.SExpr -> SMT.SExpr
bvUGt = flip SMT.bvULt

-- crude formatting:

-- | Wrapper whose 'Show' instance prints a list one element per line:
-- the first line starts with "[ ", later lines with " , ", and a closing
-- " ]" line follows. Every line, including the last, ends in a newline.
newtype Vertical a = Vertical [a]

instance Show a => Show (Vertical a) where
  show (Vertical xs) = unlines (rows ++ [" ]"])
    where
      rows = [ lead ++ show x | (lead, x) <- zip ("[ " : repeat " , ") xs ]
    
