Advertisement
Guest User

State Machine Sample

a guest
Jun 29th, 2010
1,620
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 7.80 KB | None | 0 0
#include <cassert>
#include <exception>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <utility>
#include <boost/shared_ptr.hpp>

using namespace std;
  8.  
  9. // Represents an abstract transition.
  10. // StateMachine maintains the mapping of transitions
  11. // to actual target states.
  12. class Transition
  13. {
  14. public:
  15.     typedef string TransitionID;
  16.     Transition(const TransitionID& id_) : id(id_) {}
  17.     TransitionID id;
  18.  
  19.     // Two common transitions for starting/stopping
  20.     // the state process.
  21.     static Transition Init;
  22.     static Transition Exit;
  23. };
  24.  
  25. Transition Transition::Init("_Init");
  26. Transition Transition::Exit("_Exit");
  27.  
  28. bool operator<(const Transition& rhs, const Transition& lhs)
  29. { return rhs.id < lhs.id; }
  30.  
  31. bool operator==(const Transition& rhs, const Transition& lhs)
  32. { return rhs.id == lhs.id; }
  33.  
  34. // A base class for data containers passed around between states.
  35. struct Context
  36. {
  37.     virtual ~Context() {}
  38. };
  39.  
  40. // Represents a single state.
  41. class IState
  42. {
  43. public:
  44.     typedef string StateID;
  45.     virtual ~IState() {}
  46.     virtual StateID getID() const = 0;
  47.     virtual Transition run(Context& context) = 0;
  48. };
  49.  
  50. class StateMachineException : public exception
  51. {
  52. public:
  53.     StateMachineException(const string& what) : mWhat(what) {}
  54.     ~StateMachineException() throw() {}
  55.     virtual const char* what() const throw() { return mWhat.c_str(); }
  56. private:
  57.     string mWhat;
  58. };
  59.  
  60. // Represents a finite state machine.
  61. class StateMachine
  62. {
  63. public:
  64.     typedef IState::StateID StateID;
  65.     ~StateMachine() { clear(); }
  66.     void clear();
  67.     // init() must be called before run().
  68.     // Override this to register states and transitions.
  69.     virtual void init();
  70.     virtual void run(Context& context, StateID initial);
  71.  
  72. protected:
  73.     typedef boost::shared_ptr<IState> StatePtr;
  74.     // Use these methods to set up the state graph.
  75.     // Every state in your machine must be registered.
  76.     // Every transition in your machine except for Transition::Exit must be registered.
  77.     void registerState(StatePtr state);
  78.     void registerTransition(StatePtr from, Transition t, StatePtr to);
  79.  
  80. private:
  81.     StatePtr getNextState(StatePtr fromState, Transition t) const;
  82.     StatePtr getState(StateID stateID) const;
  83.  
  84.     typedef pair<StateID, Transition> StateTransitionPair;
  85.     typedef map<StateTransitionPair, StateID> StateTransitionMap;
  86.     typedef map<StateID, StatePtr> StateMap;
  87.     StateMap stateRegistry;
  88.     StateTransitionMap stateTransitionRegistry;
  89. };
  90.  
  91. void StateMachine::init()
  92. {
  93. }
  94.  
  95. void StateMachine::clear()
  96. {
  97.     stateRegistry.clear();
  98.     stateTransitionRegistry.clear();
  99. }
  100.  
  101. void StateMachine::registerState(StatePtr state)
  102. {
  103.     if (! state)
  104.         throw StateMachineException("Null pointer passed to registerState");
  105.     stateRegistry[state->getID()] = state;
  106. }
  107.  
  108. void StateMachine::registerTransition(StatePtr fromState, Transition t, StatePtr toState)
  109. {
  110.     if (! fromState || ! toState)
  111.         throw StateMachineException("Null pointer passed to registerTransition");
  112.     stateTransitionRegistry[StateTransitionPair(fromState->getID(), t)] = toState->getID();
  113. }
  114.  
  115. StateMachine::StatePtr StateMachine::getNextState(StatePtr fromState, Transition t) const
  116. {
  117.     assert(fromState);
  118.     StateTransitionMap::const_iterator pvStateID  = stateTransitionRegistry.find(
  119.         StateTransitionPair(fromState->getID(), t));
  120.     if (pvStateID == stateTransitionRegistry.end()) {
  121.         ostringstream msg;
  122.         msg << "transition \"" << t.id << "\"";
  123.         msg << " is not in the transition registry for state \"" << fromState->getID() << "\"";
  124.         throw StateMachineException(msg.str());
  125.     }
  126.  
  127.     return getState(pvStateID->second);
  128. }
  129.  
  130. StateMachine::StatePtr StateMachine::getState(IState::StateID stateID) const
  131. {
  132.     StateMap::const_iterator pvState = stateRegistry.find(stateID);
  133.     if (pvState == stateRegistry.end()) {
  134.         ostringstream msg;
  135.         msg << "state \"" << stateID << "\" is not in the state registry";
  136.         throw StateMachineException(msg.str());
  137.     }
  138.  
  139.     return pvState->second;
  140. }
  141.  
  142. void StateMachine::run(Context& context, IState::StateID initial)
  143. {
  144.     Transition t = Transition::Init;
  145.     StatePtr nextState = getState(initial);
  146.     while (true) {
  147.         StatePtr currentState = nextState;
  148.         t = currentState->run(context);
  149.         if (t == Transition::Exit)
  150.             break;
  151.         nextState = getNextState(currentState, t);
  152.     }
  153. }
  154.  
  155. // A simple state machine implementation follows...
  156.  
  157. // The data for this state machine.
  158. // Pass the selection from state A through the rest of the machine.
  159. struct ProcessContext : public Context
  160. {
  161.     string selection;
  162. };
  163.  
  164. // A sample state.
  165. class StateA : public IState
  166. {
  167. public:
  168.     static StateID ID; // every state has one of these
  169.     static Transition Continue;     // the state's declared transitions
  170.     static Transition Fail;
  171.     virtual StateID getID() const { return StateA::ID; }
  172.     virtual Transition run(Context& context) {
  173.         // Doing this without the Context superclass and downcast
  174.         // (i.e. via templates) is possible but requires a
  175.         // very different design.
  176.         ProcessContext& pc = dynamic_cast<ProcessContext&>(context);
  177.         cout << "Which do you mean, an African or European swallow?" << endl;
  178.         getline(cin, pc.selection);
  179.         if (pc.selection == "African" || pc.selection == "European") {
  180.             // the state machine maps this transition
  181.             // to the appropriate next state
  182.             return Continue;
  183.         } else {
  184.             return Fail;
  185.         }
  186.     }
  187. };
  188.  
  189. // The string IDs specified here help greatly with debugging, at a cost of efficiency.
  190. // They can be generated for convenience, or converted to symbols for better performance.
  191. IState::StateID StateA::ID("StateA");
  192. Transition StateA::Continue("Continue");
  193. Transition StateA::Fail("Fail");
  194.  
  195. // Two other trivial states follow.
  196.  
  197. class StateB : public IState
  198. {
  199. public:
  200.     static StateID ID;
  201.     // No transitions declared; only Exit is used.
  202.     virtual StateID getID() const { return StateB::ID; }
  203.     virtual Transition run(Context& context) {
  204.         ProcessContext& pc = dynamic_cast<ProcessContext&>(context);
  205.         cout << "You said: \"" << pc.selection << "\"" << endl;
  206.         return Transition::Exit; // this stops the machine
  207.     }
  208. };
  209.  
  210. IState::StateID StateB::ID("StateB");
  211.  
  212. class StateError : public IState
  213. {
  214. public:
  215.     static StateID ID;
  216.     virtual StateID getID() const { return StateError::ID; }
  217.     virtual Transition run(Context& context) {
  218.         cout << "Sorry, you have died." << endl;
  219.         return Transition::Exit; // this stops the machine
  220.     }
  221. };
  222.  
  223. IState::StateID StateError::ID("StateError");
  224.  
  225. // Here's one state machine.
  226. // You could define a Process2 that wires up the states a different way.
  227. // The register methods overwrite previous work, so you could also have
  228. // Process2 derive from Process1 and have its init() override certain
  229. // transitions that Process1 set.
  230. class Process1 : public StateMachine
  231. {
  232. public:
  233.     virtual void init();
  234.     using StateMachine::run;
  235.     virtual void run() {
  236.         ProcessContext ctx;
  237.         run(ctx, StateA::ID);
  238.     }
  239.  
  240. private:
  241.     StatePtr initState;
  242. };
  243.  
  244. void Process1::init()
  245. {
  246.     StatePtr stateA(new StateA);
  247.     StatePtr stateB(new StateB);
  248.     StatePtr stateError(new StateError);
  249.  
  250.     registerState(stateA);
  251.     registerState(stateB);
  252.     registerState(stateError);
  253.     registerTransition(stateA, StateA::Continue, stateB);
  254.     registerTransition(stateA, StateA::Fail, stateError);
  255. }
  256.  
  257. int main()
  258. {
  259.     try {
  260.         Process1 p;
  261.         p.init();
  262.         p.run();
  263.         return 0;
  264.     } catch (exception& e) {
  265.         cerr << "Caught exception of type " << typeid(e).name() << ": " << e.what() << endl;
  266.         return 1;
  267.     }
  268. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement