--
-- Simple Finite Domain, Constraint Propagation Solver
--
-- Inspired by 'Fast Procedural Level Population with Playability Constraints'
-- by Ian Horswill and Leif Foged.
--
-- paper: http://www.aaai.org/ocs/index.php/AIIDE/AIIDE12/paper/view/5466/5691
-- code: http://code.google.com/p/constraint-thingy/
--
-- I haven't implemented any of the more advanced parts of the paper like
-- interval domains or the DAG based pathing stuff.
--
-- It's not very efficient as it uses tables (Lua's built in hash map) for
-- representing domains instead of bit arrays but it is simple.
--
-- variable = {
--     name = <name>,
--     values = { [<value>] = true }*
-- }
-- constraint = function ( solver, variable ) -> boolean
--     (returns false when the constraint can no longer be satisfied)
-- solver = {
--     variables = { [variable] = true }+
--     varmap = { [<name>] = variable }+
--     constraints = { [variable] = { [constraint] = true }* }+
--     stack = { { [variable] = <saved values> }* }*
--     random = <boolean>
-- }
--

--- Print a formatted line: print(string.format(...)).
function printf( ... )
    print(string.format(...))
end

-- Set helpers. A set is a table whose keys are its members, each mapped to
-- true.

--- Singleton set containing just `value`.
function set( value )
    return { [value] = true }
end

--- True if `set` has no members.
function set_empty( set )
    return next(set) == nil
end

--- True if `set` has exactly one member.
function set_singular( set )
    local k = next(set)
    return k ~= nil and next(set, k) == nil
end

--- Number of members in `set`.
function set_count( set )
    local result = 0
    for _ in pairs(set) do
        result = result + 1
    end
    return result
end

--- Shallow copy of `set`.
function set_copy( set )
    local result = {}
    for k in pairs(set) do
        result[k] = true
    end
    return result
end

--- New set of the members present in both `set1` and `set2`.
function set_intersection( set1, set2 )
    local result = {}
    for k in pairs(set1) do
        if set2[k] then
            result[k] = true
        end
    end
    return result
end

--- New set of the members of `set1` that are not in `set2`.
function set_subtract( set1, set2 )
    local result = {}
    for k in pairs(set1) do
        if not set2[k] then
            result[k] = true
        end
    end
    return result
end

--- True if `set1` and `set2` have exactly the same members.
function set_equal( set1, set2 )
    for k in pairs(set1) do
        if not set2[k] then
            return false
        end
    end
    for k in pairs(set2) do
        if not set1[k] then
            return false
        end
    end
    return true
end

--- Members of `set` as an array (order unspecified).
function set_toarray( set )
    local result = {}
    for k in pairs(set) do
        result[#result+1] = k
    end
    return result
end

--- Set containing every element of the sequence `array`.
function set_fromarray( array )
    local result = {}
    for _, v in ipairs(array) do
        result[v] = true
    end
    return result
end

--- "{a b c}" rendering of `set` (member order unspecified).
function set_tostring( set )
    local parts = {}
    for k in pairs(set) do
        -- tostring so members that table.concat rejects (e.g. booleans)
        -- still print.
        parts[#parts+1] = tostring(k)
    end
    return string.format("{%s}", table.concat(parts, ' '))
end

--- Debug helper: print every frame on the solver's undo stack.
function print_stack( solver )
    for index, frame in ipairs(solver.stack) do
        printf('%d ->', index)
        for variable, values in pairs(frame) do
            printf(' %s = %s', variable.name, set_tostring(values))
        end
    end
end

--- Push an empty undo frame; save() records into the top frame.
function newframe( solver )
    local stack = solver.stack
    stack[#stack+1] = {}
end

--- Record `var`'s current domain in the top frame so unwind() can restore
-- it. Only the first save of a variable per frame is kept: that is the
-- domain it had when the frame was pushed. Overwriting it on a later
-- narrowing within the same frame would make unwind() restore an
-- intermediate, already-narrowed domain.
function save( solver, var )
    local stack = solver.stack
    local frame = stack[#stack]
    if frame[var] == nil then
        frame[var] = var.values
    end
end

--- Pop the top frame, restoring every domain recorded in it.
function unwind( solver )
    local stack = solver.stack
    local frame = stack[#stack]
    for variable, values in pairs(frame) do
        variable.values = values
    end
    stack[#stack] = nil
end

--- Map of variable name -> value. Asserts every variable is singular.
function solution( solver )
    local result = {}
    for variable in pairs(solver.variables) do
        assert(set_singular(variable.values))
        result[variable.name] = next(variable.values)
    end
    return result
end

--- Restrict `var`'s domain to its intersection with `set` and propagate the
-- change through every constraint attached to `var`.
-- Returns false if the domain would become empty or a constraint fails.
-- Changes made before a failure stay recorded on the current frame, so the
-- caller must unwind().
function narrow( solver, var, set )
    -- printf('  narrow %s to %s', vartos(var), set_tostring(set))
    local newvalues = set_intersection(var.values, set)

    if set_empty(newvalues) then
        return false
    end

    -- newvalues is a subset of var.values, so equality means "no change".
    if not set_equal(var.values, newvalues) then
        save(solver, var)
        var.values = newvalues

        for constraint in pairs(solver.constraints[var]) do
            if not constraint(solver, var) then
                return false
            end
        end
    end

    return true
end

--- In-place Fisher-Yates shuffle using math.random; returns `array`.
function shuffle( array )
    for i = 1, #array-1 do
        local j = math.random(i, #array)
        array[i], array[j] = array[j], array[i]
    end
    return array
end

--- Like pairs(), but visits the entries of `tbl` in random order.
-- The entries are copied up front, so later changes to `tbl` do not affect
-- the iteration.
function randpairs( tbl )
    local kvs = {}
    for k, v in pairs(tbl) do
        kvs[#kvs+1] = { k, v }
    end
    shuffle(kvs)

    local index = 0
    return function ()
        index = index + 1
        local kv = kvs[index]
        if kv then
            return kv[1], kv[2]
        end
        return nil
    end
end

--- Depth-first search over all variables. Each complete assignment is
-- passed to coroutine.yield() as a name -> value map (see solution()), so
-- this must run inside a coroutine (newSolver sets that up).
-- Raises 'empty' if any variable starts with an empty domain.
function solve( solver )
    -- If there's any empty variables we're buggered.
    for variable in pairs(solver.variables) do
        if set_empty(variable.values) then
            error('empty')
        end
    end

    -- Base frame: whatever the initial propagation below narrows is
    -- restored once the search is complete.
    newframe(solver)

    -- The search never narrows a variable that is already singular, so the
    -- constraints on such variables would otherwise never be checked.
    -- Run them once before searching.
    for variable in pairs(solver.variables) do
        if set_singular(variable.values) then
            for constraint in pairs(solver.constraints[variable]) do
                if not constraint(solver, variable) then
                    unwind(solver)
                    return
                end
            end
        end
    end

    local varsarray = set_toarray(solver.variables)

    -- We want to try and solve the 'smallest' variables first.
    table.sort(varsarray, function ( lhs, rhs )
        return set_count(lhs.values) < set_count(rhs.values)
    end)

    if solver.random then
        shuffle(varsarray)
    end

    local iterator = solver.random and randpairs or pairs

    local function aux( solver, index )
        if index > #varsarray then
            coroutine.yield(solution(solver))
            return
        end

        local variable = varsarray[index]

        -- Already singular (from the start or by propagation): carry on.
        if set_singular(variable.values) then
            aux(solver, index+1)
            return
        end

        for value in iterator(variable.values) do
            newframe(solver)
            -- printf('  try %s = %s', variable.name, value)
            if narrow(solver, variable, set(value)) then
                aux(solver, index+1)
            end
            unwind(solver)
        end
    end

    aux(solver, 1)

    unwind(solver)
end

--- Debug helper: print each key/value pair of `tbl`.
function printTable( tbl )
    for k, v in pairs(tbl) do
        print(k, v)
    end
end

--- Resolve the names in `varnames` to variables via `varmap`.
-- Asserts that every name is known and that no variable appears twice.
function environment( varnames, varmap )
    local result = {}
    local repeated = {}

    for _, varname in ipairs(varnames) do
        local variable = varmap[varname]
        assert(variable, string.format("unknown variable '%s'", tostring(varname)))
        assert(not repeated[variable], variable.name)
        result[#result+1] = variable
        repeated[variable] = true
    end

    return result
end

--- "name={values}" rendering of `var`, optionally with an explicit
-- domain `values` in place of var.values.
function vartos( var, values )
    values = values or var.values
    return string.format("%s=%s", var.name, set_tostring(values))
end

-- Failure counters (debug statistics only).
local numcfails = 0
local numfwfails = 0

-- Constraint factories: name -> function ( vars, params ) returning a
-- constraint function ( solver, var ) -> boolean.
local _constraints = {
    -- The two variables must take different values.
    NotEqual = function ( vars )
        assert(#vars == 2)
        local var1, var2 = vars[1], vars[2]

        return function ( solver, var )
            assert(var == var1 or var == var2)
            local other = (var == var1) and var2 or var1

            if set_singular(var.values) then
                local newvalues = set_subtract(other.values, var.values)
                -- printf('  NotEqual %s %s => %s', vartos(var), vartos(other), set_tostring(newvalues))
                return narrow(solver, other, newvalues)
            end

            return true
        end
    end,

    -- The two variables must take the same value.
    Equal = function ( vars )
        assert(#vars == 2)
        local var1, var2 = vars[1], vars[2]

        return function ( solver, var )
            assert(var == var1 or var == var2)
            local other = (var == var1) and var2 or var1

            if set_singular(var.values) then
                local newvalues = set_copy(var.values)
                -- printf('  Equal %s %s => %s', vartos(var), vartos(other), set_tostring(newvalues))
                return narrow(solver, other, newvalues)
            end

            return true
        end
    end,

    -- Between params.min and params.max (inclusive) of the variables must
    -- take the value params.value.
    Cardinality = function ( vars, params )
        local min = params.min
        local max = params.max
        local value = params.value
        assert(0 <= min)
        assert(min <= max)
        assert(value)
        assert(min <= #vars)

        local variables = set_fromarray(vars)

        return function ( solver, var )
            assert(variables[var])

            local possible = {}
            local definite = {}
            local numpossible = 0
            local numdefinite = 0

            for variable in pairs(variables) do
                if variable.values[value] then
                    possible[variable] = true
                    numpossible = numpossible + 1

                    if set_singular(variable.values) then
                        definite[variable] = true
                        numdefinite = numdefinite + 1
                    end
                end
            end

            if numpossible < min or numdefinite > max then
                numcfails = numcfails + 1
                return false
            end

            -- We have enough possibles, but only just, so try and make
            -- them definite to be sure.
            if numpossible == min then
                local valueset = set(value)
                for variable in pairs(possible) do
                    if not narrow(solver, variable, valueset) then
                        numcfails = numcfails + 1
                        return false
                    end
                end
                return true
            end

            -- We've got the maximum number of definites so make the
            -- possibles that are not also definite into impossibles.
            if numdefinite == max then
                local valueset = set(value)
                for variable in pairs(possible) do
                    if not definite[variable] then
                        local newvalues = set_subtract(variable.values, valueset)
                        -- printf('  Cardinality %s => %s', vartos(variable), set_tostring(newvalues))
                        if not narrow(solver, variable, newvalues) then
                            numcfails = numcfails + 1
                            return false
                        end
                    end
                end
                return true
            end

            return true
        end
    end,

    -- Once every variable is singular, their (numeric) values must add up
    -- to params.total.
    Sum = function ( vars, params )
        assert(#vars > 0)
        local total = params.total
        assert(total)

        return function ( solver, var )
            local count = 0

            for _, variable in ipairs(vars) do
                if not set_singular(variable.values) then
                    return true
                end
                -- TODO: Need a 'get unique value' function.
                count = count + next(variable.values)
            end

            -- printf('  Sum count:%d total:%d', count, total)
            return count == total
        end
    end,

    -- `vars` is a row-major width x height grid. Once every cell is
    -- singular, the cells whose value is a key of params.passable must form
    -- one 4-way connected region.
    FourWayConnected = function ( vars, params )
        local height = params.height
        local width = params.width
        local passable = params.passable
        assert(#vars == width * height)

        -- Index of cell (x, y) in the row-major `vars` array.
        local at = function ( x, y )
            local result = x + (y-1)*width
            assert(1 <= result and result <= width * height)
            return result
        end

        -- Build up the 'neighbour' relation.
        local neighbours = {}
        for y = 1, height do
            for x = 1, width do
                local peers = {}
                if y > 1 then peers[#peers+1] = vars[at(x, y-1)] end      -- North
                if x < width then peers[#peers+1] = vars[at(x+1, y)] end  -- East
                if y < height then peers[#peers+1] = vars[at(x, y+1)] end -- South
                if x > 1 then peers[#peers+1] = vars[at(x-1, y)] end      -- West
                neighbours[vars[at(x, y)]] = peers
            end
        end

        -- Breadth-first flood fill: number of cells in `passables`
        -- reachable from `start` (counting `start` itself).
        local function reachable( start, passables )
            local seen = { [start] = true }
            local frontier = { start }
            local count = 1

            while #frontier > 0 do
                local nextFrontier = {}
                for _, cell in ipairs(frontier) do
                    for _, other in ipairs(neighbours[cell]) do
                        if passables[other] and not seen[other] then
                            seen[other] = true
                            count = count + 1
                            nextFrontier[#nextFrontier+1] = other
                        end
                    end
                end
                frontier = nextFrontier
            end

            return count
        end

        return function ( solver, var )
            -- Only checkable once every cell has a definite value.
            for _, variable in ipairs(vars) do
                if not set_singular(variable.values) then
                    return true
                end
            end

            -- First create a set of passable variables.
            local passables = {}
            local numpassables = 0
            for _, variable in ipairs(vars) do
                if passable[next(variable.values)] then
                    passables[variable] = true
                    numpassables = numpassables + 1
                end
            end

            -- No passable cells at all is trivially connected.
            if numpassables == 0 then
                return true
            end

            local count = reachable(next(passables), passables)

            if count ~= numpassables then
                numfwfails = numfwfails + 1
                -- printf('#card:%d #fourw:%d', numcfails, numfwfails)
            end

            return count == numpassables
        end
    end,
}

-- tbl = {
--     vars = { [<name>] = { <value>+ } }+,
--     constraints = { { <constraint name>, vars = { <name>+ }, params = { ... } } }*,
--     dump = <boolean>?,    -- print variables and constraint attachments
--     random = <boolean>?,  -- randomise variable and value order
-- }

--- Build a solver from `tbl` (format above) and return an iterator
-- function. Each call returns the next solution (a name -> value map), or
-- nil when the search is exhausted. Errors raised during the search are
-- re-raised to the caller.
function newSolver( tbl )
    local dump = tbl.dump == true

    local variables = {}
    local varmap = {}

    for name, values in pairs(tbl.vars) do
        assert(#values >= 1)
        local variable = {
            name = name,
            values = set_fromarray(values)
        }
        variables[variable] = true
        varmap[name] = variable

        if dump then
            print(vartos(variable))
        end
    end

    local constraints = {}
    for variable in pairs(variables) do
        constraints[variable] = {}
    end

    for index, spec in ipairs(tbl.constraints or {}) do
        local name = spec[1]
        local factory = _constraints[name]
        assert(factory, string.format("unknown constraint '%s'", tostring(name)))

        local cvars = environment(spec.vars, varmap)
        local constraint = factory(cvars, spec.params)

        for _, variable in ipairs(cvars) do
            if dump then
                printf('constraint #%d %s', index, variable.name)
            end
            constraints[variable][constraint] = true
        end
    end

    local result = {
        variables = variables,
        varmap = varmap,
        constraints = constraints,
        stack = {},
        random = tbl.random and true or false,
    }

    local coro = coroutine.create(function ()
        solve(result)
    end)

    return function ()
        local status, value = coroutine.resume(coro)
        if status then
            return value
        end
        error(value)
    end
end

--- "{ a=x b=y }" rendering of a solution, sorted by variable name.
function solution_tostring( solution )
    local sorted = {}
    for varname, value in pairs(solution) do
        sorted[#sorted+1] = { varname, value }
    end

    table.sort(sorted, function ( lhs, rhs ) return lhs[1] < rhs[1] end)

    local parts = {}
    -- ipairs, not pairs: the sorted order must be preserved.
    for _, data in ipairs(sorted) do
        parts[#parts+1] = string.format("%s=%s", data[1], data[2])
    end

    return string.format("{ %s }", table.concat(parts, ' '))
end

--- Print every solution produced by `solver`; returns the count.
function enumerate( solver )
    local count = 0
    for solution in solver do
        count = count + 1
        printf('#%d %s', count, solution_tostring(solution))
    end
    printf('%d solutions', count)
    print()
    return count
end

--- Like enumerate(), but only prints every 1000th solution; returns the
-- count.
function enumerate2( solver )
    local count = 0
    for solution in solver do
        count = count + 1
        if count % 1000 == 0 then
            printf('#%d %s', count, solution_tostring(solution))
        end
    end
    printf('%d solutions', count)
    print()
    return count
end

local solver = newSolver {
    vars = {
        a = { 'x', 'y', 'z' },
        b = { 'x', 'y', 'z' },
        c = { 'x', 'y', 'z' },
    },
    constraints = {},
}

enumerate(solver)

local solver = newSolver {
    dump = true,
    vars = {
        a = { 'x', 'y', 'z' },
        b = { 'x', 'y', 'z' },
        c = { 'x', 'y', 'z' },
    },
    constraints = {
        { 'NotEqual', vars = { 'a', 'b' } },
        { 'NotEqual', vars = { 'a', 'c' } },
        { 'NotEqual', vars = { 'b', 'c' } },
    },
}

assert(enumerate(solver) == 6)

local solver = newSolver {
    vars = {
        a = { 'x', 'y', 'z' },
        b = { 'x', 'y', 'z' },
        c = { 'x', 'y', 'z' },
    },
    constraints = {
        { 'Equal', vars = { 'a', 'b' } },
        { 'Equal', vars = { 'a', 'c' } },
        { 'Equal', vars = { 'b', 'c' } },
    },
}

assert(enumerate(solver) == 3)

--- NotEqual constraint specs for every unordered pair of `varnames`.
function genNotEquals( varnames )
    local result = {}
    for i = 1, #varnames do
        for j = i+1, #varnames do
            result[#result+1] = { 'NotEqual', vars = { varnames[i], varnames[j] } }
        end
    end
    return result
end

local solver = newSolver {
    vars = {
        a = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        b = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        c = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        d = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        e = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        f = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        g = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        h = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        i = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        j = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        k = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
        l = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', },
    },
    constraints = genNotEquals { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l' },
}

-- This should have 12! = 479,001,600 solutions.
-- enumerate2(solver)

local solver = newSolver {
    vars = {
        a = { '0', '1', },
        b = { '0', '1', },
        c = { '0', '1', },
        d = { '0', '1', },
    },
    constraints = {
        { 'Cardinality', vars = { 'a', 'b', 'c', 'd' }, params = { min = 2, max = 3, value = '1' } },
    },
}

assert(enumerate(solver) == 10)

local solver = newSolver {
    vars = {
        a = { '0', '1', },
        b = { '0', '1', },
        c = { '0', '1', },
        d = { '0', '1', },
        e = { '0', '1', },
        f = { '0', '1', },
        g = { '0', '1', },
        h = { '0', '1', },
        i = { '0', '1', },
        j = { '0', '1', },
        k = { '0', '1', },
        l = { '0', '1', },
        m = { '0', '1', },
        n = { '0', '1', },
        o = { '0', '1', },
        r = { '0', '1', },
        s = { '0', '1', },
        t = { '0', '1', },
        u = { '0', '1', },
        v = { '0', '1', },
        w = { '0', '1', },
        x = { '0', '1', },
        y = { '0', '1', },
        z = { '0', '1', },
    },
    constraints = {
        {
            'Cardinality',
            vars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' },
            params = { min = 13, max = 13, value = '1' }
        },
    },
}

-- enumerate2(solver)

local solver = newSolver {
    vars = {
        a = { 0, 1, 2, 3, 4, 5, },
        b = { 0, 1, 2, 3, 4, 5, },
        c = { 0, 1, 2, 3, 4, 5, },
        d = { 0, 1, 2, 3, 4, 5, },
    },
    constraints = {
        { 'Sum', vars = { 'a', 'b', 'c', 'd' }, params = { total = 2 } },
    },
}

enumerate(solver)

local solver = newSolver {
    vars = {
        a11 = { '#', '.' }, a21 = { '#', '.' }, a31 = { '#', '.' },
        a12 = { '#', '.' }, a22 = { '#', '.' }, a32 = { '#', '.' },
        a13 = { '#', '.' }, a23 = { '#', '.' }, a33 = { '#', '.' },
    },
    constraints = {
        {
            'FourWayConnected',
            vars = { 'a11', 'a21', 'a31', 'a12', 'a22', 'a32', 'a13', 'a23', 'a33' },
            params = { width = 3, height = 3, passable = { ['.'] = true } }
        },
    },
}

-- enumerate(solver)

--- Enumerate random width x height maps of '#' (wall) and '.' (floor)
-- whose floor cells are 4-way connected and make up between
-- floor(width*height*minf) and floor(width*height*maxf) cells.
-- Prints each map as it is found and returns the number of maps.
function enumeratemap( width, height, minf, maxf )
    assert(width >= 3)
    assert(height >= 3)
    assert(0 <= minf and minf <= 1)
    assert(0 <= maxf and maxf <= 1)
    assert(minf <= maxf)

    local min = math.floor((width*height)*minf)
    local max = math.floor((width*height)*maxf)

    local vars = {}
    local varnames = {}
    local domain = { '#', '.' }
    local passable = { ['.'] = true }

    -- Row-major order, as FourWayConnected expects.
    for y = 1, height do
        for x = 1, width do
            local var = string.format("%d_%d", x, y)
            vars[var] = domain
            varnames[#varnames+1] = var
        end
    end

    local solver = newSolver {
        random = true,
        vars = vars,
        constraints = {
            {
                'FourWayConnected',
                vars = varnames,
                params = { width = width, height = height, passable = passable },
            },
            {
                'Cardinality',
                vars = varnames,
                params = { min = min, max = max, value = '.' }
            }
        },
    }

    local count = 0
    for solution in solver do
        count = count + 1
        printf('#%d', count)
        for y = 1, height do
            local line = {}
            for x = 1, width do
                local var = string.format("%d_%d", x, y)
                line[#line+1] = solution[var]
            end
            print(table.concat(line))
        end
        print()
    end

    printf('%d solutions', count)
    print()

    return count
end

enumeratemap(10, 9, 0.2, 0.5)