fork download
  1. #include <bits/stdc++.h>
  using ll = long long;
  #define itachi ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
  constexpr int maxn = 210000;  // capacity of every per-vertex array (vertices are 1-indexed)
  using namespace std;

  // Modulus applied to every accumulated quantity.
  constexpr int MOD = 1000000007;

  ll n;                   // number of vertices, read in main
  ll val[maxn];           // value of each vertex as read from input
  ll rank_val[maxn];      // 1-based compressed rank of val[]; equal values share a rank
  ll child[maxn];         // subtree sizes written by the most recent countChild call
  ll B_array[maxn];       // per-vertex path accumulator written by dfs_B
  bool banned[maxn];      // vertices already chosen as centroids
  vector<int> adj[maxn];  // undirected adjacency lists

  // Fenwick trees indexed by rank: sum of stored values and number of entries (mod MOD).
  ll sum_tree[maxn + 7];
  ll cnt_tree[maxn + 7];

  int ma_rank = 0;  // NOTE(review): not referenced anywhere in the visible code
  13.  
  14. // --- BIT LOGIC ---
  15. void update_bit(int i, ll v, ll c) {
  16. v = (v % MOD + MOD) % MOD;
  17. c = (c % MOD + MOD) % MOD;
  18. for (; i <= n; i += i & -i) {
  19. sum_tree[i] = (sum_tree[i] + v) % MOD;
  20. cnt_tree[i] = (cnt_tree[i] + c) % MOD;
  21. }
  22. }
  23.  
  24. ll query_sum(int i) {
  25. ll res = 0;
  26. for (; i > 0; i -= i & -i) res = (res + sum_tree[i]) % MOD;
  27. return res;
  28. }
  29.  
  30. ll query_cnt(int i) {
  31. ll res = 0;
  32. for (; i > 0; i -= i & -i) res = (res + cnt_tree[i]) % MOD;
  33. return res;
  34. }
  35.  
  36. // --- CENTROID TEMPLATE ---
  37. void countChild(int u, int p) {
  38. child[u] = 1;
  39. for (int v : adj[u]) {
  40. if (v == p || banned[v]) continue;
  41. countChild(v, u);
  42. child[u] += child[v];
  43. }
  44. }
  45.  
  46. int find_centroid(int u, int p, int sz_total) {
  47. for (int v : adj[u]) {
  48. if (v != p && !banned[v] && child[v] > sz_total / 2)
  49. return find_centroid(v, u, sz_total);
  50. }
  51. return u;
  52. }
  53.  
  54. // --- LOGIC BÀI 1 ---
  55. void dfs_B(int u, int p, ll current_B, vector<int>& list_v) {
  56. B_array[u] = current_B;
  57. list_v.push_back(u);
  58. update_bit(rank_val[u], val[u], 1);
  59.  
  60. for (int v : adj[u]) {
  61. if (v != p && !banned[v]) {
  62. ll count = query_cnt(rank_val[v]);
  63. ll sum_gt = (query_sum(n) - query_sum(rank_val[v]) + MOD) % MOD;
  64. ll cross = (val[v] % MOD * count % MOD + sum_gt) % MOD;
  65. ll next_B = (current_B + cross + val[v] % MOD) % MOD;
  66. dfs_B(v, u, next_B, list_v);
  67. }
  68. }
  69. update_bit(rank_val[u], -val[u], -1); // Rollback BIT
  70. }
  71.  
  72. ll calc_F(vector<int>& S) {
  73. sort(S.begin(), S.end(), [](int a, int b) { return val[a] < val[b]; });
  74. ll F = 0, SumSz = 0;
  75. for (int x : S) {
  76. ll s_x = child[x] % MOD;
  77. ll v_x = val[x] % MOD;
  78. F = (F + 2LL * v_x % MOD * s_x % MOD * SumSz % MOD + v_x * s_x % MOD * s_x % MOD) % MOD;
  79. SumSz = (SumSz + s_x) % MOD;
  80. }
  81. return F;
  82. }
  83.  
  84. ll Total_Ans = 0;
  85.  
  86. void process_centroid(int C) {
  87. countChild(C, 0); // Cập nhật lại child size trong cây con hiện tại
  88. int N_C = child[C];
  89.  
  90. ll sum_pairs = 0, part1 = 0, F_sub = 0;
  91. vector<int> all_v;
  92.  
  93. update_bit(rank_val[C], val[C], 1);
  94.  
  95. for (int v : adj[C]) {
  96. if (!banned[v]) {
  97. vector<int> cur_list;
  98. ll count = query_cnt(rank_val[v]);
  99. ll sum_gt = (query_sum(n) - query_sum(rank_val[v]) + MOD) % MOD;
  100. ll cross = (val[v] % MOD * count % MOD + sum_gt) % MOD;
  101. ll next_B = (val[C] % MOD + cross + val[v] % MOD) % MOD;
  102.  
  103. dfs_B(v, C, next_B, cur_list);
  104.  
  105. ll s_v = child[v];
  106. sum_pairs = (sum_pairs + s_v * s_v % MOD) % MOD;
  107. for (int x : cur_list) part1 = (part1 + B_array[x] % MOD * (N_C - s_v) % MOD) % MOD;
  108.  
  109. F_sub = (F_sub + calc_F(cur_list)) % MOD;
  110. all_v.insert(all_v.end(), cur_list.begin(), cur_list.end());
  111. }
  112. }
  113.  
  114. update_bit(rank_val[C], -val[C], -1); // Reset BIT hoàn toàn
  115.  
  116. sum_pairs = (1LL * (N_C - 1) * (N_C - 1) % MOD - sum_pairs + MOD) % MOD * 500000004 % MOD;
  117. part1 = (part1 + val[C] % MOD * (1 - sum_pairs % MOD + MOD) % MOD) % MOD;
  118.  
  119. ll F_total = calc_F(all_v);
  120. ll part2 = (F_total - F_sub + MOD) % MOD * 500000004 % MOD;
  121.  
  122. Total_Ans = (Total_Ans + part1 + part2) % MOD;
  123. }
  124.  
  125. void solve(int u) {
  126. countChild(u, 0);
  127. int root = find_centroid(u, 0, child[u]);
  128.  
  129. process_centroid(root);
  130.  
  131. banned[root] = 1;
  132. for (int v : adj[root]) {
  133. if (!banned[v]) solve(v);
  134. }
  135. }
  136.  
  137. int main() {
  138. itachi
  139. if (!(cin >> n)) return 0;
  140. vector<long long> sorted_vals;
  141. for (int i = 1; i <= n; i++) {
  142. cin >> val[i];
  143. sorted_vals.push_back(val[i]);
  144. }
  145. sort(sorted_vals.begin(), sorted_vals.end());
  146. for (int i = 1; i <= n; i++) {
  147. rank_val[i] = lower_bound(sorted_vals.begin(), sorted_vals.end(), val[i]) - sorted_vals.begin() + 1;
  148. }
  149. for (int i = 1; i < n; i++) {
  150. int u, v; cin >> u >> v;
  151. adj[u].push_back(v); adj[v].push_back(u);
  152. }
  153.  
  154. solve(1);
  155. cout << Total_Ans << "\n";
  156. return 0;
  157. }
Success #stdin #stdout 0.01s 11712KB
stdin
Standard input is empty
stdout
Standard output is empty