HDU2243 (考研路茫茫——单词情结)[AC自动机,矩阵乘法]

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2243

题目大意:给定n个模板串,问有多少个长度不超过L的含至少一个模板串的字符串。
首先考虑问题的相反问题,求有多少个不含任何模板串的字符串,并只考虑长度正好为$l$的字符串的数量,这个问题的解题方法见POJ2778(DNA Sequence)
现在我们已经能计算长度为$l$的字符串的数量,但问题要求的是长度不超过$L$的字符串数量,答案对应的矩阵为 $A^1+A^2+…+A^L$ 。
令 $S_n=A^1+A^2+…+A^n$,则$S_n=S_{n-1}A+A$
$$
\left[ \begin{matrix} S_n & A \end{matrix} \right] = \left[ \begin{matrix} S_{n-1} & A \end{matrix} \right] \left[ \begin{matrix} A & 0 \\ E & E \end{matrix} \right] = \left[ \begin{matrix} 0 & A \end{matrix} \right] {\left[ \begin{matrix} A & 0 \\ E & E \end{matrix} \right]}^n
$$
这样就能用矩阵快速幂求出不含任何模板串的答案对应的矩阵了。接下来要用总方案数减去这个答案。
使用同样的方法求出$T_n=26^1+26^2+…+26^n$,并用$T_n$减去上述答案即可。
$$
\left[ \begin{matrix} T_n & 26 \end{matrix} \right] = \left[ \begin{matrix} T_{n-1} & 26 \end{matrix} \right] \left[ \begin{matrix} 26 & 0 \\ 1 & 1 \end{matrix} \right] = \left[ \begin{matrix} 0 & 26 \end{matrix} \right] {\left[ \begin{matrix} 26 & 0 \\ 1 & 1 \end{matrix} \right]}^n
$$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;
typedef unsigned long long ull;
const int maxn=32;

int ch[maxn][26];
int f[maxn];
bool danger[maxn];
int sz;
void init (){
sz=1;
memset(ch[0],0,sizeof(ch[0]));
memset(f,0,sizeof(f));
memset(danger,0,sizeof(danger));
}
int idx(char c) {
return c-'a';
}
void insert(char *s) {
int u=0,n=strlen(s);
for(int i=0;i<n;i++) {
int c=idx(s[i]);
if(!ch[u][c]) {
memset(ch[sz],0,sizeof(ch[sz]));
danger[sz]=false;
ch[u][c]=sz++;
}
u=ch[u][c];
}
danger[u]=true;
}
void getFail() {
queue<int> q;
while(!q.empty()) q.pop();
f[0]=0;
for(int c=0;c<26;c++) {
int u=ch[0][c];
if(u) { f[u]=0;q.push(u); }
}
while(!q.empty()) {
int r=q.front(); q.pop();
for(int c=0;c<26;c++) {
int u=ch[r][c];
if(!u) { ch[r][c]=ch[f[r]][c];continue; }
q.push(u);
int v=f[r];
while(v && !ch[v][c]) v=f[v];
f[u]=ch[v][c];
danger[u]|=danger[ch[v][c]];
}
}
}

ull mat[maxn*2][maxn*2];
ull A[maxn][maxn];
void build() {
memset(A,0,sizeof(A));
for(int i=0;i<sz;i++) if(!danger[i]) {
for(int j=0;j<26;j++) {
int u=ch[i][j];
if(!danger[u]) A[i][u]++;
}
}
memset(mat,0,sizeof(mat));
for(int i=0;i<sz;i++) {
for(int j=0;j<sz;j++) {
mat[i][j]=A[i][j];
}
}
for(int i=0;i<sz;i++)
mat[i+sz][i]=mat[i+sz][i+sz]=1;
}

ull ans[maxn*2][maxn*2];
void pow_mod(int y) {
ull tmp[maxn*2][maxn*2];
memset(ans,0,sizeof(ans));
for(int i=0;i<sz*2;i++)
ans[i][i]=1;
for(;y;y>>=1) {
if(y&1) {
memset(tmp,0,sizeof(tmp));
for(int i=0;i<sz*2;i++)
for(int j=0;j<sz*2;j++)
for(int k=0;k<sz*2;k++)
tmp[i][j]=((ans[i][k]*mat[k][j])+tmp[i][j]);
memcpy(ans,tmp,sizeof(tmp));
}
memset(tmp,0,sizeof(tmp));
for(int i=0;i<sz*2;i++)
for(int j=0;j<sz*2;j++)
for(int k=0;k<sz*2;k++)
tmp[i][j]=((mat[i][k]*mat[k][j])+tmp[i][j]);
memcpy(mat,tmp,sizeof(tmp));
}
}

ull tot[2][2];
ull mat2[2][2]={
{26,0},
{1,1}
};
void pow_mod_sum(int y) {
ull tmp[2][2];
mat2[0][0]=26;
mat2[0][1]=0;
mat2[1][0]=1;
mat2[1][1]=1;
memset(tot,0,sizeof(tot));
for(int i=0;i<2;i++)
tot[i][i]=1;
for(;y;y>>=1) {
if(y&1) {
memset(tmp,0,sizeof(tmp));
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
tmp[i][j]=((tot[i][k]*mat2[k][j])+tmp[i][j]);
memcpy(tot,tmp,sizeof(tmp));
}
memset(tmp,0,sizeof(tmp));
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
tmp[i][j]=((mat2[i][k]*mat2[k][j])+tmp[i][j]);
memcpy(mat2,tmp,sizeof(tmp));
}
}

char s[15];
ull seq[maxn][maxn*2];
ull ans2[maxn][maxn*2];
int main() {
int n,l;
while(~scanf("%d%d",&n,&l)) {
init();
for(int i=0;i<n;i++) {
scanf("%s",s);
insert(s);
}
getFail();
build();
pow_mod(l);
memset(seq,0,sizeof(seq));
memset(ans2,0,sizeof(ans2));
for(int i=0;i<sz;i++) {
for(int j=0;j<sz;j++) {
seq[i][j+sz]=A[i][j];
}
}
for(int i=0;i<sz;i++) {
for(int j=0;j<2*sz;j++) {
for(int k=0;k<2*sz;k++) {
ans2[i][j]=seq[i][k]*ans[k][j]+ans2[i][j];
}
}
}
ull a=0;
for(int i=0;i<sz;i++) {
a=(ans2[0][i]+a);
}
pow_mod_sum(l);
ull sum=tot[1][0]*26;
printf("%I64u\n",sum-a);
}
}