[图论记录]AT2267 [AGC008E] Next or Nextnext

· · 个人记录

题意 : 给定一个长度为 n 的 序列 a ,问有多少长度为 n 的排列 p ,满足对于任意 ip_i=a_ip_{p_i}=a_i

答案对 10^9+7 取模。

------------ 根据排列 $p$ ,对每个 $i$ 连边 $i\rightarrow p_i$ ,会形成若干个有向环。 约束相当于 : 在 $u$ 点两步数以内要能走到 $a_u$。 根据序列 $a$ ,对每个 $i$ 连边 $i\rightarrow a_i$ ,会形成基环内向树森林。 先观察一个 $p$ 图可能对应怎样的 $a$ 图 : - 情况 $1$ : 全部有 $p_i=a_i

会形成一个和原来相同的环。

接下来考虑反映射,即由 a 图得到 p 图。

咋这么难写……

#include<algorithm>
#include<cstdio>
#define ll long long
#define MaxN 100500
using namespace std;
const int mod=1000000007,inv2=500000004;
struct UFS{
  int f[MaxN];
  void Init(int n)
  {for (int i=1;i<=n;i++)f[i]=i;}
  int find(int u)
  {return f[u]==u ? u : f[u]=find(f[u]);}
  bool merge(int u,int v){
    u=find(u);v=find(v);
    if (u==v)return 0;
    f[u]=v;return 1;
  }
}T;
vector<int> g[MaxN];
int st[MaxN],tot,fl;
void pfs(int u)
{
  st[++tot]=u;
  if (st[tot]==fl)return ;
  for (int i=0;i<g[u].size();i++){
    pfs(g[u][i]);
    if (st[tot]==fl)return ;
  }tot--;
}
bool vis[MaxN];
int buf,ans=1;
void dfs(int u)
{
  vis[u]=1;buf++;
  int c=0;
  for (int i=0;i<g[u].size();i++)
    if (!vis[g[u][i]]){
      dfs(g[u][i]);
      c++;
    }
  if (c>1)ans=0;
}
int o[MaxN];
void solve(int u,int v)
{
  tot=0;fl=v;pfs(u);
  reverse(st+1,st+tot+1);
  for (int i=1;i<=tot;i++)vis[st[i]]=1;
  int l1=tot;while(l1>=1&&g[st[l1]].size()+(l1==1)==1)l1--;
  if (l1==0){o[tot]++;return ;}
  for (int i=1,las=l1-tot;i<=tot;i++)
    if (g[st[i]].size()+(i==1)>1){
      int c2=i-las;buf=-1;dfs(st[i]);
      ans=(ans*((buf<c2)+(buf<=c2)))%mod;
      las=i;
    }
}
ll powM(ll a,int t=mod-2){
  ll ret=1;
  while(t){
    if (t&1)ret=ret*a%mod;
    a=a*a%mod;t>>=1;
  }return ret;
}
int fac[MaxN],ifac[MaxN],pw2[MaxN],ipw2[MaxN];
ll C(int n,int m)
{return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
void Init(int n)
{
  fac[0]=pw2[0]=ipw2[0]=1;
  for (int i=1;i<=n;i++){ 
    fac[i]=1ll*fac[i-1]*i%mod;
    pw2[i]=(pw2[i-1]<<1)%mod;
    ipw2[i]=1ll*ipw2[i-1]*inv2%mod;
  }ifac[n]=powM(fac[n]);
  for (int i=n;i;i--)
    ifac[i-1]=1ll*ifac[i]*i%mod;
}
struct Data{int u,v;}b[MaxN];
int n,m;
int main()
{
  scanf("%d",&n);
  T.Init(n);Init(n);
  for (int i=1;i<=n;i++)pw2[i]=(pw2[i-1]<<1)%mod;
  for (int i=1,p;i<=n;i++){
    scanf("%d",&p);
    if (T.merge(i,p))g[p].push_back(i);
    else b[++m]=(Data){i,p};
  }for (int i=1;i<=m;i++)solve(b[i].u,b[i].v);
  for (int i=1;i<=n;i++){
    if (!o[i])continue;
    int c=o[i],buf=0,pw=1;
    for (int k=0;2*k<=c;k++){
      int sav=C(c,2*k)*fac[2*k]%mod*ifac[k]%mod*ipw2[k]%mod*pw%mod;
      if ((i&1)&&i>1)sav=1ll*sav*pw2[c-2*k]%mod;
      buf=(buf+sav)%mod;
      pw=1ll*pw*i%mod;
    }ans=1ll*ans*buf%mod;
  }printf("%d",ans);
  return 0;
}