Range Sum of Multiset
本题要求实现一个支持在线插入(1)、删除(2)、区间和(3)的数据结构。根据数据规模 $1\leq Q\leq 5\cdot 10^5$,我们需要 $O(\lg n)$ 的时间支持这三种操作。
如果只需要支持操作 1、2,我们可以采用平衡树(红黑树、Treap)或者跳表。为了支持操作 3,我们需要在原数据结构的结点上额外维护求和信息:为了求 $[l,r]$ 区间内的和,我们可以维护前缀和(即不超过某个值的所有元素之和),再用两个前缀和之差得到区间和。
对于 Treap,要注意的是每个结点维护的并不是整棵子树之和,而是左子树之和加上结点自身的值(键值乘以重复次数);示例代码中还对称地维护了右子树之和加上结点自身的值,以便在旋转时更新。
对于跳表,可以在每层的每个结点上维护:在该层中,从前一个结点(不含)到该结点(含)之间所有元素之和。实现的时候要注意指针的使用。
示例代码 (Treap & 跳表)
#include <bits/stdc++.h>
#include <random>
using namespace std;
#define MAXN 500000
// #define DEBUG
/* Treap */
/*
 * One treap vertex: a distinct key, how many copies of it are stored, and the
 * cached sums used to answer prefix-sum queries. Children are referred to by
 * index into the owning pool; index 0 stands for "no child".
 */
typedef struct treap_node {
    int64_t key;        // the stored value
    int64_t sum_left;   // w * key plus the total of all elements in the left subtree
    int64_t sum_right;  // w * key plus the total of all elements in the right subtree
    int pri;            // heap priority; a smaller value is kept closer to the root
    int w;              // number of copies of key currently stored
    int l;              // pool index of the left child (0 = none)
    int r;              // pool index of the right child (0 = none)
} treap_node;
struct treap_rotate {
treap_node node[MAXN + 1];
int rt, sz;
void l_rotate(int &k) {
int t = node[k].r;
node[k].r = node[t].l;
node[t].l = k;
// update sum
node[k].sum_right = node[t].sum_left - node[t].w * node[t].key +
node[k].w * node[k].key;
node[t].sum_left += node[k].sum_left;
// update pointer
k = t;
}
void r_rotate(int &k) {
int t = node[k].l;
node[k].l = node[t].r;
node[t].r = k;
// update sum
node[k].sum_left = node[t].sum_right - node[t].w * node[t].key +
node[k].w * node[k].key;
node[t].sum_right += node[k].sum_right;
// update pointer
k = t;
}
// Insert
void _insert(int &k, int64_t x) {
if (!k) {
k = ++sz;
node[sz].key = node[sz].sum_left = node[sz].sum_right = x;
node[sz].pri = rand();
node[sz].w = 1;
node[sz].l = node[sz].r = 0;
} else {
if (node[k].key == x) {
node[k].sum_left += x;
node[k].sum_right += x;
node[k].w++;
} else if (node[k].key < x) {
node[k].sum_right += x;
_insert(node[k].r, x);
if (node[node[k].r].pri < node[k].pri) l_rotate(k);
} else {
node[k].sum_left += x;
_insert(node[k].l, x);
if (node[node[k].l].pri < node[k].pri) r_rotate(k);
}
}
}
void insert(int64_t x) { _insert(rt, x); }
// Remove
int _find(int &k, int64_t x) {
if (!k) return 0;
if (node[k].key == x) return 1;
if (x < node[k].key)
return _find(node[k].l, x);
else
return _find(node[k].r, x);
}
int _remove(int &k, int64_t x) {
if (!k) return 0;
if (node[k].key == x) {
if (node[k].w > 1) {
node[k].w--;
node[k].sum_left -= x;
node[k].sum_right -= x;
return node[k].w + 1;
} else {
if (node[k].l == 0 || node[k].r == 0) {
k = node[k].l + node[k].r;
return 1;
} else if (node[node[k].l].pri < node[node[k].r].pri) {
r_rotate(k);
return _remove(k, x);
} else {
l_rotate(k);
return _remove(k, x);
}
}
} else if (node[k].key < x) {
node[k].sum_right -= x;
return _remove(node[k].r, x);
} else {
node[k].sum_left -= x;
return _remove(node[k].l, x);
}
return 0;
}
int remove(int64_t x) {
if (_find(rt, x))
return _remove(rt, x);
else
return 0;
}
// Query
int64_t _querysum(int &k, int64_t x) {
if (!k) return 0;
if (node[k].key <= x) {
return node[k].sum_left + _querysum(node[k].r, x);
} else {
return _querysum(node[k].l, x);
}
}
int64_t query(int64_t l, int64_t r) {
return _querysum(rt, r) - _querysum(rt, l);
}
};
/* Skip List */
// Upper bound on the height of a skip-list node: random_level() never returns
// more than this, and the head/tail sentinels are created with exactly this
// many levels.
const int max_level = 20;
/*
 * One skip-list node: a distinct key, its multiplicity, and two per-level
 * arrays of length `level`.
 *
 *   nxt[i]     - successor of this node on level i (nullptr until linked)
 *   sum_pre[i] - running sum attached to this node on level i (starts at 0;
 *                its exact meaning is defined by the list that owns the node)
 *
 * The node owns both arrays and releases them in its destructor, so a plain
 * `delete node` frees everything. Copying is disabled because a copy would
 * share, and later double-free, the arrays.
 */
typedef struct SkipNode {
    int64_t key;
    int w, level;
    SkipNode **nxt;
    int64_t *sum_pre;

    // Creates a node for key k with `level` levels, multiplicity 1, all
    // successors null and all running sums zero.
    SkipNode(int64_t k, int level)
        : key(k),
          w(1),
          level(level),
          nxt(new SkipNode *[level]()),      // () value-initialises to nullptr
          sum_pre(new int64_t[level]()) {}   // () value-initialises to 0

    ~SkipNode() {
        delete[] nxt;
        delete[] sum_pre;
    }

    SkipNode(const SkipNode &) = delete;
    SkipNode &operator=(const SkipNode &) = delete;
} skip_node;
// Multiset of int64_t values implemented as a skip list with two sentinels.
//
// Invariant maintained by insert/remove: for a node n that is linked on level
// i, n->sum_pre[i] is the total (key times multiplicity) of all elements that
// lie after n's predecessor on level i, up to and including n itself. The
// tail sentinel follows the same rule except that its own key is not counted.
// Walking any search path and adding the sum_pre of every node stepped onto
// therefore yields a prefix sum.
//
// The class has no destructor; nodes are only freed when remove() deletes one.
class skip_list {
private:
// Sentinels: head has key LLONG_MIN, tail has key LLONG_MAX, both max_level tall.
skip_node *head, *tail;
// Picks the height of a new node: starts at 1 and grows by one level for each
// consecutive draw of 1 from a uniform {0, 1} distribution, capped at max_level.
int random_level() {
int random_level = 1;
// Function-local statics: the engine is seeded from the clock once, on first use.
static int seed = time(NULL);
static default_random_engine e(seed);
static uniform_int_distribution<int> u(0, 1);
// Earlier rand()-based variant, kept by the author for reference:
// while (rand() % 2 && random_level < max_level) random_level++;
while (u(e) % 2 && random_level < max_level) random_level++;
return random_level;
};
// Returns the last node whose key is strictly smaller than x (possibly head).
skip_node *_find_pre(int64_t x) {
skip_node *tmp = head;
for (int i = max_level - 1; i >= 0; --i) {
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
tmp = tmp->nxt[i];
}
return tmp;
}
// Returns the total of all stored elements whose key is <= x by summing the
// sum_pre of every node the search steps onto.
int64_t _sum(int64_t x) {
skip_node *tmp = head;
int64_t sum = 0;
for (int i = max_level - 1; i >= 0; --i) {
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key <= x) {
tmp = tmp->nxt[i];
sum += tmp->sum_pre[i];
}
}
return sum;
}
public:
// Builds an empty list: head linked directly to tail on every level.
skip_list() {
head = new skip_node(LLONG_MIN, max_level);
tail = new skip_node(LLONG_MAX, max_level);
for (int i = 0; i < max_level; ++i) {
head->nxt[i] = tail;
}
};
// Inserts one copy of x.
void insert(int64_t x) {
// First node with key >= x (tail if there is none).
skip_node *node = _find_pre(x)->nxt[0];
if (node->key == x) {
// Key already present: bump the multiplicity and every sum that covers it.
node->w++;
// Levels the node itself is linked on.
for (int i = 0; i < node->level; ++i) node->sum_pre[i] += x;
// Levels above the node's height: there the element is accounted for by the
// first level-i node after it.
skip_node *tmp = head;
for (int i = max_level - 1; i >= node->level; --i) {
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
tmp = tmp->nxt[i];
tmp->nxt[i]->sum_pre[i] += x;
}
} else {
// New key: create a node of random height and splice it in top-down.
int r_level = random_level();
skip_node *new_node = new skip_node(x, r_level);
skip_node *tmp = head;
for (int i = max_level - 1; i >= 0; --i) {
// Total of the elements stepped over on level i, i.e. those between the
// predecessor found on level i + 1 and the predecessor on level i.
int64_t acc_sum = 0;
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x) {
tmp = tmp->nxt[i];
acc_sum += tmp->sum_pre[i];
}
if (i < r_level) {
// Stash the partial total one level up; the pass after this loop turns the
// stashed values into real sum_pre entries.
if (i + 1 < r_level) new_node->sum_pre[i + 1] = acc_sum;
// Link the new node in after its level-i predecessor.
new_node->nxt[i] = tmp->nxt[i];
tmp->nxt[i] = new_node;
} else {
// The new node is not linked on this level; its successor here absorbs x.
tmp->nxt[i]->sum_pre[i] += x;
}
}
// On level 0 the node covers only itself.
new_node->sum_pre[0] = x;
for (int i = 1; i < new_node->level; ++i) {
// Stashed partial + everything covered one level below = full level-i span.
new_node->sum_pre[i] += new_node->sum_pre[i - 1];
// The successor no longer covers the elements the new node took over; x itself
// is excluded because the successor never counted it.
new_node->nxt[i]->sum_pre[i] -= new_node->sum_pre[i] - x;
}
}
}
// Removes one copy of x. Returns 0 if x is not stored, otherwise the number
// of copies that were stored before the call.
int remove(int64_t x) {
skip_node *node = _find_pre(x)->nxt[0];
if (node->key != x) {
return 0;
} else {
if (node->w > 1) {
// More than one copy: keep the node, just take x out of every covering sum
// (mirror image of the "key already present" branch of insert).
node->w--;
for (int i = 0; i < node->level; ++i) node->sum_pre[i] -= x;
skip_node *tmp = head;
for (int i = max_level - 1; i >= node->level; --i) {
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
tmp = tmp->nxt[i];
tmp->nxt[i]->sum_pre[i] -= x;
}
// w was already decremented, so this is the count before removal.
return node->w + 1;
} else {
// Last copy: unlink the node on every level it occupies and free it.
skip_node *tmp = head;
for (int i = max_level - 1; i >= 0; --i) {
while (tmp->nxt[i] != nullptr && tmp->nxt[i]->key < x)
tmp = tmp->nxt[i];
if (i < node->level) {
// Here tmp->nxt[i] is the node being removed: its successor inherits the
// node's span minus x, then the predecessor is linked past it.
tmp->nxt[i]->nxt[i]->sum_pre[i] +=
tmp->nxt[i]->sum_pre[i] - x;
tmp->nxt[i] = tmp->nxt[i]->nxt[i];
} else {
// Above the node's height its successor was covering x; drop it.
tmp->nxt[i]->sum_pre[i] -= x;
}
}
delete node;
return 1;
}
}
}
// Difference of two prefix sums: total of the elements with l < key <= r.
int64_t query(int64_t l, int64_t r) { return _sum(r) - _sum(l); }
};
/*
Backend selection: Multiset is the treap.
Author's recorded result:
Time: 9s, Memory: 22.1 MB
*/
typedef struct treap_rotate Multiset;
/*
Alternative backend: the skip list (disabled; swap the two typedefs to use it).
Author's recorded results for different max_level values:
max_level = 10, TLE
max_level = 15, Time: 13.4s, Memory: 59.0 MB
max_level = 20, Time: 13.8s, Memory: 59.0 MB
max_level = 40, Time: 14.8s, Memory: 59.0 MB
*/
// typedef struct skip_list Multiset;
// The single container instance used by main(). Being a global, it is
// zero-initialised before main() runs.
Multiset S;
// Working variables of main():
//   Q      - number of operations to process
//   mod    - modulus used when decoding operands
//   lans   - result of the most recent range-sum query (0 until the first one)
//   y      - raw operand of an insert/remove; x - its decoded value
//   u, v   - raw bounds of a range-sum query;  l, r - their decoded values
//   cnt    - value returned by S.remove()
int64_t Q, mod, lans, y, x, l, r, u, v, cnt;
// Operation code of the current line: 0 = insert, 1 = remove, 2 = range sum.
int q;
/*
 * Driver: reads Q and mod, then processes Q operations from stdin.
 *
 *   0 y    insert x into S
 *   1 y    remove one copy of x from S and print the value S.remove() returned
 *   2 u v  print S.query(l, r) and remember it in lans
 *
 * Unless DEBUG is defined, operands are decoded with the previous range-sum
 * answer: x = (y + lans) % mod, l = (u + lans) % mod, r = (v + lans) % mod.
 * With DEBUG defined the raw values are used as-is. For operation 2 the
 * bounds are swapped when l > r.
 *
 * Processing stops quietly as soon as the input is exhausted or malformed.
 */
int main() {
    // "%lld" requires long long, while int64_t may be a different type (for
    // example long), so all scanf/printf traffic goes through long long.
    long long in_a = 0, in_b = 0;
    if (scanf("%lld%lld", &in_a, &in_b) != 2) return 0;
    Q = in_a;
    mod = in_b;
#ifndef DEBUG
    if (mod == 0) return 0;  // the decoding below would divide by zero
#endif
    for (int64_t t = 0; t < Q; ++t) {
        if (scanf("%d", &q) != 1) break;
        switch (q) {
            case 0:
                if (scanf("%lld", &in_a) != 1) return 0;
                y = in_a;
#ifndef DEBUG
                x = (y + lans) % mod;
#else
                x = y;
#endif
                S.insert(x);
                break;
            case 1:
                if (scanf("%lld", &in_a) != 1) return 0;
                y = in_a;
#ifndef DEBUG
                x = (y + lans) % mod;
#else
                x = y;
#endif
                cnt = S.remove(x);
                printf("%lld\n", (long long)cnt);
                break;
            case 2:
                if (scanf("%lld%lld", &in_a, &in_b) != 2) return 0;
                u = in_a;
                v = in_b;
#ifndef DEBUG
                l = (u + lans) % mod;
                r = (v + lans) % mod;
#else
                l = u;
                r = v;
#endif
                if (l > r) swap(l, r);
                lans = S.query(l, r);
                printf("%lld\n", (long long)lans);
                break;
            default:
                // Unknown operation code: ignore it.
                break;
        }
    }
    return 0;
}