#include <cassert>
#include <exception>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>

#include <boost/shared_ptr.hpp>
using namespace std;
// An abstract, named transition between states.
// A Transition only carries an identifier; StateMachine maintains the
// mapping from transitions to the actual target states.
class Transition
{
public:
    typedef string TransitionID;

    // Builds a transition with the given identifier.
    // The constructor is not explicit, so a TransitionID converts
    // implicitly to a Transition.
    Transition(const TransitionID& name)
        : id(name)
    {
    }

    // Two common transitions for starting/stopping the state process.
    static Transition Init;
    static Transition Exit;

    // Identifier of this transition.
    TransitionID id;
};

// Definitions of the two shared transitions declared above.
Transition Transition::Init("_Init");
Transition Transition::Exit("_Exit");
// Orders two transitions by comparing their ids.
// `left` is the left-hand operand and `right` the right-hand operand.
bool operator<(const Transition& left, const Transition& right)
{
    return left.id < right.id;
}
// Two transitions are equal exactly when their ids are equal.
// `left` is the left-hand operand and `right` the right-hand operand.
bool operator==(const Transition& left, const Transition& right)
{
    return left.id == right.id;
}
// Base class for the data containers passed around between states.
// It has no members of its own; the virtual destructor makes the type
// polymorphic, so derived containers can be recovered with dynamic_cast
// and destroyed correctly through a Context pointer or reference.
struct Context
{
    virtual ~Context()
    {
    }
};
// Interface for a single state of a state machine.
class IState
{
public:
    typedef string StateID;

    virtual ~IState()
    {
    }

    // Returns the identifier of this state.
    virtual StateID getID() const = 0;

    // Executes the state against the shared context and returns the
    // transition that was chosen.
    virtual Transition run(Context& context) = 0;
};
// Exception carrying a human-readable message about a state machine error.
class StateMachineException : public exception
{
public:
    // Stores a copy of `message`; it is later reported by what().
    StateMachineException(const string& message)
        : mWhat(message)
    {
    }

    virtual ~StateMachineException() throw()
    {
    }

    // Returns the stored message. The pointer refers to this object's
    // own string, so it is only valid while the exception is alive.
    virtual const char* what() const throw()
    {
        return mWhat.c_str();
    }

private:
    string mWhat;
};
// Represents a finite state machine.
// Keeps a registry of states keyed by state ID and a registry of
// transitions that maps (source state ID, transition) to a target state ID.
// Derive from this class and override init() to wire up a concrete machine.
class StateMachine
{
public:
    typedef IState::StateID StateID;

    // Virtual because this class is meant to be derived from (init() and
    // run() are virtual): destroying a derived machine through a
    // StateMachine pointer must also run the derived destructor.
    virtual ~StateMachine() { clear(); }

    // Empties both the state registry and the transition registry.
    void clear();

    // init() must be called before run().
    // Override this to register states and transitions.
    virtual void init();

    // Runs the machine, starting with the state registered under `initial`
    // and following the transition returned by each state until a state
    // returns Transition::Exit.
    // Throws StateMachineException when a state or transition that is
    // needed has not been registered.
    virtual void run(Context& context, StateID initial);

protected:
    typedef boost::shared_ptr<IState> StatePtr;

    // Use these methods to set up the state graph.
    // Every state in your machine must be registered.
    // Every transition in your machine except for Transition::Exit must be registered.
    // Both methods throw StateMachineException when given a null pointer.
    void registerState(StatePtr state);
    void registerTransition(StatePtr from, Transition t, StatePtr to);

private:
    // Looks up the target state registered for (fromState, t).
    StatePtr getNextState(StatePtr fromState, Transition t) const;

    // Looks up a registered state by its ID.
    StatePtr getState(StateID stateID) const;

    typedef pair<StateID, Transition> StateTransitionPair;
    typedef map<StateTransitionPair, StateID> StateTransitionMap;
    typedef map<StateID, StatePtr> StateMap;

    StateMap stateRegistry;                     // state ID -> state object
    StateTransitionMap stateTransitionRegistry; // (state ID, transition) -> target state ID
};
// Default setup hook: registers nothing.
// Derived machines override this to register their states and transitions.
void StateMachine::init()
{
    // Intentionally empty.
}
// Removes every registered transition and every registered state.
void StateMachine::clear()
{
    stateTransitionRegistry.clear();
    stateRegistry.clear();
}
// Stores `state` in the state registry under the ID it reports via getID().
// An entry already registered under that ID is overwritten.
// Throws StateMachineException if `state` is null.
void StateMachine::registerState(StatePtr state)
{
    if (state) {
        stateRegistry[state->getID()] = state;
        return;
    }

    throw StateMachineException("Null pointer passed to registerState");
}
// Records that taking transition `t` from `fromState` leads to `toState`.
// Only the IDs of the two states are stored; a mapping already present for
// the same (state, transition) pair is overwritten.
// Throws StateMachineException if either state pointer is null.
void StateMachine::registerTransition(StatePtr fromState, Transition t, StatePtr toState)
{
    const bool bothValid = fromState && toState;
    if (!bothValid) {
        throw StateMachineException("Null pointer passed to registerTransition");
    }

    const StateTransitionPair edge(fromState->getID(), t);
    stateTransitionRegistry[edge] = toState->getID();
}
// Resolves the state that follows `fromState` when transition `t` is taken.
// `fromState` must not be null (checked with assert).
// Throws StateMachineException if the (state, transition) pair was never
// registered, or if the registered target ID has no state in the registry.
StateMachine::StatePtr StateMachine::getNextState(StatePtr fromState, Transition t) const
{
    assert(fromState);

    const StateTransitionPair key(fromState->getID(), t);
    const StateTransitionMap::const_iterator hit = stateTransitionRegistry.find(key);
    if (hit != stateTransitionRegistry.end()) {
        return getState(hit->second);
    }

    // No mapping: report both the transition and the state it was taken from.
    ostringstream msg;
    msg << "transition \"" << t.id << "\""
        << " is not in the transition registry for state \"" << fromState->getID() << "\"";
    throw StateMachineException(msg.str());
}
// Returns the state registered under `stateID`.
// Throws StateMachineException if no state with that ID is registered.
StateMachine::StatePtr StateMachine::getState(IState::StateID stateID) const
{
    const StateMap::const_iterator entry = stateRegistry.find(stateID);
    if (entry != stateRegistry.end()) {
        return entry->second;
    }

    ostringstream msg;
    msg << "state \"" << stateID << "\" is not in the state registry";
    throw StateMachineException(msg.str());
}
// Drives the machine: starts at the state registered under `initial`, runs
// it, and moves to the state mapped to the transition it returns. Stops as
// soon as a state returns Transition::Exit.
// Throws StateMachineException if `initial` or a transition target cannot
// be resolved.
void StateMachine::run(Context& context, IState::StateID initial)
{
    StatePtr current = getState(initial);
    for (;;) {
        const Transition outcome = current->run(context);
        if (outcome == Transition::Exit) {
            return;
        }
        current = getNextState(current, outcome);
    }
}
// A simple state machine implementation follows...
// The data shared by the states of the sample machine.
// `selection` carries the answer read in state A through the rest of the
// machine.
struct ProcessContext : Context
{
    string selection;
};
// A sample state: asks the user a question, stores the answer in the
// context, and picks a transition depending on the answer.
class StateA : public IState
{
public:
    static StateID ID;          // every state has one of these
    static Transition Continue; // returned when the answer is one of the two expected values
    static Transition Fail;     // returned for any other answer

    virtual StateID getID() const
    {
        return ID;
    }

    virtual Transition run(Context& context)
    {
        // The machine only hands out the Context base class, so downcast to
        // the concrete data type. Avoiding the superclass and downcast
        // (i.e. via templates) is possible but requires a very different
        // design. A mismatched context type makes dynamic_cast throw.
        ProcessContext& data = dynamic_cast<ProcessContext&>(context);

        cout << "Which do you mean, an African or European swallow?" << endl;
        getline(cin, data.selection);

        const bool recognised =
            data.selection == "African" || data.selection == "European";

        // The state machine maps the returned transition to the
        // appropriate next state.
        return recognised ? Continue : Fail;
    }
};

// The string IDs specified here help greatly with debugging, at a cost of efficiency.
// They can be generated for convenience, or converted to symbols for better performance.
IState::StateID StateA::ID("StateA");
Transition StateA::Continue("Continue");
Transition StateA::Fail("Fail");
// Two other trivial states follow.
// Echoes the selection stored in the context, then stops the machine.
class StateB : public IState
{
public:
    static StateID ID;

    // No transitions declared; only Exit is used.
    virtual StateID getID() const
    {
        return ID;
    }

    virtual Transition run(Context& context)
    {
        // Downcast to the concrete data type to read the stored answer.
        ProcessContext& data = dynamic_cast<ProcessContext&>(context);
        cout << "You said: \"" << data.selection << "\"" << endl;

        // Returning Exit stops the machine.
        return Transition::Exit;
    }
};

IState::StateID StateB::ID("StateB");
// Error state: prints a failure message, then stops the machine.
class StateError : public IState
{
public:
    static StateID ID;

    virtual StateID getID() const
    {
        return ID;
    }

    // The context is not used by this state.
    virtual Transition run(Context& /*context*/)
    {
        cout << "Sorry, you have died." << endl;

        // Returning Exit stops the machine.
        return Transition::Exit;
    }
};

IState::StateID StateError::ID("StateError");
// Here's one state machine.
// You could define a Process2 that wires up the states a different way.
// The register methods overwrite previous work, so you could also have
// Process2 derive from Process1 and have its init() override certain
// transitions that Process1 set.
class Process1 : public StateMachine
{
public:
    // Registers the states and transitions of this machine.
    // Must be called before run().
    virtual void init();

    // Keep the base-class run(Context&, StateID) overload visible next to
    // the parameterless overload declared below.
    using StateMachine::run;

    // Convenience entry point: runs the machine on a fresh ProcessContext,
    // starting at StateA.
    virtual void run()
    {
        ProcessContext ctx;
        run(ctx, StateA::ID);
    }
};
void Process1::init()
{
StatePtr stateA(new StateA);
StatePtr stateB(new StateB);
StatePtr stateError(new StateError);
registerState(stateA);
registerState(stateB);
registerState(stateError);
registerTransition(stateA, StateA::Continue, stateB);
registerTransition(stateA, StateA::Fail, stateError);
}
// Program entry point: builds the sample machine and runs it once.
// Returns 0 when the machine finishes normally, or 1 after reporting any
// std::exception that escapes from setup or execution.
int main()
{
    try {
        Process1 p;
        p.init(); // init() must be called before run()
        p.run();
        return 0;
    } catch (const exception& e) {
        // Catch by const reference: the exception is only inspected here.
        // typeid on the reference reports the dynamic type of the exception.
        cerr << "Caught exception of type " << typeid(e).name() << ": " << e.what() << endl;
        return 1;
    }
}