遺伝的アルゴリズムを書いてみる

hackernewsのリンク記事でに、Genetic Algorithm(遺伝的アルゴリズム)をJavaScriptで書いた

ってのを見たのですが、ソースがわけわからなかったので*1、理解するため、遺伝的アルゴリズムってのを調べて自分で書いてみました。

まず、Wikipediaをみたのですが、日本語と英語の記事に載せてある手続きが違うんですよね。

日本語版では交叉と突然変異は「どちらか」と書いてあるんですが、英語版は交叉したあと突然変異させるというようになってます。で、他のサイトをみると、どうも英語版の手続きが基本のようです。

手続きもすごく短い。簡単に言うと「繰り返し、適当に近そうな奴を選んで混ぜてみたり、たまにちょっとずらしたりしてみてテストさせる」、ってことのようで、遺伝うんぬんとは無関係に、普通に正解探しをやってるのと変わらない気もする。

ということで、それをそのままフレームワーク化したのが、GeneticAlgorithmクラスです。

class GeneticAlgorithm(object):
    def __init__(self, genetics):
        self.genetics = genetics
        pass

    def run(self):
        population = self.genetics.initial()
        while True:
            fits_pops = [(self.genetics.fitness(ch),  ch) for ch in population]
            if self.genetics.check_stop(fits_pops): break
            population = self.next(fits_pops)
            pass
        return population

    def next(self, fits):
        parents_generator = self.genetics.parents(fits)
        size = len(fits)
        nexts = []
        while len(nexts) < size:
            parents = next(parents_generator)
            cross = random.random() < self.genetics.probability_crossover()
            children = self.genetics.crossover(parents) if cross else parents
            for ch in children:
                mutate = random.random() < self.genetics.probability_mutation()
                nexts.append(self.genetics.mutation(ch) if mutate else ch)
                pass
            pass
        return nexts[0:size]
    pass

で、この中で使っている、initial()、fitness(chromo)、parents(pop)、crossover(parents)、mutation(chromo)などを、問題をこのアルゴリズムに合うようにエンコードし、各種戦術を選んで実装し、与えてあげれば計算してくれることになります。

選択や交叉、変異での戦術のほうは典型的なものから選べばいいので、結局重要なのは、「解きたい問題をどうエンコードするか」、であると思います。(でも、これについても英語版Wikipediaにはナップサック問題ではどうするかについて書いてあるけど、日本語版には戦術のバリエーションばかりでエンコードについての言及が一つもない状態だったりします。)

で、例の記事の問題は与えたテキストをGAで推測するもので、それをこのクラスでつかえるようにしたのが、以下です。

    class GuessText(GeneticFunctions):
        def __init__(self, target_text,
                     limit=200, size=400,
                     prob_crossover=0.9, prob_mutation=0.2):
            self.target = self.text2chromo(target_text)
            self.counter = 0

            self.limit = limit
            self.size = size
            self.prob_crossover = prob_crossover
            self.prob_mutation = prob_mutation
            pass

        # GeneticFunctions interface impls
        def probability_crossover(self):
            return self.prob_crossover

        def probability_mutation(self):
            return self.prob_mutation

        def initial(self):
            return [self.random_chromo() for j in range(self.size)]

        def fitness(self, chromo):
            # larger is better, matched == 0
            return -sum(abs(c - t) for c, t in zip(chromo, self.target))

        def check_stop(self, fits_populations):
            self.counter += 1
            if self.counter % 10 == 0:
                best_match = list(sorted(fits_populations))[-1][1]
                fits = [f for f, ch in fits_populations]
                best = max(fits)
                worst = min(fits)
                ave = sum(fits) / len(fits)
                print(
                    "[G %3d] score=(%4d, %4d, %4d): %r" %
                    (self.counter, best, ave, worst,
                     self.chromo2text(best_match)))
                pass
            return self.counter >= self.limit

        def parents(self, fits_populations):
            while True:
                father = self.tournament(fits_populations)
                mother = self.tournament(fits_populations)
                yield (father, mother)
                pass
            pass

        def crossover(self, parents):
            father, mother = parents
            index1 = random.randint(1, len(self.target) - 2)
            index2 = random.randint(1, len(self.target) - 2)
            if index1 > index2: index1, index2 = index2, index1
            child1 = father[:index1] + mother[index1:index2] + father[index2:]
            child2 = mother[:index1] + father[index1:index2] + mother[index2:]
            return (child1, child2)

        def mutation(self, chromosome):
            index = random.randint(0, len(self.target) - 1)
            vary = random.randint(-5, 5)
            mutated = list(chromosome)
            mutated[index] += vary
            return mutated

        # internals
        def tournament(self, fits_populations):
            alicef, alice = self.select_random(fits_populations)
            bobf, bob = self.select_random(fits_populations)
            return alice if alicef > bobf else bob

        def select_random(self, fits_populations):
            return fits_populations[random.randint(0, len(fits_populations)-1)]

        def text2chromo(self, text):
            return [ord(ch) for ch in text]
        def chromo2text(self, chromo):
            return "".join(chr(max(1, min(ch, 255))) for ch in chromo)

        def random_chromo(self):
            return [random.randint(1, 255) for i in range(len(self.target))]
        pass

呼び出し方は、以下のようにするだけ

GeneticAlgorithm(GuessText("Hello World!")).run()

結果は、このようになります。

$ python3 genetic.py
[G  10] score=(-136, -287, -503): 'BqSkp\x1aIri\x85|\x17'
[G  20] score=( -64,  -94, -136): 'Bfgkp\x1a`qwgu\x1b'
[G  30] score=( -39,  -54,  -71): 'Gekkp\x1aVqwku\x1e'
[G  40] score=( -25,  -32,  -44): 'Gflln\x1fVoskt\x1f'
[G  50] score=( -12,  -21,  -30): 'Gemkp Woskd\x1b'
[G  60] score=(  -4,  -10,  -19): 'Gfllo\x1fWorkd!'
[G  70] score=(  -1,   -3,   -9): 'Hello World '
[G  80] score=(   0,   -1,   -6): 'Hello World!'
[G  90] score=(   0,    0,   -8): 'Hello World!'
[G 100] score=(   0,    0,   -7): 'Hello World!'
[G 110] score=(   0,    0,   -8): 'Hello World!'
[G 120] score=(   0,    0,   -6): 'Hello World!'
[G 130] score=(   0,    0,  -12): 'Hello World!'
[G 140] score=(   0,    0,   -9): 'Hello World!'
[G 150] score=(   0,    0,   -9): 'Hello World!'
[G 160] score=(   0,    0,   -5): 'Hello World!'
[G 170] score=(   0,    0,   -8): 'Hello World!'
[G 180] score=(   0,    0,   -8): 'Hello World!'
[G 190] score=(   0,    0,   -6): 'Hello World!'
[G 200] score=(   0,    0,   -6): 'Hello World!'

一般的にはbetterを見つけるものだとは思うんですが、この例ではだいたい70回から130回程度で正解にたどり着くようです。

ソース全体

*1:人工知能とかの分野などは、エンジニアリング技術が弱くてソースが汚いのはよくあることだが