Mo’s Algorithm

Mo’s algorithm is a square-root-decomposition based technique that solves problems of the following form:

Given a static array A of length N, answer Q queries of the form [l, r], each asking for some property of the subarray A[l..r]. The property could be anything from the sum, to the 3rd smallest integer, to the median, etc. An important note is that Mo’s algorithm is offline: it needs all queries up front so that it can reorder them, and therefore the array must stay static (no updates between queries).

Naively solving this problem results in a time complexity of O(NQ), which is often too slow. 

To use Mo’s algorithm, we need to maintain a data structure that supports the following operations:

  1. insert a value
  2. remove a value
  3. query the current answer

The crux of Mo’s algorithm is that we reorder the queries in such a way that processing all of them takes only O((N+Q)sqrt(N)) insert/remove operations in total.

Divide the array into roughly sqrt(N) blocks of size sqrt(N). Each query’s left endpoint falls in one of these blocks. Sort the queries by the block containing their left endpoint, breaking ties by right endpoint; all queries whose left endpoints share a block are then processed together, in increasing order of their right endpoints.

As we traverse queries we maintain the current left and right endpoints of the subarray we have in our data structure and modify it to match the current query by moving left or right. 

// Problem-specific hooks that the caller must implement:
//   add(index) - insert the element at `index` into the data structure,
//   del(index) - remove the element at `index` from the data structure,
//   query()    - return the answer for the elements currently inserted.
void del(int index);
void add(int index);
int query();

// Block size of about sqrt(N). Assumes N (the array length) is declared before
// this line. Clamped to at least 1 because sqrt(N) truncates to 0 when N == 0,
// and a block size is used as a divisor.
int BLOCKSIZE = max(1, static_cast<int>(sqrt(N)));

// An offline range query over the inclusive interval [l, r]; `ind` records
// where the query appeared in the input so its answer can be stored there.
struct Query {
    int l;    // left endpoint, inclusive
    int r;    // right endpoint, inclusive
    int ind;  // original position of the query
};

bool compare(Query &l, Query &r){
    return l.l == r.l ? l.r < r.r : l.l < r.l;
}

void doMos(vector<Query> queries){
    sort(queries.begin(), queries.end(), compare); //rearrange the queries into correct order
    int leftPtr = 0, rightPtr = 0;
    vector<int> answer(queries.size());
    for(Query q : queries){
        while(leftPtr > q.l){ //if my current range is to the right, we need to keep removing until we reach the left endpoint
            leftPtr–; add(leftPtr);
        }
        while(leftPtr < q.l){ //if we’re to the left then we need to delete some indices
            del(leftPtr); leftPtr++; 
        }
        while(rightPtr < q.r){
            rightPtr++;
            add(rightPtr);
        }
        while(rightPtr > q.r){
            del(rightPtr);
            rightPtr–;
        }
        answer[q.ind] = query();
    }
    return answer;
}

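The time complexity of this algorithm ends up being O((N+Q)sqrt(N)) calls to add/del. Within one block the right pointer only moves forward, so it travels O(N) per block and O(N sqrt(N)) over all sqrt(N) blocks. The left pointer stays inside one block between consecutive queries of the same block, so it moves O(sqrt(N)) per query, or O(Q sqrt(N)) in total. Multiply by the cost of add/del to get the full running time.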

Optimizations:

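  • We can sort right endpoints in increasing order within odd-numbered blocks and in decreasing order within even-numbered blocks. The right pointer then sweeps back and forth instead of jumping back to the start of the array at every new block, which reduces its total movement.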
  • We can order the queries along a Hilbert curve, as described here: https://codeforces.com/blog/entry/61203

Practice Problem

https://dmoj.ca/problem/nccc9s4

Code:

#include<bits/stdc++.h>
using namespace std;

// A range query with inclusive bounds [l, r] and `id`, its position in the
// input (used to store the answer in input order).
struct Query {
    int l;   // left bound
    int r;   // right bound
    int id;  // input position
};
const int B = 500;       // block size used to bucket queries by left endpoint
const int MM = 1e5 + 2;  // freq is indexed by array value; values assumed to lie in [0, MM]

// Value-frequency table that tracks the mode: the most frequent value, with
// ties broken toward the larger value.
//
//  - add(v) inserts v permanently and keeps `ans` up to date.
//  - resetAns() copies `ans` into `fakeans`; fakeAdd(v) then inserts v while
//    updating only `fakeans`.
//  - rem(v) only decrements v's count and updates neither answer. It is
//    therefore only meant for rolling back earlier fakeAdd() calls: once they
//    are all undone, the counts match what `ans` was computed from.
struct DS{
    vector<int> freq;
    int ans = 0;
    int fakeans = 0;
    DS(){
        freq.resize(MM + 1);
    }
    // Zero every count and both answers.
    void clr(){
        fill(freq.begin(), freq.end(), 0);
        ans = 0;
        fakeans = 0;
    }
    void add(int ii){
        bump(ii, ans);
    }
    void resetAns(){
        fakeans = ans;
    }
    void fakeAdd(int v){
        bump(v, fakeans);
    }
    void rem(int v){
        --freq[v];
    }
private:
    // Increment v's count and update `best` (either ans or fakeans) so it
    // stays the most frequent value, preferring the larger value on ties.
    void bump(int v, int &best){
        ++freq[v];
        if(freq[v] > freq[best]){
            best = v;
        }else if(freq[v] == freq[best]){
            best = max(best, v);
        }
    }
};
int main(){
    ios_base::sync_with_stdio(false); cin.tie(0);
    int N, Q; cin >> N >> Q;
    vector<int> a(N);
    for(auto &u : a)cin >> u;
    vector<vector<Query>> qq(N/B + 10);
    vector<int> answer(Q);
    vector<int> f(N+1);
    for(int i = 0; i < Q; i++){
        int l, r; cin >> l >> r;
        –l, –r;
        if((l / B) == (r / B)){
            int ans = 0, fans = 0;
            for(int j = l; j <= r; j++){
                ++f[a[j]];
                if(f[a[j]] > fans){
                    fans = f[a[j]];
                    ans = a[j];
                }else if(f[a[j]] == fans){
                    ans = max(ans, a[j]);
                }
            }
            answer[i] = ans;
            for(int j = l; j <= r; j++){
                –f[a[j]];
            }
            continue;
        }
        qq[l / B].push_back({l, r, i});
    }
    DS ds;
    for(int i = 0; i <= (N+B-1) / B; i++){
        sort(qq[i].begin(), qq[i].end(), [&](auto &l, auto &r){return l.r < r.r;});
        ds.clr();
        int s = (i+1) * B;
        for(auto [l, r, id] : qq[i]){
            while(s <= r){
                ds.add(a[s++]);
            }
            ds.resetAns();
            for(int m = l; m < (i+1)*B; m++){
                ds.fakeAdd(a[m]);
            }
            answer[id] = ds.fakeans;
            for(int m = l; m < (i+1)*B; m++){
                ds.rem(a[m]);
            }
        }
    }
    for(int i = 0; i < Q; i++)cout << answer[i] << ‘\n’;
}