fix more bugs
[fp.git] / src / HM.hs
1 -- |
2 -- Module      :  HM
3 -- Copyright   :  Tomáš Musil 2014
4 -- License     :  BSD-3
5 --
6 -- Maintainer  :  tomik.musil@gmail.com
7 -- Stability   :  experimental
8 --
9 -- This is a toy implementation of \-calculus with Hindley-Milner type system.
10
11 module HM
12   ( -- * Types
13     Type(..)
14   , TypeScheme(..)
15   , Term(..)
16   , TypedTerm(..)
17     -- * Type inference
18   , algW
19   ) where
20
21 import Control.Monad.Except
22 import Control.Monad.State
23 import qualified Data.Set as Set
24 import qualified Data.Map as Map
25
26 import HM.Term
27 import HM.Parser
28
29 type Substitution = Map.Map TypeVarName Type
30 type VarS a = State Int a
31 type TypeEnv = Map.Map VarName TypeScheme
32 data TIState = TIState {tiSupply :: Int} deriving (Show)
33 type TI a = ExceptT String (State TIState) a
34
35 runTI :: TI a -> (Either String a, TIState)
36 runTI t = runState (runExceptT t) $ TIState 0
37
38 newVar :: TI Type
39 newVar = do
40   s <- get
41   put s {tiSupply = tiSupply s + 1}
42   return (TypeVar $ "a" ++ show (tiSupply s))
43
44 freeVarsT :: Type -> Set.Set TypeVarName
45 freeVarsT (Primitive _) = Set.empty
46 freeVarsT (TypeVar t) = Set.singleton t
47 freeVarsT (TypeFunction a b) = freeVarsT a `Set.union` freeVarsT b
48
49 freeVarsS :: TypeScheme -> Set.Set TypeVarName
50 freeVarsS (TScheme t) = freeVarsT t
51 freeVarsS (TSForAll v s) = v `Set.delete` freeVarsS s
52
53 substituteT :: Substitution -> Type -> Type
54 substituteT _ t@(Primitive _) = t
55 substituteT s t@(TypeVar v) = Map.findWithDefault t v s
56 substituteT s (TypeFunction a b) = TypeFunction (substituteT s a) (substituteT s b)
57
58 substituteS :: Substitution -> TypeScheme -> TypeScheme
59 substituteS s (TScheme t) = TScheme $ substituteT s t
60 substituteS s (TSForAll v t) = TSForAll v $ substituteS (v `Map.delete` s) t
61
62 idSub :: Substitution
63 idSub = Map.empty
64
65 composeSub :: Substitution -> Substitution -> Substitution
66 composeSub s1 s2 = Map.map (substituteT s1) s2 `Map.union` s1
67
68 varBind :: TypeVarName -> Type -> TI Substitution
69 varBind v t | t == TypeVar v = return idSub
70             | v `Set.member` freeVarsT t = fail $ "occur check failed: " ++ v ++ " ~ " ++ show t
71             | otherwise = return $ Map.singleton v t
72
73 instantiate :: TypeScheme -> TI Type
74 instantiate (TScheme t) = return t
75 instantiate (TSForAll v t) = do 
76   nv <- newVar
77   instantiate $ substituteS (Map.singleton v nv) t
78
79 generalize :: TypeEnv -> Type -> TypeScheme
80 generalize e t = foldr TSForAll (TScheme t) vars
81   where
82     vars = Set.toList $ freeVarsT t Set.\\ Set.unions (map freeVarsS $ Map.elems e)
83
84 unify :: Type -> Type -> TI Substitution
85 unify (TypeVar a) t = varBind a t
86 unify t (TypeVar a) = varBind a t
87 unify (TypeFunction a b) (TypeFunction a' b') = do
88   s1 <- unify a a'
89   s2 <- unify (substituteT s1 b) (substituteT s1 b')
90   return $ s1 `composeSub` s2
91 unify (Primitive a) (Primitive b) | a == b = return idSub
92 unify a b = fail $ "cannot unify " ++ show a ++ " with " ++ show b
93
94 tiLit :: Literal -> TI (Substitution, Type)
95 tiLit (LBool _) = return (idSub, Primitive "Bool")
96 tiLit (LInt _) = return (idSub, Primitive "Integer")
97 tiLit (LFunc If) = do
98   a <- newVar
99   return (idSub, Primitive "Bool" `TypeFunction` (a `TypeFunction` (a `TypeFunction` a)))
100
101 ti :: TypeEnv -> TypedTerm -> TI (Substitution, Type)
102 ti e (TTerm tr sch) = do
103   (s, t) <- ti e (NTTerm tr)
104   sch' <- instantiate sch
105   s' <- unify t sch'
106   return (s', substituteT s' sch')
107 ti _ (NTTerm (Lit l)) = tiLit l
108 ti e (NTTerm (Var v)) = case Map.lookup v e of
109   Nothing -> fail $ "unbound variable: " ++ v
110   Just sigma -> do
111     t <- instantiate sigma
112     return (idSub, t)
113 ti e (NTTerm (Lam x y)) = do
114   tv <- newVar
115   let e' = Map.insert x (TScheme tv) e
116   (s, t) <-  ti e' y
117   return (s, TypeFunction (substituteT s tv) t)
118 ti e (NTTerm (App a b)) = do
119   tv <- newVar 
120   (s1, t1) <- ti e a
121   (s2, t2) <- ti (Map.map (substituteS s1) e) b
122   s3 <- unify (substituteT s2 t1) (TypeFunction t2 tv)
123   return (s3 `composeSub` s2 `composeSub` s1, substituteT s3 tv)
124 ti e (NTTerm (Let x a b)) = do
125   (s1, t1) <- ti e a
126   let t' = generalize (Map.map (substituteS s1) e) t1
127       e' = Map.insert x t' e
128   (s2, t2) <- ti (Map.map (substituteS s1) e') b
129   return (s1 `composeSub` s2, t2)
130  
131 algW :: TypedTerm -> Either String Type
132 algW t = fst . runTI $ do
133   (s, u) <- ti Map.empty t
134   return $ substituteT s u