📝 문제 정보#
🧐 관찰 및 접근#
- 일단 세개를 따로 구하면 되는 것 같다.
- XOR이 제일 쉬워보인다.
- XOR은 역연산이 존재하므로, $A_i \oplus A_j = K$를 만족하는 $A_j$는 $A_i \oplus K$로 구할 수 있다.
- $A_i \leq 1\,000\,000$이므로 그냥 구하면 될듯. 양쪽에서 구했으니 2로 나누는것만 조심하면 되겠다.
- $K = 0$일때를 주의하자.
- AND를 다음으로 봐보자.
- $A_i, A_j$ 모두 $K$와 and연산한 결과는 $K$여야 한다.
- 그리고 $A_i - K, A_j - K$는 and 연산한 결과가 $0$이어야 한다.
- 숫자 $N$개에서 두 수의 and 결과가 $0$인 조합 개수를 빠르게 구할수가 있나?
- 이 같은 아이디어로 OR도 해결할 수 있는 것 같으니, 이걸 고민해보자.
- 약간 포함배제같은맛으로 될라나?
- 일단 쉽게 $A_i$에서 $k$번째 비트 하나가 켜져있다고 해보자.
- 그러면 나머지들에서 $k$번째 비트가 꺼져있기만 하면 된다.
- 이건 뭐 $k$번째 비트가 켜져있는 / 꺼져있는 집합으로 하면 될틴데,…
- $A_i$에서 $k_1, k_2$번째 비트 두개가 켜져있다고 해보자.
- 이번에는 나머지들에서 $k_1, k_2$ 번째 비트 모두가 꺼져있어야 한다.
- 그거 개수는 $k_1$이 켜진거 + $k_2$가 켜진거 - $k_1, k_2$가 켜진것으로 구할 수 있고, 이거까진 또 쉽게 되는거같기도? 왜냐면 둘다 켜진건 처음에 보고있으니까.
- $A_i$에서 $k_1, k_2, k_3$ 세개가..
- 나머지들에서 세개가 다 꺼져있어야..
- 근데 이건 각자 더하고 둘둘 빼고 세개 더하고.. 는 까다롭네. 앞에서 하면서 저장해왔으면 되는거같긴 한데..
- 그런데, 이렇게 하다보면 결국 비트가 20개정도니까… 어? 되는거같기도 한데?
- 일단 쉽게 $A_i$에서 $k$번째 비트 하나가 켜져있다고 해보자.
- 어떻게 풀 수 있나 고민해봤는데, 결국 각 비트 자체 대해서 저장한다음에, 다른 배열에서 이를 이용해서 만들어보자.
- $\text{cntExact}$를 해당 비트가 그대로 켜져있는것 (그 숫자 자체)
- $\text{cntOr}$을 켜져있는 비트중 하나라도 켜져있는것이라고 생각하면, 이걸 앞에서 얻은것들로 계산할 수 있을 것 같다.
- 이를 SOS_DP 로 알려져있다.
- 두 수의 and 결과가 0인 것의 개수는, $A_i$를 정했다면 ${A_i}$ 를 뒤집은 비트에서 부분집합의 합과 같다.
- OR은 어떻게 풀릴까?
- $A_i, A_j$ 모두 $K$랑 or 연산한 결과가 $K$여야 한다.
- 예를들어 목표가 $1101$이고, $A_i$가 $1000$이라면 $A_j$는 $?101$이어야 한다. $A_i$에서 꺼진 비트는 모두 켜져있어야 한다.
- 부분집합의 합을 구할때, $A$를 뒤집어서 저장하는 것으로 가능할 것 같다. $?$ 인 부분들에 대해 부분집합의 합을 계산하면 된다.
💻 풀이#
- 코드 (C++):
void solve(){
int N, K; cin >> N >> K;
vector<int> A(N);
rep(i, 0, N) cin >> A[i];
int all1 = (1 << 20) - 1;
// count and
{
ll ans = 0;
vector<int> cnt(1<<20, 0), SOS(1<<20, 0);
rep(i, 0, N) if((A[i] & K) == K){
cnt[A[i] - K]++;
SOS[A[i] - K]++;
}
rep(i, 0, 20) rep(mask, 0, 1<<20) if(mask & (1<<i)) SOS[mask] += SOS[mask^(1<<i)];
rep(i, 0, N) if((A[i] & K) == K){
int mask = all1 - (A[i] - K);
ans += SOS[mask];
}
ans -= cnt[0];
cout << ans/2 << ' ';
}
// count or
{
ll ans = 0;
vector<int> cnt(1<<20, 0), SOS(1<<20, 0);
rep(i, 0, N) if((A[i] | K) == K){
cnt[A[i]]++;
SOS[K - A[i]]++;
}
rep(i, 0, 20) rep(mask, 0, 1<<20) if(mask & (1<<i)) SOS[mask] += SOS[mask^(1<<i)];
rep(i, 0, N) if((A[i] | K) == K){
ans += SOS[A[i]];
}
ans -= cnt[K];
cout << ans/2 << ' ';
}
// count xor
{
ll ans = 0;
vector<int> cnt(1<<20, 0);
rep(i, 0, N) cnt[A[i]]++;
rep(i, 0, N) ans += cnt[A[i] ^ K];
if(K == 0) ans -= N;
cout << ans/2;
}
}