WC2018 州区划分,子集卷积学习笔记

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<stack>
#include<cmath>
using namespace std;
typedef long long ll;

// Cities are 1-indexed in p/d/w (and fa), and f/g are indexed by subset size
// 0..n, so every array of length maxn supports at most n = maxn - 1 cities.
const int maxn = 23;
const int mod = 998244353;
// Per-subset arrays are indexed by bitmasks in [0, 1<<n) with n <= maxn-1, so
// 1 << (maxn-1) entries suffice. Sizing them 1 << maxn doubled the static
// storage of f and g (~1.5 GB -> ~0.77 GB now) with no reachable extra index.
// NOTE(review): if the problem's real bound is n <= 21, maxn = 22 would halve
// this again -- confirm against the problem statement.
const int maxs = 1 << (maxn - 1);

int n,m,q,lim;
// p[i]   : adjacency bitmask of city i (bit j-1 set when road i-j was read)
// d[i]   : scratch degree counter used by check()
// w[i]   : weight of city i
// cnt[s] : number of cities in bitmask s
// sum[s] : weight term of s computed by check() (total^q for q in {0,1,2})
// inv[s] : sum[s]^(mod-2), i.e. its inverse mod `mod` when sum[s] != 0
int p[maxn],d[maxn],w[maxn],cnt[maxs],sum[maxs],inv[maxs];
// f[c][s], g[c][s]: DP table / valid-state table for subsets of size c; both
// spend most of the DP in the OR (subset-sum) transformed domain.
int f[maxn][maxs],g[maxn][maxs];

// Binary (square-and-multiply) exponentiation: returns i^po reduced mod `mod`.
// po is expected to be >= 0; with a negative po the right shift keeps it
// non-zero and the loop would not terminate.
int qsm(int i,int po){
    long long base=i;
    long long acc=1;
    for(;po;po>>=1){
        if(po&1) acc=acc*base%mod;
        base=base*base%mod;
    }
    return (int)acc;
}

// Union-find parent links for cities 1..n (reset by check() for each subset).
int fa[maxn];
// Return the root of x's set, then point every node on the path from x
// directly at that root (full path compression).
int find(int x){
    int root=x;
    while(fa[root]!=root) root=fa[root];
    while(fa[x]!=root){
        const int up=fa[x];
        fa[x]=root;
        x=up;
    }
    return root;
}
// Classify the set of cities in bitmask s and cache its weight term.
//
// Side effect: sets sum[s] from the total weight of s (mod `mod`):
//   q == 0 -> 1,  q == 1 -> total,  any other q -> total^2.
// This is done for every non-empty s, single cities included, so sum[] and
// inv[] are meaningful for all subsets. (Single cities used to keep sum == 0.)
//
// Returns true when the subgraph induced by s is disconnected (isolated
// cities count as their own component) or contains a city of odd degree,
// i.e. when no closed walk uses every internal road exactly once while
// visiting every city of s. A single city always returns false.
bool check(int s){
    // Use a local accumulator so repeated calls for the same s do not double it.
    int total=0;
    for(int i=1;i<=n;++i)
        if(s&(1<<(i-1))) total=(total+w[i])%mod;
    // Widen before squaring: an int*int product overflows once total > 46340.
    sum[s]= q==0 ? 1 : q==1 ? total : (int)(1ll*total*total%mod);

    if(cnt[s]==1) return false;

    for(int i=1;i<=n;++i) d[i]=0,fa[i]=i;
    int comps=cnt[s];  // every city of s starts as its own component
    for(int i=1;i<=n;++i){
        if(!(s&(1<<(i-1)))) continue;
        const int nb=p[i]&s;  // neighbours of i that are also in s
        // Only j > i, so each road is processed exactly once.
        for(int j=i+1;j<=n;++j){
            if(!(nb&(1<<(j-1)))) continue;
            const int ri=find(i);
            const int rj=find(j);
            if(ri!=rj){
                fa[ri]=rj;
                --comps;
            }
            ++d[i];
            ++d[j];
        }
    }
    if(comps>1) return true;
    for(int i=1;i<=n;++i)
        if((s&(1<<(i-1)))&&(d[i]&1)) return true;
    return false;
}

// In-place OR (subset-sum) transform over A[0 .. lim-1], values mod `mod`.
// type == 1: each A[s] becomes the sum of A[t] over all t that are subsets of s.
// Any other type: the inverse transform (subtract instead of add).
void fwt(int *A,int type){
    const bool forward=(type==1);
    for(int half=1;half<lim;half<<=1){
        for(int base=0;base<lim;base+=half<<1){
            for(int off=base;off<base+half;++off){
                const int lo=A[off];
                int &hi=A[off+half];
                hi= forward ? (hi+lo)%mod : (hi-lo+mod)%mod;
            }
        }
    }
}

// Read the next decimal integer from stdin. Any characters before the first
// digit are skipped; if a '-' appears among them, the result is negated.
// The character is kept in an int so EOF is distinguishable from data bytes.
// If EOF is reached before any digit, 0 is returned instead of looping forever
// (the previous char-based version never checked for EOF).
ll read(){
    ll s=0,f=1;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        s=s*10+ch-'0';
        ch=getchar();
    }
    return s*f;
}

int main(){
    n=read(),m=read(),q=read();
    lim=1<<n;
    int u,v;
    for(int i=1;i<=m;++i){
        u=read(),v=read();
        p[u]|=(1<<v-1),p[v]|=(1<<u-1);
    }
    for(int i=1;i<=n;++i) w[i]=read();
    
    for(int i=1;i<(1<<n);++i) cnt[i]=cnt[i>>1]+(i&1);
    
    for(int i=1;i<(1<<n);++i){
        if(check(i)) g[cnt[i]][i]=sum[i];
        else g[cnt[i]][i]=0;
        inv[i]=qsm(sum[i],mod-2);
    }
    for(int i=0;i<=n;++i) fwt(g[i],1);
    f[0][0]=1;
    fwt(f[0],1);
    for(int i=1;i<=n;++i){
        for(int j=0;j<=i-1;++j){
            for(int k=0;k<(1<<n);++k){
                f[i][k]=(f[i][k]+1ll*f[j][k]*g[i-j][k]%mod)%mod;
            }
        }
        fwt(f[i],-1);
        for(int k=0;k<(1<<n);++k){
            if(i==cnt[k]) f[i][k]=1ll*f[i][k]*inv[k]%mod;
            else f[i][k]=0;
        }
        if(i!=n) fwt(f[i],1);
    }
    
    printf("%d\n",f[n][(1<<n)-1]);

    return 0;
}

https://www.luogu.org/problemnew/solution/P4221