Pynote

Python、機械学習、画像処理について

networkx - 点の位置を指定して、グラフを見やすく表示する方法について

概要

networkx で作成したグラフを graphviz で描画する場合に、点の位置を指定することで見やすいグラフを作成する方法について紹介する。

自動で点をレイアウトする。

graphviz には、グラフの点を配置するいくつかのアルゴリズムが組み込まれている。
これを利用すると、点同士が重なったりしないように自動的にレイアウトしてくれる。

import networkx as nx
from IPython.display import Image, display

# グラフを作成する。
G = nx.star_graph(10)

# 各レイアウト方法で描画する。
progs = ['neato', 'dot', 'twopi', 'circo', 'fdp']
for p in progs:
    A = nx.nx_agraph.to_agraph(G)  # AGraph に変換する。
    A.node_attr.update(fixedsize=True, width=0.35, height=0.35)  # 点の大きさを固定する。
    A.graph_attr.update(abelloc='b', label=p)  # タイトルをつける。
    
    png = A.draw(format='png', prog=p)
    display(Image(png))


明示的に点の位置を指定する。

点の pos 属性に pos="x, y!" と指定することで点の位置を明示的に指定できる。(例: pos="1,3!")
ただし、レイアウトのアルゴリズムは circo または fdp を指定する必要があり、それ以外のアルゴリズムの場合、この属性は無視される。

グラフを描画するヘルパー関数を用意する。
pos には、{点: (x, y), 点: (x, y), ...} である dict を渡す。この dict は自分で作ってもよいが、networkx の関数で2部グラフや円状のグラフ用のレイアウトを作成する関数があるので、それを利用することもできる。

import networkx as nx
from IPython.display import Image, display


def draw_graph(G, pos, **kwargs):
    '''グラフを描画する。
    Args:
        G: グラフ
        pos: {点: (x, y), 点: (x, y), ...} である dict。
    '''
    A = nx.nx_agraph.to_agraph(G)
    # 点の大きさを固定する。
    A.node_attr.update(fixedsize=True, width=0.35, height=0.35, **kwargs)
    # 点の位置を設定する。
    for n, (x, y) in pos.items():
        A.add_node(n, pos='{},{}!'.format(x, y))
    # グラフを描画する。
    png = A.draw(format='png', prog='neato')
    display(Image(png))

サンプル

完全グラフ

G = nx.complete_graph(5)
pos = nx.circular_layout(G)
draw_graph(G, pos)

2部グラフ

networkx.bipartite_layout でレイアウトを作成できる。

bipartite_layout(G, nodes, align='vertical', scale=1, center=None, aspect_ratio=1.3333333333333333)

第2引数には、align='vertical' の場合は左側、align='horizontal' の場合は上側に来る点のグループを指定する。

G = nx.complete_bipartite_graph(3, 4)
g1, g2 = nx.bipartite.sets(G)
pos = nx.bipartite_layout(G, g1, align='vertical')
draw_graph(G, pos)


G = nx.complete_bipartite_graph(3, 4)
g1, g2 = nx.bipartite.sets(G)
pos = nx.bipartite_layout(G, g2, align='vertical')
draw_graph(G, pos)


G = nx.complete_bipartite_graph(3, 4)
g1, g2 = nx.bipartite.sets(G)
pos = nx.bipartite_layout(G, g1, align='horizontal')
draw_graph(G, pos)



G = nx.complete_bipartite_graph(3, 4)
g1, g2 = nx.bipartite.sets(G)
pos = nx.bipartite_layout(G, g2, align='horizontal')
draw_graph(G, pos)


回路

G = nx.cycle_graph(3)
pos = nx.circular_layout(G)
draw_graph(G, pos)


空グラフ

G = nx.empty_graph(3)
pos = nx.circular_layout(G)
draw_graph(G, pos)


2次元グリッド

def grid_2d_layout(G):
    return {n: n for n in G.nodes}

G = nx.grid_2d_graph(4, 3)
pos = grid_2d_layout(G)
draw_graph(G, pos, fontsize=8)


def path_layout(G):
    return {n: (n, 0) for n in G.nodes}

G = nx.path_graph(3)
pos = path_layout(G)
draw_graph(G, pos)


星グラフ

def star_layout(G):
    nodes = list(G.nodes)
    g = G.subgraph(nodes[1:])
    pos = nx.circular_layout(g, center=(1, 1))
    pos[nodes[0]] = (1, 1)
    return pos

G = nx.star_graph(5)
pos = star_layout(G)
draw_graph(G, pos)


点1つのグラフ

G = nx.trivial_graph()
pos = star_layout(G)
draw_graph(G, pos)


車輪

def star_layout(G):
    nodes = list(G.nodes)
    g = G.subgraph(nodes[1:])
    pos = nx.circular_layout(g, center=(1, 1))
    pos[nodes[0]] = (1, 1)
    return pos

G = nx.wheel_graph(6)
pos = star_layout(G)
draw_graph(G, pos)

import networkx as nx

def tree_layout(G):
    pos = nx.nx_pydot.graphviz_layout(G, prog='dot')
    pos = {k: (x / 96, y / 96) for k, (x, y) in pos.items()}
    return pos

G = nx.balanced_tree(2, 3)
pos = tree_layout(G)
draw_graph(G, pos)

円状の木

import networkx as nx
from IPython.display import Image, display


def draw_graph(G, pos):
    '''グラフを描画する。
    Args:
        G: グラフ
        pos: {点: (x, y), 点: (x, y), ...} である dict。
    '''
    A = nx.nx_agraph.to_agraph(G)
    A.node_attr.update(shape='point', color='blue')
    for n, (x, y) in pos.items():
        A.add_node(n, pos='{},{}!'.format(x, y))

    png = A.draw(format='png', prog='neato')
    display(Image(png))

def circular_tree_layout(G):
    pos = nx.nx_pydot.graphviz_layout(G, prog='twopi')
    pos = {k: (x / 96, y / 96) for k, (x, y) in pos.items()}
    return pos

G = nx.balanced_tree(3, 5)
pos = circular_tree_layout(G)
draw_graph(G, pos)