{-# language GHC2021, DataKinds, TypeFamilies #-}

import Ersatz
import Prelude hiding (and, or, not, (&&), (||), any, all)
import Ersatz.Solver
import Ersatz.Solver.Minisat
import Data.List (transpose)
import Control.Monad (replicateM)

import Data.Kind
import GHC.TypeLits
import Data.Proxy

-- | Ask a SAT solver for two 5x5 matrices @a@ and @b@ over @Number 3@ and
-- print @a@, @b@ and the two compared products.
--
-- Constraints posted to the solver:
--
-- * both matrices satisfy @monotone@ (top-left and bottom-right entries are
--   at least 1) and @restricted@ (first column is @1,0,..,0@ and last row is
--   @0,..,0,1@);
-- * @a^2 * b^2@ is entrywise @>=@ @b^3 * a^3@ and strictly greater in the
--   top-right entry.
--
-- If the solver does not deliver a model the program fails with a message
-- naming the solver result instead of an incomplete-pattern crash.
main :: IO ()
main = do
  -- The executable "kissat" is driven through Ersatz's cryptominisat5
  -- interface. NOTE(review): presumably because both solvers share an
  -- output format -- confirm.
  result <- solveWith (cryptominisat5Path "kissat") $ do
    let dim = 5
        -- every matrix entry is a fresh 'Number' of width 3
        number = exists  @(Number 3)
        matrix = replicateM dim $ replicateM dim $ number
        -- top-left and bottom-right entries are at least 1
        monotone m = head (head m) >=? 1 && last (last m) >=? 1
        -- first column is (1,0,..,0), last row is (0,..,0,1)
        restricted m =
            head (transpose m) === encode ([1] <> replicate (dim-1) 0)
         && last m === encode (replicate (dim-1) 0 <> [1])
        -- entrywise >= plus a strict increase in the top-right entry
        gt p q = geq p q && last (head p) >? last (head q)
        geq p q =
          and [ x >=? y | (xs, ys) <- zip p q, (x, y) <- zip xs ys ]
    a <- matrix
    assert $ monotone a && restricted a
    b <- matrix
    assert $ monotone b && restricted b
    let -- alternative instance: lhs = times a b ; rhs = times b a
        a2 = times a a
        b2 = times b b
        a3 = times a2 a
        b3 = times b2 b
        lhs = times a2 b2
        rhs = times b3 a3
    assert $ gt lhs rhs
    return [a,b,lhs,rhs]
  case result of
    (Satisfied, Just ms) -> print $ map Vertical ms
    -- Unsatisfiable, unsolved, or no decodable model: report it explicitly.
    (status, _) ->
      fail $ "no model available, solver result: " <> show status

-- | Matrix product of two matrices given as lists of rows.
--
-- Entry @(i, j)@ of the result is the sum of the pairwise products of row
-- @i@ of the first argument and column @j@ of the second. 'zipWith' stops at
-- the shorter operand, so mismatched inner dimensions are truncated rather
-- than rejected.
times :: Num a => [[a]] -> [[a]] -> [[a]]
times a b =
  [ [ sum (zipWith (*) row col) | col <- cols ] | row <- a ]
  where
    cols = transpose b

-- | An unsigned number of nominal bit width @w@ with an explicit overflow
-- flag. Values built with 'make' keep at most @w@ bits in 'contents' and
-- record in 'overflow' whether any further bit was set.
data Number (w :: Natural) = Number
  { contents :: Bits -- ^ the bits of the value that are kept
  , overflow :: Bit  -- ^ set when the value did not fit into the kept bits
  }

-- | Cut a bit vector down to width @w@.
--
-- The first @w@ bits of the input become 'contents'; the remaining bits are
-- combined with 'or' into 'overflow', so the flag is set exactly when one of
-- the dropped bits is set.
make :: forall (w :: Natural) . KnownNat w => Bits -> Number w
make (Bits bits) = Number { contents = Bits kept, overflow = or dropped }
  where
    width = fromIntegral (natVal (Proxy @w))
    (kept, dropped) = splitAt width bits

-- | Arithmetic on width-limited unsigned numbers.
--
-- Sums and products are computed on the underlying 'Bits', cut back to @w@
-- bits with 'make', and flagged as overflowed when either operand had already
-- overflowed or the result did not fit.
--
-- All class methods are defined so that none falls through to a missing
-- method at runtime:
--
-- * 'fromInteger' rejects negative literals with an explicit error;
-- * 'abs' is the identity, since values are unsigned;
-- * 'signum' is 1 when the value is non-zero (any content bit set, or the
--   overflow flag set) and 0 otherwise;
-- * 'negate' is an error, since there are no negative values to return.
instance KnownNat w => Num (Number w) where
  fromInteger i
    | i >= 0 = encode $ fromIntegral i
    | otherwise =
        error $ "Number.fromInteger: negative literal " <> show i
  a + b = withOverflow (+) a b
  a * b = withOverflow (*) a b
  abs = id
  signum a =
    let Bits bs = contents a
    in  make @w $ Bits [overflow a || or bs]
  negate _ = error "Number.negate: unsigned numbers cannot be negated"

-- | Apply a 'Bits' operation to the contents of two numbers, cut the result
-- back to @w@ bits with 'make', and set the overflow flag if either operand
-- had overflowed or the result did not fit.
withOverflow
  :: forall (w :: Natural) . KnownNat w
  => (Bits -> Bits -> Bits) -> Number w -> Number w -> Number w
withOverflow op a b =
  let n = make @w $ contents a `op` contents b
  in  n { overflow = overflow a || overflow b || overflow n }
    
-- | A fresh 'Number' is built from @w@ fresh literals, one per bit, passed
-- through 'make'.
instance KnownNat w => Variable (Number w) where
  literally fresh =
    make @w . Bits . map Var <$> replicateM width fresh
    where
      width = fromIntegral $ natVal (Proxy @w)

-- | Two numbers are equal only when neither has overflowed and their
-- contents agree; an overflowed number is equal to nothing, itself included.
instance Equatable (Number w) where
  lhs === rhs = fits lhs && fits rhs && contents lhs === contents rhs
    where
      fits = not . overflow

-- | Ordering that takes the overflow flags into account.
--
-- '<?' holds when the left operand has not overflowed and either the right
-- operand has overflowed or the contents compare strictly smaller.
-- '<=?' requires that neither operand has overflowed.
--
-- NOTE(review): the two therefore disagree when only the right operand has
-- overflowed ('<?' can hold, '<=?' cannot) -- confirm this is intended.
instance Orderable (Number w) where
  x <? y =
    not (overflow x) && (overflow y || contents x <? contents y)
  x <=? y =
    not (overflow x) && not (overflow y) && contents x <=? contents y

-- | Numbers decode to 'Natural'.
--
-- 'encode' stores the constant's bits as contents with the overflow flag
-- cleared; it does not go through 'make', so the width @w@ is not applied
-- here. 'decode' reads contents and flag from the solution; when the flag is
-- set the decoded value is an @error "overflow"@ thunk that fires only when
-- the value is inspected.
instance Codec (Number w) where
  type Decoded (Number w) = Natural
  encode n = Number { contents = encode (fromIntegral n), overflow = false }
  decode s a = value <$> decode s (contents a) <*> decode s (overflow a)
    where
      value c o
        | o = error "overflow"
        | otherwise = fromIntegral c

-- crude formatting:

-- | Wrapper whose 'Show' instance lays a list out vertically: one element
-- per line, the first prefixed with @"[ "@, the others with @" , "@, followed
-- by a closing @" ]"@ line.
newtype Vertical a = Vertical [a]

instance Show a => Show (Vertical a) where
  -- Without this case an empty list would render as a lone closing bracket.
  show (Vertical []) = "[ ]\n"
  show (Vertical xs) =
    unlines $ zipWith (<>) ("[ " : repeat " , ") (map show xs) <> [" ]"]
