#include <cassert>
#include <exception>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>

using namespace std;

// Represents an abstract transition.
// StateMachine maintains the mapping of transitions
// to actual target states.
class Transition {
public:
    typedef string TransitionID;

    Transition(const TransitionID& id_) : id(id_) {}

    TransitionID id;

    // Two common transitions for starting/stopping
    // the state process.
    static Transition Init;
    static Transition Exit;
};

Transition Transition::Init("_Init");
Transition Transition::Exit("_Exit");

// Orders transitions by ID so they can participate in map keys.
bool operator<(const Transition& lhs, const Transition& rhs) {
    return lhs.id < rhs.id;
}

// Two transitions are the same transition when their IDs match.
bool operator==(const Transition& lhs, const Transition& rhs) {
    return lhs.id == rhs.id;
}

// A base class for data containers passed around between states.
struct Context {
    virtual ~Context() {}
};

// Represents a single state.
class IState {
public:
    typedef string StateID;

    virtual ~IState() {}

    // The key under which StateMachine registers and looks up this state.
    virtual StateID getID() const = 0;

    // Performs this state's work and returns the transition to follow.
    // Returning Transition::Exit stops the machine.
    virtual Transition run(Context& context) = 0;
};

// Thrown for state-graph errors: null registrations, unknown state IDs,
// and transitions that were never registered.
class StateMachineException : public exception {
public:
    StateMachineException(const string& what) : mWhat(what) {}

    const char* what() const noexcept override { return mWhat.c_str(); }

private:
    string mWhat;
};

// Represents a finite state machine.
class StateMachine {
public:
    typedef IState::StateID StateID;

    // Virtual because StateMachine is designed to be subclassed (see
    // Process1); destroying a derived machine through a StateMachine*
    // would otherwise be undefined behavior. The registries release
    // their contents on their own.
    virtual ~StateMachine() {}

    // Forgets every registered state and transition.
    void clear();

    // init() must be called before run().
    // Override this to register states and transitions.
    virtual void init();

    // Runs states starting from `initial` until one returns
    // Transition::Exit. Throws StateMachineException if `initial`, a
    // transition, or a transition's target is not registered.
    virtual void run(Context& context, StateID initial);

protected:
    typedef std::shared_ptr<IState> StatePtr;

    // Use these methods to set up the state graph.
    // Every state in your machine must be registered.
    // Every transition in your machine except for Transition::Exit must be registered.
    // Registering an existing state ID or (state, transition) pair again
    // overwrites the earlier entry.
    void registerState(StatePtr state);
    void registerTransition(StatePtr from, Transition t, StatePtr to);

private:
    StatePtr getNextState(const StatePtr& fromState, const Transition& t) const;
    StatePtr getState(const StateID& stateID) const;

    // (source state ID, transition) -> target state ID.
    typedef pair<StateID, Transition> StateTransitionPair;
    typedef map<StateTransitionPair, StateID> StateTransitionMap;
    // state ID -> state object.
    typedef map<StateID, StatePtr> StateMap;

    StateMap stateRegistry;
    StateTransitionMap stateTransitionRegistry;
};

void StateMachine::init() {
    // The base machine registers nothing; subclasses wire up their graph.
}

void StateMachine::clear() {
    stateRegistry.clear();
    stateTransitionRegistry.clear();
}

void StateMachine::registerState(StatePtr state) {
    if (!state)
        throw StateMachineException("Null pointer passed to registerState");
    stateRegistry[state->getID()] = state;
}

void StateMachine::registerTransition(StatePtr fromState, Transition t, StatePtr toState) {
    if (!fromState || !toState)
        throw StateMachineException("Null pointer passed to registerTransition");
    // Only the target's ID is stored; it is resolved through stateRegistry
    // when the transition fires, so the target must be registered too.
    stateTransitionRegistry[StateTransitionPair(fromState->getID(), t)] = toState->getID();
}

// Looks up the state that `t` leads to from `fromState`.
StateMachine::StatePtr StateMachine::getNextState(const StatePtr& fromState,
                                                  const Transition& t) const {
    assert(fromState);
    StateTransitionMap::const_iterator pvStateID =
        stateTransitionRegistry.find(StateTransitionPair(fromState->getID(), t));
    if (pvStateID == stateTransitionRegistry.end()) {
        ostringstream msg;
        msg << "transition \"" << t.id << "\"";
        msg << " is not in the transition registry for state \"" << fromState->getID() << "\"";
        throw StateMachineException(msg.str());
    }
    return getState(pvStateID->second);
}

// Looks up a registered state by ID.
StateMachine::StatePtr StateMachine::getState(const StateID& stateID) const {
    StateMap::const_iterator pvState = stateRegistry.find(stateID);
    if (pvState == stateRegistry.end()) {
        ostringstream msg;
        msg << "state \"" << stateID << "\" is not in the state registry";
        throw StateMachineException(msg.str());
    }
    return pvState->second;
}

void StateMachine::run(Context& context, IState::StateID initial) {
    StatePtr currentState = getState(initial);
    while (true) {
        Transition t = currentState->run(context);
        if (t == Transition::Exit)
            break;
        currentState = getNextState(currentState, t);
    }
}

// A simple state machine implementation follows...

// The data for this state machine.
// Pass the selection from state A through the rest of the machine.
struct ProcessContext : public Context {
    string selection;
};

// A sample state.
class StateA : public IState {
public:
    static StateID ID;              // every state has one of these
    static Transition Continue;     // the state's declared transitions
    static Transition Fail;

    StateID getID() const override { return StateA::ID; }

    Transition run(Context& context) override {
        // Doing this without the Context superclass and downcast
        // (i.e. via templates) is possible but requires a
        // very different design. Throws std::bad_cast if the machine
        // was run with a context that is not a ProcessContext.
        ProcessContext& pc = dynamic_cast<ProcessContext&>(context);
        cout << "Which do you mean, an African or European swallow?" << endl;
        getline(cin, pc.selection);
        if (pc.selection == "African" || pc.selection == "European") {
            // the state machine maps this transition
            // to the appropriate next state
            return Continue;
        } else {
            return Fail;
        }
    }
};

// The string IDs specified here help greatly with debugging, at a cost of efficiency.
// They can be generated for convenience, or converted to symbols for better performance.
IState::StateID StateA::ID("StateA");
Transition StateA::Continue("Continue");
Transition StateA::Fail("Fail");

// Two other trivial states follow.
class StateB : public IState {
public:
    static StateID ID;
    // No transitions declared; only Exit is used.

    StateID getID() const override { return StateB::ID; }

    Transition run(Context& context) override {
        ProcessContext& pc = dynamic_cast<ProcessContext&>(context);
        cout << "You said: \"" << pc.selection << "\"" << endl;
        return Transition::Exit;    // this stops the machine
    }
};

IState::StateID StateB::ID("StateB");

class StateError : public IState {
public:
    static StateID ID;

    StateID getID() const override { return StateError::ID; }

    Transition run(Context& /*context*/) override {
        cout << "Sorry, you have died." << endl;
        return Transition::Exit;    // this stops the machine
    }
};

IState::StateID StateError::ID("StateError");

// Here's one state machine.
// You could define a Process2 that wires up the states a different way.
// The register methods overwrite previous work, so you could also have
// Process2 derive from Process1 and have its init() override certain
// transitions that Process1 set.
class Process1 : public StateMachine {
public:
    void init() override;

    // Keep the base run(Context&, StateID) visible next to run() below.
    using StateMachine::run;

    // Convenience entry point: runs from StateA with a fresh context.
    virtual void run() {
        ProcessContext ctx;
        run(ctx, StateA::ID);
    }
};

void Process1::init() {
    StatePtr stateA = std::make_shared<StateA>();
    StatePtr stateB = std::make_shared<StateB>();
    StatePtr stateError = std::make_shared<StateError>();

    registerState(stateA);
    registerState(stateB);
    registerState(stateError);

    registerTransition(stateA, StateA::Continue, stateB);
    registerTransition(stateA, StateA::Fail, stateError);
}

int main() {
    try {
        Process1 p;
        p.init();
        p.run();
        return 0;
    } catch (const exception& e) {
        cerr << "Caught exception of type " << typeid(e).name() << ": " << e.what() << endl;
        return 1;
    }
}