GRL_2_A - Minimum Spanning Tree / PythonでAizu Online Judgeを解く

    Posted on 2018/10/14

    概要

    与えられた重み付きの無向グラフ \(G(V, E)\) について最小全域木を求め、その全域木に含まれる辺の総和を出力してください。

    制約

    • \(1 \leq |V| \leq 10,000\)
    • \(0 \leq |E| \leq 100,000\)
    • \(0 \leq w_i \leq 10,000\)
    • グラフは連結である
    • グラフは平行辺を持たない
    • グラフは自己ループを持たない

    入力

    %python
    def input(f):
        global v, e, s, t, w
        
        v, e = map(int, f.readline().split())
        s = []
        t = []
        w = []
        for _ in range(e):
            s_, t_, w_ = map(int, f.readline().split())
            s.append(s_)
            t.append(t_)
            w.append(w_)

    サンプル入力

    %sh
    cat << EOF > /tmp/input1
    4 6
    0 1 2
    1 2 1
    2 3 1
    3 0 1
    0 2 3
    1 3 5
    EOF
    
    cat << EOF > /tmp/input2
    6 9
    0 1 1
    0 2 3
    1 2 1
    1 3 7
    2 4 1
    1 4 3
    3 4 1
    3 5 1
    4 5 6
    EOF

    Python: 入力をnetworkxで可視化してみる

    %sh
    pip install networkx
    Collecting networkx
    Collecting decorator>=4.3.0 (from networkx)
      Using cached https://files.pythonhosted.org/packages/bc/bb/a24838832ba35baf52f32ab1a49b906b5f82fb7c76b2f6a7e35e140bac30/decorator-4.3.0-py2.py3-none-any.whl
    Installing collected packages: decorator, networkx
    Successfully installed decorator-4.3.0 networkx-2.2
    You are using pip version 9.0.1, however version 18.1 is available.
    You should consider upgrading via the 'pip install --upgrade pip' command.
    

    %python
    import os
    
    if os.environ.get('ZEPPELIN_HOME'):
        import networkx as nx
        import matplotlib.pyplot as plt
        
        def visualize():
            G = nx.MultiDiGraph()
            for i in range(v):
                G.add_node(i, node_color='b')
            for i in range(e):
                G.add_edge(s[i], t[i], t=w[i])
    
            plt.figure()
            plt.style.use('seaborn')
            plt.axis("off")
            pos = nx.spring_layout(G)
            nodes = nx.draw_networkx_nodes(G, pos, node_color='w', linewidths=2)
            nodes.set_edgecolor('black')
            nx.draw_networkx_labels(G, pos)
            nx.draw_networkx_edges(G, pos)
    
            edge_labels = dict([((a, b), c['t']) for a, b, c in G.edges(data=True)])
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=16)
    
            plt.show()

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input1') as f:
            print('Example #1')
            input(f)
            visualize()
    Example #1
    

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input2') as f:
            print('Example #2')
            input(f)
            visualize()
    Example #2
    

    考え方

    最小全域木を求めるアルゴリズムにはクラスカル法とプリム法と呼ばれるものがありますが、今回は前者を使って最小全域木を求めてみます。

    やることは主に2つで:

    • 優先度付きキューに重み順で辺を入れる
    • キューから取り出した辺について、それぞれ別々の木に属していればその辺を使って接続する

    これらを繰り返すことで最小全域木が求まります。

    Pythonで優先度付きキューを実装するには heapq というモジュールを使うと良さそうです。

    実装: UnionFind

    %python
    class UnionFind:
        def __init__(self, v):
            self.parent = [-1] * v
        
        def find(self, x):
            if self.parent[x] == -1:
                return x
            else:
                self.parent[x] = self.find(self.parent[x])
                return self.parent[x]
        
        def merge(self, x, y):
            px = self.find(x)
            py = self.find(y)
            if px != py:
                self.parent[py] = px
        
        def is_same(self, x, y):
            return self.find(x) == self.find(y)

    実装 - クラスカル法

    %python
    import sys
    import heapq
    
    
    def init():
        global inf
        inf = sys.maxsize
    
    
    def solve():
        q = []
        for i in range(e):
            heapq.heappush(q, (w[i], i))
        
        res = 0
        uf = UnionFind(v)
        while len(q) > 0:
            wi, i = heapq.heappop(q)
            if not uf.is_same(s[i], t[i]):
                uf.merge(s[i], t[i])
                res += wi
        
        print(res)

    Example Output #1

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input1'
    
    with open(os.environ.get('INPUT') or '/dev/stdin') as f:
        input(f)
        init()
        solve()
    3
    

    Example Output #2

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input2'
    
        with open(os.environ.get('INPUT') or '/dev/stdin') as f:
            input(f)
            init()
            solve()
    5
    

    1WA

    提出したものの不正解。

    %sh
    cat << EOF > /tmp/input3
    10 15
    8 1 4457
    1 3 2531
    3 4 7111
    4 2 1088
    2 5 3124
    9 0 4427
    6 2 2005
    7 9 6489
    3 8 2313
    1 2 7125
    5 4 4987
    1 9 6782
    1 5 1147
    7 2 4875
    1 6 5959
    EOF

    27999 が正しい出力でしたが…

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input3'
    
        with open(os.environ.get('INPUT') or '/dev/stdin') as f:
            input(f)
            init()
            solve()
    29906
    

    29906 を返している… :-|

    グラフの可視化もやっておきましょう

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input3') as f:
            print('Example #3')
            input(f)
            visualize()
    Example #3
    

    見るに堪えない… :-/

    原因

    heapq で構築した list を for in でイテレートしてました。heapq.heappop を使う必要があります

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input3'
    
        with open(os.environ.get('INPUT') or '/dev/stdin') as f:
            input(f)
            
            q = []
            for i in range(e):
                heapq.heappush(q, w[i])
            
            print('# for in:')
            for wi in q:
                print(wi)
    # for in:
    1088
    2313
    1147
    2531
    3124
    2005
    4427
    6489
    4457
    7125
    4987
    7111
    6782
    4875
    5959
    

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input3'
    
        with open(os.environ.get('INPUT') or '/dev/stdin') as f:
            input(f)
            
            q = []
            for i in range(e):
                heapq.heappush(q, w[i])
            
            print('# heappop')
            while len(q) > 0:
                wi = heapq.heappop(q)
                print(wi)
    # heappop
    1088
    1147
    2005
    2313
    2531
    3124
    4427
    4457
    4875
    4987
    5959
    6489
    6782
    7111
    7125
    

    修正したら無事に通りました。

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        os.environ['INPUT'] = '/tmp/input3'
    
        with open(os.environ.get('INPUT') or '/dev/stdin') as f:
            input(f)
            init()
            solve()
    27999
    

    heapq, listに乗っかっていてカジュアル感があるので使いやすく感じます。