module RecursiveInference where


import Control.Monad (ap, liftM, when, zipWithM_)
import Data.List (nub, union, (\\))


-- | Identifiers: names of program variables, type variables and type
-- constructors.
type Id = String

-- | Name of the @n@-th fresh type variable, e.g. @enumId 3 == "$t_3"@.
-- Program variable names are not allowed to start with '$'; presumably the
-- prefix is chosen so generated names never clash with program identifiers.
enumId :: Int -> Id
enumId = ("$t_" ++) . show

-- | Kinds: 'Star' is the kind of ordinary types; @Kfun k1 k2@ is the kind of
-- a type constructor that takes an argument of kind @k1@ and yields @k2@.
data Kind
  = Star
  | Kfun Kind Kind
  deriving (Eq, Ord, Show)

-- | Type expressions.
data Type
  = TVar Tyvar      -- ^ type variable
  | TCon Tycon      -- ^ type constant, e.g. @Int@ or @(->)@
  | TAp Type Type   -- ^ type application
  | TGen Int        -- ^ quantified variable; indexes the kinds of a 'Scheme'
  deriving (Eq, Ord, Show)

-- | A type variable: its name and kind.
data Tyvar
  = Tyvar Id Kind
  deriving (Eq, Ord, Show)

-- | A type constructor: its name and kind.
data Tycon
  = Tycon Id Kind
  deriving (Eq, Ord, Show)

-- | A type scheme.  In @Forall ks t@ the quantified variables appear in @t@
-- as @TGen 0@, @TGen 1@, ..., and the kind of @TGen i@ is @ks !! i@.
data Scheme
  = Forall [Kind] Type
  deriving (Eq, Show)

-- Built-in type constants.  Each is a 'TCon' tagged with its kind: the
-- nullary types have kind Star, the list constructor Star -> Star, and the
-- function and pair constructors Star -> Star -> Star.
tUnit :: Type
tUnit    = TCon (Tycon "()" Star)

tChar :: Type
tChar    = TCon (Tycon "Char" Star)

tInt :: Type
tInt     = TCon (Tycon "Int" Star)

tInteger :: Type
tInteger = TCon (Tycon "Integer" Star)

tFloat :: Type
tFloat   = TCon (Tycon "Float" Star)

tDouble :: Type
tDouble  = TCon (Tycon "Double" Star)


-- | The list type constructor @[]@ (unapplied).
tList :: Type
tList    = TCon (Tycon "[]" (Kfun Star Star))

-- | The function type constructor @(->)@ (unapplied); see 'fn'.
tArrow :: Type
tArrow   = TCon (Tycon "(->)" (Kfun Star (Kfun Star Star)))

-- | The pair type constructor @(,)@ (unapplied); see 'pair'.
tTuple2 :: Type
tTuple2  = TCon (Tycon "(,)" (Kfun Star (Kfun Star Star)))

-- | The string type, @[Char]@.
tString :: Type
tString = TAp tList tChar

-- | Function type: @a `fn` b@ is @a -> b@.  Declared right associative, so
-- @a `fn` b `fn` c@ means @a -> (b -> c)@.
infixr 4 `fn`
fn :: Type -> Type -> Type
fn a b = TAp (TAp tArrow a) b

-- | The list type over the given element type.
list :: Type -> Type
list = TAp tList

-- | The pair type @(a, b)@.
pair :: Type -> Type -> Type
pair a = TAp (TAp tTuple2 a)


-- | Things whose kind can be computed.
class HasKind t where
  kind :: t -> Kind

instance HasKind Tyvar where
  kind (Tyvar _ k) = k

instance HasKind Tycon where
  kind (Tycon _ k) = k

-- | The kind of a type expression.  An application @TAp t a@ has the result
-- kind of @t@.  An ill-kinded application (applying something of kind Star)
-- and a 'TGen' (whose kind is stored in the enclosing 'Scheme', not in the
-- type) are reported with an explicit error rather than an opaque
-- non-exhaustive pattern failure.
instance HasKind Type where
  kind (TCon tc) = kind tc
  kind (TVar u)  = kind u
  kind (TAp t _) = case kind t of
    Kfun _ k -> k
    Star     -> error "kind: type application of a type of kind *"
  kind (TGen _)  = error "kind: TGen has no intrinsic kind"

-- | A substitution: an association list from type variables to types.
-- 'apply' uses 'lookup', so the first entry for a variable wins.
type Subst = [(Tyvar, Type)]

-- | The empty (identity) substitution.
nullSubst :: Subst
nullSubst = []

-- | The substitution mapping a single variable to a type.
(+->) :: Tyvar -> Type -> Subst
(+->) u t = [(u, t)]

-- | Things that mention type variables and can have a substitution applied.
class Types t where
  -- | Replace every type variable bound by the substitution.
  apply :: Subst -> t -> t
  -- | The type variables ('TVar's) occurring in the value.
  tv    :: t -> [Tyvar]

instance Types Type where
  apply s (TVar u)  =
    case lookup u s of
      Just t  -> t
      Nothing -> TVar u
  apply s (TAp l r) = TAp (apply s l) (apply s r)
  -- Constants and quantified variables are unaffected.
  apply _ t         = t

  tv (TVar u)  = [u]
  tv (TAp l r) = tv l `union` tv r
  tv _         = []


-- Quantified variables are 'TGen's, so a substitution never touches them and
-- 'tv' only reports the scheme's free variables.
instance Types Scheme where
  apply s (Forall ks t) = Forall ks (apply s t)
  tv (Forall _ t) = tv t

-- | Element-wise; 'tv' merges the variables of all elements without repeats.
instance Types a => Types [a] where
  apply s = map (apply s)
  tv = nub . concatMap tv



-- | A scheme that quantifies over nothing (a monotype).
toScheme :: Type -> Scheme
toScheme = Forall []

-- | Generalise @t@ over those of its type variables that are listed in @vs@.
-- The i-th such variable (in order of 'tv t') becomes @TGen i@, and its kind
-- is stored at position i of the scheme's kind list.
quantify :: [Tyvar] -> Type -> Scheme
quantify vs t = Forall (map kind bound) (apply generalise t)
  where
    bound      = filter (`elem` vs) (tv t)
    generalise = zip bound (map TGen [0 ..])


-- | Instantiate a scheme, replacing each quantified variable with a fresh
-- type variable of the recorded kind.
freshInst :: Scheme -> TI Type
freshInst (Forall ks t) = fmap (`inst` t) (mapM newTVar ks)

-- | Substitute types for quantified variables: @TGen n@ becomes the n-th type.
class Instantiate t where
  inst  :: [Type] -> t -> t

instance Instantiate Type where
  inst ts ty = case ty of
    TAp l r -> TAp (inst ts l) (inst ts r)
    TGen n  -> ts !! n
    _       -> ty


-- | Substitution composition: @apply (s1 \@\@ s2)@ behaves like applying
-- @s2@ first and then @s1@.
infixr 4 @@
(@@) :: Subst -> Subst -> Subst
s1 @@ s2 = map (\(u, t) -> (u, apply s1 t)) s2 ++ s1



-- | Compute a unifier of two types: a substitution that makes them equal,
-- or 'fail' in @m@ when none exists.
--
-- NOTE(review): 'fail' is called under a plain 'Monad' constraint.  From
-- GHC 8.8 on, 'fail' belongs to 'MonadFail', so these signatures (and 'TI',
-- which runs them via 'unify') would need a 'MonadFail' constraint/instance
-- there.  Confirm the target compiler.
mgu     :: Monad m => Type -> Type -> m Subst
-- | Bind a type variable to a type, subject to the occurs and kind checks.
varBind :: Monad m => Tyvar -> Type -> m Subst

mgu (TAp l r)  (TAp l' r') = do
  s1 <- mgu l l'
  -- Unify the arguments under what was learned from the heads.
  s2 <- mgu (apply s1 r) (apply s1 r')
  return (s2 @@ s1)
mgu (TVar u)   t           =
  varBind u t
mgu t          (TVar u)    =
  varBind u t
mgu (TCon tc1) (TCon tc2)
           | tc1==tc2 = return nullSubst
mgu _  _              = fail "types do not unify"

varBind u t | t == TVar u      = return nullSubst          -- trivially unified
            | u `elem` tv t    = fail "occurs check fails"  -- would be an infinite type
            | kind u /= kind t = fail "kinds do not match"
            | otherwise        = return (u +-> t)


-- | An assumption @i :>: sc@: identifier @i@ has type scheme @sc@.
data Assump =
  Id :>: Scheme
  deriving (Show)

instance Types Assump where
  apply s (i :>: sch) = i :>: apply s sch
  tv (_ :>: sch)      = tv sch

-- | Look up an identifier's scheme; the first matching assumption wins.
-- Fails in @m@ with "unbound identifier: <i>" if there is none.
--
-- NOTE(review): uses 'fail' under 'Monad'; GHC >= 8.8 requires 'MonadFail'.
find :: Monad m => Id -> [Assump] -> m Scheme
find i [] = fail ("unbound identifier: " ++ i)
find i ((i' :>: sch) : rest)
  | i == i'   = return sch
  | otherwise = find i rest



---- monad


-- | The type-inference monad: a state monad threading the current
-- substitution and a counter used to name fresh type variables.
newtype TI a = TI (Subst -> Int -> (Subst, Int, a))

instance Functor TI where
  fmap = liftM

-- 'pure' is defined directly (canonical form); 'Monad' relies on the default
-- 'return = pure' rather than defining 'return' and deriving 'pure' from it.
instance Applicative TI where
  pure x = TI (\s n -> (s, n, x))
  (<*>)  = ap

instance Monad TI where
  TI f >>= g = TI (\s n -> case f s n of
                            (s', m, x) -> let TI gx = g x
                                          in  gx s' m)

-- | Run an inference from the empty substitution with the fresh-name counter
-- at 0, returning only the result.
runTI :: TI a -> a
runTI (TI f) = let (_, _, result) = f nullSubst 0 in result

-- | Read the current substitution.
getSubst :: TI Subst
getSubst = TI $ \sub cnt -> (sub, cnt, sub)

-- | Unify two types under the current substitution and extend it with the
-- resulting unifier.  Failure comes from 'mgu'.
unify :: Type -> Type -> TI ()
unify t1 t2 =
  getSubst >>= \sub -> mgu (apply sub t1) (apply sub t2) >>= extSubst

-- | Extend the current substitution with @s'@ (applied after it).
extSubst :: Subst -> TI ()
extSubst s' = TI $ \sub cnt -> (s' @@ sub, cnt, ())

-- | A fresh type variable of the given kind, named from the counter.
newTVar :: Kind -> TI Type
newTVar k = TI $ \sub cnt -> (sub, cnt + 1, TVar (Tyvar (enumId cnt) k))


-- | An inference function for syntax of type @e@: given the assumptions in
-- scope, produce a result of type @t@ in the 'TI' monad.
type Infer e t = [Assump] -> e -> TI t


-- Inference

-- | Literal constants.
data Literal
  = LitInt  Integer
  | LitChar Char
  | LitRat  Rational
  | LitStr  String

-- | The (monomorphic) type of a literal.
tiLit :: Literal -> TI Type
tiLit lit = case lit of
  LitChar _ -> return tChar
  LitInt  _ -> return tInteger
  LitStr  _ -> return tString
  LitRat  _ -> return tFloat



-- | Expressions of the source language.
data Expr
  = Var   Id              -- ^ variable, looked up in the assumptions
  | Lit   Literal         -- ^ literal
  | Const Assump          -- ^ constant carrying its own name and scheme
  | Ap    Expr Expr       -- ^ application
  | Abs   Id Expr         -- ^ one-parameter lambda abstraction
  | Let   BindGroup Expr  -- ^ local binding group


-- | Infer the type of an expression under the given assumptions.
tiExpr :: Infer Expr Type
tiExpr as (Var i) =
  -- Look up the variable's scheme and instantiate it freshly.
  find i as >>= freshInst
tiExpr _  (Lit l) =
  tiLit l
tiExpr _  (Const (_ :>: sc)) =
  -- A constant carries its own scheme, so no lookup is needed.
  freshInst sc
tiExpr as (Ap e f) = do
  te <- tiExpr as e
  tf <- tiExpr as f
  t  <- newTVar Star
  -- e must be a function from f's type to the fresh result type t.
  unify (tf `fn` t) te
  return t
tiExpr as (Abs i e) = do
  -- The parameter gets a monomorphic (unquantified) scheme in the body.
  t  <- newTVar Star
  te <- tiExpr ((i :>: toScheme t) : as) e
  return (t `fn` te)
tiExpr as (Let binds e) = do
  as' <- tiBindGroup as binds
  tiExpr (as' ++ as) e

-- | An implicitly typed binding: name and defining expression.
type Impl = (Id, Expr)
-- | An explicitly typed binding: name, declared scheme and expression.
type Expl = (Id, Scheme, Expr)
-- | A binding group: explicitly typed bindings, plus implicitly typed
-- bindings split into sub-groups that are inferred in order.
type BindGroup = ([Expl], [[Impl]])

-- | Infer a binding group and return the new assumptions it introduces.
-- Explicit bindings are visible everywhere through their declared schemes.
-- Implicit sub-groups are inferred first, then each explicit binding is
-- checked against its signature.
tiBindGroup :: Infer BindGroup [Assump]
tiBindGroup as (es, iss) = do
  let as' = [ v :>: sc | (v, sc, _) <- es ]
  as'' <- tiSeq tiImpls (as' ++ as) iss
  mapM_ (tiExpl (as'' ++ as' ++ as)) es
  return (as'' ++ as')

-- | Check an explicitly typed binding against its declared scheme.  The
-- inferred type, generalised over variables not free in the assumptions,
-- must give back exactly the declared scheme.  Otherwise the signature
-- claims more polymorphism than the definition has.
--
-- NOTE(review): 'fail' here runs in 'TI', which has no 'MonadFail' instance;
-- this only compiles where 'fail' is still a 'Monad' method (GHC < 8.8).
tiExpl :: Infer Expl ()
tiExpl as (_, sc, e) = do
  t <- freshInst sc
  tiBinding as e t
  s <- getSubst
  let t'  = apply s t
      fs  = tv (apply s as)
      gs  = tv t' \\ fs
      sc' = quantify gs t'
  when (sc /= sc') $
    fail "signature too general"

-- | Infer a group of mutually recursive, implicitly typed bindings.  Every
-- binding first gets a fresh monomorphic placeholder type, so recursive uses
-- within the group are monomorphic.  After inference, all types are
-- generalised over the variables not free in the outer assumptions.
tiImpls :: Infer [Impl] [Assump]
tiImpls as bs = do
  ts <- mapM (const (newTVar Star)) bs
  let is    = map fst bs
      exprs = map snd bs
      as'   = zipWith (:>:) is (map toScheme ts) ++ as
  zipWithM_ (tiBinding as') exprs ts

  s <- getSubst
  let ts'  = apply s ts
      fs   = tv (apply s as)
      -- 'foldr union []' rather than the partial 'foldr1' so an empty group
      -- is handled totally.  It gives the same result for non-empty groups.
      gs   = foldr union [] (map tv ts') \\ fs
      scs' = map (quantify gs) ts'
  return (zipWith (:>:) is scs')



-- | Infer the type of a binding's expression and unify it with the expected
-- type @t@.
tiBinding :: [Assump] -> Expr -> Type -> TI ()
tiBinding as e t = tiExpr as e >>= unify t


-- | Infer a sequence of groups in order.  Each group sees the assumptions
-- produced by the earlier ones.  The result lists later groups' assumptions
-- first.
tiSeq :: Infer bg [Assump] -> Infer [bg] [Assump]
tiSeq ti = go
  where
    go _   []           = return []
    go env (bg : rest)  = do
      new  <- ti env bg
      more <- go (new ++ env) rest
      return (more ++ new)


-- | A program: binding groups inferred in order by 'tiProgram'.
type Program = [BindGroup]

-- | Infer a whole program under the initial assumptions.  Returns the
-- program's assumptions with the final substitution applied.
tiProgram :: [Assump] -> Program -> [Assump]
tiProgram as bgs =
  runTI (tiSeq tiBindGroup as bgs >>= \inferred ->
           fmap (`apply` inferred) getSubst)


-- | Convenience for the examples: a binding group with no explicitly typed
-- bindings and one implicitly typed binding @name = expr@.
bind :: String -> Expr -> BindGroup
bind name expr = ([], pure (pure (name, expr)))


-- Types "e = let x = 123 in x"
example1 :: [Assump]
example1 = tiProgram [] [bind "e" letExpr]
  where
    letExpr = Let (bind "x" (Lit (LitInt 123))) (Var "x")

-- Types "f = \x -> x + x", with (+) :: Int -> Int -> Int
example2 :: [Assump]
example2 = tiProgram env [bind "f" double]
  where
    env    = ["+" :>: toScheme (tInt `fn` tInt `fn` tInt)]
    double = Abs "x" (Ap (Ap (Var "+") (Var "x")) (Var "x"))

-- Types "e = let id = \x -> x in id id "polymorphism!""
-- The let generalises id, so it can be applied to itself.
example3 :: [Assump]
example3 = tiProgram [] [bind "e" (Let (bind "id" idFn) body)]
  where
    idFn = Abs "x" (Var "x")
    body = Ap (Ap (Var "id") (Var "id")) (Lit (LitStr "polymorphism!"))

-- Types "f x = 2 + (f (x + 1))"
example4 :: [Assump]
example4 = tiProgram env [bind "f" fDef]
  where
    env      = ["+" :>: toScheme (tInteger `fn` tInteger `fn` tInteger)]
    plus l r = Ap (Ap (Var "+") l) r
    recCall  = Ap (Var "f") (plus (Var "x") (Lit (LitInt 1)))
    fDef     = Abs "x" (plus (Lit (LitInt 2)) recCall)

-- Types the mutually recursive pair
--     f x = 2 * g x
--     g y = f (y + 1)
-- which are inferred together as a single implicit group.
example5 :: [Assump]
example5 = tiProgram env [([], [[("f", fDef), ("g", gDef)]])]
  where
    intOp        = tInteger `fn` tInteger `fn` tInteger
    env          = ["+" :>: toScheme intOp, "*" :>: toScheme intOp]
    binop op l r = Ap (Ap (Var op) l) r
    fDef         = Abs "x" (binop "*" (Lit (LitInt 2)) (Ap (Var "g") (Var "x")))
    gDef         = Abs "y" (Ap (Var "f") (binop "+" (Var "y") (Lit (LitInt 1))))

-- Types
--     identity x = x
--     foo n = (identity identity) n
--
-- identity remains a -> a because it is in a different (earlier) implicit
-- group from foo, so it is generalised before foo uses it.
example6 :: [Assump]
example6 = tiProgram env [([], [[("identity", identity)], [("foo", foo)]])]
  where
    env      = ["+" :>: toScheme (tInteger `fn` tInteger `fn` tInteger)]
    identity = Abs "x" (Var "x")
    selfApp  = Ap (Var "identity") (Var "identity")
    succN    = Ap (Ap (Var "+") (Var "n")) (Lit (LitInt 1))
    foo      = Abs "n" (Ap selfApp succN)

-- Types:
--    a x = [b x]
--
--    b :: a -> a
--    b y = let foo = c 'c' in y
--
--    c z = "foo" ++ a z
--
-- b is explicitly typed.  a and c are inferred as two separate implicit
-- groups, a first, and b is then checked against its signature.
example7 :: [Assump]
example7 = tiProgram assumps p
  where -- (++) :: [a] -> [a] -> [a]
        concatT = Forall [Star] (list (TGen 0) `fn` (list (TGen 0) `fn` list (TGen 0)))
        assumps = ["++" :>: concatT]

        bx = Ap (Var "b") (Var "x") -- b x
        listConSch = Forall [Star] (TGen 0 `fn` (list $ TGen 0))  -- a -> [a]
        listConAssump = "[]" :>: listConSch
        a = Abs "x" (Ap (Const listConAssump) bx) -- a x = [b x]

        cc = Ap (Var "c") (Lit $ LitChar 'c') -- c 'c'
        fooBinding = ([], [[("foo", cc)]]) -- foo = c 'c'
        b = Abs "y" (Let fooBinding (Var "y")) -- b y = let foo = c 'c' in y
        bScheme = Forall [Star] (TGen 0 `fn` TGen 0) -- a -> a

        az = Ap (Var "a") (Var "z") -- a z
        literalFoo = Lit $ LitStr "foo" -- "foo"
        c = Abs "z" (Ap (Ap (Var "++") literalFoo) az) -- c z = "foo" ++ a z

        p = [([("b", bScheme, b)], [[("a", a)], [("c", c)]])]



-- | Render each assumption on its own line.
showExample :: [Assump] -> String
showExample = unlines . map showAssump

-- | Render an assumption as @name :: type@; the scheme's kinds are omitted
-- and quantified variables appear as @_tN@.
showAssump :: Assump -> String
showAssump (i :>: sc) = case sc of
  Forall _ t -> concat [i, " :: ", showType t]

-- | Render a type in Haskell-like syntax.  Lists print as @[a]@, pairs as
-- @(a, b)@ and functions as @a -> b@.  Type application is left associative,
-- so the argument, not the function, is parenthesised when compound
-- (@Maybe (Maybe a)@, not @Maybe Maybe a@).
showType :: Type -> String
showType t = case t of
  TVar (Tyvar var _) -> var
  TCon (Tycon con _) -> con
  TGen i             -> "_t" ++ show i
  TAp (TCon (Tycon "[]" _)) elemT ->
    "[" ++ showType elemT ++ "]"
  TAp (TAp (TCon (Tycon "(,)" _)) a) b ->
    "(" ++ showType a ++ ", " ++ showType b ++ ")"
  TAp (TAp (TCon (Tycon "(->)" _)) a) b ->
    -- Arrows associate to the right: only the domain may need parentheses.
    showTypeParens a ++ " -> " ++ showType b
  TAp f x ->
    showType f ++ " " ++ showTypeParens x

-- | Render a type used as an operand (an application argument or an arrow's
-- domain).  Compound types are parenthesised, except list types, whose
-- brackets already delimit them.
showTypeParens :: Type -> String
showTypeParens t = case t of
  TAp (TCon (Tycon "[]" _)) _ -> showType t
  TAp _ _                     -> "(" ++ showType t ++ ")"
  _                           -> showType t
