more work on HM interpreter
[fp.git] / src / HM.hs
index 83bbe04..e643938 100644 (file)
--- a/src/HM.hs
+++ b/src/HM.hs
 -- Maintainer  :  tomik.musil@gmail.com
 -- Stability   :  experimental
 --
 -- Maintainer  :  tomik.musil@gmail.com
 -- Stability   :  experimental
 --
--- This is a toy implementation of λ-calculus with Hindley-Milner type system.
+-- This is a toy implementation of \-calculus with Hindley-Milner type system.
 
 module HM
   ( -- * Types
     Type(..)
 
 module HM
   ( -- * Types
     Type(..)
-  , Term
+  , TypeScheme(..)
+  , Term(..)
+  , TypedTerm(..)
     -- * Type inference
   , algW
   ) where
 
     -- * Type inference
   , algW
   ) where
 
+import Control.Monad.Except
+import Control.Monad.State
+import qualified Data.Set as Set
+import qualified Data.Map as Map
+
 import HM.Term
 import HM.Parser
 
 import HM.Term
 import HM.Parser
 
-type Substitution = TypeScheme -> TypeScheme
-
-fresh :: TypeVarName
-fresh = undefined
-
-substitute :: TypeScheme -> TypeVarName -> TypeScheme -> TypeScheme
-substitute = undefined
-
-unify :: TypeScheme -> TypeScheme -> Either String Substitution
-unify (TScheme (Primitive a)) (TScheme (Primitive b)) | a == b = Right id
-unify (TScheme (TypeVar a)) (TScheme (TypeVar b)) | a == b = Right id
-unify a b = Left $ "cannot unify " ++ show a ++ " with " ++ show b
-
-algW :: HMTerm -> Either String TypeScheme
-algW (HMTerm (Var _) t) = Right t
-algW (HMTerm (Lambda x t) (TScheme p)) = do
-  let v = TScheme (TypeVar fresh)
-      np = substitute v x t
-  unify p np
-algW (HMTerm (App u v) t) = do
-  tu <- algW u
-  tv <- algW v
-  case tu of
-    (TScheme (TypeFunction a b)) -> do
-      unify a tv
-      return b
-    _ -> Left $ "cannot apply " ++ show tu ++ " to " ++ show tv
+type Substitution = Map.Map TypeVarName Type
+type VarS a = State Int a
+type TypeEnv = Map.Map VarName TypeScheme
+data TIState = TIState {tiSupply :: Int} deriving (Show)
+type TI a = ExceptT String (State TIState) a
+
+runTI :: TI a -> (Either String a, TIState)
+runTI t = runState (runExceptT t) $ TIState 0
+
+newVar :: TI Type
+newVar = do
+  s <- get
+  put s {tiSupply = tiSupply s + 1}
+  return (TypeVar $ "a" ++ show (tiSupply s))
+
+freeVarsT :: Type -> Set.Set TypeVarName
+freeVarsT (Primitive _) = Set.empty
+freeVarsT (TypeVar t) = Set.singleton t
+freeVarsT (TypeFunction a b) = freeVarsT a `Set.union` freeVarsT b
+
+freeVarsS :: TypeScheme -> Set.Set TypeVarName
+freeVarsS (TScheme t) = freeVarsT t
+freeVarsS (TSForAll v s) = v `Set.delete` freeVarsS s
+
+substituteT :: Substitution -> Type -> Type
+substituteT _ t@(Primitive _) = t
+substituteT s t@(TypeVar v) = Map.findWithDefault t v s
+substituteT s (TypeFunction a b) = TypeFunction (substituteT s a) (substituteT s b)
+
+substituteS :: Substitution -> TypeScheme -> TypeScheme
+substituteS s (TScheme t) = TScheme $ substituteT s t
+substituteS s (TSForAll v t) = TSForAll v $ substituteS (v `Map.delete` s) t
+
+idSub :: Substitution
+idSub = Map.empty
+
+composeSub :: Substitution -> Substitution -> Substitution
+composeSub s1 s2 = Map.map (substituteT s1) s2 `Map.union` s1
+
+varBind :: TypeVarName -> Type -> TI Substitution
+varBind v t | t == TypeVar v = return idSub
+            | v `Set.member` freeVarsT t = fail $ "occur check failed: " ++ v ++ " in " ++ show t
+            | otherwise = return $ Map.singleton v t
+
+instantiate :: TypeScheme -> TI Type
+instantiate (TScheme t) = return t
+instantiate (TSForAll v t) = do 
+  nv <- newVar
+  instantiate $ substituteS (Map.singleton v nv) t
+
+generalize :: TypeEnv -> Type -> TypeScheme
+generalize e t = foldr TSForAll (TScheme t) vars
+  where
+    vars = Set.toList $ freeVarsT t Set.\\ Set.unions (map freeVarsS $ Map.elems e)
+
+unify :: Type -> Type -> TI Substitution
+unify (TypeVar a) t = varBind a t
+unify t (TypeVar a) = varBind a t
+unify (TypeFunction a b) (TypeFunction a' b') = do
+  s1 <- unify a a'
+  s2 <- unify b b'
+  return $ s1 `composeSub` s2
+unify (Primitive a) (Primitive b) | a == b = return idSub
+unify a b = fail $ "cannot unify " ++ show a ++ " with " ++ show b
+
+tiLit :: Literal -> TI (Substitution, Type)
+tiLit (LBool _) = return (idSub, Primitive "Bool")
+tiLit (LInt _) = return (idSub, Primitive "Integer")
+tiLit (LFunc If) = do
+  a <- newVar
+  return (idSub, Primitive "Bool" `TypeFunction` (a `TypeFunction` (a `TypeFunction` a)))
+
+ti :: TypeEnv -> TypedTerm -> TI (Substitution, Type)
+ti e (TTerm tr sch) = do
+  (s, t) <- ti e (NTTerm tr)
+  sch' <- instantiate sch
+  s' <- unify t sch'
+  return (s', substituteT s' sch')
+ti _ (NTTerm (Lit l)) = tiLit l
+ti e (NTTerm (Var v)) = case Map.lookup v e of
+  Nothing -> fail $ "unbound variable: " ++ v
+  Just sigma -> do
+    t <- instantiate sigma
+    return (idSub, t)
+ti e (NTTerm (Lam x y)) = do
+  tv <- newVar
+  let e' = Map.insert x (TScheme tv) e
+  (s, t) <-  ti e' y
+  return (s, TypeFunction (substituteT s tv) t)
+ti e (NTTerm (App a b)) = do
+  tv <- newVar 
+  (s1, t1) <- ti e a
+  (s2, t2) <- ti (Map.map (substituteS s1) e) b
+  s3 <- unify (substituteT s2 t1) (TypeFunction t2 tv)
+  return (s3 `composeSub` s2 `composeSub` s1, substituteT s3 tv)
+ti e (NTTerm (Let x a b)) = do
+  (s1, t1) <- ti e a
+  let t' = generalize (Map.map (substituteS s1) e) t1
+      e' = Map.insert x t' e
+  (s2, t2) <- ti (Map.map (substituteS s1) e') b
+  return (s1 `composeSub` s2, t2)
+  
+algW :: TypedTerm -> Either String Type
+algW t = fst . runTI $ do
+  (s, u) <- ti Map.empty t
+  return $ substituteT s u