GRL_2_A - Minimum Spanning Tree / PythonでAizu Online Judgeを解く

Posted on 2018/10/14

概要

与えられた重み付きの無向グラフ \(G(V, E)\) について最小全域木を求め、その全域木に含まれる辺の総和を出力してください。

制約

  • \(1 \leq |V| \leq 10,000\)
  • \(0 \leq |E| \leq 100,000\)
  • \(0 \leq w_i \leq 10,000\)
  • グラフは連結である
  • グラフは平行辺を持たない
  • グラフは自己ループを持たない

入力

%python
def input(f):
    global v, e, s, t, w
    
    v, e = map(int, f.readline().split())
    s = []
    t = []
    w = []
    for _ in range(e):
        s_, t_, w_ = map(int, f.readline().split())
        s.append(s_)
        t.append(t_)
        w.append(w_)

サンプル入力

%sh
cat << EOF > /tmp/input1
4 6
0 1 2
1 2 1
2 3 1
3 0 1
0 2 3
1 3 5
EOF

cat << EOF > /tmp/input2
6 9
0 1 1
0 2 3
1 2 1
1 3 7
2 4 1
1 4 3
3 4 1
3 5 1
4 5 6
EOF

Python: 入力をnetworkxで可視化してみる

%sh
pip install networkx
Collecting networkx
Collecting decorator>=4.3.0 (from networkx)
  Using cached https://files.pythonhosted.org/packages/bc/bb/a24838832ba35baf52f32ab1a49b906b5f82fb7c76b2f6a7e35e140bac30/decorator-4.3.0-py2.py3-none-any.whl
Installing collected packages: decorator, networkx
Successfully installed decorator-4.3.0 networkx-2.2
You are using pip version 9.0.1, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

%python
import os

if os.environ.get('ZEPPELIN_HOME'):
    import networkx as nx
    import matplotlib.pyplot as plt
    
    def visualize():
        G = nx.MultiDiGraph()
        for i in range(v):
            G.add_node(i, node_color='b')
        for i in range(e):
            G.add_edge(s[i], t[i], t=w[i])

        plt.figure()
        plt.style.use('seaborn')
        plt.axis("off")
        pos = nx.spring_layout(G)
        nodes = nx.draw_networkx_nodes(G, pos, node_color='w', linewidths=2)
        nodes.set_edgecolor('black')
        nx.draw_networkx_labels(G, pos)
        nx.draw_networkx_edges(G, pos)

        edge_labels = dict([((a, b), c['t']) for a, b, c in G.edges(data=True)])
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=16)

        plt.show()

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input1') as f:
        print('Example #1')
        input(f)
        visualize()
Example #1

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input2') as f:
        print('Example #2')
        input(f)
        visualize()
Example #2

考え方

最小全域木を求めるアルゴリズムにはクラスカル法とプリム法と呼ばれるものがありますが、今回は前者を使って最小全域木を求めてみます。

やることは主に2つで:

  • 優先度付きキューに重み順で辺を入れる
  • キューから取り出した辺について、それぞれ別々の木に属していればその辺を使って接続する

これらを繰り返すことで最小全域木が求まります。

Pythonで優先度付きキューを実装するには heapq というモジュールを使うと良さそうです。

実装: UnionFind

%python
class UnionFind:
    def __init__(self, v):
        self.parent = [-1] * v
    
    def find(self, x):
        if self.parent[x] == -1:
            return x
        else:
            self.parent[x] = self.find(self.parent[x])
            return self.parent[x]
    
    def merge(self, x, y):
        px = self.find(x)
        py = self.find(y)
        if px != py:
            self.parent[py] = px
    
    def is_same(self, x, y):
        return self.find(x) == self.find(y)

実装 - クラスカル法

%python
import sys
import heapq


def init():
    global inf
    inf = sys.maxsize


def solve():
    q = []
    for i in range(e):
        heapq.heappush(q, (w[i], i))
    
    res = 0
    uf = UnionFind(v)
    while len(q) > 0:
        wi, i = heapq.heappop(q)
        if not uf.is_same(s[i], t[i]):
            uf.merge(s[i], t[i])
            res += wi
    
    print(res)

Example Output #1

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input1'

with open(os.environ.get('INPUT') or '/dev/stdin') as f:
    input(f)
    init()
    solve()
3

Example Output #2

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input2'

    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        init()
        solve()
5

1WA

提出したものの不正解。

%sh
cat << EOF > /tmp/input3
10 15
8 1 4457
1 3 2531
3 4 7111
4 2 1088
2 5 3124
9 0 4427
6 2 2005
7 9 6489
3 8 2313
1 2 7125
5 4 4987
1 9 6782
1 5 1147
7 2 4875
1 6 5959
EOF

27999 が正しい出力でしたが…

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input3'

    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        init()
        solve()
29906

29906 を返している… :-|

グラフの可視化もやっておきましょう

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input3') as f:
        print('Example #3')
        input(f)
        visualize()
Example #3

見るに堪えない… :-/

原因

heapq で構築した list を for in でイテレートしてました。heapq.heappop を使う必要があります

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input3'

    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        
        q = []
        for i in range(e):
            heapq.heappush(q, w[i])
        
        print('# for in:')
        for wi in q:
            print(wi)
# for in:
1088
2313
1147
2531
3124
2005
4427
6489
4457
7125
4987
7111
6782
4875
5959

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input3'

    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        
        q = []
        for i in range(e):
            heapq.heappush(q, w[i])
        
        print('# heappop')
        while len(q) > 0:
            wi = heapq.heappop(q)
            print(wi)
# heappop
1088
1147
2005
2313
2531
3124
4427
4457
4875
4987
5959
6489
6782
7111
7125

修正したら無事に通りました。

%python
if os.environ.get('ZEPPELIN_HOME'):
    os.environ['INPUT'] = '/tmp/input3'

    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        init()
        solve()
27999

heapq, listに乗っかっていてカジュアル感があるので使いやすく感じます。