CCPC2023秦皇岛

C.题意:给定一个字符串,多次区间询问给定一个子串。问有多少种方案删掉最短的区间使得子串变成回文串

Sol:考虑每次询问可以独立处理,当然并不是指把子串单独提出来,而是按顺序正常回答。

  • 一个关键的点是,考虑如果第一个字符和最后一个字符不一样,那一定要删一个,依次推理下去,我们可以得到应该保留最长回文前缀或最长回文后缀。

  • 再考虑如果第一个字符和最后一个字符一样,那我们类双指针缩小删除区间,直到两边不一样。此时又回归到第一种情况。

根据上面这个观察,我们得到获得最小删除代价的做法:

  • 对于一个区间[l, r] 我们首先可以先让首尾相同字符不断相
    消, 最后剩余区间[x, y], 如果这个区间为空则为回文串。这一部分等价于求原串后缀l 和反串后缀n − r + 1 的最长公共前缀( LCP ), 再对区间长度取min。
  • 如果不为回文串, 那么str[x]=str[y]str[x] = str[y], x, y 其中之一会被删除。
    删除最少长度, 等价于保留最多长度, 即保留x 开头的一个
    回文串或y 结尾的一个回文串。这是经典的区间最长回文子串问题,可通过在回文自动机(PAM ) 上倍增跳fail 实现

这样我们就可以得到删除最小的区间长度,下面计算方案数。

  • 假设删除[x, x + t] 是一个最优解, 在求方案数时, 我们还可以在保持剩余字符串不变的前提下, 将此区间向左进行滑动。具体地, str[x − 1] = str[x + t] 则左移一次,str[x − 2] = str[x + t − 1] 则左移两次…
  • 这里有一个思考的点:首先这里一定是左移,为什么不能右移?如果右移,表示s[x]=s[x+t+1],又因为我们保留的是最长回文后缀,所以s[y]=s[x+t-1],则得到s[x]=x[y],这与前面的算法流程矛盾,我们一定是一直找到左右端点不相等才开始找回文前后缀的。

可以发现这里的左移次数就是求两个反串后缀的LCP, 由于L 的存在, 我们还要对x− L 取min。类似地, [y − t, y] 也可以进行右移, 且和左边的区间不交, 所
以可以直接相加。可以通过画图简单地证明只有这种形式的串才能满足最短的
要求。

  • 求LCP 可以用后缀数组,加上PAM 倍增的预处理, 所以时空复杂度均为O(n log n).

实现细节与debug:

1.注意区间位下标到回文自动机点的编号的映射

2.后缀数组的求lcp需要封装一下,要牢记lc数组是排名为i的和i-1的lcp,所以区间查询是左开右闭的。区间长度为1需要特判。

3.回文树上倍增的时候我们要讨论如果当前节点已经满足条件了。不然我们是倍增到最后一个不满足条件的点,父节点就是我们要找的点。

3.统计方案数的时候,我们需要注意判断反串的上下界,正串的上界,反串的坐标社id函数映射的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
void solve() {
int n;
cin >> n;
string s;
cin >> s;

string rs = s;
reverse(rs.begin(), rs.end());

string tmp = s + "&" + rs;
SA sa(tmp);
SparseTable<int> qmx(sa.lc, [](int i, int j) { return min(i, j); });

auto pre = [&](PAM& pp, string tt) {
deb(tt);
pp.work(tt);
int tot = pp.t.size() - 1;
vector<vector<int>> e(tot + 1);
for (int i = 2; i <= tot; i++) e[pp.fail(i)].push_back(i);
int jie = __lg(tot);
vector st(jie + 1, vector<int>(tot + 1));

auto dfs = [&](auto self, int u) -> void {
for (auto v : e[u]) {
// deb(u, v);
st[0][v] = u;
self(self, v);
}
};

dfs(dfs, 1);
for (int j = 1; j <= jie; j++) {
for (int i = 1; i <= tot; i++) {
st[j][i] = st[j - 1][st[j - 1][i]];
}
}
return st;
};
auto getlcp = [&](int pos1, int pos2) {
int c1 = min(sa.rk[pos1], sa.rk[pos2]);
int c2 = max(sa.rk[pos1], sa.rk[pos2]);
assert(c1 < c2);
return qmx.get(c1 + 1, c2);
};
auto st1 = pre(pam1, s);
auto st2 = pre(pam2, rs);
int tot1 = pam1.size() - 1, tot2 = pam2.size() - 1;
int jie1 = __lg(tot1), jie2 = __lg(tot2);
deb(s);

int m;
cin >> m;
auto id = [&](int x) {
return 2 * n + 2 - x;
};
for (int i = 1; i <= m; i++) {
int l, r;
cin >> l >> r;
int lcp1 = getlcp(l, id(r));
int len = r - l + 1;
if (len == 1) {
cout << 0 << " " << 0 << endl;
continue;
}
lcp1 = min(lcp1, len);
deb(lcp1);
if (lcp1 == len) {
cout << 0 << " " << 0 << endl;
continue;
}

int ql = l + lcp1, qr = r - lcp1;
int cur = pam1.idpos[qr], cul = pam2.idpos[n + 1 - ql];
deb(ql, qr);
// deb(cur, cul);
deb(pam1.len(cur), pam2.len(cul));
int maxpre, maxsuf;
if (pam1.len(cur) <= qr - ql + 1) {
maxsuf = pam1.len(cur);
} else {
for (int j = jie1; j >= 0; j--) {
if (pam1.len(st1[j][cur]) > qr - ql + 1) {
cur = st1[j][cur];
}
}
maxsuf = max(pam1.len(st1[0][cur]), 1);
}
//--------------------
if (pam2.len(cul) <= qr - ql + 1) {
maxpre = pam2.len(cul);
} else {
for (int j = jie2; j >= 0; j--) {
if (pam2.len(st2[j][cul]) > qr - ql + 1) {
cul = st2[j][cul];
}
}

maxpre = max(1, pam2.len(st2[0][cul]));
}
// deb(cur, cul);
deb(maxpre, maxsuf);
int ans1 = qr - ql + 1 - max(maxsuf, maxpre);
int ans2 = 0;
assert(ans1 < qr - ql + 1);
deb("aaa");
if (qr - ql + 1 - maxsuf == ans1) {
ans2++;
int fl = ql, fr = fl + ans1 - 1;
deb("maxsuf", fl, fr);
if (id(fl - 1) > n + 1 && id(fl - 1) < 2 * n + 2) {
int tmpcp = getlcp(id(fl - 1), id(fr));
tmpcp = min(tmpcp, fl - 1 - l + 1);
ans2 += tmpcp;
}
}
deb("aaa");
if (qr - ql + 1 - maxpre == ans1) {
ans2++;
int fr = qr, fl = fr - ans1 + 1;
if ((fr + 1) < n + 1) {
int tmpcp = getlcp(fr + 1, fl);
tmpcp = min(tmpcp, r - (fr + 1) + 1);
ans2 += tmpcp;
}
}
cout << ans1 << " " << ans2 << endl;
}
}