型推論つづきのつづき

高階関数型推論の問題

とりあえずできたが、今度は、推論として完全ではない(まだ実装に入れていない)部分があるために起きる問題にぶつかる。

Haskellでは、普通 idid x = x x は型が (((...->t)->t)->t) という無限に続く型になってしまい、型エラーとなって定義できない。
型なしの式としてみれば、このididは引数によっては実行でき停止する。

  • 例: id x = x; (idid id) 10 => (id id) 10 => id 10 => 10

でも無限ループに落ちるパターンもある。

  • 例: idid idid => idid idid => ...

前回のアルゴリズムで型推論を実装すると、idid :: ((a->b)->b)->b という結果が返り、idid にどんな関数を渡しても、その式の型は var type になる。id id は許可し、x x は禁止するにはどうすればいいか。それ以前に、id x = x; id id の型が関数型にならないのも問題だが。なんとなく、関数型の unification が甘いのが原因だろうと思っている。

# syntax tree nodes for the lang
class Expr:
    """Base class shared by every syntax tree node of the language."""

    def __repr__(self):
        # Printed as the concrete class name immediately followed by the
        # instance's attribute dict, e.g. Val{'value': 10}.
        return "%s%r" % (self.__class__.__name__, vars(self))

class Val(Expr):
    """Literal node that wraps a plain Python value."""

    def __init__(self, python_value):
        # Stored under the attribute name ``value``.
        self.value = python_value

class Add(Expr):
    """Addition node holding a left and a right operand expression."""

    def __init__(self, left, right):
        self.left, self.right = left, right


class Let(Expr):
    """Binding node: ``name`` is bound to ``value`` for use inside ``body``."""

    def __init__(self, name, value, body):
        self.name, self.value, self.body = name, value, body

class Ref(Expr):
    """Reference node that refers to a binding by its name."""

    def __init__(self, name):
        # Name to look up in the enclosing environment.
        self.name = name

class Lambda(Expr):
    """Function literal node: parameter names plus a body expression."""

    def __init__(self, params, body):
        self.params, self.body = params, body

class Apply(Expr):
    """Application node: a function expression and its argument expressions."""

    def __init__(self, func, args):
        self.func, self.args = func, args

class Typed(Expr):
    """Type declaration node: an expression paired with its declared type."""

    def __init__(self, expr, type):
        self.expr, self.type = expr, type

# types for the lang
class Type:
    """Common base class of the type representations of the language."""

class TAtom(Type):
    """Atomic (non-composite) type identified by a label."""

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        # The label itself is the printed form of the type.
        return self.label

class TFunc(Type):
    """Function type: parameter types plus a return type."""

    def __init__(self, params_t, ret_t):
        self.params_t, self.ret_t = params_t, ret_t

    def __repr__(self):
        # Rendered as "(<repr of params_t>) -> <repr of ret_t>".
        return "(%r) -> %r" % (self.params_t, self.ret_t)

class TVar(Type):
    """Type variable.

    A separate instance must be created for each point of the code that
    needs a variable type.
    """

    def __repr__(self):
        # Every variable prints the same way.
        return "<variable type>"
# for inference
class TypeCheckError(Exception):
    """Error raised when type checking / inference fails."""

class TypeEnv:
    """Table from a name to its type, used while inferring.

    Environments are chained through ``parent``: a lookup that misses the
    local table is delegated to the enclosing environment.
    """

    def __init__(self, parent=None):
        self.table = {}
        self.parent = parent

    def put(self, name, type):
        """Bind ``name`` to ``type`` in this environment only."""
        self.table[name] = type

    def get(self, name):
        """Return the type bound to ``name``, searching enclosing scopes.

        Raises:
            TypeCheckError: if neither this environment nor any ancestor
                binds ``name``.
        """
        try:
            return self.table[name]
        except KeyError:
            # Only a genuinely missing key moves the search outward; any
            # other failure (e.g. an unhashable name) is a caller bug and
            # must not be reported as "not found".
            if self.parent is None:
                raise TypeCheckError("%s not found in type env" % name)
            return self.parent.get(name)
class TypeMap:
    """Table from a variable type to the type it refers to, used while inferring.

    Maps are chained through ``parent``: a lookup that misses the local
    table is delegated to the parent map.
    """

    def __init__(self, parent=None):
        self.table = {}
        self.parent = parent

    def put(self, vtype, type):
        """Record in this map that ``vtype`` refers to ``type``."""
        self.table[vtype] = type

    def get(self, vtype):
        """Return what ``vtype`` refers to, or ``vtype`` itself when unbound.

        The lookup follows a single binding; the returned type is not
        resolved any further.
        """
        try:
            return self.table[vtype]
        except KeyError:
            # Only a missing key means "not bound here"; other failures
            # (e.g. an unhashable key) propagate instead of being
            # silently reported as an unbound variable.
            if self.parent is None:
                return vtype
            return self.parent.get(vtype)
def _unify_or_report(expr, constraint, type_map):
    """Unify one (Type, Type) pair; on failure print the offending expr and re-raise."""
    try:
        return unify([constraint], type_map)
    except Exception:
        # Diagnostic only: the original exception is re-raised unchanged.
        # print(...) with a single argument works on Python 2 and 3.
        print("Type Error at: %s" % repr(expr))
        raise


def inference(expr, table, type_map):
    """Infer the type of ``expr``: (Expr, TypeEnv, TypeMap) -> Type, TypeMap.

    Args:
        expr: syntax tree node; exactly one of Val, Add, Let, Ref, Lambda,
            Apply or Typed (matched on the exact class, not subclasses).
        table: TypeEnv holding the types of the names in scope.
        type_map: TypeMap holding the variable-type bindings found so far.

    Returns:
        A ``(type, type_map)`` pair: the inferred type and the type map
        produced by the unifications performed on the way.

    Raises:
        TypeCheckError: when a referenced name is not in ``table`` or a
            unification fails (the failing expression is printed first).
        Exception: when ``expr`` is not one of the supported node classes.
    """
    kind = expr.__class__

    if kind is Val:
        # The atom's label is the Python type name of the literal.
        return TAtom(type(expr.value).__name__), type_map

    if kind is Add:
        left_type, type_map = inference(expr.left, table, type_map)
        right_type, type_map = inference(expr.right, table, type_map)
        ret_type = TAtom("int")
        # Addition is checked against a fixed (int, int) -> int signature.
        expected = TFunc([TAtom("int"), TAtom("int")], TAtom("int"))
        actual = TFunc([left_type, right_type], ret_type)
        type_map = _unify_or_report(expr, (expected, actual), type_map)
        return ret_type, type_map

    if kind is Let:
        value_type, type_map = inference(expr.value, table, type_map)
        # The binding lives in a child environment visible to the body only.
        scope = TypeEnv(table)
        scope.put(expr.name, value_type)
        return inference(expr.body, scope, type_map)

    if kind is Ref:
        return table.get(expr.name), type_map

    if kind is Lambda:
        # Every parameter starts out as its own fresh variable type.
        scope = TypeEnv(table)
        for key in expr.params:
            scope.put(key, TVar())
        body_type, type_map = inference(expr.body, scope, type_map)
        # Parameter types are resolved after the body has been inferred,
        # each with its own fresh cache.
        arg_types = [concrete_type(scope.get(key), type_map, {})
                     for key in expr.params]
        return TFunc(arg_types, body_type), type_map

    if kind is Apply:
        func_type, type_map = inference(expr.func, table, type_map)
        arg_types = []
        for arg in expr.args:
            arg_type, type_map = inference(arg, table, type_map)
            arg_types.append(arg_type)
        # The call's result is a fresh variable, unified with the callee's
        # return type through a function type built from the arguments.
        ret_type = TVar()
        call_shape = TFunc(arg_types, ret_type)
        type_map = _unify_or_report(expr, (func_type, call_shape), type_map)
        return concrete_type(ret_type, type_map, {}), type_map

    if kind is Typed:
        expr_type, type_map = inference(expr.expr, table, type_map)
        type_map = _unify_or_report(expr, (expr.type, expr_type), type_map)
        return expr_type, type_map

    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)
def concrete_type(type, type_map, cache):
    """Return ``type`` with its variable types replaced via ``type_map``.

    Args:
        type: the TAtom, TFunc or TVar to resolve.
        type_map: TypeMap consulted for the binding of each TVar.
        cache: dict from an already visited type to its result; callers
            pass a fresh ``{}``. It also stops the recursion on
            self-referential types.

    Returns:
        The TAtom itself, a newly built TFunc whose components are
        resolved, or for a TVar the resolved form of what it refers to
        (the variable itself when it is unbound).

    Raises:
        TypeCheckError: when ``type`` is none of TAtom, TFunc or TVar.
    """
    if type in cache:
        return cache[type]

    kind = type.__class__

    if kind is TAtom:
        cache[type] = type
        return type

    if kind is TFunc:
        # Register the (still empty) result before descending so that a
        # nested occurrence of this same function type hits the cache.
        resolved = TFunc([], None)
        cache[type] = resolved
        for param_t in type.params_t:
            resolved.params_t.append(concrete_type(param_t, type_map, cache))
        resolved.ret_t = concrete_type(type.ret_t, type_map, cache)
        return resolved

    if kind is TVar:
        bound = type_map.get(type)
        # Provisional entry: ends the recursion when the variable is
        # unbound (get returns the variable itself) or when its binding
        # mentions the variable again.
        cache[type] = bound
        bound = concrete_type(bound, type_map, cache)
        cache[type] = bound
        return bound

    # The original message formatted the undefined names left/right, which
    # raised NameError instead of the intended TypeCheckError.
    raise TypeCheckError("cannot concretize unsupported type: %s" % repr(type))
def unify(queue, type_map):
    """Solve a work list of type constraints: ([(Type, Type)], TypeMap) -> TypeMap.

    Args:
        queue: list of (Type, Type) pairs that must be made equal. It is
            used as a stack and is consumed by the call.
        type_map: TypeMap with the bindings known so far; it is never
            modified, new bindings go into child maps layered on top.

    Returns:
        ``type_map`` itself when no variable had to be bound, otherwise the
        innermost of the child TypeMaps created for the new bindings.

    Raises:
        TypeCheckError: on differing parameter counts, differing atomic
            labels, or structurally incompatible types.

    Note:
        A variable is bound to the other side as-is: existing bindings are
        not resolved first and no occurs check is made.
    """
    while queue:
        left, right = queue.pop()

        if left is right:
            # The very same type object is trivially unified; binding a
            # variable to itself would only shadow an earlier binding.
            continue

        if left.__class__ is not TVar and right.__class__ is TVar:
            # Normalise so that a variable, if any, is on the left.
            left, right = right, left

        if left.__class__ is TVar:
            # One new layer per binding keeps the caller's map untouched.
            type_map = TypeMap(type_map)
            type_map.put(left, right)
            # Keep going: the remaining constraints still have to hold
            # (returning here silently dropped them).
            continue

        if left.__class__ is TFunc and right.__class__ is TFunc:
            if len(left.params_t) != len(right.params_t):
                raise TypeCheckError(
                    "mismatched param size: %s %s" % (left, right))
            # Decompose into pairwise constraints on parameters and result.
            queue.extend(zip(left.params_t, right.params_t))
            queue.append((left.ret_t, right.ret_t))
            continue

        if left.__class__ is TAtom and right.__class__ is TAtom:
            if left.label != right.label:
                raise TypeCheckError(
                    "mismatched atomic types: %s %s" % (left.label, right.label))
            continue

        raise TypeCheckError(
            "mismatched structure of types: %s %s" % (repr(left), repr(right)))

    return type_map
def print_type(expr):
    """Infer the type of ``expr`` in an empty environment and print it.

    The inferred type is resolved through the resulting type map before
    printing. Nothing is returned.
    """
    inferred, bindings = inference(expr, TypeEnv(), TypeMap())
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the bare print statement is a syntax error on Python 3.
    print(concrete_type(inferred, bindings, {}))
# Demo expressions. Each one is labelled and its inferred type printed.
# print(...) with a single argument behaves the same on Python 2 and 3.

# expected type: <var> -> <var>
#  id = \x -> x
#  id
print("expr2_0")
expr2_0 = Let("id", Lambda(["x"], Ref("x")),
              Ref("id"))
print_type(expr2_0)

# expected type: int
#  id = \x -> x
#  id(10)
print("expr2_1")
expr2_1 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Val(10)]))
print_type(expr2_1)

# expected type: <var> -> <var>
#  id: 'a->'a
#  id(id): 'a->'a
print("expr2_2")
expr2_2 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Ref("id")]))
print_type(expr2_2)

# valid?:
#  idid = \x -> x(x)
#  idid: (('a->'b)->'b)->'b
print("expr2_3")
expr2_3 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Ref("idid"))
print_type(expr2_3)

#  idid(id)
print("expr2_3_1")
expr2_3_1 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                Let("id", Lambda(["x"], Ref("x")),
                    Apply(Ref("idid"), [Ref("id")])))
print_type(expr2_3_1)

#  idid(id)(10)
print("expr2_3_2")
expr2_3_2 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                Let("id", Lambda(["x"], Ref("x")),
                    Apply(Apply(Ref("idid"), [Ref("id")]), [Val(10)])))
print_type(expr2_3_2)

# invalid (evaluating it would never terminate):
#  idid = \x -> x(x)
#  idid(idid)=>idid(idid)=>...
print("expr2_4")
expr2_4 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Ref("idid"), [Ref("idid")]))
print_type(expr2_4)