On this page
article
MergeSort Tree
Sobre
Se for construída sobre um array:
count(i, j, a, b) retorna quantos
elementos de v[i..j] pertencem a [a, b]
report(i, j, a, b) retorna os índices dos
elementos de v[i..j] que pertencem a [a, b]
retorna o vetor ordenado
Se for construída sobre pontos (x, y):
count(x1, x2, y1, y2) retorna quantos pontos
pertencem ao retângulo (x1, y1), (x2, y2)
report(x1, x2, y1, y2) retorna os índices dos pontos que
pertencem ao retângulo (x1, y1), (x2, y2)
retorna os pontos ordenados lexicograficamente
(assume x1 <= x2, y1 <= y2)
kth(y1, y2, k) retorna o índice do ponto com k-ésimo menor
x dentre os pontos que possuem y em [y1, y2] (0 based)
Se quiser usar para achar o k-ésimo valor em um range, construir
com ms_tree t(v, true), e chamar kth(l, r, k)
Usa O(n log(n)) de memória
Complexidades:
construir - O(n log(n))
count - O(log(n))
report - O(log(n) + k) para k índices retornados
kth - O(log(n))
Link original: mergeSortTree.cpp
Código
// Merge sort tree over 2D points.
//
// The points are sorted by x and a segment tree is built over that order;
// every node keeps its points sorted by y. Each stored entry also records how
// many of the entries before it (inclusive) came from the left child, so a
// position in a node's y-sorted list can be translated into positions in both
// children in O(1). Queries therefore binary-search on y only once, at the root.
//
// Built from points (x, y): count/report work on the rectangle
// [x1, x2] x [y1, y2]. Built from an array v: x is the position and y the
// value (or the other way around when `inv` is true).
template <typename T = int> struct ms_tree {
    // Input points as {x, y, original index}, sorted lexicographically.
    std::vector<std::tuple<T, T, int>> v;
    // Number of points.
    int n;
    // t[p][0]   = {min_x, max_x, 0} of node p.
    // t[p][i>0] = {y, idx, left}: i-th point of node p in (y, idx) order, where
    //             `left` counts how many of t[p][1..i] belong to child 2*p.
    std::vector<std::vector<std::tuple<T, T, int>>> t;
    // All y coordinates in sorted order (the root's y-sorted list).
    std::vector<T> vy;

    // Builds the tree from points given as (x, y) pairs; point i keeps index i.
    // An empty input yields an empty tree on which every query is valid.
    ms_tree(const std::vector<std::pair<T, T>>& vv)
        : n(static_cast<int>(vv.size())), t(4*n), vy(n) {
        v.reserve(n);
        for (int i = 0; i < n; i++) v.emplace_back(vv[i].first, vv[i].second, i);
        std::sort(v.begin(), v.end());
        if (n == 0) return; // nothing to build; build() would read v[0] and t[1]
        build(1, 0, n-1);
        for (int i = 0; i < n; i++) vy[i] = std::get<0>(t[1][i+1]);
    }

    // Builds the tree from an array: element i becomes the point (i, vv[i]),
    // or (vv[i], i) when `inv` is true (index and value swapped).
    ms_tree(const std::vector<T>& vv, bool inv = false) : ms_tree(to_points(vv, inv)) {}

    // Converts an array into the point list used by the array constructor.
    static std::vector<std::pair<T, T>> to_points(const std::vector<T>& vv, bool inv) {
        const int len = static_cast<int>(vv.size());
        std::vector<std::pair<T, T>> pts;
        pts.reserve(len);
        for (int i = 0; i < len; i++) {
            const T pos = static_cast<T>(i);
            if (inv) pts.emplace_back(vv[i], pos);
            else pts.emplace_back(pos, vv[i]);
        }
        return pts;
    }

    // Fills node p, which covers v[l..r] (inclusive, l <= r), by merging the
    // y-sorted lists of its two children.
    void build(int p, int l, int r) {
        t[p].reserve(r-l+2);
        t[p].emplace_back(std::get<0>(v[l]), std::get<0>(v[r]), 0); // {min_x, max_x, 0}
        if (l == r) {
            t[p].emplace_back(std::get<1>(v[l]), static_cast<T>(std::get<2>(v[l])), 0);
            return;
        }
        const int m = (l+r)/2;
        build(2*p, l, m), build(2*p+1, m+1, r);

        // t itself is never resized here, so these references stay valid while
        // t[p] grows.
        const auto& lc = t[2*p];
        const auto& rc = t[2*p+1];
        const int nl = m-l+1, nr = r-m;
        int L = 0, R = 0;
        while (L < nl or R < nr) {
            // Entries taken from the left child so far (0 from the header entry).
            const int left = std::get<2>(t[p].back());
            const bool take_right = L == nl or (R < nr and rc[1+R] < lc[1+L]);
            if (take_right) {
                t[p].push_back(rc[1 + R++]);
                std::get<2>(t[p].back()) = left;
            } else {
                t[p].push_back(lc[1 + L++]);
                std::get<2>(t[p].back()) = left+1;
            }
        }
    }

    // Number of points with y strictly less than the given y.
    int get_l(T y) const {
        return static_cast<int>(std::lower_bound(vy.begin(), vy.end(), y) - vy.begin());
    }
    // Number of points with y less than or equal to the given y.
    int get_r(T y) const {
        return static_cast<int>(std::upper_bound(vy.begin(), vy.end(), y) - vy.begin());
    }

    // Number of points with x in [x1, x2] and y in [y1, y2].
    // An inverted y range (y1 > y2) is treated as empty.
    int count(T x1, T x2, T y1, T y2) const {
        return count_rec(1, get_l(y1), get_r(y2), x1, x2);
    }

    // Original indices of the points with x in [x1, x2] and y in [y1, y2].
    // Nodes are visited from left to right; within a fully covered node the
    // indices come out in that node's (y, idx) order.
    std::vector<int> report(T x1, T x2, T y1, T y2) const {
        std::vector<int> ret;
        report_rec(1, get_l(y1), get_r(y2), x1, x2, ret);
        return ret;
    }

    // Original index of the point with the k-th smallest x (0-based) among the
    // points with y in [y1, y2]; -1 when k is negative or there are not more
    // than k such points.
    int kth(T y1, T y2, int k) const {
        const int l = get_l(y1), r = get_r(y2);
        if (k < 0 or k >= r-l) return -1; // also covers the empty tree
        return kth_rec(1, l, r, k);
    }

    // count() restricted to node p and to positions [l, r) of its y-sorted list.
    int count_rec(int p, int l, int r, const T& x1, const T& x2) const {
        if (l >= r) return 0; // checked first: t[p] may not exist for an empty range
        const T& lo = std::get<0>(t[p][0]);
        const T& hi = std::get<1>(t[p][0]);
        if (x2 < lo or hi < x1) return 0;
        if (x1 <= lo and hi <= x2) return r-l;
        const int nl = std::get<2>(t[p][l]), nr = std::get<2>(t[p][r]);
        return count_rec(2*p, nl, nr, x1, x2) + count_rec(2*p+1, l-nl, r-nr, x1, x2);
    }

    // report() restricted to node p and positions [l, r); appends to `out`.
    void report_rec(int p, int l, int r, const T& x1, const T& x2, std::vector<int>& out) const {
        if (l >= r) return;
        const T& lo = std::get<0>(t[p][0]);
        const T& hi = std::get<1>(t[p][0]);
        if (x2 < lo or hi < x1) return;
        if (x1 <= lo and hi <= x2) {
            for (int i = l; i < r; i++) out.push_back(static_cast<int>(std::get<1>(t[p][i+1])));
            return;
        }
        const int nl = std::get<2>(t[p][l]), nr = std::get<2>(t[p][r]);
        report_rec(2*p, nl, nr, x1, x2, out);
        report_rec(2*p+1, l-nl, r-nr, x1, x2, out);
    }

    // kth() restricted to node p and positions [l, r). When the node holds at
    // most k candidates, k is reduced by that amount and -1 is returned so the
    // caller can continue in the sibling.
    int kth_rec(int p, int l, int r, int& k) const {
        if (k >= r-l) {
            k -= r-l;
            return -1;
        }
        if (r-l == 1) return static_cast<int>(std::get<1>(t[p][l+1]));
        const int nl = std::get<2>(t[p][l]), nr = std::get<2>(t[p][r]);
        const int found = kth_rec(2*p, nl, nr, k);
        if (found != -1) return found;
        return kth_rec(2*p+1, l-nl, r-nr, k);
    }
};