Advertisement
sleepy_coder

15-Puzzle BnB A* search

Mar 3rd, 2020
268
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.35 KB | None | 0 0
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. using ll = long long;
  5. using pii = pair<int, int>;
  6.  
  7. #define     endl               "\n"
  8. #define     fast_io            ios::sync_with_stdio(false); cin.tie(0);
  9. #define     file_io            freopen("input.txt", "r", stdin);   \
  10.                                freopen("output.txt", "w", stdout);
  11. #define     all(x)             begin(x), end(x)
  12. #define     debug(x)           cerr <<"Line "<< __LINE__ <<" : "<< #x " = "<< x <<endl;
  13.  
  14. template<typename T, typename TT>
  15. ostream& operator<<(ostream &os, const pair<T, TT> &t) { return os<<"("<<t.first<<", "<<t.second<<")"; }
  16. template<typename T>
  17. ostream& operator<<(ostream& os, const vector<T> &t) { for(auto& i: t) os<<i<<" "; return os; }
  18.  
  19.  
  20.  
  21. int N;
  22. using ull = unsigned long long int;
  23. using pii = pair<int, int>;
  24.  
  25.  
  26.  
  27.  
  28.  
  29.  
  30. #define LIMIT_DEPTH 60
  31. #define NODE_LIMIT 10000000
  32.  
  33.  
  34.  
  35.  
  36. // done .. hash for mapping vector in c++
  37. struct VectorHash {
  38.     size_t operator()(const vector<int>& v) const {
  39.         hash<int> hasher;
  40.         size_t seed = 0;
  41.         for (int i : v) {
  42.             seed ^= hasher(i) + 0x9e3779b9 + (seed<<6) + (seed>>2);
  43.         }
  44.         return seed;
  45.     }
  46. };
  47.  
  48.  
  49.  
  50.  
  51.  
  52. // done
  53. vector<int> processInput(string line) {
  54.     /// "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0"
  55.     vector <int> tokens;
  56.     stringstream ss(line);
  57.     string intermediate;
  58.     // Tokenizing w.r.t. space ' '
  59.     while(getline(ss, intermediate, ',')) {
  60.         tokens.push_back(stoi(intermediate));
  61.     }
  62.     return tokens;
  63. }
  64.  
  65.  
  66.  
  67.  
  68.  
  69.  
  70.  
  71.  
  72. //done
  73. int countInversion(const vector<int>& a, const vector<int>& b) {
  74.     assert(a.size() == b.size());
  75.     unordered_map<int, int> pos;
  76.     for(int i = 0; i < (int)b.size(); ++i) pos.insert({b[i], i});
  77.     int c(0);
  78.     for(int i = 0; i < (int)a.size()-1; ++i) {
  79.         for(int j = i + 1; j < (int)a.size(); ++j) {
  80.             if(a[i] && a[j] && pos[a[i]] > pos[a[j]]) {
  81.                 ++c;
  82.             }
  83.         }
  84.     }
  85.     return c;
  86. }
  87.  
  88.  
  89.  
  90.  
  91.  
  92.  
  93.  
  94.  
  95.  
  96.  
  97.  
  98.  
  99.  
  100. class NPuzzle {
  101.  
  102.     //done
  103.     struct StateNode {
  104.         StateNode* parent; 
  105.         vector<int> state;
  106.         pii blank;
  107.         pii cx; // cx = g(x) + h(x) //  g_score + heuristic_cost
  108.         stack<StateNode*> children;
  109.        
  110.         //done
  111.         explicit StateNode(const vector<int> s, const pii& prevBlank, const pii& newBlank, const int& fx, StateNode* p) {
  112.             parent = p;
  113.             state = s;
  114.             swap(state[prevBlank.first*N + prevBlank.second], state[newBlank.first*N + newBlank.second]);
  115.             blank = newBlank;
  116.             cx = {fx, INT_MAX};
  117.         }
  118.            
  119.         //done
  120.         friend ostream& operator <<(ostream& out, StateNode* obj) {
  121.             for(int i = 0; i < N; ++i) {
  122.                 for(int j = 0; j < N; ++j) {
  123.                     out << obj->state[i*N + j] << '\t';
  124.                 }
  125.                 cout << endl;
  126.             }
  127.             return out;
  128.         }
  129.        
  130.         //done
  131.         void printPath() {
  132.             if(this->parent) {
  133.                 parent->printPath();
  134.                 if(this->blank.first > parent->blank.first) cout << "D" << endl;
  135.                 else if(this->blank.first < parent->blank.first) cout << "U" << endl;
  136.                 else if(this->blank.second > parent->blank.second) cout << "R" << endl;
  137.                 else cout << "L" << endl;
  138.             }
  139.             cout << (this) << endl;
  140.         }
  141.      
  142.         //done
  143.         ~StateNode() {
  144.             parent = nullptr;
  145.             while(!children.empty()) {
  146.                 StateNode* cur = children.top();
  147.                 children.pop();
  148.                
  149.                 if(cur) delete cur;
  150.                 cur = nullptr;
  151.             }
  152.             state.clear();
  153.         }
  154.     };
  155.    
  156.     //done losed f_score in heap
  157.     struct comp {
  158.         bool operator()(const StateNode* lhs, const StateNode* rhs) const {
  159.             int lCost = lhs->cx.first + lhs->cx.second;
  160.             int rCost = rhs->cx.first + rhs->cx.second;
  161.             return (lCost == rCost)? lhs->cx.first > rhs->cx.first : lCost > rCost;
  162.         }
  163.     };  
  164.  
  165.     //done
  166.     bool isSafe(const pii& cell) {
  167.         return cell.first >= 0 && cell.second >= 0 && cell.first < N && cell.second < N;
  168.     }
  169.    
  170.     //done
  171.     // this is heuristic 1 for h(x)
  172.     int hammingDistance(const vector<int>& curr, const vector<int>& goal) {
  173.         int conflicts(0);
  174.         for(int i = 0; i < (int)curr.size(); ++i) {
  175.             if(curr[i] && curr[i] != goal[i]) ++conflicts;
  176.         }  
  177.         return conflicts;
  178.     }
  179.  
  180.     //done
  181.     // this is heuristic 2 for h(x)
  182.     int manhattanDistance(const vector<int>& curr, const vector<int>& goal) {
  183.         vector<pii> indx(N*N);
  184.         int c(0);
  185.         for(int i = 0; i < N*N; ++i) {
  186.             indx[goal[i]] = {i/N, i%N};
  187.         }
  188.         for(int i = 0; i < N*N; ++i) {
  189.             if(curr[i]) c += abs(i/N - indx[curr[i]].first) + abs(i%N - indx[curr[i]].second);
  190.         }
  191.         return c;
  192.     }
  193.  
  194.     //done
  195.     bool isSolvable(const vector<int>& a, const vector<int>& b, int blankX) {
  196.         int invCount = countInversion(a, b);
  197.         int pos = N - blankX;
  198.         if (N & 1) {
  199.             return !(invCount & 1);
  200.         } else {
  201.             if (pos & 1) return !(invCount & 1);
  202.             else return invCount & 1;
  203.         }
  204.     }
  205.  
  206.     //done
  207.     int findBlankX(const vector<int>& placeMents) {
  208.         for(int i = 0; i < (int)placeMents.size(); ++i) {
  209.             if(!placeMents[i]) {
  210.                 return i/N;
  211.             }
  212.         }
  213.         return -1;
  214.     }
  215.  
  216. public:
  217.  
  218.     void solve15puzzle(const vector<int>& initial, const vector<int>& goal,  const pii& blankCell, bool choice = false) {
  219.  
  220.         vector<int> middle(goal);
  221.         sort(all(middle));
  222.         if(!isSolvable(initial, middle, blankCell.first) || !isSolvable(goal, middle, findBlankX(goal))) {
  223.             cout << "Not Solvable " << endl;
  224.             return;
  225.         } else {
  226.             cout << "Solvable" << endl;
  227.         }
  228.        
  229.         int expanded(0), maxDepth(0);
  230.         vector<pii> dir({{1, 0}, {0, 1}, {-1, 0}, {0, -1}});
  231.        
  232.        
  233.        
  234.        
  235.         ///closedSet all are aleary evaluated
  236.         unordered_set< vector<int>, VectorHash > closedList;
  237.         ///stores g_score of each node in openList with minimum
  238.         unordered_map< vector<int>, int, VectorHash > g_score;
  239.        
  240.         ///list provides node with minimum f_score value
  241.         priority_queue<StateNode*, vector<StateNode*>, comp> openList;
  242.        
  243.        
  244.        
  245.        
  246.        
  247.         StateNode* rootNode = new StateNode(initial, blankCell, blankCell, 0, nullptr);
  248.        
  249.         rootNode->cx.second = choice ? hammingDistance(rootNode->state, goal) : manhattanDistance(rootNode->state, goal);
  250.        
  251.         g_score[rootNode->state] = 0;
  252.        
  253.         openList.push(rootNode);
  254.        
  255.         while(!openList.empty()) {
  256.            
  257.             StateNode* curr = openList.top();
  258.             //cout << curr->cx << endl;
  259.             //success
  260.             if(!curr->cx.second) {
  261.                 cout << curr->cx.first << " steps needed" << endl;
  262.                 cout << expanded << " number of nodes expanded" << endl;
  263.                 curr->printPath();
  264.                 break;
  265.             }
  266.            
  267.             openList.pop();
  268.            
  269.             if(!closedList.count(curr->state))
  270.                 closedList.insert(curr->state);
  271.            
  272.                        
  273.             for(const auto& d : dir) {
  274.                 int x = d.first + curr->blank.first;
  275.                 int y = d.second + curr->blank.second;
  276.                 if(!isSafe({x, y})) continue;
  277.                    
  278.                 int tentative_g_score = curr->cx.first + 1;
  279.                 StateNode* neighbor = new StateNode(curr->state, curr->blank, {x, y}, tentative_g_score, curr);
  280.                 curr->children.push(neighbor);
  281.                 neighbor->cx.second = choice ? hammingDistance(neighbor->state, goal) : manhattanDistance(neighbor->state, goal);
  282.                
  283.                 if(closedList.find(neighbor->state) != closedList.end()) continue;
  284.                
  285.                 if(!g_score.count(neighbor->state) || tentative_g_score < g_score[neighbor->state]) {
  286.                     neighbor->parent = curr;
  287.                     //if(!g_score.count(neighbor->state)) {
  288.                     //  openList.push(neighbor);
  289.                     //}
  290.                     openList.push(neighbor);
  291.                     g_score[neighbor->state] = tentative_g_score;
  292.                    
  293.                 }
  294.             }
  295.            
  296.             curr = nullptr;
  297.         }
  298.        
  299.         delete rootNode;
  300.         rootNode = nullptr;
  301.     }
  302.    
  303. };
  304.  
  305.  
  306. int main(int argc, char** argv) {
  307.     fast_io
  308.    
  309.     /**
  310.     4
  311.     <1, 3, 0, 4, 5, 2, 6, 8, 9, 10, 7, 12, 13, 14, 11, 15>
  312.     <1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 13, 14, 15, 12>
  313.     1
  314.     */
  315.    
  316.     string str;
  317.     getline(cin, str);
  318.     N = stoi(str);
  319.    
  320.     vector<int> initial, goal;
  321.     pii blank;
  322.    
  323.     getline(cin, str);
  324.     initial = processInput(str.substr(1, str.size()-2));
  325.    
  326.     for(int i = 0; i < N*N; ++i) {
  327.         if(!initial[i]) {
  328.             blank = {i/N, i%N};
  329.             break;
  330.         }
  331.     }
  332.    
  333.     getline(cin, str);
  334.     goal = processInput(str.substr(1, str.size()-2));
  335.    
  336.     getline(cin, str);
  337.     bool choice = (str == "1");
  338.    
  339.    
  340.     cout << "input: " << initial << endl;
  341.     cout << "target: " << goal << endl;
  342.     cout << "blank at cell: " << blank << endl << endl;
  343.    
  344.     NPuzzle puzzle;
  345.  
  346.     auto st = chrono::steady_clock::now();
  347.     //solve puzzle
  348.     puzzle.solve15puzzle(initial, goal, blank, choice);
  349.    
  350.     auto et = chrono::steady_clock::now();
  351.     cout << "Execution Time: " << chrono::duration<double, milli>(et-st).count() << "ms" << endl << endl;
  352.  
  353.    
  354.     return 0;
  355. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement