Pythonで型推論のアルゴリズムを書いてみるが

記憶だけをたどって、型推論アルゴリズムが書けるかどうかやってみる。

let recやパラメトリック型は抜きで、まず書いてみようとしたが途中で断念。
(つづき→ http://d.hatena.ne.jp/bellbind/20071122)

以下の型推論のコードはたぶん不正確です。

構文木

# syntax tree nodes for the lang
class Expr:
    """Base class shared by every syntax-tree node of the language."""

    def __repr__(self):
        # Printed form is the class name immediately followed by the
        # attribute dictionary, e.g. Val{'value': 5}.
        return "%s%r" % (self.__class__.__name__, self.__dict__)

class Val(Expr):
    """Literal node wrapping a Python value."""

    def __init__(self, python_value):
        # The value is stored unchanged; no conversion happens here.
        self.value = python_value

class Add(Expr):
    """Addition node holding a left and a right operand expression."""

    def __init__(self, left, right):
        # Assignment order (left, then right) is kept because the printed
        # form of a node follows attribute insertion order.
        self.left, self.right = left, right


class Let(Expr):
    """Binding node: one name, the expression bound to it, and a body."""

    def __init__(self, name, value, body):
        # Assignment order is kept because the printed form of a node
        # follows attribute insertion order.
        self.name, self.value, self.body = name, value, body

class Ref(Expr):
    """Variable-reference node identified by the referenced name."""

    def __init__(self, name):
        self.name = name

class Lambda(Expr):
    """Function-literal node: a sequence of parameter names and a body."""

    def __init__(self, params, body):
        # Assignment order is kept because the printed form of a node
        # follows attribute insertion order.
        self.params, self.body = params, body

class Apply(Expr):
    """Application node: a function expression and its argument expressions."""

    def __init__(self, func, args):
        # Assignment order is kept because the printed form of a node
        # follows attribute insertion order.
        self.func, self.args = func, args

class Typed(Expr):
    """Type-declaration node: an expression annotated with a declared type."""

    def __init__(self, expr, type):
        # `type` shadows the builtin, but it is the public parameter name
        # and therefore stays as it is.
        self.expr, self.type = expr, type

let ( (a b) (c d) body) = (apply (lambda (a c) body) (b d))なので、LetがひとつでLambdaが複数の引数なのは変ではある。これはある意味Addのせい。

# types for the lang
class Type:
    """Common base class of the type representations of the language."""

class TAtom(Type):
    """Atomic type, identified by its label."""

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        # The label is returned unchanged as the printed form.
        text = self.label
        return text

class TFunc(Type):
    """Function type: parameter types plus a return type."""

    def __init__(self, params_t, ret_t):
        self.params_t, self.ret_t = params_t, ret_t

    def __repr__(self):
        # Both parts are rendered with repr(), e.g. "([int]) -> int".
        return "(%r) -> %r" % (self.params_t, self.ret_t)

class TVar(Type):
    """Type variable; a fresh one must be instantiated for each point of code.

    ``concrete`` holds the type bound to the variable, or None while the
    variable is still unbound.
    """

    def __init__(self):
        self.concrete = None

    def __repr__(self):
        bound = self.concrete
        if bound is not None:
            return "<type=%r>" % (bound,)
        return "<variable type>"

TVarは変数型。ある意味型推論内部処理専用。多相型導入すれば使えるようにするかも。

プログラム例

# example
# Sample program: bind x to 10, bind func to a one-parameter lambda that
# adds x to its argument, then apply func to 5.
expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))

# Show the syntax tree in its repr() form.
print(expr0)

evaluator

コードを実行させる機能は型推論には不要だけど、お約束で実装しておく。

ランタイム環境や値など

# for evaluator
class VarNotFound(Exception):
    """Raised when a name is not bound in an Env or in any of its parents."""

class Env:
    """Runtime environment: a name -> value table chained to a parent Env."""

    def __init__(self, parent=None):
        """Create an empty table; `parent` is consulted when a lookup misses."""
        self.parent = parent
        self.table = {}

    def put(self, name, value):
        """Bind `name` to `value` in this environment only."""
        self.table[name] = value

    def get(self, name):
        """Return the value bound to `name`, searching parents on a miss.

        Raises:
            VarNotFound: if neither this environment nor any parent binds
                `name`.
        """
        try:
            return self.table[name]
        except KeyError:
            # Only a genuine miss falls through to the parent; any other
            # error (e.g. an unhashable name) is a caller bug and propagates.
            pass
        if self.parent is None:
            raise VarNotFound(name)
        return self.parent.get(name)

# for runtime value
class Value:
    """Base class of the runtime values."""

class VInt(Value):
    """Runtime value wrapping the Python value stored in ``value``."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # Must read the instance attribute: a bare `value` is not a name in
        # scope inside this method and would raise NameError.
        return repr(self.value)

class VFunc(Value):
    """Runtime function value pairing a piece of code with an environment."""

    def __init__(self, code, env):
        self.env, self.code = env, code

    def __repr__(self):
        return "func: %r" % (self.code,)

evaluate

def evaluate(expr, env):
    """Evaluate a syntax tree in a runtime environment.

    (Expr, Env) -> Value

    Dispatch is on the exact class of `expr`; subclasses of the node
    classes are not matched.

    Raises:
        Exception: if the class of `expr` is not a supported node kind.
        Errors raised by `env.get` for unbound names propagate unchanged.
    """
    kind = expr.__class__
    if kind is Val:
        return VInt(expr.value)
    if kind is Add:
        # Left operand is evaluated before the right one.
        left = evaluate(expr.left, env)
        right = evaluate(expr.right, env)
        return VInt(left.value + right.value)
    if kind is Let:
        # The bound value is evaluated in the outer environment, so the
        # name is not visible inside its own definition (no let rec).
        child_env = Env(env)
        child_env.put(expr.name, evaluate(expr.value, env))
        return evaluate(expr.body, child_env)
    if kind is Ref:
        return env.get(expr.name)
    if kind is Lambda:
        # Capture the current environment together with the code.
        return VFunc(expr, env)
    if kind is Apply:
        func = evaluate(expr.func, env)
        # Parameters are bound in a child of the function's captured
        # environment; arguments are evaluated in the caller's environment.
        child_env = Env(func.env)
        for index, name in enumerate(func.code.params):
            child_env.put(name, evaluate(expr.args[index], env))
        return evaluate(func.code.body, child_env)
    if kind is Typed:
        # The declared type plays no role at run time; evaluate the wrapped
        # expression in the same environment (env was previously omitted).
        return evaluate(expr.expr, env)
    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

visitorとかにするのも面倒なので、exprのクラスで場合わけ^^;

実行コード例

# Evaluate the sample program in a fresh, empty environment and print the
# resulting runtime value.
print(evaluate(expr0, Env()))

型推論

型環境とか

# for inference
class TypeCheckError(Exception):
    """Raised when type checking fails or a name has no type in the TypeEnv."""

class TypeEnv:
    """Type table for inference: a name -> type map chained to a parent."""

    def __init__(self, parent=None):
        """Create an empty table; `parent` is consulted when a lookup misses."""
        self.table = {}
        self.parent = parent

    def put(self, name, type):
        """Record `type` as the type of `name` in this table only."""
        self.table[name] = type

    def get(self, name):
        """Return the type recorded for `name`, searching parents on a miss.

        Raises:
            TypeCheckError: if neither this table nor any parent knows
                `name`.
        """
        try:
            return self.table[name]
        except KeyError:
            # Only a genuine miss falls through to the parent; any other
            # error (e.g. an unhashable name) is a caller bug and propagates.
            pass
        if self.parent is None:
            raise TypeCheckError("%s not found in type env" % name)
        return self.parent.get(name)

evaluate用のEnvとほぼ一緒。追加専用Map

推論関数など、たぶん関数名は変。

def inference(expr, table):
    """Infer the type of an expression.

    (Expr, TypeEnv) -> Type

    Dispatch is on the exact class of `expr`. Add, Apply and Typed build
    the expected and the actual type and compare them with check_type;
    on failure the offending expression is printed and the error is
    re-raised.

    Raises:
        TypeCheckError: propagated from check_type or from table.get.
        Exception: if the class of `expr` is not a supported node kind.
    """
    if expr.__class__ is Val:
        return TAtom("int")

    if expr.__class__ is Add:
        # Addition is treated as a function (int, int) -> int and compared
        # against the function type built from the actual operand types.
        func_type = TFunc([TAtom("int"), TAtom("int")], TAtom("int"))

        left_type = inference(expr.left, table)
        right_type = inference(expr.right, table)
        compare_type = TFunc([left_type, right_type], TAtom("int"))
        try:
            check_type(func_type, compare_type, [])
        except Exception:
            # Single-argument print(...) is valid in Python 2 and 3.
            print("Type Error at: %s" % repr(expr))
            raise
        return TAtom("int")

    if expr.__class__ is Let:
        # The bound value is typed in the outer table; only the body sees
        # the new name.
        value_type = inference(expr.value, table)
        table = TypeEnv(table)
        table.put(expr.name, value_type)
        return inference(expr.body, table)

    if expr.__class__ is Ref:
        return table.get(expr.name)

    if expr.__class__ is Lambda:
        # Every parameter starts as a fresh type variable, which checking
        # the body may bind.
        table = TypeEnv(table)
        for key in expr.params:
            table.put(key, TVar())
        body_type = inference(expr.body, table)
        arg_types = [table.get(key) for key in expr.params]
        return TFunc(arg_types, body_type)

    if expr.__class__ is Apply:
        func_type = inference(expr.func, table)
        arg_types = [inference(arg, table) for arg in expr.args]
        # The result type is unknown up front: use a fresh variable and let
        # check_type bind it against the callee's return type.
        ret_type = TVar()
        compare_type = TFunc(arg_types, ret_type)
        try:
            check_type(func_type, compare_type, [])
        except Exception:
            print("Type Error at: %s" % repr(expr))
            raise
        if ret_type.concrete is None:
            # parametric: the result stays an unbound type variable
            return ret_type
        return ret_type.concrete

    if expr.__class__ is Typed:
        # Assignment, not comparison: `==` here left expr_type undefined.
        expr_type = inference(expr.expr, table)
        try:
            check_type(expr.type, expr_type, [])
        except Exception:
            print("Type Error at: %s" % repr(expr))
            raise
        return expr_type

    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

意味論でのルール(だいたい証明図表現になってる)に当たる部分。基本は、プログラムの式がルールの下側(結論)のパターンにマッチすることで適用するルールが決まり、そこから上側(前提)へたどっていく感じ。その過程で型の構造をチェックする。

Applyでは、関数側の型と、実際の引数の型と戻り値用の型変数から組み立てた関数型とを作って、この二つの同型チェックをする感じ。Addも同様。Typedは宣言した型と式の型をチェックする。ほかは単に型を作って返すだけ。

その同型チェックがこれ。

def check_type(left, right, checked_tvars):
    """Check that two types have the same structure, binding type variables.

    Args:
        left: the first type (TAtom, TFunc or TVar).
        right: the type compared against `left`.
        checked_tvars: list of the TVars already visited during this check;
            it is appended to in place and stops a variable from being
            processed twice.

    Returns:
        None. An unbound TVar on the left gets its ``concrete`` attribute
        set as a side effect.

    Raises:
        TypeCheckError: if the two types do not match.
        Exception: if `left` is not a TAtom, TFunc or TVar.
    """
    # Messages use only labels, class names and counts: repr() of a type is
    # avoided here because bound type variables can form cyclic structures.
    if left.__class__ is TAtom:
        if right.__class__ is TAtom:
            if left.label != right.label:
                raise TypeCheckError(
                    "atomic type mismatch: %s vs %s"
                    % (left.label, right.label))
            return
        if right.__class__ is TVar:
            # Variables are always handled with the variable on the left.
            check_type(right, left, checked_tvars)
            return
        raise TypeCheckError(
            "cannot match %s with %s"
            % (left.__class__.__name__, right.__class__.__name__))

    elif left.__class__ is TFunc:
        if right.__class__ is TFunc:
            if len(left.params_t) != len(right.params_t):
                raise TypeCheckError(
                    "parameter count mismatch: %d vs %d"
                    % (len(left.params_t), len(right.params_t)))
            for param_l, param_r in zip(left.params_t, right.params_t):
                check_type(param_l, param_r, checked_tvars)
            check_type(left.ret_t, right.ret_t, checked_tvars)
            return
        if right.__class__ is TVar:
            check_type(right, left, checked_tvars)
            return
        raise TypeCheckError(
            "cannot match %s with %s"
            % (left.__class__.__name__, right.__class__.__name__))

    elif left.__class__ is TVar:
        if left in checked_tvars:
            # Already visited in this check: stop to avoid endless
            # recursion on self-referential types.
            return
        checked_tvars.append(left)

        if left.concrete is None:
            # Unbound: bind to whatever `right` resolves to (possibly None
            # when `right` is itself an unbound variable).
            left.concrete = get_concrete(right)
        else:
            check_type(left.concrete, right, checked_tvars)
        return

    else:
        raise Exception("Not supported Type: %s" % left.__class__.__name__)

def get_concrete(type):
    """Resolve a type through any chain of type variables.

    Returns None for None, follows ``concrete`` while the value is a TVar,
    and returns any other value unchanged.
    """
    if type is not None and type.__class__ is TVar:
        return get_concrete(type.concrete)
    return type

prologの推論ぽいのがここ。
構造のチェックだけじゃなくて中で型変数の解決もしてるけど。型変数をキャッシュして自己再帰になる式(id(id)のようなもの)を回避してたりと、仕様はちょっと怪しい(型変数に参照先を持たせてるせいか。たぶんevalのように型変数環境を使うように実装するべきなんだろう)。

コード例

# examples

# NOTE: every print below uses the single-argument call form, which behaves
# the same under Python 2 and Python 3.

# valid: int
print("expr0")
expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))
print(inference(expr0, TypeEnv()))

# structural error: func expects an int but receives a lambda
print("expr1e")
expr1e = Let("x", Val(10),
             Let("func", Lambda(["num"],
                                Add(Ref("x"), Ref("num"))),
                 Apply(Ref("func"), [Lambda(["x"], Val(5))])))

try:
    print(inference(expr1e, TypeEnv()))  # raise TypeCheckError
except TypeCheckError:
    print("expr1e is ERROR")

# valid: <var> -> <var>
print("expr2_0")
expr2_0 = Let("id", Lambda(["x"], Ref("x")),
              Ref("id"))
print(inference(expr2_0, TypeEnv()))

# valid: int
print("expr2_1")
expr2_1 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Val(10)]))
print(inference(expr2_1, TypeEnv()))

# valid: <var> -> <var>
print("expr2_2")
expr2_2 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Ref("id")]))
print(inference(expr2_2, TypeEnv()))

# valid:
print("expr2_3")
expr2_3 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Ref("idid"))
print(inference(expr2_3, TypeEnv()))

# valid:
print("expr2_4")
expr2_4 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Ref("idid"), [Ref("idid")]))
print(inference(expr2_4, TypeEnv()))

# invalid: idid applies its argument as a function, but 10 is an int
print("expr2_5")
expr2_5 = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Ref("idid"), [Val(10)]))
try:
    print(inference(expr2_5, TypeEnv()))
except TypeCheckError:
    print("expr2_5 is ERROR")

実行結果

expr0
int
expr1e
Type Error at: Apply{'args': [Lambda{'body': Val{'value': 5}, 'params': ['x']}], 'func': Ref{'name': 'func'}}
expr1e is ERROR
expr2_0
([<variable type>]) -> <variable type>
expr2_1
<variable type>
expr2_2
<variable type>
expr2_3
([<type=([<type=([...]) -> <variable type>>]) -> <variable type>>]) -> <variable type>
expr2_4
<variable type>
expr2_5
Type Error at: Apply{'args': [Val{'value': 10}], 'func': Ref{'name': 'idid'}}
expr2_5 is ERROR