PythonでAizu Online Judgeを解く/GRL_2_B - Minimum-Cost Arborescence

Posted on

最小全域有向木

重み付き有向グラフ \(G(V, E)\) について、頂点 \(r\) を根とする最小全域有向木の辺の重みの総和を求めてください。

制約

  • \(1 \leq |V| \leq 100\)
  • \(0 \leq |E| \leq 1,000\)
  • \(0 \leq w_i \leq 10,000\)
  • グラフ \(G\) は \(r\) を根とする有向木をもつ

Arborescenceって何

アーボレセンス。頭の arbor がラテン語で木を意味しているようです。

サンプル入力

%sh
cat << EOF > /tmp/input1.txt
4 6 0
0 1 3
0 2 2
2 0 1
2 3 1
3 0 1
3 1 5
EOF

cat << EOF > /tmp/input2.txt
6 10 0
0 2 7
0 1 1
0 3 5
1 4 9
2 1 6
1 3 2
3 4 3
4 2 2
2 5 8
3 5 3
EOF

input

%python
def input(f):
    global v, e, r
    global s, t, w
    
    v, e, r = map(int, f.readline().split())
    s = []
    t = []
    w = []
    for _ in range(e):
        s_, t_, w_ = map(int, f.readline().split())
        s.append(s_)
        t.append(t_)
        w.append(w_)

サンプル入力を可視化

%sh
pip install networkx
Requirement already satisfied (use --upgrade to upgrade): networkx in /opt/conda/lib/python2.7/site-packages
Requirement already satisfied (use --upgrade to upgrade): decorator>=4.3.0 in /opt/conda/lib/python2.7/site-packages (from networkx)
You are using pip version 8.1.2, however version 19.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

%python
import os

if os.environ.get('ZEPPELIN_HOME'):
    import networkx as nx
    import matplotlib.pyplot as plt
    
    def generate_input_graph():
        G = nx.MultiDiGraph()
        node_colors = []
        for i in range(v):
            G.add_node(i)
            if i == r:
                node_colors.append('#ff9900')
            else:
                node_colors.append('w')
        for i in range(e):
            G.add_edge(s[i], t[i], t=w[i])
        return G, node_colors
    
    def visualize(G, node_colors, seed=1):
        plt.figure()
        plt.style.use('seaborn')
        plt.axis("off")
        pos = nx.spring_layout(G, seed)
        nodes = nx.draw_networkx_nodes(G, pos, node_color=node_colors, linewidths=2)
        nodes.set_edgecolor('black')
        nx.draw_networkx_labels(G, pos)
        nx.draw_networkx_edges(G, pos)

        edge_labels = dict([((a, b), c['t']) for a, b, c in G.edges(data=True)])
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.25, font_size=16)

        plt.show()

Example #1

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input1.txt') as f:
        print('Example #1')
        input(f)
        G, nc = generate_input_graph()
        visualize(G, nc)
    
    plt.show()
Example #1
<matplotlib.figure.Figure at 0x7fe3dc0f1a90>

Example #2

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input2.txt') as f:
        print('Example #2')
        input(f)
        G, nc = generate_input_graph()
        visualize(G, nc)
    
    plt.show()
Example #2
<matplotlib.figure.Figure at 0x7fe3dc08e110>

各頂点について最小の入辺を残す

まずは頂点ごとに入ってくる辺の中で最小のものを選択していきます。

%python
with open('/tmp/input1.txt') as f:
    input(f)

best = [(10**18, -1) for _ in range(v)]
for i in range(e):
    best[t[i]] = min(best[t[i]], (w[i], s[i]))

best
[(1, 2), (3, 0), (2, 0), (1, 2)]

%python
G = nx.MultiDiGraph()
node_colors = []
for i in range(v):
    G.add_node(i)
    if i == r:
        node_colors.append('#ffcc00')
    else:
        node_colors.append('w')
for i, b in enumerate(best):
    if i == r:
        continue
    G.add_edge(b[1], i, t=b[0])

visualize(G, node_colors)

%python
with open('/tmp/input2.txt') as f:
    input(f)

best = [(10**18, -1) for _ in range(v)]
for i in range(e):
    best[t[i]] = min(best[t[i]], (w[i], s[i]))

best
[(1000000000000000000, -1), (1, 0), (2, 4), (2, 1), (3, 3), (3, 3)]

%python
G = nx.MultiDiGraph()
node_colors = []
for i in range(v):
    G.add_node(i)
    if i == r:
        node_colors.append('#ffcc00')
    else:
        node_colors.append('w')
for i, b in enumerate(best):
    if i == r:
        continue
    G.add_edge(b[1], i, t=b[0])

visualize(G, node_colors, seed=4)

サンプルケースのように選んだ辺から構築されたグラフが閉路を持たない場合はこの操作を実行すれば最小全域有向木が求まります。

構築したグラフに閉路があったらどうなるか

問題となるのは構築されたグラフが閉路を持つ場合です。

閉路ができるケース

%sh
cat << EOF > /tmp/input3.txt
6 6 0
0 1 1
0 2 9
2 3 1
3 4 1
4 5 1
5 2 1
EOF

%python
if os.environ.get('ZEPPELIN_HOME'):
    with open('/tmp/input3.txt') as f:
        print('Example #3')
        input(f)
        G, nc = generate_input_graph()
        visualize(G, nc, seed=1)
    
    plt.show()
Example #3
<matplotlib.figure.Figure at 0x7fe3d5562690>

上のグラフに対して、各頂点で最小の入辺を取っていくとどうなるでしょうか。

%python
with open('/tmp/input3.txt') as f:
    input(f)

best = [(10**18, -1) for _ in range(v)]
for i in range(e):
    best[t[i]] = min(best[t[i]], (w[i], s[i]))

best
[(1000000000000000000, -1), (1, 0), (1, 5), (1, 2), (1, 3), (1, 4)]

%python
G = nx.MultiDiGraph()
node_colors = []
for i in range(v):
    G.add_node(i)
    if i == r:
        node_colors.append('#ffcc00')
    else:
        node_colors.append('w')
for i, b in enumerate(best):
    if i == r:
        continue
    G.add_edge(b[1], i, t=b[0])

visualize(G, node_colors, seed=8)
<matplotlib.figure.Figure at 0x7fe3dc3cc410>

分離されてしまいました。このままでは根 \(r\) から到達できない頂点があるので閉路をうまく崩していく必要があります。

閉路の縮約

https://en.wikipedia.org/wiki/Edmonds%27_algorithm

各辺について

  • 同じ閉路に含まれる
  • 追加しない
  • 閉路に入る場合
  • その辺のコストから終点の最小入辺のコストを引く(閉路内の最小入辺を削除して、見ている辺を追加)
  • それ以外はそのまま残す

上記の手順を繰り返すことで閉路を縮約していきます。

%python
edges = []
for i in range(e):
    edges.append((s[i], t[i], w[i]))

edges
[(0, 1, 1), (0, 2, 9), (2, 3, 1), (3, 4, 1), (4, 5, 1), (5, 2, 1)]

%python
def solve(v, edges, r):
    best = [(10**18, -1) for _ in range(v)]
    for s, t, w in edges:
        best[t] = min(best[t], (w, s))
    best[r] = (0, -1)
    
    group = [0] * v
    comp = [0] * v
    used = [0] * v
    cnt = 0
    for i in range(v):
        if used[i]:
            continue
        
        chain = []
        cur = i
        
        while not used[cur] and cur != -1:
            chain.append(cur)
            used[cur] = 1
            cur = best[cur][1]
        
        if cur != -1:
            cycle = 0
            for x in chain:
                group[x] = cnt
                if x == cur:
                    cycle = 1
                    comp[cnt] = 1
                if not cycle:
                    cnt += 1
            if cycle:
                cnt += 1
        else:
            for x in chain:
                group[x] = cnt
                cnt += 1
    
    if cnt == v:
        return sum(map(lambda x: x[0], best))
    
    next_edges = []
    for s, t, w in edges:
        gs = group[s]
        gt = group[t]
        if gs == gt:
            continue
        if comp[gt]:
            next_edges.append((gs, gt, w - best[t][0]))
        else:
            next_edges.append((gs, gt, w))
    
    res = sum(best[x][0] for x in range(v) if x != r and comp[group[x]])
    return res + solve(cnt, next_edges, group[r])

solve(v, edges, r)
13