Ridwanul_Haque

A* (Artificial Intelligence)

Nov 30th, 2021
1,049
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include <chrono>
  2.  
  3. #define MAX 105
  4.  
  5. #include <set>
  6.  
  7. #ifndef NPUZZLE_ASTARSEARCH_H
  8. #define NPUZZLE_ASTARSEARCH_H
  9.  
  10. #define MANHATTAN_DISTANCE 1
  11. #define HAMMING_DISTANCE 2
  12. #define LINEAR_CONFLICT 3
  13.  
  14. #define LIMIT_DEPTH 60
  15. #define NODE_LIMIT 10000000
  16.  
  17. #define cost_ cost
  18. #define parent_ parent
  19.  
  20. #ifndef NPUZZLE_NODE_H
  21. #define NPUZZLE_NODE_H
  22. #define BOARD_SQ_SIZE 3
  23. #define PRINT_W 3
  24.  
  25. //#include<bits/stdc++.h>
  26. #include <iostream>
  27. #include <iomanip>
  28. #include <string>
  29. #include <vector>
  30. #include <map>
  31. #include <queue>
  32. #include <algorithm>
  33. #include <cmath>
  34. #include <cstdlib>
  35. #include <cstring>
  36.  
  37. #define RIGHT 0
  38. #define LEFT 1
  39. #define DOWN 2
  40. #define UP 3
  41.  
  42. typedef int direction_t;
  43. typedef int8_t puzzle_t;
  44.  
  45. using namespace std;
  46.  
  47. int dirX[4] = {0, 0, 1, -1}; // RIGHT-LEFT-DOWN-UP
  48. int dirY[4] = {1, -1, 0, 0}; // RIGHT-LEFT-DOWN-UP
  49.  
  50.  
  51.  
  52. class Node
  53. {
  54. public:
  55.     puzzle_t **A = nullptr;
  56.     bool emptyNode = true;
  57.     static int boardSqSize;
  58.  
  59.     friend ostream &operator<<(ostream &os, const Node &node);
  60.  
  61.     Node()
  62.     {
  63.         emptyNode = true;
  64.         A = new puzzle_t *[boardSqSize];
  65.         for (int i = 0; i < boardSqSize; ++i)
  66.         {
  67.             A[i] = new puzzle_t[boardSqSize];
  68.             memset(A[i], 0, boardSqSize * sizeof(A[0][0]));
  69.         }
  70.     }
  71.  
  72.  
  73.     Node(const Node &node)
  74.     {
  75.         this->~Node();
  76. //      emptyNode = false;
  77.         this->emptyNode = node.emptyNode;
  78.         A = new puzzle_t *[boardSqSize];
  79.         for (int i = 0; i < boardSqSize; ++i)
  80.         {
  81.             A[i] = new puzzle_t[boardSqSize];
  82.         }
  83.         for (int i = 0; i < boardSqSize; i++)
  84.         {
  85.             for (int j = 0; j < boardSqSize; j++)
  86.             {
  87.                 A[i][j] = node.A[i][j];
  88.             }
  89.         }
  90.     }
  91.  
  92.     Node &operator=(const Node &node)
  93.     {
  94.         this->~Node();
  95.         this->emptyNode = node.emptyNode;
  96.         A = new puzzle_t *[boardSqSize];
  97.         for (int i = 0; i < boardSqSize; ++i)
  98.         {
  99.             A[i] = new puzzle_t[boardSqSize];
  100.         }
  101.         for (int i = 0; i < boardSqSize; i++)
  102.         {
  103.             for (int j = 0; j < boardSqSize; j++)
  104.             {
  105.                 A[i][j] = node.A[i][j];
  106.             }
  107.         }
  108.         return *this;
  109.     }
  110.  
  111.     ~Node()
  112.     {
  113.         if (A == nullptr) return;
  114.         for (int i = 0; i < boardSqSize; ++i)
  115.         {
  116.             delete A[i];
  117.         }
  118.         delete[] A;
  119.         A = nullptr;
  120.     }
  121.  
  122.     bool operator==(const Node &right) const
  123.     {
  124.         for (int i = 0; i < boardSqSize; i++)
  125.             for (int j = 0; j < boardSqSize; j++)
  126.                 if (A[i][j] != right.A[i][j]) return false;
  127.         return true;
  128.     }
  129.  
  130.     bool operator!=(const Node &right) const
  131.     {
  132.         return !(*this == right);
  133.     }
  134.  
  135.     bool operator<(const Node &right) const
  136.     {
  137.         for (int i = 0; i < boardSqSize; i++)
  138.         {
  139.             for (int j = 0; j < boardSqSize; j++)
  140.             {
  141.                 if (A[i][j] < right.A[i][j]) return true;
  142.                 else if (A[i][j] == right.A[i][j]) continue;
  143.                 else return false;
  144.             }
  145.         }
  146.         return false;
  147.     }
  148.  
  149.     bool isSolveAble()
  150.     {
  151.  
  152.         int blank_row_no = -1;
  153.         vector<int> arr;
  154.         for (int i = 0; i < boardSqSize; i++)
  155.             for (int j = 0; j < boardSqSize; j++)
  156.             {
  157.                 if (A[i][j])
  158.                     arr.push_back(A[i][j]);
  159.                 else
  160.                     blank_row_no = i;
  161.             }
  162.         int invCount = getInvCount(arr);
  163.         bool boardSizeOdd = static_cast<bool>(boardSqSize & 1);
  164. //      cout << boardSizeOdd << " " << blank_row_no << " " << invCount << endl;
  165.         if (boardSizeOdd && !(invCount & 1)) // odd-board & even-inversions
  166.             return true;
  167.         else if (!boardSizeOdd && ((blank_row_no + getInvCount(arr)) & 1)) // even-board & odd-sum
  168.             return true;
  169.         return false;
  170.     }
  171.  
  172.  
  173.     static int getInvCount(const vector<int> &arr)
  174.     {
  175.         int inv_count = 0;
  176.         for (int i = 0; i < arr.size() - 1; i++)
  177.             for (int j = i + 1; j < arr.size(); j++)
  178.                 if (arr[i] > arr[j])
  179.                     inv_count++;
  180.  
  181.         return inv_count;
  182.     }
  183.  
  184.     // not works donno why not
  185.     Node getNode(int direction, int zX = -1, int zY = -1)
  186.     {
  187.         if (A == nullptr || direction > 3)
  188.             return *this;
  189.  
  190.         if (zX == -1 || zY == -1)
  191.         {
  192.             if (!getZeroPos(*this, zX, zY))
  193.                 return Node();
  194.         }
  195.  
  196.         int zXnew = zX + dirX[direction];
  197.         int zYnew = zY + dirY[direction];
  198.  
  199.         if (zXnew < 0 || zYnew < 0 || zXnew >= Node::boardSqSize || zYnew >= Node::boardSqSize)
  200.             return Node();
  201.  
  202.         Node v = *this;
  203. //      cout << v;
  204.         swap(v.A[zX][zY], v.A[zXnew][zYnew]);
  205.         return v;
  206.     }
  207.  
  208.     static bool getZeroPos(const Node &node, int &zX, int &zY)
  209.     {
  210.         zX = zY = -1;
  211.         for (int i = 0; i < Node::boardSqSize; i++)
  212.         {
  213.             for (int j = 0; j < Node::boardSqSize; j++)
  214.                 if (!node.A[i][j])
  215.                 {
  216.                     zX = i, zY = j;
  217.                     return true;
  218.                 }
  219.         }
  220.         return false;
  221.     }
  222.  
  223.     static int oppositeDirection(int direction)
  224.     {
  225.         switch (direction)
  226.         {
  227.         case LEFT:
  228.             return RIGHT;
  229.         case RIGHT:
  230.             return LEFT;
  231.         case UP:
  232.             return DOWN;
  233.         case DOWN:
  234.             return UP;
  235.         default:
  236.             return EOF;
  237.         }
  238.     }
  239.  
  240.     bool isEmptyNode() const
  241.     {
  242.         return emptyNode;
  243.     }
  244. };
  245.  
  246. int Node::boardSqSize = 0;
  247.  
  248. ostream &operator<<(ostream &os, const Node &node)
  249. {
  250.     if (!node.A) return os;
  251.     for (int i = 0; i < Node::boardSqSize; i++)
  252.     {
  253.         for (int j = 0; j < Node::boardSqSize; j++)
  254.             if (node.A[i][j])
  255.                 os << setw(PRINT_W) << (static_cast<int>(node.A[i][j])) << " ";
  256.             else
  257.                 os << setw(PRINT_W) << "  " << " ";
  258.         os << endl;
  259.     }
  260.     os << " ----------- " << std::endl;
  261.     return os;
  262. }
  263.  
  264. #endif //NPUZZLE_NODE_H
  265.  
  266.  
  267. typedef int cost_t;
  268. typedef int parent_t;
  269.  
  270.  
  271. struct NodeInfo
  272. {
  273.     bool isClosed;
  274.     cost_t cost;
  275.     parent_t parent;
  276.  
  277.     bool operator==(const NodeInfo &rhs) const
  278.     {
  279.         return parent == rhs.parent &&
  280.                cost == rhs.cost;
  281.     }
  282.  
  283.     bool operator!=(const NodeInfo &rhs) const
  284.     {
  285.         return !(rhs == *this);
  286.     }
  287. };
  288.  
  289.  
  290. class aStarSearch
  291. {
  292. public:
  293.     map<Node, NodeInfo> visited;//
  294.  
  295.     size_t openedCount;
  296.     int max_depth;
  297.     int nPushed;
  298.  
  299.     int heuristicType = 0;
  300.  
  301.     bool isValid(int x, int y)
  302.     {
  303.         return x >= 0 && y >= 0 && x < Node::boardSqSize && y < Node::boardSqSize;
  304.     }
  305.  
  306.     static double HammingDistance(const Node &a, const Node &b)
  307.     {
  308.         int conflicts = 0;
  309.         for (int i = 0; i < Node::boardSqSize; i++)
  310.             for (int j = 0; j < Node::boardSqSize; j++)
  311.                 if (a.A[i][j] && a.A[i][j] != b.A[i][j])conflicts++;
  312.         return conflicts;
  313.     }
  314.  
  315.     static double ManHattan(const Node &a, const Node &b)
  316.     {
  317.         int sum = 0;
  318.         puzzle_t pR[(Node::boardSqSize * Node::boardSqSize) + 1];
  319.         puzzle_t pC[(Node::boardSqSize * Node::boardSqSize) + 1];
  320.         for (int r = 0; r < Node::boardSqSize; r++)
  321.         {
  322.             for (int c = 0; c < Node::boardSqSize; c++)
  323.             {
  324.                 pR[a.A[r][c]] = static_cast<puzzle_t>(r);
  325.                 pC[a.A[r][c]] = static_cast<puzzle_t>(c);
  326.             }
  327.         }
  328.         for (int r = 0; r < Node::boardSqSize; r++)
  329.             for (int c = 0; c < Node::boardSqSize; c++)
  330.                 if (b.A[r][c])
  331.                     sum += abs(pR[b.A[r][c]] - r) + abs(pC[b.A[r][c]] - c);
  332.         return sum;
  333.     }
  334.  
  335.     static double nLinearConflicts(const Node &a, const Node &b)
  336.     {
  337.         int conflicts = 0;
  338.         puzzle_t pR[(Node::boardSqSize * Node::boardSqSize) + 1];
  339.         puzzle_t pC[(Node::boardSqSize * Node::boardSqSize) + 1];
  340.         for (int r = 0; r < Node::boardSqSize; r++)
  341.         {
  342.             for (int c = 0; c < Node::boardSqSize; c++)
  343.             {
  344.                 pR[a.A[r][c]] = static_cast<puzzle_t>(r);
  345.                 pC[a.A[r][c]] = static_cast<puzzle_t>(c);
  346.             }
  347.         }
  348.  
  349.         // row conflicts - @checked_okay
  350.         for (int r = 0; r < Node::boardSqSize; r++)
  351.         {
  352.             for (int cl = 0; cl < Node::boardSqSize; cl++)
  353.             {
  354.                 for (int cr = cl + 1; cr < Node::boardSqSize; cr++)
  355.                 {
  356.                     if (b.A[r][cl] && b.A[r][cr] && r == pR[b.A[r][cl]] && pR[b.A[r][cl]] == pR[b.A[r][cr]] &&
  357.                             pC[b.A[r][cl]] > pC[b.A[r][cr]])
  358.                     {
  359.                         conflicts++;
  360. //                      cout << b.A[r][cl] << " " << b.A[r][cr] << endl;
  361. //                      cout << pC[b.A[r][cl]] << " " << pC[b.A[r][cr]] << endl;
  362.                     }
  363.                 }
  364.             }
  365.         }
  366.  
  367.         // column conflicts -
  368.         for (int c = 0; c < Node::boardSqSize; c++)
  369.         {
  370.             for (int rU = 0; rU < Node::boardSqSize; rU++)
  371.             {
  372.                 for (int rD = rU + 1; rD < Node::boardSqSize; rD++)
  373.                 {
  374.                     if (b.A[rU][c] && b.A[rD][c] && c == pC[b.A[rU][c]] && pC[b.A[rU][c]] == pC[b.A[rD][c]] &&
  375.                             pR[b.A[rU][c]] > pR[b.A[rD][c]])
  376.                     {
  377.                         conflicts++;
  378. //                      cout << b.A[rU][c] << " " << b.A[rD][c] << endl;
  379. //                      cout << pR[b.A[rU][c]] << " " << pR[b.A[rD][c]] << endl;
  380.                     }
  381.                 }
  382.             }
  383.         }
  384.  
  385.         return conflicts;
  386.     }
  387.  
  388.     static double LinearConflicts(const Node &a, const Node &b)
  389.     {
  390.         return ManHattan(a, b) + 2 * nLinearConflicts(a, b);
  391.     }
  392.  
  393.     double Heuristic(const Node &a, const Node &b)
  394.     {
  395.         if (heuristicType == HAMMING_DISTANCE) return HammingDistance(a, b);
  396.         if (heuristicType == MANHATTAN_DISTANCE) return ManHattan(a, b);
  397.         if (heuristicType == LINEAR_CONFLICT) return LinearConflicts(a, b);
  398.         return 0;
  399.     }
  400.  
  401.     int AStarSearch(const Node &Start, const Node &Goal)
  402.     {
  403.         int nExpanded = 0;
  404.         max_depth = 0;
  405.         nPushed = 0;
  406.  
  407.         priority_queue<pair<double, Node> > openList;
  408.         openList.push({0, Start});
  409.         visited[Start] = {false, 0, EOF};
  410.  
  411.         while (!openList.empty())
  412.         {
  413.             Node u = openList.top().second;
  414.             openList.pop();
  415.             ++nExpanded;
  416.             NodeInfo &uInfo = visited[u];
  417.             uInfo.isClosed = true;
  418.  
  419.             max_depth = max(max_depth, visited[u].cost);
  420.  
  421.             if (u == Goal)
  422.             {
  423.                 break;
  424.             }
  425.  
  426.             if (uInfo.cost > LIMIT_DEPTH)
  427.             {
  428.                 cout << "Height limit Exceeded @" << endl << u;
  429.                 break;
  430.             }
  431.  
  432.  
  433.             if (visited.size() > NODE_LIMIT)
  434.             {
  435.                 cout << "Node limit Exceeded @" << endl << u;
  436.                 break;
  437.             }
  438.  
  439.             int zX = -1, zY = -1;
  440.             Node::getZeroPos(u, zX, zY);
  441.  
  442.             for (direction_t dir = 0; dir < 4; dir++)
  443.             {
  444.                 int zXnew = zX + dirX[dir];
  445.                 int zYnew = zY + dirY[dir];
  446.                 if (isValid(zXnew, zYnew))
  447.                 {
  448.                     Node v = u;
  449.                     swap(v.A[zX][zY], v.A[zXnew][zYnew]);
  450.  
  451.                     bool isVisited = visited.find(v) != visited.end();
  452.                     if (isVisited && visited[v].isClosed)continue;
  453.  
  454.                     double newCost = uInfo.cost + 1;
  455.                     if (!isVisited || newCost < visited[v].cost)   //2nd condition might not be needed
  456.                     {
  457.                         ++nPushed;
  458.                         visited[v] = {false, static_cast<cost_t>(newCost), Node::oppositeDirection(dir)};
  459.                         double Priority = newCost + Heuristic(v, Goal);
  460.                         openList.push({-Priority, v});
  461.                     }
  462.                 }
  463.             }
  464.  
  465.         }
  466.         openedCount = visited.size();
  467.         return nExpanded;
  468.     }
  469.  
  470.     void setHeuristic(int heuristic = MANHATTAN_DISTANCE)
  471.     {
  472.         heuristicType = heuristic;
  473.     }
  474.  
  475.     virtual ~aStarSearch()
  476.     {
  477.         heuristicType = 0;
  478.         visited.clear();
  479.     }
  480. };
  481.  
  482. #endif //NPUZZLE_ASTARSEARCH_H
  483.  
  484.  
  485. void printSolution(aStarSearch &starSearch, const Node &Start, const Node &Goal)
  486. {
  487.     auto now = Goal;
  488.  
  489.     //print soln
  490.     vector<Node> Path;
  491.     while (starSearch.visited[now].parent_ != EOF)
  492.     {
  493.         Path.push_back(now);
  494.         now = now.getNode(starSearch.visited[now].parent_);
  495.     }
  496.     Path.push_back(Start);
  497.     reverse(Path.begin(), Path.end());
  498.     for (auto &i : Path) cout << i;
  499. }
  500.  
  501. void executeSearch(const Node &Start, const Node &Goal, int heuristic, bool printSol = true)
  502. {
  503.  
  504.     auto *starSearch = new aStarSearch();
  505.     starSearch->setHeuristic(heuristic);
  506.     auto startTime = chrono::steady_clock::now();
  507.     int nExpanded = starSearch->AStarSearch(Start, Goal);
  508.     auto endTime = chrono::steady_clock::now();
  509.  
  510.     auto diff = endTime - startTime;
  511.     cout << "No of Steps: " << (int) starSearch->visited[Goal].cost_ << endl;
  512.     cout << "No of Nodes Expanded: " << nExpanded << endl;
  513.     cout << "No of Nodes Opened: " << starSearch->openedCount << endl;
  514.     cout << "No of Nodes Pushed: " << starSearch->nPushed << endl;
  515.     cout << "Max Depth Reached: " << starSearch->max_depth << endl;
  516.     cout << "Execution Time: " << chrono::duration<double, milli>(diff).count() << "ms" << endl;
  517.     cout << endl;
  518.     fflush(stdout);
  519.  
  520.     if (printSol) printSolution(*starSearch, Start, Goal);
  521.     delete starSearch;
  522. }
  523.  
  524. int main()
  525. {
  526.     freopen("in.txt", "r", stdin);
  527.    // freopen("out.txt", "w", stdout);
  528.     int boardSqSize = 3;
  529.     cin >> boardSqSize;
  530.     Node::boardSqSize = boardSqSize;
  531.  
  532.     Node Goal;
  533.     for (int i = 0; i < boardSqSize; i++)
  534.         for (int j = 0; j < boardSqSize; j++)
  535.             Goal.A[i][j] = static_cast<puzzle_t>(i * Node::boardSqSize + j + 1);
  536.     Goal.A[Node::boardSqSize - 1][Node::boardSqSize - 1] = 0;
  537.  
  538.     Node Start;
  539.     int x;
  540.     for (int i = 0; i < boardSqSize; i++)
  541.         for (int j = 0; j < boardSqSize; j++)
  542.         {
  543.             cin >> x;
  544.             Start.A[i][j] = static_cast<puzzle_t>(x);
  545.         }
  546.  
  547.  
  548.     cout << "Start: \n" << Start;
  549.     cout << "Goal: \n" << Goal;
  550.  
  551.     if (!Start.isSolveAble())
  552.     {
  553.         cout << "No Solution" << endl;
  554.     }
  555.     else
  556.     {
  557.         {
  558.             cout << "# Linear Conflicts Heuristics: " << endl;
  559.             executeSearch(Start, Goal, LINEAR_CONFLICT, false);
  560.  
  561.             cout << "# ManHattan Distance Heuristics: " << endl;
  562.             executeSearch(Start, Goal, MANHATTAN_DISTANCE, false);
  563.  
  564.             cout << "#Hamming Distance Heuristics: " << endl;
  565.             executeSearch(Start, Goal, HAMMING_DISTANCE, true);
  566.         }
  567.     }
  568. }
  569.  
RAW Paste Data