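Let's start by finding which "round" the cell is in, that is, how many times the spiral has gone fully around before reaching this cell. Throughout, M is the width of the grid (number of columns), N its height (number of rows), and x and y are the column and row of the cell, counted from 0 at the top-left corner. The round is the cell's distance to the nearest border:
int n = Math.min(Math.min(x, y), Math.min(M - x - 1, N - y - 1));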
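To get a feel for what n means, the sketch below prints the round number of every cell of a small grid. The class name SpiralRounds, the helper round() and the 6×5 grid size are arbitrary choices for illustration, not part of the original code.

```java
class SpiralRounds {
    // Hypothetical helper: round (ring) of cell (x, y) in a grid that is
    // M cells wide and N cells tall, i.e. its distance to the nearest border.
    static int round(int M, int N, int x, int y) {
        return Math.min(Math.min(x, y), Math.min(M - x - 1, N - y - 1));
    }

    public static void main(String[] args) {
        int M = 6, N = 5; // example size: 6 columns, 5 rows
        for (int y = 0; y < N; y++) {
            StringBuilder row = new StringBuilder();
            for (int x = 0; x < M; x++) {
                row.append(round(M, N, x, y));
            }
            System.out.println(row);
        }
    }
}
```

This prints the rings as concentric frames of equal digits:

    000000
    011110
    012210
    011110
    000000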
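The first full round consists of 2*(M + N) - 4 cells, the next one of 2*(M + N) - 12 cells, and so on: each round is two cells narrower and two cells shorter than the one outside it, so it has 8 cells fewer. More generally, round i consists of 2*(M + N - 2) - 8*i cells.
This holds for every round that is at least two cells wide and two cells tall. Only the innermost round can be thinner than that, and the rounds outside the cell's own round, which are the only ones we add up below, are always full.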
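If you'd rather not take the per-round count on faith, a brute-force check is easy: count the cells whose distance to the border is i and compare with the formula. This is a sketch; the class and method names are made up for illustration, and it only checks full rounds (width M - 2*i and height N - 2*i both at least 2).

```java
class RoundSizes {
    // Closed-form size of full round i in an M-wide, N-tall grid.
    static int roundSize(int M, int N, int i) {
        return 2 * (M + N - 2) - 8 * i;
    }

    // Brute force: count the cells whose distance to the nearest border is i.
    static int countRound(int M, int N, int i) {
        int count = 0;
        for (int y = 0; y < N; y++) {
            for (int x = 0; x < M; x++) {
                int r = Math.min(Math.min(x, y), Math.min(M - x - 1, N - y - 1));
                if (r == i) count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int M = 7, N = 6; // example size: 7 columns, 6 rows
        // Only full rounds: both M - 2*i and N - 2*i must be at least 2.
        for (int i = 0; 2 * i + 2 <= Math.min(M, N); i++) {
            System.out.println("round " + i + ": formula " + roundSize(M, N, i)
                    + ", counted " + countRound(M, N, i));
        }
    }
}
```

For the 7×6 grid this reports 22, 14 and 6 cells for rounds 0, 1 and 2, matching the counts; together they cover all 42 cells.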
So how many cells are in the first n rounds? Sum the per-round count found above:
sum(0 <= i < n : 2*(M + N - 2) - 8*i) = 2*n*(M + N - 2) - 8 * sum(0 <= i < n : i)
= 2*n*(M + N - 2) - 8 * n * (n - 1) / 2
= 2*n*(M + N) - 4*n - 4*n*n + 4*n
= 2*n*(M + N - 2*n)
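A quick numeric sanity check of the closed form against the plain sum. The class and method names are illustrative only; the grid size 7×6 is an arbitrary example.

```java
class RoundPrefix {
    // Closed form from the derivation: number of cells in rounds 0 .. n-1.
    static int cellsBefore(int M, int N, int n) {
        return 2 * n * (M + N - 2 * n);
    }

    // Direct summation of the per-round sizes, for comparison.
    static int cellsBeforeBySum(int M, int N, int n) {
        int total = 0;
        for (int i = 0; i < n; i++) {
            total += 2 * (M + N - 2) - 8 * i;
        }
        return total;
    }

    public static void main(String[] args) {
        int M = 7, N = 6; // example size: 7 columns, 6 rows
        for (int n = 0; n <= 3; n++) {
            System.out.println(n + ": " + cellsBefore(M, N, n) + " = " + cellsBeforeBySum(M, N, n));
        }
    }
}
```

Both columns read 0, 22, 36 and 42; for n = 3 the three rounds make up the whole 7×6 grid.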
We can already add this value to the index:
int index = 2 * n * (M + N - 2 * n);
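Now we only need to find where in the current round the cell is. Each round is walked clockwise: along the top from left to right, down the right side, back along the bottom from right to left, and up the left side: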
if (n == y) {
    // top of this round, walked left to right
    index += x - n;
} else {
    // add full top of this round
    index += M - 2 * n;
    if (n == M - x - 1) {
        // right side of this round, walked top to bottom
        index += y - (n + 1);
    } else {
        // add full right side of this round
        index += N - 2 * n - 1;
        if (n == N - y - 1) {
            // bottom of this round, walked right to left (measured in columns, so M)
            index += M - x - 1 - (n + 1);
        } else {
            // add full bottom of this round
            index += M - 2 * n - 1;
            // left side of this round, walked bottom to top (measured in rows, so N)
            index += N - y - 1 - (n + 1);
        }
    }
}
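For reference, here is one way to put the pieces together into a single self-contained method. It is a sketch: the class name Spiral and the main method are made up for illustration, and the nested if/else is flattened into early returns, which gives the same result. Along the bottom the offset is counted in columns (M), along the left side in rows (N).

```java
class Spiral {
    // Index of cell (x, y) in a clockwise spiral starting at the top-left
    // corner of a grid M cells wide and N cells tall (x = column, y = row).
    static int spiral(int M, int N, int x, int y) {
        // Round (ring) the cell lies on: its distance to the nearest border.
        int n = Math.min(Math.min(x, y), Math.min(M - x - 1, N - y - 1));
        // Cells in all rounds outside this one.
        int index = 2 * n * (M + N - 2 * n);
        if (y == n) {
            return index + (x - n);                    // top, left to right
        }
        index += M - 2 * n;                            // full top
        if (x == M - n - 1) {
            return index + (y - (n + 1));              // right side, top to bottom
        }
        index += N - 2 * n - 1;                        // full right side
        if (y == N - n - 1) {
            return index + (M - x - 1 - (n + 1));      // bottom, right to left
        }
        index += M - 2 * n - 1;                        // full bottom
        return index + (N - y - 1 - (n + 1));          // left side, bottom to top
    }

    public static void main(String[] args) {
        int M = 4, N = 3; // a non-square example: 4 columns, 3 rows
        for (int y = 0; y < N; y++) {
            StringBuilder row = new StringBuilder();
            for (int x = 0; x < M; x++) {
                row.append(String.format("%3d", spiral(M, N, x, y)));
            }
            System.out.println(row);
        }
    }
}
```

On the 4×3 grid this prints

      0  1  2  3
      9 10 11  4
      8  7  6  5

which is the clockwise spiral order.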
I wrapped this code in a method spiral(M, N, x, y) that returns index, and ran it as follows:
System.out.println(spiral(3, 3, 0, 0));
System.out.println(spiral(3, 3, 1, 0));
System.out.println(spiral(3, 3, 2, 0));
System.out.println(spiral(3, 3, 2, 1));
System.out.println(spiral(3, 3, 2, 2));
System.out.println(spiral(3, 3, 1, 2));
System.out.println(spiral(3, 3, 0, 2));
System.out.println(spiral(3, 3, 0, 1));
System.out.println(spiral(3, 3, 1, 1));
This prints:
0
1
2
3
4
5
6
7
8
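The 3×3 example is square, so it cannot tell M and N apart. To check the formula on non-square grids you can generate reference values by actually walking the spiral and numbering cells as you go; grid[y][x] should then equal spiral(M, N, x, y) for every cell. This is a sketch under the same conventions (clockwise, starting top-left); the class name SpiralWalk and the method walk() are hypothetical.

```java
class SpiralWalk {
    // Fill an N-row, M-column grid with 0, 1, 2, ... by walking the spiral
    // clockwise from the top-left corner. grid[y][x] holds the index of (x, y).
    static int[][] walk(int M, int N) {
        int[][] grid = new int[N][M];
        int top = 0, bottom = N - 1, left = 0, right = M - 1;
        int k = 0;
        while (top <= bottom && left <= right) {
            for (int x = left; x <= right; x++) grid[top][x] = k++;        // top row
            for (int y = top + 1; y <= bottom; y++) grid[y][right] = k++;  // right column
            if (top < bottom) {                                            // skip if single row
                for (int x = right - 1; x >= left; x--) grid[bottom][x] = k++; // bottom row
            }
            if (left < right) {                                            // skip if single column
                for (int y = bottom - 1; y > top; y--) grid[y][left] = k++;    // left column
            }
            top++; bottom--; left++; right--;                              // move one round inward
        }
        return grid;
    }

    public static void main(String[] args) {
        int[][] grid = walk(4, 3); // 4 columns, 3 rows
        for (int[] row : grid) {
            StringBuilder line = new StringBuilder();
            for (int v : row) line.append(String.format("%3d", v));
            System.out.println(line);
        }
    }
}
```

The two guards keep a single remaining row or column from being numbered twice. For the 4×3 grid the walk gives 0 1 2 3 / 9 10 11 4 / 8 7 6 5.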