GRL_2_B - Minimum-Cost Arborescence / PythonでAizu Online Judgeを解く

    Posted on 2019/02/03

    最小全域有向木

    重み付き有向グラフ \(G(V, E)\) について、頂点 \(r\) を根とする最小全域有向木の辺の重みの総和を求めてください。

    制約

    • \(1 \leq |V| \leq 100\)
    • \(0 \leq |E| \leq 1,000\)
    • \(0 \leq w_i \leq 10,000\)
    • グラフ \(G\) は \(r\) を根とする有向木をもつ

    Arborescenceって何

    アーボレセンス。頭の arbor がラテン語で木を意味しているようです。

    サンプル入力

    %sh
    cat << EOF > /tmp/input1.txt
    4 6 0
    0 1 3
    0 2 2
    2 0 1
    2 3 1
    3 0 1
    3 1 5
    EOF
    
    cat << EOF > /tmp/input2.txt
    6 10 0
    0 2 7
    0 1 1
    0 3 5
    1 4 9
    2 1 6
    1 3 2
    3 4 3
    4 2 2
    2 5 8
    3 5 3
    EOF

    input

    %python
    def input(f):
        global v, e, r
        global s, t, w
        
        v, e, r = map(int, f.readline().split())
        s = []
        t = []
        w = []
        for _ in range(e):
            s_, t_, w_ = map(int, f.readline().split())
            s.append(s_)
            t.append(t_)
            w.append(w_)

    サンプル入力を可視化

    %sh
    pip install networkx
    Requirement already satisfied (use --upgrade to upgrade): networkx in /opt/conda/lib/python2.7/site-packages
    Requirement already satisfied (use --upgrade to upgrade): decorator>=4.3.0 in /opt/conda/lib/python2.7/site-packages (from networkx)
    You are using pip version 8.1.2, however version 19.0.1 is available.
    You should consider upgrading via the 'pip install --upgrade pip' command.
    

    %python
    import os
    
    if os.environ.get('ZEPPELIN_HOME'):
        import networkx as nx
        import matplotlib.pyplot as plt
        
        def generate_input_graph():
            G = nx.MultiDiGraph()
            node_colors = []
            for i in range(v):
                G.add_node(i)
                if i == r:
                    node_colors.append('#ff9900')
                else:
                    node_colors.append('w')
            for i in range(e):
                G.add_edge(s[i], t[i], t=w[i])
            return G, node_colors
        
        def visualize(G, node_colors, seed=1):
            plt.figure()
            plt.style.use('seaborn')
            plt.axis("off")
            pos = nx.spring_layout(G, seed)
            nodes = nx.draw_networkx_nodes(G, pos, node_color=node_colors, linewidths=2)
            nodes.set_edgecolor('black')
            nx.draw_networkx_labels(G, pos)
            nx.draw_networkx_edges(G, pos)
    
            edge_labels = dict([((a, b), c['t']) for a, b, c in G.edges(data=True)])
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=16)
    
            plt.show()

    Example #1

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input1.txt') as f:
            print('Example #1')
            input(f)
            G, nc = generate_input_graph()
            visualize(G, nc)
        
        plt.show()
    Example #1
    <matplotlib.figure.Figure at 0x7fe3dc0f1a90>
    

    Example #2

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input2.txt') as f:
            print('Example #2')
            input(f)
            G, nc = generate_input_graph()
            visualize(G, nc)
        
        plt.show()
    Example #2
    <matplotlib.figure.Figure at 0x7fe3dc08e110>
    

    各頂点について最小の入辺を残す

    まずは頂点ごとに入ってくる辺の中で最小のものを選択していきます。

    %python
    with open('/tmp/input1.txt') as f:
        input(f)
    
    best = [(10**18, -1) for _ in range(v)]
    for i in range(e):
        best[t[i]] = min(best[t[i]], (w[i], s[i]))
    
    best
    [(1, 2), (3, 0), (2, 0), (1, 2)]

    %python
    G = nx.MultiDiGraph()
    node_colors = []
    for i in range(v):
        G.add_node(i)
        if i == r:
            node_colors.append('#ffcc00')
        else:
            node_colors.append('w')
    for i, b in enumerate(best):
        if i == r:
            continue
        G.add_edge(b[1], i, t=b[0])
    
    visualize(G, node_colors)
    <matplotlib.figure.Figure at 0x7fe3d5e0ec50>

    %python
    with open('/tmp/input2.txt') as f:
        input(f)
    
    best = [(10**18, -1) for _ in range(v)]
    for i in range(e):
        best[t[i]] = min(best[t[i]], (w[i], s[i]))
    
    best
    [(1000000000000000000, -1), (1, 0), (2, 4), (2, 1), (3, 3), (3, 3)]

    %python
    G = nx.MultiDiGraph()
    node_colors = []
    for i in range(v):
        G.add_node(i)
        if i == r:
            node_colors.append('#ffcc00')
        else:
            node_colors.append('w')
    for i, b in enumerate(best):
        if i == r:
            continue
        G.add_edge(b[1], i, t=b[0])
    
    visualize(G, node_colors, seed=4)
    <matplotlib.figure.Figure at 0x7fe3dc0f2390>

    サンプルケースのように選んだ辺から構築されたグラフが閉路を持たない場合はこの操作を実行すれば最小全域有向木が求まります。

    構築したグラフに閉路があったらどうなるか

    問題となるのは構築されたグラフが閉路を持つ場合です。

    閉路ができるケース

    %sh
    cat << EOF > /tmp/input3.txt
    6 6 0
    0 1 1
    0 2 9
    2 3 1
    3 4 1
    4 5 1
    5 2 1
    EOF

    %python
    if os.environ.get('ZEPPELIN_HOME'):
        with open('/tmp/input3.txt') as f:
            print('Example #3')
            input(f)
            G, nc = generate_input_graph()
            visualize(G, nc, seed=1)
        
        plt.show()
    Example #3
    <matplotlib.figure.Figure at 0x7fe3d5562690>
    

    上のグラフに対して、各頂点で最小の入辺を取っていくとどうなるでしょうか。

    %python
    with open('/tmp/input3.txt') as f:
        input(f)
    
    best = [(10**18, -1) for _ in range(v)]
    for i in range(e):
        best[t[i]] = min(best[t[i]], (w[i], s[i]))
    
    best
    [(1000000000000000000, -1), (1, 0), (1, 5), (1, 2), (1, 3), (1, 4)]

    %python
    G = nx.MultiDiGraph()
    node_colors = []
    for i in range(v):
        G.add_node(i)
        if i == r:
            node_colors.append('#ffcc00')
        else:
            node_colors.append('w')
    for i, b in enumerate(best):
        if i == r:
            continue
        G.add_edge(b[1], i, t=b[0])
    
    visualize(G, node_colors, seed=8)
    <matplotlib.figure.Figure at 0x7fe3dc3cc410>
    

    分離されてしまいました。このままでは根 \(r\) から到達できない頂点があるので閉路をうまく崩していく必要があります。

    閉路の縮約

    https://en.wikipedia.org/wiki/Edmonds%27_algorithm

    各辺について

    • 同じ閉路に含まれる
    • 追加しない
    • 閉路に入る場合
    • その辺のコストから終点の最小入辺のコストを引く(閉路内の最小入辺を削除して、見ている辺を追加)
    • それ以外はそのまま残す

    上記の手順を繰り返すことで閉路を縮約していきます。

    %python
    edges = []
    for i in range(e):
        edges.append((s[i], t[i], w[i]))
    
    edges
    [(0, 1, 1), (0, 2, 9), (2, 3, 1), (3, 4, 1), (4, 5, 1), (5, 2, 1)]

    %python
    def solve(v, edges, r):
        best = [(10**18, -1) for _ in range(v)]
        for s, t, w in edges:
            best[t] = min(best[t], (w, s))
        best[r] = (0, -1)
        
        group = [0] * v
        comp = [0] * v
        used = [0] * v
        cnt = 0
        for i in range(v):
            if used[i]:
                continue
            
            chain = []
            cur = i
            
            while not used[cur] and cur != -1:
                chain.append(cur)
                used[cur] = 1
                cur = best[cur][1]
            
            if cur != -1:
                cycle = 0
                for x in chain:
                    group[x] = cnt
                    if x == cur:
                        cycle = 1
                        comp[cnt] = 1
                    if not cycle:
                        cnt += 1
                if cycle:
                    cnt += 1
            else:
                for x in chain:
                    group[x] = cnt
                    cnt += 1
        
        if cnt == v:
            return sum(map(lambda x: x[0], best))
        
        next_edges = []
        for s, t, w in edges:
            gs = group[s]
            gt = group[t]
            if gs == gt:
                continue
            if comp[gt]:
                next_edges.append((gs, gt, w - best[t][0]))
            else:
                next_edges.append((gs, gt, w))
        
        res = sum(best[x][0] for x in range(v) if x != r and comp[group[x]])
        return res + solve(cnt, next_edges, group[r])
    
    solve(v, edges, r)
    13