9. 回溯
总体思路是深度优先遍历(DFS)。
9.1. 子集树
子集树大小为 \(\mathcal{O}(m^n)\) ,\(m\) 是树的分支个数( \(m\) 叉树),\(n\) 是树的深度。
算法描述:
1void backtrack(int t)
2{
3 if(t >= n) output(x);
4 else
5 {
6 for(int i = 0; i < m; ++i)
7 {
8 x[t] = i;
9 if(constrain(t) and bound(t)) backtrack(t+1);
10 }
11 }
12}
9.2. 排列树
排列树大小为 \(\mathcal{O}(n!)\) 。
算法描述:
1void backtrack(int t)
2{
3 if(t >= n) output(x);
4 else
5 {
6 for(int k = t; k < n; ++k)
7 {
8 swap(x[t], x[k]);
9 if(constrain(t) and bound(t)) backtrack(t+1);
10 swap(x[t], x[k]);
11 }
12 }
13}
9.3. 0-1背包问题
算法描述:
1void backtrack(int t)
2{
3 if(t >= n)
4 {
5 best_value = curr_value;
6 bext_x = x;
7 return;
8 }
9 else
10 {
11 if(curr_weight + w[t] <= W)
12 {
13 x[t] = 1;
14 curr_weight += w[t]; // 进入左子树
15 curr_value += v[t];
16 backtrack(t+1);
17
18 curr_weight -= w[t]; // 状态恢复
19 curr_value -= v[t];
20 }
21 x[t] = 0;
22 backtrack(t+1); // 进入右子树
23 }
24}
9.4. 八皇后问题
八皇后问题共有 92 组解。
1bool place(int t, int* x)
2{
3 for(int j = 0; j < t; ++j)
4 {
5 if(x[j] == x[t] || abs(j - t) == abs(x[j] - x[t])) return false; // 在同一列或同一斜线上
6 }
7 return true;
8}
9
10void backtrack(int t, int n, int* x, int& sum)
11{
12 if(t == n) ++sum;
13 else
14 {
15 for(int i = 0; i < n; ++i)
16 {
17 x[t] = i;
18 if(place(t, x)) backtrack(t+1, n, x, sum);
19 }
20 }
21}
9.5. 实例
全排列(含重复元素)。Hint:在交换第 \(i\) 个元素与第 \(j\) 个元素之前,要求数组的 \([i, j)\) 区间中的元素没有与第 \(j\) 个元素重复。
https://blog.csdn.net/so_geili/article/details/71078945
\(\color{darkgreen}{Code}\)
1int cnt = 0; // 不同排列的个数 2 3//检查[from,to)之间的元素和第to号元素是否相同 4bool isRepeat(int* A, int from, int to) 5{ 6 for(int i = from; i < to; i++) 7 { 8 if(A[to] == A[i]) return true; 9 } 10 return false; 11} 12 13void permutation(int* A, int t, int n) 14{ 15 if(t == n) 16 { 17 cnt++; 18 Output(A); 19 } 20 else 21 { 22 for(int j = t; j < n; j++) 23 { 24 if(!isRepeat(A, t, j)) 25 { 26 swap(A[t], A[j]); 27 permutation(A, t+1, n); 28 swap(A[t], A[j]); 29 } 30 } 31 } 32}
Next Permutation 下一个排列。Hint:从后往前先找到第一个开始下降的数字 \(x\) (下标 \(i\) ),再从后往前找到第一个比 \(x\) 大的数 \(y\) (下标 \(j\) );交换 \(x\) 和 \(y\) ;翻转区间 \([i+1, end]\) 。
https://www.cnblogs.com/grandyang/p/4428207.html
\(\color{darkgreen}{Code}\)
1class Solution 2{ 3public: 4 void nextPermutation(vector<int> &num) 5 { 6 int i, j, n = num.size(); 7 for (i = n - 2; i >= 0; --i) 8 { 9 if (num[i + 1] > num[i]) 10 { 11 for (j = n - 1; j > i; --j) 12 { 13 if (num[j] > num[i]) break; 14 } 15 swap(num[i], num[j]); 16 reverse(num.begin() + i + 1, num.end()); 17 return; 18 } 19 } 20 reverse(num.begin(), num.end()); // 当前排列是最大的排列,则翻转为最小的排列 21 } 22};
按字典序输出序列 \(1,2,...,n\) 的全排列。Hint:深度优先遍历。
\(\color{darkgreen}{Code}\)
1void DFS(int* arr, bool* used, int n, int t) 2{ 3 if(t == n) 4 { 5 for(int i = 0; i < n; ++i) cout << arr[i] << ends; 6 cout << endl; 7 return; 8 } 9 for(int digit = 1; digit <= n; ++digit) 10 { 11 if(!used[digit - 1]) 12 { 13 used[digit - 1] = true; 14 arr[t] = digit; 15 DFS(arr, used, n, t+1); 16 used[digit - 1] = false; 17 } 18 } 19}
[LeetCode] Permutation Sequence 输出序列 \(1,2,...,n\) 的第 \(k\) 个排列(字典序)。Hint:方法一,按字典序深度优先遍历;方法二,逐步缩小搜索范围,如: \(perm [ 1,2,3 ] = \{1 + perm [ 2,3 ] \} + \{2 + perm [ 1,3 ] \} + \{3 + perm [ 1,2 ] \}\) 。
https://leetcode.com/problems/permutation-sequence/
\(\color{darkgreen}{Code}\)
1// https://leetcode.com/problems/permutation-sequence/discuss/22507/%22Explain-like-I'm-five%22-Java-Solution-in-O(n) 2 3class Solution 4{ 5public: 6 string getPermutation(int n, int k) 7 { 8 string nums = ""; 9 vector<int> factorial(n+1, 1); 10 for(int i = 1; i <= n; ++i) 11 { 12 nums += to_string(i); 13 factorial[i] = i; 14 } 15 partial_sum(factorial.begin(), factorial.end(), factorial.begin(), multiplies<int>()); // f(n) = n!, f(0) = 1 16 17 string res = ""; 18 while(n) 19 { 20 int id = (k - 1) / factorial[n-1]; // k - 1,下标从 0 开始 21 res += nums[id]; 22 nums.erase(nums.begin() + id); // 得到 n - 1 个数的序列 23 k -= id * factorial[n-1]; // 在 n - 1 个数的序列中继续查找第 k - id * factorial[n-1] 个排列 24 --n; 25 } 26 return res; 27 } 28};
输出序列 \(1,2,...,n\) 的所有子集(组合),共 \(2^n\) 组。Hint:方法一,回溯,二叉子集树;方法二,递归,序列每增加一个数,组合数增加一倍,增加的这些组合是在之前的组合的基础上插入该数得到的; 方法三,当 \(n < 32\) ,可以使用一个 int 型的变量 \(k\) ( \(1 \leqslant k \leqslant 2^n\) )来表示组合的状态,当该变量的二进制表示的第 \(i\) 位为 1,则表示当前组合中包含数字 \(i\) 。
\(\color{darkgreen}{Code}\)
1// 方法一,回溯 2 3void backtrack(int n, vector<int>& tmp, vector<vector<int>>& res) 4{ 5 if (n == 0) 6 { 7 res.push_back(tmp); 8 return; 9 } 10 backtrack(n - 1, tmp, res); // 不包含 n 11 tmp.push_back(n); 12 backtrack(n - 1, tmp, res); // 包含 n 13 tmp.pop_back(); 14} 15 16vector<vector<int>> combination(int n) 17{ 18 assert(n > 0); 19 vector<vector<int>> res; 20 vector<int> tmp; 21 backtrack(n, tmp, res); 22 return res; 23}
1// 方法二,递归 2 3void combinationRecursive(int n, vector<vector<int>>& res) 4{ 5 if (n == 1) 6 { 7 res[1].push_back(1); 8 return; 9 } 10 11 combinationRecursive(n - 1, res); 12 13 int pre_num = pow(2, n - 1); // 在 1 ~ n-1 的组合上插入数字 n 14 for (int i = 0; i < pre_num; ++i) 15 { 16 res[i + pre_num].push_back(n); 17 for (int j = 0; j < res[i].size(); ++j) 18 { 19 res[i + pre_num].push_back(res[i][j]); 20 } 21 } 22} 23 24vector<vector<int>> combination(int n) 25{ 26 assert(n > 0); 27 int num = pow(2, n); 28 vector<vector<int>> res(num, vector<int>{}); 29 combinationRecursive(n, res); 30 return res; 31}
1// 方法三,统计二进制中 1 的个数 2 3vector<vector<int>> combination(int n) 4{ 5 assert(n > 0); 6 int num = pow(2, n); 7 vector<vector<int>> res(num, vector<int>{}); 8 int k = num - 1; 9 while (k >= 0) 10 { 11 int pos = n - 1; 12 while (pos >= 0) 13 { 14 if (k & (1 << pos)) res[k].push_back(pos + 1); 15 --pos; 16 } 17 --k; 18 } 19 return res; 20}
输出整数集合的所有组合(包含重复元素)。Hint:统计每个元素的频率 \(f\) ,在组合过程中,该元素可取的个数最少为零个,最多为 \(f\) 个;回溯。
https://leetcode.com/problems/subsets-ii/
\(\color{darkgreen}{Code}\)
1from collections import Counter 2class Solution: 3 def backtrack(self, ints: List[int], freqs: List[int], tmp: List[int], res:List[List[int]], t:int): 4 if t == len(ints): 5 res.append(tmp[:]) ## 注意:这里必须是添加tmp的副本到res中,否则随着tmp改变,res中的元素也会改变 6 return 7 for k in range(freqs[t] + 1): 8 tmp.extend([ints[t]] * k) 9 self.backtrack(ints, freqs, tmp, res, t+1) 10 if tmp: 11 for _ in range(k): tmp.pop() 12 def subsetsWithDup(self, nums: List[int]) -> List[List[int]]: 13 cnt = Counter(nums) 14 ints, freqs = list(cnt.keys()), list(cnt.values()) ## python3 中需要把 dict_keys、dict_values 类型转换为 list 15 res = [] 16 tmp = [] 17 self.backtrack(ints, freqs, tmp, res, 0) 18 return res
[LeetCode] Distinct Subsequences II 子序列个数(含重复元素的组合数)。Hint:方法一,动态规划,设 \(dp[k]\) 是以 \(S[k]\) 结尾的子序列个数, 如果不考虑重复,则 \(dp[k] = dp[0] + dp[1] + \cdots + dp[k-1] + 1\) ,即在前面的子序列末尾追加 \(S[k]\) ,或 \(S[k]\) 单独构成的子序列( \(+1\) ); 然而要减掉以 \(S[k]\) 结尾的重复子序列: \(dp[k]\ -= dp[r],\ 0 \leqslant r < k \ \&\& \ S[k]=S[r]\) ; 方法二,回溯:设当前子序列集合最后一个元素的下标为 \(i\) ,在把当前字符(设下标为 \(t\) )加入子序列集合时, 需要考虑区间 \((i, t)\) (如果当前子序列集合为空,区间为 \([0, t)\) )内是否有 \(S[t]\) 的重复元素,如果有,则不能把 \(S[t]\) 插入当前子序列中,否则就造成重复;回溯方法严重超时。
https://leetcode.com/problems/distinct-subsequences-ii
\(\color{darkgreen}{Code}\)
1// 方法一 2 3class Solution 4{ 5public: 6 int distinctSubseqII(string S) 7 { 8 if(S.empty()) return 0; 9 vector<long long> dp(S.size(), 0); 10 dp[0] = 1; 11 for(size_t i = 1; i < S.size(); ++i) 12 { 13 dp[i] = accumulate(dp.begin(), dp.begin() + i, 1LL); // + 1,这里的 1LL 表示 long long int,默认的 int 型导致溢出,结果错误 14 for(size_t k = 0; k < i; ++k) 15 { 16 if(S[k] == S[i]) // 减去重复 17 { 18 dp[i] -= dp[k]; 19 while(dp[i] < 0) dp[i] += 1000000007; // 减法操作可能会使得 dp[i] < 0 20 } 21 } 22 dp[i] = dp[i] % 1000000007; 23 } 24 return accumulate(dp.begin(), dp.end(), 0LL) % 1000000007; // 0LL 25 } 26};
1// 方法一改进型 2 3// 设 dp[l] 是以 S[l] 结尾的不重复子序列个数(定义与上面的方法一相同), 4// 设 end[i] 是以字符 'a' + i 结尾的子序列个数,0 <= i < 26,S[l] = 'a' + i, 5// 如果该字符出现在多个位置,如 {j,k,l},则 end[i] = dp[j] + dp[k] + dp[l], 6// 由方法一可知:dp[l] = \sum_{m=0}^{l-1} dp[m] + 1 - dp[j] - dp[k], 7// 因此 end[i] = \sum_{m=0}^{l-1} dp[m] + 1 = \sum_{n=0}^25 end[n] + 1 8 9class Solution 10{ 11public: 12 int distinctSubseqII(string S) 13 { 14 if(S.empty()) return 0; 15 long long end[26] = {0}; 16 for(size_t i = 0; i < S.size(); ++i) 17 { 18 end[S[i] - 'a'] = accumulate(end, end + 26, 1LL) % 1000000007; 19 } 20 return accumulate(end, end + 26, 0LL) % 1000000007; 21 } 22};
1// 方法二 2 3class Solution 4{ 5public: 6 int distinctSubseqII(string S) 7 { 8 if(S.empty()) return 0; 9 int ans = 0; 10 vector<int> subS; // 当前子序列集合 11 DFS(S, 0, subS, ans); 12 return ans; 13 } 14private: 15 bool hasRepeat(string& S, vector<int>& subS, int t) 16 { 17 bool repeat = false; 18 size_t i; 19 if(subS.empty()) i = 0; 20 else i = subS.back() + 1; 21 for(; i < t; ++i) 22 { 23 if(S[i] == S[t]) 24 { 25 repeat = true; 26 break; 27 } 28 } 29 return repeat; 30 } 31 void DFS(string& S, int t, vector<int>& subS, int &ans) 32 { 33 if(t == S.size()) 34 { 35 if(!subS.empty()) ans = (ans + 1) % 1000000007; 36 return; 37 } 38 DFS(S, t+1, subS, ans); // 当前子序列集合不包括 S[t] 39 if(!hasRepeat(S, subS, t)) // 区间 (i, t) (或 [0, t))内不包括 S[t] 的重复字符,才可以把 S[t] 加入当前子序列集合 40 { 41 subS.push_back(t); 42 DFS(S, t+1, subS, ans); 43 subS.pop_back(); 44 } 45 } 46};
Word search 查找字符串路径。
https://leetcode.com/problems/word-search/
\(\color{darkgreen}{Code}\)
1class Solution { 2public: 3 bool findPath(vector<vector<char>>& board, string word, bool** flag, int x, int y, int k) 4 { 5 if(k == word.size()) return true; 6 for(int t = 0; t < 4; ++t) 7 { 8 int tx = x + mv[t][0]; 9 int ty = y + mv[t][1]; 10 11 if(flag[tx+1][ty+1] && board[tx][ty] == word[k]) 12 { 13 flag[tx+1][ty+1] = false; // 设置 flag 14 if(findPath(board, word, flag, tx, ty, k+1)) return true; 15 flag[tx+1][ty+1] = true; // flag 还原 16 } 17 18 } 19 return false; 20 } 21 bool exist(vector<vector<char>>& board, string word) { 22 if(word == "") return true; 23 if(board.size()==0) return false; 24 int M = board.size(); 25 int N = board[0].size(); 26 bool** flag = new bool*[M+2]; // 设置一圈边界,标记为 false,后面访问 board 中的 4 个领域不用再判断是否越界;flag 的大小为 (M+2)x(N+2) 27 for(int m = 0; m < M+2; ++m) 28 { 29 flag[m] = new bool[N+2]; 30 for(int n = 0; n < N+2; ++n) 31 { 32 if(m == 0 || m == M+1 || n == 0 || n == N+1) flag[m][n] = false; 33 else flag[m][n] = true; 34 } 35 } 36 bool EXIST = false; 37 for(int i = 0; i < M; ++i) 38 { 39 for(int j = 0; j < N; ++j) 40 { 41 if(board[i][j] == word[0]) 42 { 43 flag[i+1][j+1] = false; // 注意: flag 的下标与 board 相差 1 44 if(findPath(board, word, flag, i, j, 1)) 45 { 46 EXIST = true; 47 break; // 跳出第二重循环 48 } 49 flag[i+1][j+1] = true; // flag 还原 50 } 51 } 52 if(EXIST) break; // 跳出第一重循环 53 } 54 55 for(int m = 0; m < M+2; ++m) delete[] flag[m]; 56 delete[] flag; 57 58 return EXIST; 59 } 60private: 61 static const int mv[4][2]; 62}; 63 64const int Solution::mv[4][2] = {{-1,0},{0,-1},{0,1},{1,0}};
Knuth-Shuffle,公平的洗牌算法:生成每一种排列的概率都是 \(\frac{1}{n!}\)。
\(\color{darkgreen}{Code}\)
1void shuffle(int* arr, int n) 2{ 3 for(int i = n - 1; i >= 0; --i) 4 { 5 swap(arr[i], arr[rand()%(i+1)]); 6 } 7}