Advertisement
peltorator

2-SAT

May 29th, 2023
847
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.92 KB | None | 0 0
  1. struct SAT_2 {
  2.     int n;
  3.     vector<vector<int>> g, grev;
  4.  
  5.     SAT_2(int n = 0) : n(n) {
  6.         g.assign(2 * n, {});
  7.         grev.assign(2 * n, {});
  8.     };
  9.  
  10.     void add_clause(bool positive_x, int x, bool positive_y, int y) { // (x v !y) -> (true, x, false, y)
  11.         assert(0 <= x && x < n);
  12.         assert(0 <= y && y < n);
  13.         int vx = 2 * x + (positive_x ? 0 : 1);
  14.         int vy = 2 * y + (positive_y ? 0 : 1);
  15.         g[vx ^ 1].push_back(vy);
  16.         grev[vy].push_back(vx ^ 1);
  17.         g[vy ^ 1].push_back(vx);
  18.         grev[vx].push_back(vy ^ 1);
  19.     }
  20.  
  21.     vector<int> solve() { // return.empty() -> no solution, o/w -> x_i = bool(return[i])
  22.         vector<int> used(2 * n, 0), topsort;
  23.         function<void(int)> dfs_topsort = [&](int v) {
  24.             used[v] = 1;
  25.             for (int u : g[v]) {
  26.                 if (!used[u]) {
  27.                     dfs_topsort(u);
  28.                 }
  29.             }
  30.             topsort.push_back(v);
  31.         };
  32.         for (int v = 0; v < 2 * n; v++) {
  33.             if (!used[v]) {
  34.                 dfs_topsort(v);
  35.             }
  36.         }
  37.         reverse(topsort.begin(), topsort.end());
  38.         vector<int> cols(2 * n, -1);
  39.         function<void(int, int)> dfs_color = [&](int v, int color) {
  40.             cols[v] = color;
  41.             for (int u : grev[v]) {
  42.                 if (cols[u] == -1) {
  43.                     dfs_color(u, color);
  44.                 }
  45.             }
  46.         };
  47.         int curcol = 0;
  48.         for (int v : topsort) {
  49.             if (cols[v] == -1) {
  50.                 dfs_color(v, curcol++);
  51.             }
  52.         }
  53.         for (int v = 0; v < 2 * n; v += 2) {
  54.             if (cols[v] == cols[v + 1]) {
  55.                 return {};
  56.             }
  57.         }
  58.         vector<int> solution(n, 0);
  59.         for (int v = 0; v < 2 * n; v += 2) {
  60.             solution[v >> 1] = (cols[v] > cols[v + 1] ? 1 : 0);
  61.         }
  62.         return solution;
  63.     }
  64. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement