跳转至

字典树 做题记录

P4551

P4551

钦定根为 \(root\),我们首先考虑从 \(root\) 开始 \(dfs\) 处理出每个节点 \(i\) 到其的距离 \(dis[i]\);由于两节点的 \(LCA\) 部分对最终的异或答案不产生影响,有一个很显然的想法是枚举每对节点取最大值,但 \(O(n^2)\) 无法接受,需要效率更优的算法

已知节点 \(i\),不用枚举,如何算出其他节点到 \(i\) 的异或路径长度最大值?可以按位考虑 贪心;从高到低枚举每一位,我们希望新节点的当前位与 \(dis[i]\) 的当前位尽量不同;但是可能无法保证新节点的当前位总不同,即需要 验证不同的一位是否存在,这可以使用 \(01trie\) 维护

具体的,我们将每一个 \(dis[i]\) 插入字典树,枚举每个距离 \(dis[i]\),维护当前字典树上的节点 \(node\),使得 \(node\) 能往与当前位不同的节点走就往不同的节点走即可

代码

C++
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN = 1e5+5, MAXL = 4e6+5;

int n, cnt;
ll ans;
ll dis[MAXN];

struct Edge{
    int to;
    ll weig;
};
vector <Edge> G[MAXN];

struct Node{
    int nex[2];
    bool exist;
}trie[MAXL];

void insert(ll x){
    int node = 0;
    for (int i=31; i>=0; i--){
        int c = ((x>>i)&1);
        if (!trie[node].nex[c]) cnt++, trie[node].nex[c]=cnt;
        node = trie[node].nex[c];
    }
    trie[node].exist = true;
}

void getans(ll x){
    int node = 0;
    ll res = 0;
    for (int i=31; i>=0; i--){
        int c = ((x>>i)&1);
        if (trie[node].nex[(c^1)]) node=trie[node].nex[(c^1)], res = (res|(1<<i));
        else node = trie[node].nex[c];
    }
    ans = max(ans, res);
}

void dfs(int node, int fa, ll rdis){
    dis[node] = rdis;
    for (int i=0; i<G[node].size(); i++){
        if (G[node][i].to == fa) continue;
        dfs(G[node][i].to, node, (rdis^G[node][i].weig));
    }   
}

void show(int node, string s){
    if (trie[node].exist) cout << s << endl;
    if (trie[node].nex[0]) show(trie[node].nex[0], s+"0");
    if (trie[node].nex[1]) show(trie[node].nex[1], s+"1");
}

int main(){
    scanf("%d", &n);
    for (int i=1; i<=n-1; i++){
        int node1, node2;
        ll weig;
        scanf("%d%d%lld", &node1, &node2, &weig);
        G[node1].push_back({node2, weig});
        G[node2].push_back({node1, weig});
    }

    dfs(1, 1, 0);

    for (int i=1; i<=n; i++){
        insert(dis[i]);
    }
    // show(0, "");
    for (int i=1; i<=n; i++){
        getans(dis[i]);
    }

    printf("%lld", ans);
    return 0;
}

P3294

P3294

先总结题意,设当前位置为 \(x\),总共有 \(n\) 个单词 1. 如果当前单词存在一个后缀,其位置在 \(x\) 后,花费加 \(n^2\) 2. 如果当前单词没有后缀,花费加 \(x\) 3. 如果当前单词所有后缀的位置都在 \(x\) 前,设这些位置中的最大值为 \(y\),花费加 \(x-y\)

我们发现,第 \(1\) 个情况的花费开销实在太大,远远高过后两种情况,因此我们要避免,即 使得单词的所有后缀都在填在其前面;对于另外两种情况,实质上情况 \(2\) 是情况 \(3\)\(y=0\) 时的特殊情况,因此只需要考虑情况 \(3\) 即可

显然我们需要处理单词间的后缀关系,很自然的想法就是 将每个单词反转插入字典树;字典树中非实际存在的节点没有用,因此我们进一步建出其 重构树,实质上即每个节点的最长后缀向该节点连边形成的树;我们考虑在其上进行操作

其实感性理解上,我们发现将树上的每条链连在一起的这种排法看起来比较优秀,因为再把任何一个其他链上的节点插进去会增加链后所有节点的花费;这实际上是 \(dfs\)

尝试证明这一结论;发现所有的花费中都含 \(x\),因此只需考虑 \(-y\),即我们想让 \(\sum-y\) 最小,即 \(\sum y\) 最大;对于位置 \(x\),其相应的 \(y\) 显然满足 \(y \le x\),我们想让 \(y\) 尽量大,那么最好就与 \(x\) *相邻* (或者就是 \(x\));如何让其相邻?实际上就是把链上的节点连在一起,中间没有其他链的节点;即得证

现在我们可以把每条链看做一个整体;还没完,链的长度显然也会对答案产生影响;对于节点 \(i\) 每个子树的根节点 \(j\),其价值开销就是在 \(j\)\(i\) 所有儿子的子树大小和,那么肯定 从小到大排序 最优;其实很容易理解,这就相当于排队打水计算每个人等待时间的和,肯定让打水时间少的人站在前面更优

实现上,对于重构树,我们使用并查集记录节点 \(x\)\(fa[x]\) 表示在 \(x\) 上方存在且最近的节点 (如果上方没有也可以是 \(x\) 自身),维护并查集,在搜到存在的节点 \(i\) 时使 \(last[i]\)\(i\) 连边即可

别忘了排序,统计答案时直接按照 \(dfs\) 序即可

代码

C++
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXL = 5e5+5, MAXN = 1e5+5;

int n, cnt, recnt;
int fa[MAXL];
ll ans;
string s;

struct Node{
    int child[26];
    bool exist;
    Node(){memset(child, 0, sizeof(child)); exist=false;}
}trie[MAXL];

struct ReNode{
    vector <int> child;
    int siz;
    ReNode(){child.clear(); siz=0;}
}retrie[MAXL];

void insert(string s){
    int node = 1;
    for (int i=s.size()-1; i>=0; i--){
        int u = s[i]-'a';
        if (!trie[node].child[u]) cnt++, trie[node].child[u] = cnt;
        node = trie[node].child[u];
    }
    if (!trie[node].exist) trie[node].exist = true;
    return;
}

void dfs(int node){
    if (trie[node].exist == true && node>1){
        retrie[fa[node]].child.push_back(node);
        fa[node] = node;
    }
    for (int i=0; i<26; i++){
        if (trie[node].child[i]){
            fa[trie[node].child[i]] = fa[node]; 
            dfs(trie[node].child[i]);   
        }
    }
}

void dfssiz(int node){
    retrie[node].siz = 1;
    for (int i=0; i<retrie[node].child.size(); i++){
        dfssiz(retrie[node].child[i]);
        retrie[node].siz += retrie[retrie[node].child[i]].siz;
    }
}

bool cmp(int a, int b){
    return retrie[a].siz < retrie[b].siz;
}

void dfssort(int node){
    sort(retrie[node].child.begin(), retrie[node].child.end(), cmp);
    for (int i=0; i<retrie[node].child.size(); i++){
        dfssort(retrie[node].child[i]); 
    }    
}

void get(int node){
    int tmp = recnt;
    recnt++;
    for (int i=0; i<retrie[node].child.size(); i++){
        ans += recnt-tmp;
        get(retrie[node].child[i]);
    }   
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cnt = 1;
    cin >> n;
    for (int i=1; i<=n; i++){
        cin >> s;
        insert(s);
    }
    trie[1].exist = true;
    fa[1] = 1;
    dfs(1);
    dfssiz(1);
    dfssort(1);
    get(1);

    cout << ans;
    return 0;
}

P2292

P2292

字典树什么东西,直接上自动机

有一个显然的 \(dp\),记 \(dp[i]\) 表示文本串前 \(i\) 位是否能被理解 (可行性 \(dp\)),转移即对于每个 \(i\),枚举 \(j(1 \le j \le i)\),若 \(s[j]s[j+1]...s[i]\) 是字典中的一个单词且 \(dp[j-1]\) 为真,\(dp[i]\) 即为真

初值即为 \(dp[0] = true\)

对于文本串的前缀,我们发现需要处理出其所有在字典中的后缀进行转移,这很显然就是 \(AC\) 自动机,一直跳 \(fail\) 即可;直接做会超时,考虑进行玄学优化

记录字典中单词的最大长度 \(maxl\),对于文本串当前位 \(i\),维护 \(i\) 之前可以被理解的最大位置 \(maxpl\),显然当 \(maxl+maxpl < i\) 时可以直接返回假;另外,当 \(dp[i]\) 已经为真时就没必要继续跳了,直接返回即可;这样就水过了此题

代码

C++
#include<bits/stdc++.h>
using namespace std;
const int MAXL = 2e6+5, MAXT = 2e4+5;

int n, m, cnt, maxl, maxpl;
bool dp[MAXL];
string s;

struct Node{
    int child[26];
    int fail;
    int len;
    int exist;

    Node(){memset(child, 0, sizeof(child)); fail = len = exist = 0;}
}trie[MAXT];

void insert(string s){
    int node = 1;
    for (int i=0; i<s.size(); i++){
        int u = s[i]-'a';
        if (!trie[node].child[u]) cnt++, trie[node].child[u] = cnt;
        node = trie[node].child[u];
        trie[node].len = i+1;
    }
    if (!trie[node].exist) trie[node].exist=1;
    maxl = max(maxl, trie[node].len);
    return;
}

queue <int> q;

void build(){
    for (int i=0; i<26; i++) trie[0].child[i] = 1;
    trie[1].fail = 0;
    q.push(1);

    while (!q.empty()){
        int node = q.front();
        q.pop();
        int fail = trie[node].fail;

        for (int i=0; i<26; i++){
            int v = trie[node].child[i];
            if (v){
                trie[v].fail = trie[fail].child[i];
                q.push(v);
            }
            else{
                trie[node].child[i] = trie[fail].child[i];
            }
        }   
    }
}

void query(string s){
    int node = 1;
    for (int i=0; i<s.size(); i++){
        int u = s[i]-'a';
        node = trie[node].child[u];
        for (int j=node; j>1; j=trie[j].fail){
            if (trie[j].exist == 1){
                dp[i+1] = (dp[i+1] | dp[i+1-trie[j].len]);
                if (dp[i+1] == true){maxpl = max(maxpl, i+1); break;}
                if (maxpl+maxl < i+1) break;                
            }
        }
    }       
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    cnt = 1;
    cin >> n >> m;
    for (int i=1; i<=n; i++){
        cin >> s;
        insert(s);
    }
    build();
    for (int i=1; i<=m; i++){
        memset(dp, false, sizeof(dp));
        dp[0] = true;
        maxpl = 0;
        cin >> s;
        query(s);
        cout << maxpl << endl;
    }
    return 0;
}

P2922

P2922

我们在字典树中插入每个 \(b_i\);设当前处理的文本为 \(c\),维护字典树上每个节点 \(node\) 被包含的次数 \(cnt1\) (注意当单词在该节点结束时也统计) 与在 \(node\)结束的单词数 \(cnt2\)

这样,设 \(c\) 在字典树上的终止节点为 \(node\),对于 \(c\) 包含 \(b_i\) 的情况,我们只需要统计从根到 \(node\) 上的 \(\sum cnt2\);对于 \(b_i\) 包含 \(c\) 的情况,我们只需要统计 \(node\) 上的 \(cnt1\);需要注意,由于我们在单词的终止节点同时增加了 \(cnt1\)\(cnt2\) 的贡献,因此这样我们会 多算一遍 \(node\) 上的 \(cnt2\)

最终答案即为 \(\sum cnt2+cnt1_{node}-cnt2_{node}\)

代码

C++
#include<bits/stdc++.h>
using namespace std;
const int MAXL = 5e5+5;

int N, M, num, nc, cnt, res;
string tmp;

struct Node{
    int nex[2];
    bool exist;
    int cnt1, cnt2;
    Node(){nex[0]=nex[1]=0;exist=false;cnt1=cnt2=0;}
}trie[MAXL];

void insert(string s){
    int node = 0;
    for (int i=0; i<s.size(); i++){
        int c = s[i]-'0';
        if (!trie[node].nex[c]) cnt++,trie[node].nex[c]=cnt;
        node = trie[node].nex[c];
        trie[node].cnt1++;
    }
    trie[node].cnt2++;
}

int query(string s){
    int node = 0, res = 0;
    for (int i=0; i<s.size(); i++){
        int c = s[i]-'0';
        if (!trie[node].nex[c]) return res;
        node = trie[node].nex[c];
        res += trie[node].cnt2;
    }
    res += (trie[node].cnt1-trie[node].cnt2);
    return res;
}

int main(){
    scanf("%d%d", &N, &M);
    for (int i=1; i<=N; i++){
        scanf("%d", &num);
        tmp = "";
        for (int j=1; j<=num; j++){
            scanf("%d", &nc);
            tmp += (nc+'0');
        }
        insert(tmp);
    }
    for (int i=1; i<=M; i++){
        scanf("%d", &num);
        tmp = "";
        for (int j=1; j<=num; j++){
            scanf("%d", &nc);
            tmp += (nc+'0');
        }
        printf("%d\n", query(tmp));
    }
    return 0;
}

评论