型推論つづきのつづき
とりあえずできたが、今度は推論で完全じゃない部分を入れてないためにおきる問題にぶつかる。
Haskellだと普通はidid x = x xは型が(((...->t)->t)-t)となる型エラーで定義できない。
型なしの式としてみれば、このididは引数によっては実行でき停止する。
- 例: id x = x; (idid id) 10 => (id id) 10 => id 10 => 10
でも無限ループに落ちるパターンもある。
- 例: idid idid => idid idid => ...
前回のアルゴリズム実装で型推論を実装すると、idid :: ((a->b)->b)->b と結果が返り、idid にどんな関数を渡した式の型がvar typeになる。id idはよくて、x xは禁止にするにはどうすればいいか。それ以前に id x = x; id idの型が関数型にならないのも問題だが。なんとなく関数型のunificationが甘いのが原因だろうと思っている。
# syntax tree nodes for the lang class Expr: def __repr__(self): return self.__class__.__name__ + repr(self.__dict__) pass class Val(Expr): def __init__(self, python_value): self.value = python_value pass pass class Add(Expr): def __init__(self, left, right): self.left = left self.right = right pass pass class Let(Expr): def __init__(self, name, value, body): self.name = name self.value = value self.body = body pass pass class Ref(Expr): def __init__(self, name): self.name = name pass pass class Lambda(Expr): def __init__(self, params, body): self.params = params self.body = body pass pass class Apply(Expr): def __init__(self, func, args): self.func = func self.args = args pass pass class Typed(Expr): """ for type declaration """ def __init__(self, expr, type): self.expr = expr self.type = type pass pass # types for the lang class Type: pass class TAtom(Type): """atomic type""" def __init__(self, label): self.label = label pass def __repr__(self): return self.label pass class TFunc(Type): """ function type """ def __init__(self, params_t, ret_t): self.params_t = params_t self.ret_t = ret_t pass def __repr__(self): return "(%s) -> %s" % (repr(self.params_t), repr(self.ret_t)) pass class TVar(Type): """variable type: must instanciate for each point of code """ def __init__(self): pass def __repr__(self): return "<variable type>" pass
# for inference class TypeCheckError(Exception): pass class TypeEnv: """ Type of Name Table for inferencing """ def __init__(self, parent=None): self.table = {} self.parent = parent pass def put(self, name, type): self.table[name] = type pass def get(self, name): try: return self.table[name] except: if self.parent is not None: return self.parent.get(name) else: raise TypeCheckError("%s not found in type env" % name) pass pass pass class TypeMap: """ Variable Type Reference Table for inferencing """ def __init__(self, parent=None): self.table = {} self.parent = parent pass def put(self, vtype, type): self.table[vtype] = type pass def get(self, vtype): try: return self.table[vtype] except: if self.parent is not None: return self.parent.get(vtype) else: return vtype pass pass pass
def inference(expr, table, type_map): """(Expr, TypeEnv, TypeMap) -> Type, TypeMap """ if expr.__class__ is Val: return TAtom(type(expr.value).__name__), type_map if expr.__class__ is Add: func_type = TFunc([TAtom("int"), TAtom("int")], TAtom("int")) left_type, type_map = inference(expr.left, table, type_map) right_type, type_map = inference(expr.right, table, type_map) ret_type = TAtom("int") compare_type = TFunc([left_type, right_type], ret_type) try: type_map = unify([(func_type, compare_type)], type_map) return ret_type, type_map except: print "Type Error at: %s" % repr(expr) raise pass if expr.__class__ is Let: value_type, type_map = inference(expr.value, table, type_map) table = TypeEnv(table) table.put(expr.name, value_type) body_type, type_map = inference(expr.body, table, type_map) return body_type, type_map if expr.__class__ is Ref: name_type = table.get(expr.name) return name_type, type_map if expr.__class__ is Lambda: table = TypeEnv(table) for key in expr.params: table.put(key, TVar()) pass body_type, type_map = inference(expr.body, table, type_map) arg_types = [concrete_type(table.get(key), type_map, {}) \ for key in expr.params] return TFunc(arg_types, body_type), type_map if expr.__class__ is Apply: func_type, type_map = inference(expr.func, table, type_map) arg_types = [] for arg in expr.args: arg_type, type_map = inference(arg, table, type_map) arg_types.append(arg_type) pass ret_type = TVar() compare_type = TFunc(arg_types, ret_type) try: type_map = unify([(func_type, compare_type)], type_map) return concrete_type(ret_type, type_map, {}), type_map except: print "Type Error at: %s" % repr(expr) raise pass if expr.__class__ is Typed: expr_type, type_map = inference(expr.expr, table, type_map) try: type_map = unify([(expr.type, expr_type)], type_map) return expr_type, type_map except: print "Type Error at: %s" % repr(expr) raise pass raise Exception("Not supported Expr: %s" % expr.__class__.__name__)
def concrete_type(type, type_map, cache): if cache.has_key(type): return cache[type] if type.__class__ is TAtom: cache[type] = type return type if type.__class__ is TFunc: tfunc = TFunc([], None) cache[type] = tfunc for param_t in type.params_t: concrete_p = concrete_type(param_t, type_map, cache) tfunc.params_t.append(concrete_p) pass tfunc.ret_t = concrete_type(type.ret_t, type_map, cache) return tfunc if type.__class__ is TVar: concrete_t = type_map.get(type) cache[type] = concrete_t concrete_t = concrete_type(concrete_t, type_map, cache) cache[type] = concrete_t return concrete_t raise TypeCheckError("mismatched structure of types: %s %s" % (repr(left), repr(right)))
def unify(queue, type_map): "([(Type, Type)], TypeMap) -> TypeMap" if len(queue) == 0: return type_map left, right = queue.pop() if left.__class__ is TVar: type_map = TypeMap(type_map) type_map.put(left, right) return type_map if left.__class__ is TFunc and right.__class__ is TFunc: if len(left.params_t) != len(right.params_t): raise TypeChechError("mismatched param size: %s %s" % (left, right)) for param_l, param_r in zip(left.params_t, right.params_t): queue.append((param_l, param_r)) pass queue.append((left.ret_t, right.ret_t)) return unify(queue, type_map) if left.__class__ is TAtom and right.__class__ is TAtom: if left.label != right.label: raise TypeChechError("mismatched atomic types: %s %s" % (left.label, right.label)) return unify(queue, type_map) if right.__class__ is TVar: queue.append((right, left)) return unify(queue, type_map) raise TypeCheckError("mismatched structure of types: %s %s" % (repr(left), repr(right)))
def print_type(expr): t, tmap = inference(expr, TypeEnv(), TypeMap()) print concrete_type(t, tmap, {}) pass
# valid: <var> -> <var> # id = \x -> x # id print "expr2_0" expr2_0 = Let("id", Lambda(["x"], Ref("x")), Ref("id")) print_type(expr2_0) # valid: int # id = \x -> x # id print "expr2_1" expr2_1 = Let("id", Lambda(["x"], Ref("x")), Apply(Ref("id"), [Val(10)])) print_type(expr2_1) # valid: <var> -> <var> # id: 'a->'a # id(id): 'a->'a print "expr2_2" expr2_2 = Let("id", Lambda(["x"], Ref("x")), Apply(Ref("id"), [Ref("id")])) print_type(expr2_2) # valid?: # idid: (('a->'b)->'b)->'b) print "expr2_3" expr2_3 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])), Ref("idid")) print_type(expr2_3) print "expr2_3_1" expr2_3_1 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])), Let("id", Lambda(["x"], Ref("x")), Apply(Ref("idid"), [Ref("id")]))) print_type(expr2_3_1) print "expr2_3_2" expr2_3_2 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])), Let("id", Lambda(["x"], Ref("x")), Apply(Apply(Ref("idid"), [Ref("id")]), [Val(10)]))) print_type(expr2_3_2) # invalid: # idid = \x -> x(x) # idid(idid)=>idid(idid)=>... print "expr2_4" expr2_4 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])), Apply(Ref("idid"), [Ref("idid")])) print_type(expr2_4)