LeetCode/Knight Dialer

Problem Summary

Place a knight on any numeric key of a phone keypad and let it make N - 1 hops between numeric keys, dialing an N-digit number. Find how many distinct numbers can be dialed this way. Since the count can be very large, return it modulo $10^9 + 7$.

The keypad (the knight may stand only on the digit keys):

    1 2 3
    4 5 6
    7 8 9
    * 0 #

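The knight moves as in chess: two keys in one direction, then one key perpendicular to it. That gives 8 possible moves in open space, though on the small pad most keys have only two or three legal hops.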
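To make the move set concrete, the sketch below derives each digit's knight neighbors from its row and column on the pad instead of typing them in by hand. It assumes the standard 4 × 3 layout (rows 1 2 3 / 4 5 6 / 7 8 9 / * 0 #, with 0 at row 3, column 1); the function name knightMoves is hypothetical.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Hypothetical helper: list every digit's knight moves by comparing key
// positions on the 4 x 3 pad. Only digit keys are listed; '*' and '#' are skipped.
std::vector<std::vector<int>> knightMoves() {
    // pos[d] = {row, column} of digit d (row 0 is the top row "1 2 3").
    static const int pos[10][2] = {
        {3, 1},                  // 0
        {0, 0}, {0, 1}, {0, 2},  // 1 2 3
        {1, 0}, {1, 1}, {1, 2},  // 4 5 6
        {2, 0}, {2, 1}, {2, 2},  // 7 8 9
    };
    std::vector<std::vector<int>> moves(10);
    for (int from = 0; from < 10; ++from) {
        for (int to = 0; to < 10; ++to) {
            const int dr = std::abs(pos[from][0] - pos[to][0]);
            const int dc = std::abs(pos[from][1] - pos[to][1]);
            // A knight hop covers 2 keys along one axis and 1 along the other.
            if ((dr == 2 && dc == 1) || (dr == 1 && dc == 2))
                moves[from].push_back(to);
        }
    }
    // Sanity check: 5 has no knight neighbors.
    assert(moves[5].empty());
    return moves;
}
```

With this layout, 4 gets the neighbors {0, 3, 9}, 0 gets {4, 6}, and 5 gets none. Because the test only compares distances, the relation is symmetric: whenever a knight can hop from i to j, it can hop back.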

Solution

Let $X_{n,i}$ be the number of distinct $n$-digit numbers that end with digit $i$.
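Before exploiting any symmetry, $X_{n,i}$ can be computed directly from its definition with a 10-state dynamic program. This is a brute-force sketch, mainly useful as a reference for checking the faster methods. The move table is written out by hand, and countsByLastDigit and totalNumbers are hypothetical names. It relies on hops being reversible: the keys that can hop to i are exactly i's knight neighbors.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

const long long MOD = 1000000007;

// X_n = (X_{n,0}, ..., X_{n,9}) computed straight from the definition, mod MOD.
std::vector<long long> countsByLastDigit(int n) {
    assert(n >= 1);
    // moves[i] = knight neighbors of digit i (5 has none).
    const std::vector<std::vector<int>> moves = {
        {4, 6}, {6, 8}, {7, 9}, {4, 8}, {0, 3, 9},
        {},     {0, 1, 7}, {2, 6}, {1, 3}, {2, 4}};
    std::vector<long long> x(10, 1);  // X_{1,i} = 1 for every digit
    for (int len = 1; len < n; ++len) {
        std::vector<long long> next(10, 0);
        for (int i = 0; i < 10; ++i)
            for (int j : moves[i])  // extend numbers ending in j by the hop j -> i
                next[i] = (next[i] + x[j]) % MOD;
        x = next;
    }
    return x;
}

// Total number of distinct n-digit numbers, mod MOD.
long long totalNumbers(int n) {
    const std::vector<long long> x = countsByLastDigit(n);
    return std::accumulate(x.begin(), x.end(), 0LL) % MOD;
}
```

For n = 2 each count equals the key's number of neighbors, (2, 2, 2, 2, 3, 0, 3, 2, 2, 2), for a total of 20.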

By observing the keypad, we can draw these conclusions:

  1. A knight can’t make any hop from 5 or to 5, so the only dialable number containing the digit 5 is “5” itself. That is, $X_{1,5} = 1$ and $X_{n,5} = 0$ for $n > 1$.

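  2. The corner keys 1, 3, 7, 9 play interchangeable roles. Mirroring the pad left to right (1 ↔ 3, 4 ↔ 6, 7 ↔ 9), or relabeling 1 ↔ 7, 2 ↔ 8, 3 ↔ 9, maps every knight move to another knight move. Every key starts with $X_{1,i} = 1$, so these relabelings preserve the counts: $X_{n,1} = X_{n,3} = X_{n,7} = X_{n,9}$.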

  3. Similarly, $X_{n,2} = X_{n,8}$.

  4. Likewise, $X_{n,4} = X_{n,6}$.

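  5. Each key's next count is the sum of the current counts at the keys that can hop to it. Hops are reversible, so these are exactly its knight neighbors:

    $X_{n+1,0} = X_{n,4} + X_{n,6} = 2X_{n,4}$

    $X_{n+1,1} = X_{n,6} + X_{n,8} = X_{n,2} + X_{n,4}$

    $X_{n+1,2} = X_{n,7} + X_{n,9} = 2X_{n,1}$

    $X_{n+1,4} = X_{n,0} + X_{n,3} + X_{n,9} = X_{n,0} + 2X_{n,1}$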
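The four recurrences in item 5 translate directly into a state update on $(X_{n,0}, X_{n,1}, X_{n,2}, X_{n,4})$. Here is a minimal sketch, assuming counts are kept modulo $10^9 + 7$. State, nextState and stateAt are illustrative names, not from the original.

```cpp
#include <array>
#include <cassert>

// Reduced state (X_{n,0}, X_{n,1}, X_{n,2}, X_{n,4}); the other counts follow by
// symmetry: X_{n,3} = X_{n,7} = X_{n,9} = X_{n,1}, X_{n,8} = X_{n,2}, X_{n,6} = X_{n,4}.
using State = std::array<long long, 4>;

// One step of the four recurrences, modulo 1e9+7.
State nextState(const State& x) {
    const long long MOD = 1000000007;
    State next{};
    next[0] = 2 * x[3] % MOD;           // X_{n+1,0} = 2 X_{n,4}
    next[1] = (x[2] + x[3]) % MOD;      // X_{n+1,1} = X_{n,2} + X_{n,4}
    next[2] = 2 * x[1] % MOD;           // X_{n+1,2} = 2 X_{n,1}
    next[3] = (x[0] + 2 * x[1]) % MOD;  // X_{n+1,4} = X_{n,0} + 2 X_{n,1}
    return next;
}

// X_n for n >= 1, starting from X_1 = (1, 1, 1, 1).
State stateAt(int n) {
    assert(n >= 1);
    State x = {1, 1, 1, 1};
    for (int len = 1; len < n; ++len) x = nextState(x);
    return x;
}
```

For example, stateAt(3) gives (6, 5, 4, 6) and stateAt(4) gives (12, 10, 10, 16).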

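Let $X_n = (X_{n,0}, X_{n,1}, X_{n,2}, X_{n,4})^T$. Then

$$
X_{n+1} = A X_n, \qquad
A = \begin{pmatrix}
0 & 0 & 0 & 2 \\
0 & 0 & 1 & 1 \\
0 & 2 & 0 & 0 \\
1 & 2 & 0 & 0
\end{pmatrix}
$$

Every one-digit number can be dialed, so $X_1 = (1, 1, 1, 1)^T$ and therefore $X_N = A^{N-1} X_1$.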
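Here is a short worked example of the matrix form. It starts from $X_1 = (1, 1, 1, 1)^T$, since every one-digit number can be dialed:

$$
X_2 = A X_1 =
\begin{pmatrix} 0 & 0 & 0 & 2 \\ 0 & 0 & 1 & 1 \\ 0 & 2 & 0 & 0 \\ 1 & 2 & 0 & 0 \end{pmatrix}
\begin{pmatrix} 1 \\ 1 \\ 1 \\ 1 \end{pmatrix}
= \begin{pmatrix} 2 \\ 2 \\ 2 \\ 3 \end{pmatrix},
\qquad
X_3 = A X_2 = \begin{pmatrix} 2 \cdot 3 \\ 2 + 3 \\ 2 \cdot 2 \\ 2 + 2 \cdot 2 \end{pmatrix}
= \begin{pmatrix} 6 \\ 5 \\ 4 \\ 6 \end{pmatrix}
$$

These match a count by hand. For example, $X_{2,4} = 3$ because 4 has three knight neighbors (0, 3 and 9).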

For $N > 1$ we have $X_{N,5} = 0$, so the answer is

$$\sum_{i=0}^{9} X_{N,i} = X_{N,0} + 4X_{N,1} + 2X_{N,2} + 2X_{N,4}.$$

For $N = 1$ the answer is simply 10.
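The final weighted sum can be isolated in a small helper. This sketch assumes the state already holds residues modulo $10^9 + 7$, and answerFromState is a hypothetical name. It only covers N > 1: given $X_1 = (1, 1, 1, 1)$ it would return 9, because it misses the number "5".

```cpp
#include <array>
#include <cassert>

// Hypothetical helper: combine the reduced state (X_{N,0}, X_{N,1}, X_{N,2}, X_{N,4})
// into X_{N,0} + 4 X_{N,1} + 2 X_{N,2} + 2 X_{N,4} (mod 1e9+7). Valid for N > 1 only.
long long answerFromState(const std::array<long long, 4>& x) {
    const long long MOD = 1000000007;
    for (long long v : x)
        assert(0 <= v && v < MOD);  // expects already-reduced counts
    // Weights: 0 appears once, 1/3/7/9 four times, 2/8 and 4/6 twice each.
    return (x[0] + 4 * x[1] + 2 * x[2] + 2 * x[3]) % MOD;
}
```

With the state (6, 5, 4, 6) for N = 3, it returns 6 + 20 + 8 + 12 = 46.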

So we can either apply the recurrence N - 1 times, which takes O(N) time, or compute $A^{N-1}$ by fast exponentiation (repeated squaring). The second option needs only O(log N) multiplications of a fixed 4 × 4 matrix, so it takes O(log N) time.
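For comparison with Code 2, here is an alternative sketch of the fast-power approach that uses fixed-size std::array matrices. Its mul returns a new matrix, so squaring never reads from a matrix it is overwriting. Code 2 gets the same safety by building the temporary c inside multi. Mat, mul, matPow and knightDialerPow are illustrative names, and the modulus $10^9 + 7$ follows the problem statement.

```cpp
#include <array>
#include <cassert>

using Mat = std::array<std::array<long long, 4>, 4>;
const long long MOD = 1000000007;

// Returns p * q (mod MOD) as a fresh matrix.
Mat mul(const Mat& p, const Mat& q) {
    Mat r{};  // zero-initialized
    for (int i = 0; i < 4; ++i)
        for (int k = 0; k < 4; ++k)
            for (int j = 0; j < 4; ++j)
                r[i][j] = (r[i][j] + p[i][k] * q[k][j]) % MOD;
    return r;
}

// base^e by repeated squaring: O(log e) 4 x 4 multiplications.
Mat matPow(Mat base, long long e) {
    assert(e >= 0);
    Mat result{};
    for (int i = 0; i < 4; ++i) result[i][i] = 1;  // identity
    while (e > 0) {
        if (e & 1) result = mul(result, base);
        base = mul(base, base);
        e >>= 1;
    }
    return result;
}

int knightDialerPow(int n) {
    assert(n >= 1);
    if (n == 1) return 10;  // includes the lone number "5"
    const Mat A = {{{0, 0, 0, 2},
                    {0, 0, 1, 1},
                    {0, 2, 0, 0},
                    {1, 2, 0, 0}}};
    const Mat P = matPow(A, n - 1);
    // X_N = P * (1,1,1,1)^T, so X_{N,i} is the sum of row i of P.
    const long long weight[4] = {1, 4, 2, 2};
    long long ans = 0;
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            ans = (ans + weight[i] * P[i][j]) % MOD;
    return static_cast<int>(ans);
}
```

For n = 1, 2, 3, 4 this returns 10, 20, 46 and 104.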

Code 1: O(N)

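class Solution {
public:
    const int mo = 1e9 + 7;

    int knightDialer(int n) {
        // (n == 1) adds the one-digit number "5", the only number containing 5.
        long long ans = (n == 1) + helper(n - 1, 1, 1, 1, 1);
        return ans % mo;
    }

    // Advances the state (x0, x1, x2, x4) = (X_{k,0}, X_{k,1}, X_{k,2}, X_{k,4})
    // n more steps, then returns X_{N,0} + 4 X_{N,1} + 2 X_{N,2} + 2 X_{N,4} mod mo.
    int helper(int n, int x0, int x1, int x2, int x4) {
        if (n == 0) {
            long long res = (long long)x1 * 4 + x2 * 2 + x4 * 2 + x0;
            return res % mo;
        } else
            return helper(n - 1, x4 * 2 % mo, (x2 + x4) % mo, x1 * 2 % mo,
                          ((long long)x1 * 2 + x0) % mo);
    }
};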

Code 2: O(logN)

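#include <vector>
using namespace std;

class Solution {
public:
    const int mo = 1e9 + 7;

    int knightDialer(int n) {
        if (n == 1)
            return 10;
        // a = A (the transition matrix), b = identity.
        vector<vector<int>> a(4, vector<int>(4)), b(4, vector<int>(4));
        a[0][3] = a[2][1] = a[3][1] = 2;
        a[1][2] = a[1][3] = a[3][0] = b[0][0] = b[1][1] = b[2][2] = b[3][3] = 1;
        // Exponentiation by squaring: afterwards b = A^(n-1).
        --n;
        while (n > 0) {
            if (n & 1)
                multi(b, a);
            multi(a, a);
            n >>= 1;
        }
        // X_N = b * (1,1,1,1)^T, so X_{N,i} is the sum of row i of b;
        // weight the rows by 1, 4, 2, 2 to get the total count.
        long long ans = 0;
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j)
                if (i == 0)
                    ans += b[i][j];
                else if (i == 1)
                    ans += (long long)b[i][j] * 4 % mo;
                else
                    ans += (long long)b[i][j] * 2 % mo;
        return ans % mo;
    }

    // a = a * b (mod mo). The product is built in a temporary first, so
    // multi(a, a) is safe even though both arguments refer to the same matrix.
    void multi(vector<vector<int>> &a, vector<vector<int>> &b) {
        vector<vector<int>> c(4, vector<int>(4));
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j) {
                long long tmp = 0;
                for (int k = 0; k < 4; ++k)
                    tmp = ((long long)a[i][k] * b[k][j] % mo + tmp) % mo;
                c[i][j] = tmp;
            }
        a = c;
    }
};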