数位dp一般采用记忆化搜索写起来更加清晰
对于求大数区间计数,并且要求满足一行数位性质时采用数位dp。
一般对于一个位nums[u],分成0~nums[u] - 1 和 nums[u]两种情况考虑,其中的lim(limit)表示该位是否受限,若不受限则能取0~9,否则只能取0~nums[u]。
例题
定义$dp[u][m1][m2][m3][l1][l2][p1][p2][p3]$为取到第$u$位,取模$a1,a2,a3$的余数分别为$m1,m2,m3$,三个数是否受限($l1, l2, l3$), 三个数是否取过1,因为这题答案为0满足但不在题意范围内。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
|
#include <bits/stdc++.h> #define endl '\n' #define int long long using namespace std; using ll = long long;
const int N = 66, mod = 998244353; int dp[N][15][15][15][2][2][2][2][2][2]; int n, a1, a2, a3; vector<int> nums;
bool multi = 0;
int dfs(int u, int m1, int m2, int m3, bool l1, bool l2, bool l3, bool p1, bool p2, bool p3) { if(u == -1 && !m1 && !m2 && !m3 && p1 && p2 && p3) return 1; else if(u == -1) return 0; auto &w = dp[u][m1][m2][m3][l1][l2][l3][p1][p2][p3]; if(~w) return w; int res = 0; int t1 = l1 ? nums[u] : 1, t2 = l2 ? nums[u] : 1, t3 = l3 ? nums[u] : 1; for(int i1 = 0; i1 <= t1; i1++) { for(int i2 = 0; i2 <= t2; i2++) { int i3 = i1 ^ i2; if(i3 > t3) continue; res = (res + dfs(u - 1, (m1 * 2 + i1) % a1, (m2 * 2 + i2) % a2, (m3 * 2 + i3) % a3, l1 & (i1 == nums[u]), l2 & (i2 == nums[u]), l3 & (i3 == nums[u]), i1 | p1, i2 | p2, i3 | p3)) % mod; } } return w = res; }
void solve() { cin >> n >> a1 >> a2 >> a3; while(n) nums.push_back(n % 2), n /= 2; memset(dp, -1, sizeof dp); cout << dfs((int)nums.size() - 1, 0, 0, 0, 1, 1, 1, 0, 0, 0) << '\n'; }
signed main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); int T = 1; if (multi) cin >> T; while (T--) { solve(); } return 0; }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
#include <bits/stdc++.h> #define endl '\n' #define int long long using namespace std; using ll = long long;
bool multi = 1;
const int N = 67, mod = 1e9 + 7; vector<int> nums;
int p10[20];
void init() { p10[0] = 1; for(int i = 1; i <= 20; i++) { p10[i] = (p10[i - 1] * 10) % mod; } }
struct Node { int sum, sum2, cnt; };
Node dp[N][20][75][2];
Node dfs(int u, int m1, int m2, bool lim) { if(u == -1) { if(m1 && m2) return (Node){0, 0, 1}; else return (Node){0, 0, 0}; } auto &w = dp[u][m1][m2][lim]; if(~w.sum) return w; Node res = {0, 0, 0}; int tot = lim ? nums[u] : 9; for(int i = 0; i <= tot; i++) { if(i == 7) continue; auto t = dfs(u - 1, (m1 + i) % 7, (m2 * 10 + i) % 7, lim & (i == nums[u])); res.sum = (res.sum + i * t.cnt % mod * p10[u] % mod + t.sum) % mod, res.cnt = (res.cnt + t.cnt) % mod; res.sum2 = (res.sum2 + t.cnt * (i * p10[u] % mod) % mod * (i * p10[u] % mod) % mod + 2 * i * p10[u] % mod * t.sum % mod + t.sum2) % mod; } return w = res; }
Node get(int x) { nums.clear(); memset(dp, -1, sizeof dp); while(x) nums.push_back(x % 10), x /= 10; auto t = dfs((int)nums.size() - 1, 0, 0, 1); return t; }
void solve() { int l, r; cin >> l >> r; cout << ((get(r).sum2 - get(l - 1).sum2) % mod + mod) % mod<< '\n'; }
signed main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
init();
int T = 1; if (multi) cin >> T; while (T--) { solve(); }
return 0; }
|