9c4aa76d70db5c819ce74bbf927a1e4bdafafc91
[fp.git] / src / HM.hs
1 -- |
2 -- Module      :  HM
3 -- Copyright   :  Tomáš Musil 2014
4 -- License     :  BSD-3
5 --
6 -- Maintainer  :  tomik.musil@gmail.com
7 -- Stability   :  experimental
8 --
9 -- This is a toy implementation of \-calculus with Hindley-Milner type system.
10
11 module HM
12   ( -- * Types
13     Type(..)
14   , TypeScheme(..)
15   , Term(..)
16   , TypedTerm(..)
17     -- * Type inference
18   , algW
19   , runTI
20   ) where
21
22 import Control.Monad.Except
23 import Control.Monad.State
24 import qualified Data.Set as Set
25 import qualified Data.Map as Map
26
27 import HM.Term
28 import HM.Parser
29
30 type Substitution = Map.Map TypeVarName Type
31 type VarS a = State Int a
32 type TypeEnv = Map.Map VarName TypeScheme
33 data TIState = TIState {tiSupply :: Int} deriving (Show)
34 type TI a = ExceptT String (State TIState) a
35
36 runTI :: TI a -> (Either String a, TIState)
37 runTI t = runState (runExceptT t) $ TIState 0
38
39 newVar :: TI Type
40 newVar = do
41   s <- get
42   put s {tiSupply = tiSupply s + 1}
43   return (TypeVar $ "a" ++ show (tiSupply s))
44
45 freeVarsT :: Type -> Set.Set TypeVarName
46 freeVarsT (Primitive _) = Set.empty
47 freeVarsT (TypeVar t) = Set.singleton t
48 freeVarsT (TypeFunction a b) = freeVarsT a `Set.union` freeVarsT b
49
50 freeVarsS :: TypeScheme -> Set.Set TypeVarName
51 freeVarsS (TScheme t) = freeVarsT t
52 freeVarsS (TSForAll v s) = v `Set.delete` freeVarsS s
53
54 substituteT :: Substitution -> Type -> Type
55 substituteT _ t@(Primitive _) = t
56 substituteT s t@(TypeVar v) = Map.findWithDefault t v s
57 substituteT s (TypeFunction a b) = TypeFunction (substituteT s a) (substituteT s b)
58
59 substituteS :: Substitution -> TypeScheme -> TypeScheme
60 substituteS s (TScheme t) = TScheme $ substituteT s t
61 substituteS s (TSForAll v t) = TSForAll v $ substituteS (v `Map.delete` s) t
62
63 idSub :: Substitution
64 idSub = Map.empty
65
66 composeSub :: Substitution -> Substitution -> Substitution
67 composeSub s1 s2 = Map.map (substituteT s1) s2 `Map.union` s1
68
69 varBind :: TypeVarName -> Type -> TI Substitution
70 varBind v t | t == TypeVar v = return idSub
71             | v `Set.member` freeVarsT t = fail $ "occur check failed: " ++ v ++ " in " ++ show t
72             | otherwise = return $ Map.singleton v t
73
74 instantiate :: TypeScheme -> TI Type
75 instantiate (TScheme t) = return t
76 instantiate (TSForAll v t) = do 
77   nv <- newVar
78   instantiate $ substituteS (Map.singleton v nv) t
79
80 generalize :: TypeEnv -> Type -> TypeScheme
81 generalize e t = foldr TSForAll (TScheme t) vars
82   where
83     vars = Set.toList $ freeVarsT t Set.\\ Set.unions (map freeVarsS $ Map.elems e)
84
85 unify :: Type -> Type -> TI Substitution
86 unify (TypeVar a) t = varBind a t
87 unify t (TypeVar a) = varBind a t
88 unify (TypeFunction a b) (TypeFunction a' b') = do
89   s1 <- unify a a'
90   s2 <- unify b b'
91   return $ s1 `composeSub` s2
92 unify (Primitive a) (Primitive b) | a == b = return idSub
93 unify a b = fail $ "cannot unify " ++ show a ++ " with " ++ show b
94
95 ti :: TypeEnv -> TypedTerm -> TI (Substitution, Type)
96 --ti _ (TTerm (Var v) (TScheme t@(Primitive _))) = return (idSub, t)
97 ti e (TTerm tr sch) = do
98   (s, t) <- ti e (NTTerm tr)
99   sch' <- instantiate sch
100   s' <- unify t sch'
101   return (s', substituteT s' sch')
102 ti e (NTTerm (Var v)) = case Map.lookup v e of
103   Nothing -> fail $ "unbound variable: " ++ v
104   Just sigma -> do
105     t <- instantiate sigma
106     return (idSub, t)
107 ti e (NTTerm (Lam x y)) = do
108   tv <- newVar
109   let e' = Map.insert x (TScheme tv) e
110   (s, t) <-  ti e' y
111   return (s, TypeFunction (substituteT s tv) t)
112 ti e (NTTerm (App a b)) = do
113   tv <- newVar 
114   (s1, t1) <- ti e a
115   (s2, t2) <- ti (Map.map (substituteS s1) e) b
116   s3 <- unify (substituteT s2 t1) (TypeFunction t2 tv)
117   return (s3 `composeSub` s2 `composeSub` s1, substituteT s3 tv)
118 ti e (NTTerm (Let x a b)) = do
119   (s1, t1) <- ti e a
120   let t' = generalize (Map.map (substituteS s1) e) t1
121       e' = Map.insert x t' e
122   (s2, t2) <- ti (Map.map (substituteS s1) e') b
123   return (s1 `composeSub` s2, t2)
124   
125  
126 algW :: TypedTerm -> TI Type
127 algW t = do
128   (s, u) <- ti Map.empty t
129   return $ substituteT s u