PythonでDSL記法化に挑戦してみた

PythonでのDSL風記述に挑戦するため、PythonでのProlog風のソルバーを作ってみました。

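エンジンの実装は、OKIソフトウェアの Tiny Prolog in Python (http://www.okisoft.co.jp/esc/prolog/in-python.html 、リンク先の現在のページタイトルは「合併のお知らせ|OKIソフトウェア」) をベースにしています。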

PythonPrologDSLをやってるものとしては、以下のものがありました:

そちらでは、関数のfunc_codeから変数を引っ張ってきて、適当に値を入れてcodeを直接execするというやり方になっています。これだと自由変数が使えなくなるので、変数は引数で宣言するタイプで実装してみました。

結果としては、Prologっぽくはなってるけど、Pythonicじゃなくなってしまった感じですけれど(ベースのTiny Prolog風な使い方もできはします)。

例1は、おじいさんと孫を問い合わせるもの:

import rules
e = rules.Engine()

# Facts: assigning True to e[<lambda>] registers whatever the lambda builds.
# The engine chooses the object passed for each lambda parameter from the
# parameter's name: lower-case names such as `father` become terms.
e[lambda father:
  father("ieyasu", "hirotada")
] = True
e[lambda father:
  father("hidetada", "ieyasu")
] = True
e[lambda father:
  father("hideyasu", "ieyasu")
] = True
# Rule: `head <= (goal1, goal2, ...)`.  Parameters starting with an
# upper-case letter (X, Y, Z) become logic variables.
e[lambda granpa, father, X, Y, Z, __: # __: cut operator
  granpa(X, Z) <= (father(X, Y), father(Y, Z), __)
] = True

# Query: calling the engine yields one result environment per solution;
# indexing that environment with a lambda expands the named variables.
for pat in e(lambda granpa, X: granpa(X, "hirotada")):
    print pat[lambda X: X]
    pass

ルールやクエリは、lambda中に記述するようにしてみました(defで関数にしてもかまわない)。利用する変数は、引数に並べる必要があります(ここは何とかしたいところだが)。cutはどれに当てるか悩んだのですが、"__"を使用するようにしました(Term.__pos__でもよかったかもしれない)。ルールは goal <= (hypo1, hypo2, ...)という感じです(Term.__le__を矢印っぽく使ってみる)。

クエリの結果をgeneratorで引っ張ってくるのは参考元と同じですが、このときの値の取り出しも、lambdaで取れるようにしてみました。これはいまいちだったかも。


例2は、listのappendに、FFIでトレースを表示するようにさせてみたもの

import sys
import rules

e = rules.Engine()
# One lambda may return a tuple of rules; each element is registered.
e[lambda append, writeln, A, X, H, T, B, R: (
        # FFI: returns True or False
        # A plain callable used as a goal is called with the current
        # environment, and its truthiness decides whether the goal succeeds.
        writeln(X) <= (
            lambda env:
            sys.stdout.write("last: %s\n" % repr(env.expand(X))) is None
            ),
        # Base case: [] ++ A = A (tracing A through writeln).
        append([], A, A) <= writeln(A),
        # Recursive case: (H | T) ++ B = H | R  when  T ++ B = R.
        append(H | T, B, H | R) <= append(T, B, R), # V1 | V2: list pattern
        )] = True

# Both inputs known: ask for the concatenation R.
for pat in e(lambda append, R:
             append([1, 2], [3, 4], R)):
    print pat[lambda R: R]
    pass

# Result known: ask for the pairs (A, B) that concatenate to [1, 2, 3, 4].
print
for pat in e(lambda append, A, B:
             append(A, B, [1, 2, 3, 4])):
    print pat[lambda A, B: [A, B]]

# A | B | C constrains the first list to start with elements A and B,
# followed by the tail C.
print
for pat in e(lambda append, A, B, C, D:
             append(A | B | C, D, [1, 2, 3])):
    print pat[lambda A, B, C, D: [A, B, C, D]]
    pass

一つのlambdaで複数のルールを定義できるようにしてあります。

リスト専用のパターンとして A | Bで car/cdrにマッチするようにさせてみました。Var.__or__を利用してます。__or__は左結合だけど、無理やり右結合風にマッチします(A | B | C は [A, B | C]という感じ)。


例3は、演算

import rules
e = rules.Engine()

# Operators applied to a variable build deferred tests/expressions that are
# evaluated against the environment when the goal is tried.
e[lambda add30, X, R:
      add30(X, R) <= (X > 10, R == X + 30)
      # ==: when lhs is unbound Var, bind it
  ] = True

# Same idea, but R receives its value by unifying through eq(X, X) instead
# of through the binding behaviour of ==.
e[lambda eq, add50, X, R: (
        eq(X, X),
        add50(X, R) <= (X > 10, eq(R, X + 50))
        )] = True

for pat in e(lambda add30, R:
             add30(20, R)):
    print pat[lambda R: R]
    pass

print
for pat in e(lambda add50, R:
             add50(20, R)):
    print pat[lambda R: R]
    pass

print
for pat in e(lambda add50, R:
             add50(120, R)):
    print pat[lambda R: R]
    pass

変数に二項演算子を定義して、ルールっぽくなるようにしてみました(is風に==で束縛させるようにしたのはやりすぎだったかもしれない)。

例4: ラムダ計算

import rules

e = rules.Engine()
# A small lambda-calculus evaluator written as rules.
# Lower-case parameters are terms (predicates and constructors); parameters
# starting with an upper-case letter or "_" are logic variables; "__" is cut.
e[lambda number,
  eval, ext, get, lam, app, ref, closure, cell, nil, not_found,
  X, Env, Arg, Val, Var, Body, Opr, Opd, Res, CEnv, EEnv, Const, Name, Rest,
  _Var, _Rest, _Val, _Env, __: (
        # number(X): host-language test, succeeds when X expands to an int.
        number(X) <= (lambda env: isinstance(env.expand(X), int)),

        # eval(Env, Expr, Result)
        # lam(...) evaluates to a closure capturing the current Env.
        eval(Env, lam(Arg, Body), closure(Env, Arg, Body)),
        # app(...): evaluate operator and operand, extend the closure's
        # environment with the argument binding, then evaluate the body.
        eval(Env, app(Opr, Opd), Res) <= (
            eval(Env, Opr, closure(CEnv, Arg, Body)),
            eval(Env, Opd, Val),
            ext(CEnv, Arg, Val, EEnv),
            eval(EEnv, Body, Res),
            ),
        # ref(Var): look the variable up in Env.
        eval(Env, ref(Var), Res) <= get(Env, Var, Res),
        # Numbers evaluate to themselves.
        eval(_Env, Const, Const) <= number(Const),

        # ext(Env, Var, Val, NewEnv): NewEnv is a cell chained onto Env.
        ext(Env, Var, Val, cell(Var, Val, Env)),

        # get(Env, Var, Val): walk the cell chain; nil yields not_found.
        get(nil, _Var, not_found),
        get(cell(Var, Val, _Rest), Var, Val) <= __,
        get(cell(Name, _Val, Rest), Var, Res) <= (
            Name != Var,
            get(Rest, Var, Res),
            ),
        )] = True

# Evaluate ((lambda x. x) 10) in the empty environment nil.
for pat in e(lambda eval, app, lam, ref, nil, x, Res:
                 eval(nil, app(lam(x, ref(x)), 10), Res)):
    print pat[lambda Res: Res]

prolog版等→PrologコードをHaskellの型システムで書く - ラシウラ

実装コード: rules.py

"""
DSL based rule engine
Inspired by Tiny Prolog: http://www.okisoft.co.jp/esc/prolog/in-python.html

Instantiate the engine
>>> e = Engine()

Defining rules
>>> e[lambda father:
...   father("ieyasu", "hirotada")
... ] = True
>>> e[lambda father:
...   father("hidetada", "ieyasu")
... ] = True
>>> e[lambda granpa, father, X, Y, Z, __:
...   granpa(X, Z) <= (father(X, Y), father(Y, Z), __)
... ] = True

Querying
>>> [pat[lambda X: X] for pat in e(lambda granpa, X:
...                                granpa("hidetada", X))]
['hirotada']
"""

def _eval(func):
    names = func.func_code.co_varnames[:func.func_code.co_argcount]
    args = []
    for name in names:
        if name == "__": args.append(Cut())
        elif name[0] == "_" or name[0].isupper(): args.append(Var(name))
        else: args.append(Term(name))
        pass
    return func(*args)

class Engine(object):
    """Rule database with a depth-first, backtracking resolver.

    Rules are kept in ``self.rules`` as ``(Rule, flag)`` pairs in insertion
    order.  ``engine[func] = True`` adds rules and ``engine(func)`` runs a
    query; in both cases *func* is evaluated by ``_eval``, which builds the
    DSL objects from the function's parameter names.
    """

    def __init__(self, rules=None):
        """Create an engine; a given *rules* list is used as is (not copied)."""
        if rules is None:
            rules = []
        self.rules = rules

    def __repr__(self):
        # The rule list lives in ``self.rules`` (there is no ``goals``
        # attribute on this class).
        return "Engine(%s)" % repr(self.rules)

    def __setitem__(self, func, bool):
        """Register the rule(s) built by *func* when *bool* is truthy.

        *func* is always evaluated; a falsy *bool* just discards the result.
        """
        obj = _eval(func)
        if not bool:
            return
        self.append(obj, bool)

    def __call__(self, func):
        """Run the query built by *func*.

        Returns a generator of result environments when *func* builds a
        ``Term``; returns ``None`` for anything else.
        """
        obj = _eval(func)
        if isinstance(obj, Term):
            return self.query(obj)
        return None

    def append(self, obj, bool=True):
        """Store *obj* in the rule list together with the flag *bool*.

        A ``Term`` is wrapped in a body-less ``Rule`` (a fact), a ``Rule`` is
        stored as is, and a tuple is registered element by element.  Objects
        of any other type are ignored.
        """
        if isinstance(obj, Term):
            self.rules.append((Rule(obj), bool))
        if isinstance(obj, Rule):
            self.rules.append((obj, bool))
        if isinstance(obj, tuple):
            for item in obj:
                self.append(item, bool)

    def query(self, goal):
        """Yield the environment once for every solution of *goal*.

        The same ``Env`` object is yielded each time and its bindings are
        rolled back when the generator is resumed, so read the bindings
        before asking for the next solution.
        """
        env = Env()
        for e in self.resolve([goal], env, [False]):
            yield e

    def resolve(self, goals, env, cut):
        """Solve the sequence *goals* left to right inside *env*.

        *cut* is a one-element list used as a mutable flag shared with the
        caller: it is set to True once a ``Cut`` goal has been passed, which
        makes the rule loop that started this resolution stop trying further
        alternatives.  Yields *env* once per solution.
        """
        if not goals:
            yield env
            return
        goal = goals[0]
        rest = goals[1:]
        if isinstance(goal, Cut):
            # Solve what follows the cut first, then forbid backtracking.
            for e in self.resolve(rest, env, cut):
                yield e
            cut[0] = True
            return
        if not isinstance(goal, Term) and callable(goal):
            # Host-language goal: call it with the environment and treat a
            # truthy result as success.  (Terms are callable too, hence the
            # explicit exclusion.)
            if goal(env):
                for e in self.resolve(rest, env, cut):
                    yield e
            return
        rcut = [False]  # cut flag for the body of the rule being tried
        for rule, _flag in self.rules:
            if cut[0] or rcut[0]:
                break
            if goal.name != rule.term.name:
                continue
            renv = Env()  # every rule activation gets its own variables
            env.store()   # snapshot so the bindings made below can be undone
            if self.unify((goal, env), (rule.term, renv)):
                for _ in self.resolve(rule.hypos, renv, rcut):
                    for e in self.resolve(rest, env, cut):
                        yield e
                    # A cut hit in the continuation also stops this rule.
                    if cut[0]:
                        rcut[0] = True
            env.load()

    def unify(self, a, b):
        """Unify two ``(term, env)`` pairs, recording bindings in the envs.

        Returns a truthy value on success and a falsy one on mismatch.
        Bindings made before a mismatch is detected are not undone here;
        ``resolve`` restores the caller's environment itself.
        """
        # Dereference until neither side is a variable, binding when an
        # unbound variable is reached.
        while True:
            avar, aenv = a
            bvar, benv = b
            if isinstance(avar, Var):
                apair = aenv.get(avar)
                if apair is None:
                    bvar, benv = benv.ref(bvar)
                    # Never bind a variable to itself in its own
                    # environment: Env.ref would then follow the binding
                    # forever.  Environments key variables by name, so the
                    # name is what is compared (Var overrides ==, so it
                    # cannot be used on the objects themselves).
                    same = (benv is aenv and isinstance(bvar, Var)
                            and bvar.name == avar.name)
                    if not same:
                        if not isinstance(bvar, Term) and callable(bvar):
                            # Deferred expression: evaluate it now and bind
                            # the variable to the resulting value.
                            bvar = bvar(benv)
                        aenv.put(avar, (bvar, benv))
                    return True
                else:
                    avar, aenv = apair
                    a = aenv.ref(avar)
                    continue
            elif isinstance(bvar, Var):
                # Only the left side is examined above, so swap the roles.
                a, b = b, a
                continue
            else:
                break

        avar, aenv = a
        bvar, benv = b
        if isinstance(avar, Term) and isinstance(bvar, Term):
            if avar.name != bvar.name:
                return False
            # Same functor: continue with the argument sequences below.
            avar = avar.subs
            bvar = bvar.subs

        if (isinstance(avar, list) and isinstance(bvar, list) or
            isinstance(avar, tuple) and isinstance(bvar, tuple)):
            if len(avar) != len(bvar):
                return False
            for aitem, bitem in zip(avar, bvar):
                if not self.unify((aitem, aenv), (bitem, benv)):
                    return False
            return True

        if isinstance(avar, Pair) and isinstance(bvar, Pair):
            if not self.unify((avar.left, aenv), (bvar.left, benv)):
                return False
            return self.unify((avar.right, aenv), (bvar.right, benv))

        # A Pair against a concrete list: head with first element, tail with
        # the remaining elements.  An empty list has no head.
        if isinstance(bvar, Pair) and isinstance(avar, list):
            if len(avar) == 0:
                return False
            if not self.unify((avar[0], aenv), (bvar.left, benv)):
                return False
            return self.unify((avar[1:], aenv), (bvar.right, benv))

        if isinstance(avar, Pair) and isinstance(bvar, list):
            if len(bvar) == 0:
                return False
            if not self.unify((avar.left, aenv), (bvar[0], benv)):
                return False
            return self.unify((avar.right, aenv), (bvar[1:], benv))

        # Anything else is compared as a plain value.
        return avar == bvar

class Env(object):
    """Variable bindings for one query or one rule activation.

    ``map`` maps a variable *name* to a ``(term, env)`` pair: the bound term
    and the environment in which that term has to be interpreted.
    ``stack`` holds the snapshots of ``map`` taken by ``store``.
    """

    def __init__(self, map=None):
        if map is None:
            map = {}
        self.map = map
        self.stack = []

    def __getitem__(self, func):
        """``env[lambda X: X]``: build the term from *func* and expand it."""
        obj = _eval(func)
        return self.expand(obj)

    def __repr__(self):
        return "Env(%s)" % repr(self.map)

    def store(self):
        """Push a shallow copy of the current bindings."""
        self.stack.append(self.map.copy())

    def load(self):
        """Replace the bindings with the most recently stored snapshot."""
        self.map = self.stack.pop()

    def get(self, var):
        """Return the pair bound to *var*, or None if unbound or not a Var."""
        if not isinstance(var, Var):
            return None
        return self.map.get(var.name, None)

    def put(self, var, pair):
        """Bind *var* to *pair*; variables named ``_...`` are never bound."""
        if var.name[0] == "_":
            return
        self.map[var.name] = pair

    def remove(self, var):
        """Drop the binding of *var*.

        Bindings are keyed by variable name, so the name is looked up on
        *var*; a bare name is accepted as well.  Raises KeyError when there
        is no such binding.
        """
        del self.map[getattr(var, "name", var)]

    def ref(self, var):
        """Dereference *var* and return the resulting ``(term, env)`` pair.

        Chains of variable-to-variable bindings are followed until a
        non-variable term or an unbound variable is reached.  Lists and
        unbound terms are returned unchanged, paired with this environment.
        """
        if isinstance(var, list):
            return (var, self)
        pair = self.get(var)
        if not pair:
            return (var, self)
        term, env = pair
        if isinstance(term, Var):
            return env.ref(term)
        return pair

    def expand(self, term):
        """Return *term* with bound variables replaced, recursively.

        A ``Pair`` is turned into a Python list (head followed by the
        expanded tail), a ``Term`` into a new ``Term`` with expanded
        arguments, and lists/tuples are expanded element by element.
        """
        term, env = self.ref(term)
        if isinstance(term, Pair):
            head = env.expand(term.left)
            tail = env.expand(term.right)
            if tail is None:
                tail = []
            return [head] + tail
        if isinstance(term, Term):
            return Term(term.name, env.expand(term.subs))
        if isinstance(term, list):
            return [env.expand(item) for item in term]
        if isinstance(term, tuple):
            return tuple([env.expand(item) for item in term])
        return term


class Cut(object):
    """Marker object for the cut operator (written ``__`` in the DSL)."""

    def __repr__(self):
        text = "Cut()"
        return text

class Var(object):
    """Logic variable, identified by ``name``.

    Operators applied to a variable do not compute anything immediately.
    They return a function of one argument, the environment, which expands
    the variable and applies the operator when the engine evaluates it:

    - ``|`` builds a ``Pair`` list pattern;
    - comparisons build tests; ``==`` additionally binds the variable when
      it is still unbound at evaluation time;
    - arithmetic operators build deferred expressions.

    The other operand may itself be such a function, in which case it is
    called with the environment first.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "Var(%s)" % self.name

    # __eq__ is overridden below to build a goal, so keep identity hashing
    # explicit (Python 3 would otherwise make instances unhashable).
    __hash__ = object.__hash__

    # list pattern
    def __or__(self, other):
        """``self | other``: pattern with head *self* and tail *other*."""
        return Pair(self, other)

    def __ror__(self, other):
        """``other | self``: pattern with head *other* and tail *self*."""
        return Pair(other, self)

    # predicate
    def __eq__(self, other):
        """Build a goal that compares with, or binds to, *other*."""
        def eq(env):
            l = env.expand(self)
            r = other(env) if callable(other) else other
            if isinstance(l, Var):
                # Still unbound: behave like an assignment and succeed.
                env.put(l, (r, env))
                return True
            return l == r
        return eq

    def __nonzero__(self):
        # A truth value cannot be deferred: Python requires a real bool
        # here, so using a variable directly in `if`, `not`, `and` or `or`
        # is rejected explicitly.
        raise TypeError(
            "Var(%s) has no truth value; use a comparison goal instead"
            % self.name)
    __bool__ = __nonzero__  # Python 3 name of the same hook

    def _lifted(op, reflected=False):
        """Build an operator method that defers *op* until an env is given.

        With ``reflected`` the variable is the right-hand operand.  The
        left-hand operand is always evaluated first.
        """
        def method(self, other):
            def deferred(env):
                if reflected:
                    left = other(env) if callable(other) else other
                    return op(left, env.expand(self))
                mine = env.expand(self)
                return op(mine, other(env) if callable(other) else other)
            return deferred
        return method

    def _true_div(a, b):
        # operator.truediv gives true division regardless of whether this
        # module is compiled with `from __future__ import division`.
        from operator import truediv
        return truediv(a, b)

    # predicate (comparison tests)
    __ne__ = _lifted(lambda a, b: a != b)
    __lt__ = _lifted(lambda a, b: a < b)
    __gt__ = _lifted(lambda a, b: a > b)
    __le__ = _lifted(lambda a, b: a <= b)
    __ge__ = _lifted(lambda a, b: a >= b)

    # arith
    __add__ = _lifted(lambda a, b: a + b)
    __sub__ = _lifted(lambda a, b: a - b)
    __mul__ = _lifted(lambda a, b: a * b)
    __div__ = _lifted(lambda a, b: a / b)
    __truediv__ = _lifted(_true_div)
    __floordiv__ = _lifted(lambda a, b: a // b)
    __mod__ = _lifted(lambda a, b: a % b)
    __pow__ = _lifted(lambda a, b: a ** b)
    __radd__ = _lifted(lambda a, b: a + b, reflected=True)
    __rsub__ = _lifted(lambda a, b: a - b, reflected=True)
    __rmul__ = _lifted(lambda a, b: a * b, reflected=True)
    __rdiv__ = _lifted(lambda a, b: a / b, reflected=True)
    __rtruediv__ = _lifted(_true_div, reflected=True)
    __rfloordiv__ = _lifted(lambda a, b: a // b, reflected=True)
    __rmod__ = _lifted(lambda a, b: a % b, reflected=True)
    __rpow__ = _lifted(lambda a, b: a ** b, reflected=True)

    # The factories are only needed while the class body runs.
    del _lifted, _true_div

class Pair(object):
    """List pattern: ``left`` is the first element, ``right`` is the tail."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __repr__(self):
        return "Pair(%s, %s)" % (repr(self.left), repr(self.right))

    # list pattern
    def __or__(self, other):
        """Extend the pattern so that *other* becomes the new tail.

        ``|`` associates to the left, so ``A | B | C | D`` arrives here as
        ``((A | B) | C) | D``.  The pattern is re-nested to the right at
        every step, whatever the number of operands, giving
        ``Pair(A, Pair(B, Pair(C, D)))``: the former tail becomes an element
        and *other* the tail.
        """
        tail = self.right
        if isinstance(tail, Pair):
            # Push the new tail down to the innermost pair.
            return Pair(self.left, tail | other)
        return Pair(self.left, Pair(tail, other))

class Term(object):
    """Named term with an optional sequence of sub-terms.

    Calling a term gives a new term with the same name and the call
    arguments as sub-terms; ``term <= goals`` builds a ``Rule`` whose head
    is the term.
    """

    def __init__(self, name, subs=()):
        self.name, self.subs = name, subs

    def __repr__(self):
        return "Term(%s, %r)" % (self.name, self.subs)

    # structure
    def __call__(self, *subs):
        name = self.name
        return Term(name, subs)

    # predicate rule
    def __le__(self, hypos):
        # A single goal is accepted as shorthand for a one-element body.
        body = hypos if isinstance(hypos, tuple) else (hypos, )
        return Rule(self, body)

class Rule(object):
    """Clause made of a head ``term`` and body goals ``hypos``.

    ``hypos`` defaults to an empty tuple, i.e. a rule without conditions.
    """

    def __init__(self, term, hypos=()):
        self.term, self.hypos = term, hypos

    def __repr__(self):
        return "Rule(%r, %r)" % (self.term, self.hypos)

if __name__ == "__main__":
    # Executed as a script: run the doctest examples found in this module.
    import doctest

    doctest.testmod()