AtCoder/arc092_b - Two Sequences

Posted on

概要

長さ \(N\) の数列 \(a, b\) が与えられます。数列からそれぞれ一つずつ要素を選んでペアにしたときの全てのペアの和(\(a_i + b_i\))について排他的論理和(xor)をとった値を計算してください。

制約

  • 入力はすべて整数
  • \(1 \leq N \leq 200000\)
  • \(0 \leq a_i, b_i < 2^{28}\)

Examples

まずはサンプルを手計算してみます。

Example #1

\(N = 2, a = \{1, 2\}, b = \{3, 4\}\)

のときは

\(4 (1+3) \oplus 5 (1+4) \oplus 5 (2+3) \oplus 6 (2+4) = 2\)

となります。

Example #2

\(N = 6, a = \{4, 6, 0, 0, 3, 3\}, b = \{0, 5, 6, 5, 0, 3\}\)

36通り計算するのはだるいので排他的論理和の性質を振り返ってみます。

排他的論理和(xor)とは何だったのか

入力が異なるときにビットが立つようなものでした。

AB\(A \oplus B\)
000
011
101
110

a と b を昇順に並び替えると、

\(a = \{0, 0, 3, 3, 4, 6\}\)
\(b = \{0, 0, 3, 5, 5, 6\}\)

1番目の 0 と 2番目の 0 で同じ和が生成されるので xor を取ると打ち消し合う形になります。3 についても同様であるので \(a = \{4, 6\}\) と \(b = \{5, 6\}\) についてxorを求めればよく

\(9 (4+5) \oplus 10 (4+6) \oplus 11 (6+5) \oplus 12 (6+6) = 8\)

となります。

考え方

場合分ける

ビットごとに処理するのが定石なので1ビットずつ処理していく形に落とし込むことを考えるわけですが、足し算すると繰り上がりがあるのでどうしたものかなという感じになります。

kビット目が 1 になるのは、kビット目までの和が2進数表記で 01xxx11xxx のようになる場合です。それぞれ場合分けすると、

01xxx の場合、

$$
2^{k-1} \leq a_i + b_i < 2^{k}
$$

11xxx の場合、

$$
2^k+2^{k-1} \leq a_i + b_i < 2^{k+1}
$$

※ 注意: \(a_i\) と \(b_i\) は k ビット目までに mask したものです

\(a_i\) を移行すると

01xxx の場合、

$$
2^{k-1} - a_i \leq b_i < 2^{k} - a_i
$$

11xxx の場合、

$$
2^k+2^{k-1} - a_i \leq b_i < 2^{k+1} - a_i
$$

となります。

数え上げる

各ビットについて、\(b_i\) の値を mask したものを前計算しソートしておくと、条件を満たすような \(b_i\) の個数について二分探索で求めることができます。

  p[0] = 1;
  rep (k, psize-1) {
    p[k+1] = p[k] * 2LL;
  }
 
  int_t mask = 1;
  rep (k, psize) if (k > 0) {
    rep (i, n) t[k][i] = b[i] & mask;
    sort(t[k], t[k]+n);
    mask = (mask << 1LL) | 1LL;
  }

それぞれ条件について \(a_i\) を固定し二分探索していきます。

int count(int k, int_t ai) {
  int res = 0;
  auto tk = t[k];
  int c1 = lower_bound(tk, tk+n, p[k] - ai) - lower_bound(tk, tk+n, p[k-1] - ai);
  int c2 = lower_bound(tk, tk+n, p[k+1] - ai) - lower_bound(tk, tk+n, p[k] + p[k-1] - ai);
  return c1 + c2;
}

すべてのビットについて数え上げを行い、条件を満たすようなペアの個数が奇数であれば、そのビットを残すように答えを組み立てます。

  int cnt[psize];
  fill(cnt, cnt+psize, 0);
  int_t mask2 = 1;
  rep (k, psize-1) if (k > 0) {
    rep (i, n) {
      cnt[k] += count(k, a[i] & mask2);
    }
    mask2 = (mask2 << 1LL) | 1LL;
  }
 
  int_t res = 0;
  rep (k, psize) if (k > 0) {
    if (cnt[k] % 2) {
      res += 1LL << (k-1);
    }
  }
  cout << res << endl;

雑感

ビットごとにやるのは xor の文字列を見た瞬間に浮かびましたが、そのあとの場合分けのところは苦戦しました(どうやったら克服できるのだろう……)。足し算しているのでビット数の上限設定にも注意が必要です。

%md