题目链接:Range Sum of Multiset

Range Sum of Multiset

本题要求实现一个支持在线插入(1)、删除(2)、区间和(3)三种操作的数据结构(在输入中,这三种操作的编号依次为 0、1、2)。所谓“在线”,是指每个操作数都要先加上上一次区间和查询的答案再对 mod 取模,才能得到真实的参数,因此只能按顺序逐个处理操作。根据数据规模 $1\leq Q\leq 5\cdot 10^5$,每种操作都需要在 $O(\lg n)$(期望)时间内完成。

如果只需要支持操作 1、2,我们可以采用平衡树(红黑树、Treap)或者跳表。为了支持操作 3,我们需要在原数据结构上额外维护一些与区间和相关的信息。为了求 $[l,r]$ 区间内的和,我们可以为每个结点维护一个前缀和,再用两个前缀和之差来计算区间和。

可以参考网上 Treap 和跳表的实现。

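对于 Treap,要注意结点上维护的“前缀和”并不是整棵子树之和,而是左子树之和再加上该结点自身的贡献(键值 × 重数)。查询所有不超过 $x$ 的元素之和时,从根向下走:若当前结点的键值不超过 $x$,就累加它的这个值并进入右子树,否则进入左子树。示例代码还对称地维护了“右子树之和 + 自身贡献”,供旋转时更新前缀和使用。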

对于跳表,可以为每个结点在它所在的每一层维护:在该层中,从前一结点(不含)到该结点(含)之间所有元素之和。查询时沿查找路径累加所经过结点在对应层上的这个值即可。实现的时候要注意指针的使用。

示例代码 (Treap & 跳表)

#include <bits/stdc++.h>

#include <random>
using namespace std;

#define MAXN 500000
// #define DEBUG

/* Treap */

// One treap node. All copies of the same key share a single node and `w` is
// the multiplicity. Writing self = w * key:
//   sum_left  = (sum of every element in the left subtree)  + self
//   sum_right = (sum of every element in the right subtree) + self
// `l` / `r` are child indices into the node pool (0 = empty) and `pri` is the
// heap priority.
typedef struct {
    int64_t key, sum_left, sum_right;
    int pri, w, l, r;
} treap_node;

// Multiset of int64_t values backed by a rotation-based treap stored in a
// fixed node pool. Index 0 is the empty-tree sentinel; nodes 1..sz are handed
// out in order and never recycled, so at most MAXN insertions may create a new
// node (there is no bounds check). The heap order is a min-heap on `pri`.
//
// An instance holds MAXN + 1 nodes inline, so it should have static storage
// (a namespace-scope or `static` object), not live on the stack.
struct treap_rotate {
    treap_node node[MAXN + 1];
    int rt = 0, sz = 0;
    // Priority source. Used instead of std::rand(): RAND_MAX may be as small
    // as 32767, which yields many equal priorities among MAXN nodes (ties never
    // trigger a rotation, so the tree is less well randomized), and rand()
    // shares hidden global state. Default-seeded, so runs stay reproducible.
    std::mt19937 rng;

    // Contribution of node k itself (all of its copies) to any sum covering it.
    int64_t _self(int k) const {
        return static_cast<int64_t>(node[k].w) * node[k].key;
    }

    // Left rotation at k: k's right child t becomes the subtree root and k
    // becomes t's left child. `k` is updated to refer to t.
    void l_rotate(int &k) {
        const int t = node[k].r;
        node[k].r = node[t].l;
        node[t].l = k;
        // k's new right subtree is t's old left subtree (t.sum_left minus t).
        node[k].sum_right = node[t].sum_left - _self(t) + _self(k);
        // t's new left subtree is all of k's: k's left, k, and t's old left.
        node[t].sum_left += node[k].sum_left;
        k = t;
    }

    // Right rotation at k: mirror image of l_rotate.
    void r_rotate(int &k) {
        const int t = node[k].l;
        node[k].l = node[t].r;
        node[t].r = k;
        // k's new left subtree is t's old right subtree (t.sum_right minus t).
        node[k].sum_left = node[t].sum_right - _self(t) + _self(k);
        // t's new right subtree is all of k's: t's old right, k, and k's right.
        node[t].sum_right += node[k].sum_right;
        k = t;
    }

    // Inserts one copy of x into the subtree rooted at k; `k` is updated when
    // a rotation changes the subtree root.
    void _insert(int &k, int64_t x) {
        if (!k) {
            k = ++sz;
            treap_node &fresh = node[k];
            fresh.key = fresh.sum_left = fresh.sum_right = x;
            fresh.pri = static_cast<int>(rng() >> 1);  // non-negative int
            fresh.w = 1;
            fresh.l = fresh.r = 0;
            return;
        }
        if (node[k].key == x) {
            // Existing key: x is part of `self`, which both sums include.
            node[k].sum_left += x;
            node[k].sum_right += x;
            node[k].w++;
        } else if (node[k].key < x) {
            node[k].sum_right += x;
            _insert(node[k].r, x);
            if (node[node[k].r].pri < node[k].pri) l_rotate(k);
        } else {
            node[k].sum_left += x;
            _insert(node[k].l, x);
            if (node[node[k].l].pri < node[k].pri) r_rotate(k);
        }
    }
    void insert(int64_t x) { _insert(rt, x); }

    // Returns 1 if x occurs in the subtree rooted at k, otherwise 0.
    int _find(int k, int64_t x) const {
        while (k && node[k].key != x) k = x < node[k].key ? node[k].l : node[k].r;
        return k != 0 ? 1 : 0;
    }

    // Removes one copy of x from the subtree rooted at k and returns the
    // multiplicity x had before the call. Precondition: x is present, because
    // the sums of ancestors are decremented on the way down (remove() checks).
    int _remove(int &k, int64_t x) {
        if (!k) return 0;
        if (node[k].key < x) {
            node[k].sum_right -= x;
            return _remove(node[k].r, x);
        }
        if (x < node[k].key) {
            node[k].sum_left -= x;
            return _remove(node[k].l, x);
        }
        // node[k] holds x.
        if (node[k].w > 1) {
            node[k].w--;
            node[k].sum_left -= x;
            node[k].sum_right -= x;
            return node[k].w + 1;
        }
        // Last copy with at most one child: splice the node out.
        if (node[k].l == 0 || node[k].r == 0) {
            k = node[k].l + node[k].r;
            return 1;
        }
        // Two children: lift the child with the smaller priority, which pushes
        // the doomed node one level down, then retry from the new root.
        if (node[node[k].l].pri < node[node[k].r].pri)
            r_rotate(k);
        else
            l_rotate(k);
        return _remove(k, x);
    }

    // Removes one copy of x. Returns how many copies were present before the
    // call; 0 means x was absent and nothing changed.
    int remove(int64_t x) { return _find(rt, x) ? _remove(rt, x) : 0; }

    // Sum of all elements e <= x in the subtree rooted at k.
    int64_t _querysum(int k, int64_t x) const {
        int64_t total = 0;
        while (k) {
            if (node[k].key <= x) {
                total += node[k].sum_left;  // left subtree + this key's copies
                k = node[k].r;
            } else {
                k = node[k].l;
            }
        }
        return total;
    }

    // Sum of all elements e with l < e <= r (the lower bound is exclusive).
    // NOTE(review): confirm the exclusive lower bound matches the problem's
    // definition of the queried range.
    int64_t query(int64_t l, int64_t r) const {
        return _querysum(rt, r) - _querysum(rt, l);
    }
};

/* Skip List */

const int max_level = 20;

// Skip-list node. All copies of one key share a node and `w` is the
// multiplicity. The node is linked on levels 0..level-1; for each such level i:
//   nxt[i]     - next node on level i
//   sum_pre[i] - sum of every element strictly after the previous level-i
//                node, up to and including this node (w * key)
// The node owns both arrays; copying is disabled so they cannot be freed twice.
typedef struct SkipNode {
    int64_t key;
    int w, level;
    SkipNode **nxt;
    int64_t *sum_pre;
    SkipNode(int64_t k, int level)
        : key(k),
          w(1),
          level(level),
          nxt(new SkipNode *[level]()),     // value-initialized: all nullptr
          sum_pre(new int64_t[level]()) {}  // value-initialized: all 0
    // Releases the per-level arrays; skip_list relies on `delete node` doing so.
    ~SkipNode() {
        delete[] nxt;
        delete[] sum_pre;
    }
    SkipNode(const SkipNode &) = delete;
    SkipNode &operator=(const SkipNode &) = delete;
} skip_node;

// Multiset of int64_t values backed by a skip list with head/tail sentinels
// (keys LLONG_MIN / LLONG_MAX) linked on all max_level levels. Searches always
// start from the top level. The list owns every node; copying is disabled.
class skip_list {
   private:
    skip_node *head, *tail;

    // Random level in [1, max_level]: keep flipping a fair coin while it
    // comes up 1. The engine is seeded once from the wall clock.
    int random_level() {
        static const unsigned seed = static_cast<unsigned>(std::time(nullptr));
        static std::default_random_engine engine(seed);
        static std::uniform_int_distribution<int> coin(0, 1);

        int level = 1;
        while (coin(engine) == 1 && level < max_level) ++level;
        return level;
    }

    // Last node (possibly head) whose key is < x.
    skip_node *_find_pre(int64_t x) {
        skip_node *tmp = head;
        for (int i = max_level - 1; i >= 0; --i) {
            while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
                tmp = tmp->nxt[i];
        }
        return tmp;
    }

    // Sum of all elements <= x, accumulated along the search path.
    int64_t _sum(int64_t x) {
        skip_node *tmp = head;
        int64_t sum = 0;
        for (int i = max_level - 1; i >= 0; --i) {
            while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key <= x) {
                tmp = tmp->nxt[i];
                sum += tmp->sum_pre[i];
            }
        }
        return sum;
    }

    // Adds `delta` to sum_pre[i] of the first node with key >= x on every
    // level i in [from_level, max_level). Used when the multiplicity of x's
    // node changes: on those levels the node is absent, so that successor's
    // range is the one that covers x.
    void _add_above(int64_t x, int from_level, int64_t delta) {
        skip_node *tmp = head;
        for (int i = max_level - 1; i >= from_level; --i) {
            while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
                tmp = tmp->nxt[i];
            tmp->nxt[i]->sum_pre[i] += delta;
        }
    }

   public:
    skip_list() {
        head = new skip_node(LLONG_MIN, max_level);
        tail = new skip_node(LLONG_MAX, max_level);
        for (int i = 0; i < max_level; ++i) head->nxt[i] = tail;
    }

    // Every node (sentinels included) is linked on level 0, so walking that
    // level frees them all.
    ~skip_list() {
        skip_node *cur = head;
        while (cur != nullptr) {
            skip_node *next = cur->nxt[0];
            delete cur;
            cur = next;
        }
    }
    skip_list(const skip_list &) = delete;
    skip_list &operator=(const skip_list &) = delete;

    // Inserts one copy of x.
    void insert(int64_t x) {
        skip_node *node = _find_pre(x)->nxt[0];
        if (node->key == x) {
            // Already present: every range that covers x grows by x.
            node->w++;
            for (int i = 0; i < node->level; ++i) node->sum_pre[i] += x;
            _add_above(x, node->level, x);
            return;
        }

        const int r_level = random_level();
        skip_node *new_node = new skip_node(x, r_level);
        skip_node *tmp = head;
        for (int i = max_level - 1; i >= 0; --i) {
            // acc_sum: elements skipped on level i, i.e. those after the
            // level-(i+1) predecessor up to and including the level-i one.
            int64_t acc_sum = 0;
            while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x) {
                tmp = tmp->nxt[i];
                acc_sum += tmp->sum_pre[i];
            }
            if (i < r_level) {
                if (i + 1 < r_level) new_node->sum_pre[i + 1] = acc_sum;
                new_node->nxt[i] = tmp->nxt[i];
                tmp->nxt[i] = new_node;
            } else {
                // Level without the new node: its successor now also covers x.
                tmp->nxt[i]->sum_pre[i] += x;
            }
        }
        // Turn the skipped sums into the new node's ranges, and take away from
        // each level's successor the elements that now belong to new_node.
        new_node->sum_pre[0] = x;
        for (int i = 1; i < new_node->level; ++i) {
            new_node->sum_pre[i] += new_node->sum_pre[i - 1];
            new_node->nxt[i]->sum_pre[i] -= new_node->sum_pre[i] - x;
        }
    }

    // Removes one copy of x. Returns how many copies were present before the
    // call; 0 means x was absent and nothing changed.
    int remove(int64_t x) {
        skip_node *node = _find_pre(x)->nxt[0];
        if (node->key != x) return 0;

        if (node->w > 1) {
            node->w--;
            for (int i = 0; i < node->level; ++i) node->sum_pre[i] -= x;
            _add_above(x, node->level, -x);
            return node->w + 1;
        }

        // Last copy: unlink the node, handing its range (minus x) to the
        // successor on each of its levels.
        skip_node *tmp = head;
        for (int i = max_level - 1; i >= 0; --i) {
            while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
                tmp = tmp->nxt[i];
            if (i < node->level) {
                node->nxt[i]->sum_pre[i] += node->sum_pre[i] - x;
                tmp->nxt[i] = node->nxt[i];
            } else {
                tmp->nxt[i]->sum_pre[i] -= x;
            }
        }
        delete node;
        return 1;
    }

    // Sum of all elements e with l < e <= r (the lower bound is exclusive).
    int64_t query(int64_t l, int64_t r) { return _sum(r) - _sum(l); }
};

/*
Use Treap
    Time: 9s, Memory: 22.1 MB
*/
typedef struct treap_rotate Multiset;

/*
Use Skip List
    max_level = 10, TLE
    max_level = 15, Time: 13.4s, Memory: 59.0 MB
    max_level = 20, Time: 13.8s, Memory: 59.0 MB
    max_level = 40, Time: 14.8s, Memory: 59.0 MB
*/
// typedef struct skip_list Multiset;

// The single multiset instance. It lives at namespace scope because the
// treap variant embeds a MAXN-sized node array that is too large for a stack.
Multiset S;
// Declared long long (not int64_t) so each variable read or printed with %lld
// has exactly the type that conversion requires; on LP64 Linux int64_t is
// `long`, which makes %lld a mismatched conversion (undefined behavior).
long long Q, mod, lans, y, x, l, r, u, v, cnt;
int q;

// Reads `Q mod`, then processes Q operations:
//   0 y    insert decode(y)
//   1 y    remove one copy of decode(y); print the count before removal
//          (0 if it was absent)
//   2 u v  l = decode(u), r = decode(v), swapped so that l <= r; print
//          S.query(l, r), which also becomes the new `lans`
// decode(a) = (a + lans) % mod, or just a when DEBUG is defined.
// Processing stops early if the input ends or is malformed.
int main() {
    auto decode = [](long long raw) -> long long {
#ifndef DEBUG
        return (raw + lans) % mod;
#else
        return raw;
#endif
    };

    if (scanf("%lld%lld", &Q, &mod) != 2) return 0;
    for (long long t = 0; t < Q; ++t) {
        if (scanf("%d", &q) != 1) break;
        switch (q) {
            case 0:
                if (scanf("%lld", &y) != 1) return 0;
                x = decode(y);
                S.insert(x);
                break;
            case 1:
                if (scanf("%lld", &y) != 1) return 0;
                x = decode(y);
                cnt = S.remove(x);
                printf("%lld\n", cnt);
                break;
            case 2:
                if (scanf("%lld%lld", &u, &v) != 2) return 0;
                l = decode(u);
                r = decode(v);
                if (l > r) std::swap(l, r);
                lans = S.query(l, r);
                printf("%lld\n", lans);
                break;
        }
    }
    return 0;
}