PythonのgeneratorでMapReduce

pythonにはリスト用のmapやreduce組み込み関数もあるけれど、generatorを使って平行プログラミング風にGoogleMapReduce風サーバを書いてみた。サーバと入ってもgeneratorなので、実際の実行はシングルスレッド上で処理されるです。

generatorはsendメソッドを使うことで、erlangのprocess(とかRubyのfiberとか)のようにも使うことができる、のは周知のこと

  • g = spawn(gf) => g = gf(); g.next()
  • g ! m => g.send(m)
  • receive m -> ... => m = yield; ...
  • 自己再帰呼び出し => ループ

ついでにgeneratorはg.close()で終了させれます。

コード: MapReduceサーバ

def server(docs):
    mappers = []
    for doc in docs:
        m = mapper(doc)
        m.next()
        mappers.append(m)
        pass

    while True:
        mapfunc, reducefunc, post = yield

        r = reducer(reducefunc, post)
        r.next()
        for m in mappers:
            m.send((mapfunc, r))
            pass
        r.send(None)
        pass
    pass

def mapper(doc):
    while True:
        mapfunc, reducer = yield
        reducer.send(mapfunc(doc))
        pass
    pass

def reducer(reducefunc, post):
    result = {}
    while True:
        mapped = yield
        if mapped is None:
            post(result)
            continue
        for k in mapped:
            if k in result: result[k] = reducefunc(k, [result[k], mapped[k]])
            else: result[k] = mapped[k]
            pass
        pass
    pass

reducerが終了を検知する方法はsend(None)ではしょってます(本当にマルチプロセス化するなら、mapperの数と受け取った数を比べるとか、だろうか)。

利用例

# example
docs = ["foo bar buzz",
        "foo bar quux",
        "bar buzz quux",
        "hoge hoge hoge",
        "hoge huga"]
s = server(docs)
s.next()

# map
def wordcount(doc):
    from collections import defaultdict
    result = defaultdict(lambda: 0)
    words = doc.split()
    for w in words:
        result[w] += 1
        pass
    return dict(result)

def wordexist(doc):
    result = {}
    for w in doc.split():
        result[w] = 1
        pass
    return result

# reduce
def collectcount(key, values):
    return sum(values)

# post processing
def printout(result):
    print(result)
    pass

# invoke
s.send((wordcount, collectcount, printout))
s.send((wordexist, collectcount, printout))

実行結果

$ python mapreduce.py
{'bar': 3, 'huga': 1, 'hoge': 4, 'quux': 2, 'buzz': 2, 'foo': 2}
{'bar': 3, 'huga': 1, 'hoge': 2, 'quux': 2, 'buzz': 2, 'foo': 2}

余談: generatorをrequest/response風サーバとして使う場合

上記のサーバ風generatorの書き方は以下のようになります:

def actor(init, act):
  state = init()
  while True:
      m = yield
      state = act(state, m)

このサーバはsendでメッセージを送りつける一方で、サーバでのメッセージ処理の「続きの処理」をしたい場合は、上記mapper/reducerのように、「続きの処理」をsendで受け取るように記述する必要があります。


そうではなく、Request/Response風にsendで結果を受け取って、その結果を使って続きの処理をしたい場合は、以下のようにします:

def reqres(handler):
   m = yield
   while True:
     r = handler(m)
     m = yield r

たとえば、受け取ったものを大文字にするサーバ

server = reqres(lambda s: s.upper())
server.next()
print server.send("abc")
print server.send("def")
server.close()

ただし、Webサーバ等と違い、上記handler中でサーバ自分自身にsendすることはできません。generatorはその処理中で自分自身にsend/nextすることはできないからです(やるとValueErrorが発生する)。

追記

MapReduceってmap関数で複数のプロセスを作って、それにdocsを分割して流し込むんだっけ(mapプロセス中でreduceし、最後にmapプロセス間でreduce)。

def server(docs):
    import math
    mappercount = int(math.sqrt(len(docs)))
    while True:
        mapfunc, reducefunc, post = yield

        mappers = [mapper(mapfunc, reducefunc) for i in xrange(mappercount)]
        for m in mappers: m.next()

        r = reducer(mappercount, reducefunc, post)
        r.next()
        for i in xrange(mappercount):
            s = slice(i, None, mappercount)
            mappers[i].send((docs[s], r))
            pass
        pass
    pass

def reduce_all(reducefunc, result, mapped):
    for k in mapped:
        if k in result: result[k] = reducefunc(k, [result[k], mapped[k]])
        else: result[k] = mapped[k]
        pass
    pass

def mapper(mapfunc, reducefunc):
    while True:
        docs, reducer = yield
        result = {}
        for doc in docs:
            mapped = mapfunc(doc)
            reduce_all(reducefunc, result, mapped)
            pass
        reducer.send(result)
        pass
    pass

def reducer(mappercount, reducefunc, post):
    result = {}
    count = 0
    while True:
        mapped = yield
        reduce_all(reducefunc, result, mapped)
        count += 1
        if count == mappercount:
            post(result)
            continue
        pass
    pass