graphvizを使ってツリー構造をpng画像化できる基本ノードクラス

概要

  • ツリーを構築でき、構造をpng画像として出力できるノードクラスが欲しい
  • 存在するかもしれないが、作ってみる
  • graphviz、pydotplusを使用した
  • DOMへの敬意を込めてappend_child()というメソッド名をjsから拝借した

良かった点

print_labels()、print_edges()をベタに実装してみて、ほとんど構造が同じであることに気が付いた(ともに再帰処理を行う)。print_labels()はツリーを辿りながら(その時点での)親のみを使用する。print_edges()は子も利用する。従って、ノード処理用の関数を呼び出すタイミングだけが異なる。この違いを抽象化できた。

処理タイミングのイメージ (一応)

列車の分岐ポイント。走行中の列車が分岐ポイントに差し掛かった時、そのポイントが(現在のスコープにおける)である。分岐ポイントを通過し切っていない段階では、子ノードはまだ存在しない(辿り着いていない)。親と子を同時に処理に利用したい場合は、次の分岐ポイント(現在の親からみた子)まで待つ必要がある。このことがそのまま処理タイミングの違いにつながる。

実装 (Python)

# dot_writer.py
import sys
import os
import pydotplus


class BaseNode(object):
    """ツリーのノードを表現するクラス"""

    # ノードIDの発行に使用するカウンター
    __counter__ = 0
    
    def __init__(self, label):
        BaseNode.__counter__ += 1
        self._id = BaseNode.__counter__    # ノードID
        self._label = str(label)    # ノードの表現
        self._children = []    # 子ノード

    def __str__(self):
        s = "<%s %d[%s]>" % (
            self.__class__.__name__, self._id, self._label,
        )
        return s

    def append_child(self, node):
        """子ノードを追加する"""
        if isinstance(node, BaseNode):
            self._children.append(node)

    @staticmethod
    def traverse(parent, func, rtype="a", depth=0):
        """処理関数を受け取って再帰処理を行う"""
        
        # 関数の呼び出しタイミングで切り分ける
        if rtype == "a":
            func(parent, depth)    # ここで呼び出す
            if len(parent._children) > 0:
                depth += 1
                for child in parent._children:
                    BaseNode.traverse(child, func, depth=depth)
        elif rtype == "b":
            if len(parent._children) > 0:
                depth += 1
                for child in parent._children:
                    func(parent, child, depth)    # 引数が異なる
                    BaseNode.traverse(child, func, rtype=rtype, depth=depth)

    @staticmethod
    def print_labels(parent):
        """ノードの定義部分を生成する"""

        def p(parent, depth):
            """ノードの定義部分を表示する"""
            template = "%s\"%s\" [label=\"%s\"]"
            indent = " " * 4
            print(template % (indent, parent._id, parent._label))

        # ノードの定義部分を出力する
        BaseNode.traverse(parent, p)

    @staticmethod
    def print_edges(parent):
        """ノードの接続情報を書き出す"""

        def p(parent, child, depth):
            """ノード間の接続情報を表示する"""
            template = "%s\"%s\" -> \"%s\""
            indent = " " * 8
            print(template % (indent, parent._id, child._id))

        # ノードの接続情報を出力する
        BaseNode.traverse(parent, p, rtype="b")

    @staticmethod
    def write_dot(root, path, shape="box", redirect=True):
        """グラフ情報を.dotファイルに書き込む"""

        # ファイルを作成して標準出力をリダイレクトさせる
        # redirect=Trueの場合、以降のすべてのprint()の実行結果が
        # 指定ファイルに書き込まれる
        if redirect:
            f_out = open(path, "w")
            sys.stdout = f_out

        # インデント部分
        indent = " " * 4

        # グラフ定義の開始部分を書き出す
        def_start_lines = [
            "digraph {",
            "%snode [shape=%s]" % (indent, shape),
        ]
        def_start = "\n".join(def_start_lines)
        print(def_start)

        # ノードの定義部分を書き出す
        BaseNode.print_labels(root)

        # ノードの接続情報を書き出す
        BaseNode.print_edges(root)

        # グラフ定義の終了部分を書き出す
        def_end = "}"
        print(def_end)

        # 標準出力を元に戻す
        if redirect:
            sys.stdout = sys.__stdout__

    @staticmethod
    def write_png(dot_path, output_name, remove_dot=False):
        """
        png画像を生成する
        
        .dotファイルはグラフ情報のパースの際に利用されるだけなので必要により削除可
        """

        if not os.path.exists(dot_path):
            emsg = "%s not found" % dot_path
            raise FileNotFoundError(emsg)
        else:
            with open(dot_path, "r") as f:
                dot_data = f.read()

                # グラフを生成する
                graph = pydotplus.graph_from_dot_data(dot_data)

                # png画像を生成する
                graph.write_png(output_name)

        # dotファイルを削除する
        if remove_dot:
            os.remove(dot_path)

使用例

def main():
    """BaseNodeの使用例"""

    # ノード名からノードを生成
    labels = ["root", "c1-1", "c1-2", "c2-1", "c2-2"]
    nodes = [BaseNode(label) for label in labels]
    
    # ルートノードに子を追加 (root → c1-1, c1-2)
    root = nodes[0]
    for node in nodes[1:3]: root.append_child(node)

    # 子に孫を追加 (c1-2 → c2-1, c2-2)
    for node in nodes[3:]: nodes[2].append_child(node)

    # .dotファイルを出力
    name = "tree_sample"
    dot = "./%s.dot" % name
    BaseNode.write_dot(root, dot)

    # ツリー構造をpng画像に出力
    png = "./%s.png" % name
    BaseNode.write_png(dot, png, remove_dot=True)


if __name__ == '__main__':
    main()

以下の画像が生成される。
f:id:zdassen:20170609223025p:plain