--
-- Simple Finite Domain, Constraint Propagation Solver
--
-- Inspired by 'Fast Procedural Level Population with Playability Constraints'
-- by Ian Horswill and Leif Foged.
--
-- paper: http://www.aaai.org/ocs/index.php/AIIDE/AIIDE12/paper/view/5466/5691
-- code: http://code.google.com/p/constraint-thingy/
--
-- I haven't implemented any of the more advanced parts of the paper, such as
-- interval domains or the DAG-based pathing.
--
-- It's not very efficient as it uses tables (Lua's built in hash map) for
-- representing domains instead of bit arrays but it is simple.
--
-- variable = {
--   name = <string>,
--   values = { [<value>] = true }*
-- }
-- constraint = function ( solver, variable ) -> <boolean>
--   (returns false when the current domains cannot satisfy it)
-- solver = {
--   variables = { [variable] = true }+
--   varmap = { [<string>] = variable }+
--   constraints = { [variable] = { [constraint] = true }* }+
--   stack = { { [variable] = <values saved before narrowing> }* }*
--   random = <boolean>
-- }
--- Print a string.format-style message on its own line.
-- @param ... a format string followed by the values it substitutes
function printf( ... )
  local message = string.format(...)
  print(message)
end
--- Build a set whose only member is `value`.
function set( value )
  local singleton = {}
  singleton[value] = true
  return singleton
end
--- True when `set` has no members at all.
function set_empty( set )
  local first = next(set)
  return first == nil
end
--- True when `set` has exactly one member (a decided variable).
function set_singular( set )
  local only = next(set)
  if only == nil then
    return false
  end
  return next(set, only) == nil
end
--- Number of members in `set` (linear in its size).
function set_count( set )
  local n = 0
  local member = next(set)
  while member ~= nil do
    n = n + 1
    member = next(set, member)
  end
  return n
end
--- Shallow copy of `set`: a new table with the same members.
function set_copy( set )
  local duplicate = {}
  for member in pairs(set) do
    duplicate[member] = true
  end
  return duplicate
end
--- New set holding the members of `set1` that are also in `set2`.
function set_intersection( set1, set2 )
  local common = {}
  for member in pairs(set1) do
    local shared = set2[member]
    if shared then
      common[member] = true
    end
  end
  return common
end
--- New set holding the members of `set1` that are NOT in `set2`.
function set_subtract( set1, set2 )
  local remaining = {}
  for member in pairs(set1) do
    local excluded = set2[member]
    if not excluded then
      remaining[member] = true
    end
  end
  return remaining
end
--- True when `set1` and `set2` have exactly the same members.
function set_equal( set1, set2 )
  -- Every member of `inner` must also be a member of `outer`.
  local function contained( inner, outer )
    for member in pairs(inner) do
      if not outer[member] then
        return false
      end
    end
    return true
  end
  return contained(set1, set2) and contained(set2, set1)
end
--- Array of the members of `set`, in `pairs` traversal order.
function set_toarray( set )
  local array, n = {}, 0
  for member in pairs(set) do
    n = n + 1
    array[n] = member
  end
  return array
end
--- Set of the elements of `array` (duplicates collapse). Like `ipairs`,
-- this stops at the first nil element.
function set_fromarray( array )
  local members = {}
  local i = 1
  while array[i] ~= nil do
    members[array[i]] = true
    i = i + 1
  end
  return members
end
--- Render `set` as "{m1 m2 ...}" (members in `pairs` order).
function set_tostring( set )
  local parts, n = {}, 0
  for member in pairs(set) do
    n = n + 1
    parts[n] = member
  end
  return "{" .. table.concat(parts, ' ') .. "}"
end
--- Debug helper: print each frame of the solver's undo stack, listing the
-- domains saved for every variable in that frame.
function print_stack( solver )
  local function print_frame( depth, frame )
    printf('%d ->', depth)
    for variable, saved in pairs(frame) do
      printf(' %s = %s', variable.name, set_tostring(saved))
    end
  end
  for depth, frame in ipairs(solver.stack) do
    print_frame(depth, frame)
  end
end
--- Push a fresh, empty undo frame onto the solver's stack.
function newframe( solver )
  table.insert(solver.stack, {})
end
--- Record `var`'s current domain in the innermost undo frame so that
-- `unwind` can restore it.
-- Only the FIRST save of a variable within a frame is kept: that is the
-- domain the variable had when the frame was opened. A variable can be
-- narrowed several times while one choice propagates, and letting a later
-- save overwrite the entry would make `unwind` restore an already-narrowed
-- domain, silently losing solutions on backtracking.
-- @param solver solver whose `stack` holds at least one frame
-- @param var variable whose `values` is about to be replaced
function save( solver, var )
  local stack = solver.stack
  local frame = stack[#stack]
  if frame[var] == nil then
    frame[var] = var.values
  end
end
--- Pop the innermost undo frame and restore every domain it saved.
function unwind( solver )
  local frame = table.remove(solver.stack)
  for variable, saved in pairs(frame) do
    variable.values = saved
  end
end
--- Snapshot the current assignment as { [variable name] = value }.
-- Asserts that every variable has been narrowed to exactly one value.
function solution( solver )
  local assignment = {}
  for variable in pairs(solver.variables) do
    local values = variable.values
    local value = next(values)
    assert(value ~= nil and next(values, value) == nil)
    assignment[variable.name] = value
  end
  return assignment
end
--- Restrict `var`'s domain to its intersection with `set` and propagate.
-- If the domain actually shrinks, the old domain is saved in the current
-- undo frame and every constraint attached to `var` is re-run.
-- @return false if the domain would become empty or any constraint fails,
--         true otherwise
function narrow( solver, var, set )
  local narrowed = set_intersection(var.values, set)
  if set_empty(narrowed) then
    return false
  end
  if set_equal(var.values, narrowed) then
    -- Nothing was removed, so there is nothing new to propagate.
    return true
  end
  save(solver, var)
  var.values = narrowed
  for constraint in pairs(solver.constraints[var]) do
    if not constraint(solver, var) then
      return false
    end
  end
  return true
end
--- Shuffle `array` in place (Fisher-Yates, front to back) and return it.
function shuffle( array )
  local n = #array
  local i = 1
  while i < n do
    local j = math.random(i, n)
    array[i], array[j] = array[j], array[i]
    i = i + 1
  end
  return array
end
--- Like `pairs`, but visits the key/value pairs of `tbl` in random order.
-- The pairs are snapshotted when this is called.
function randpairs( tbl )
  local entries = {}
  for key, value in pairs(tbl) do
    table.insert(entries, { key, value })
  end
  entries = shuffle(entries)
  local position = 0
  return function ()
    position = position + 1
    local entry = entries[position]
    if entry == nil then
      return nil
    end
    return entry[1], entry[2]
  end
end
--- Enumerate every solution of `solver` by depth-first search with
-- constraint propagation, yielding each one from the enclosing coroutine
-- (newSolver runs this inside coroutine.create). Each yielded value is a
-- { [name] = value } table built by `solution`. A solver with no variables
-- has exactly one, empty, solution.
-- Raises 'empty' if some variable starts with an empty domain.
function solve( solver )
  local order = set_toarray(solver.variables)
  -- Try to solve the 'smallest' variables first. Sorting also puts any
  -- empty domain at the front, which gives a cheap unsatisfiability check.
  table.sort(order,
    function ( lhs, rhs )
      return set_count(lhs.values) < set_count(rhs.values)
    end)
  -- Guard against a solver with no variables (order[1] would be nil).
  if order[1] ~= nil and set_empty(order[1].values) then
    error('empty')
  end
  if solver.random then
    -- Random mode replaces the smallest-first order with a random one.
    shuffle(order)
  end
  local iterator = solver.random and randpairs or pairs
  local function aux( solver, index )
    if index > #order then
      coroutine.yield(solution(solver))
      return
    end
    local variable = order[index]
    if set_singular(variable.values) then
      -- Already decided (initially or by propagation): move on.
      aux(solver, index + 1)
      return
    end
    -- Try each remaining value in its own undo frame; unwind restores
    -- everything narrowed while exploring that choice.
    for value in iterator(variable.values) do
      newframe(solver)
      if narrow(solver, variable, set(value)) then
        aux(solver, index + 1)
      end
      unwind(solver)
    end
  end
  aux(solver, 1)
end
--- Debug helper: print every key/value pair of `tbl`, one pair per line.
function printTable( tbl )
  for key, value in pairs(tbl) do print(key, value) end
end
--- Resolve a list of variable names into the corresponding variables.
-- @param varnames array of variable names (a constraint's `vars` list)
-- @param varmap { [name] = variable } lookup table
-- @return array of variables, in the same order as `varnames`
-- Raises an error naming the culprit if the list is missing, or if a name
-- is unknown or appears more than once.
function environment( varnames, varmap )
  assert(varnames, "constraint is missing its 'vars' list")
  local result = {}
  local seen = {}
  for _, varname in ipairs(varnames) do
    local variable = varmap[varname]
    assert(variable, string.format("unknown variable '%s'", tostring(varname)))
    assert(not seen[variable],
      string.format("variable '%s' listed more than once", tostring(variable.name)))
    result[#result+1] = variable
    seen[variable] = true
  end
  return result
end
--- Describe a variable as "name={values}". `values` overrides the
-- variable's current domain when given (and not false).
function vartos( var, values )
  local shown = values
  if not shown then
    shown = var.values
  end
  return string.format("%s=%s", var.name, set_tostring(shown))
end
-- Diagnostic counters: numcfails counts Cardinality rejections and
-- numfwfails counts FourWayConnected rejections. numcardsame and
-- numcarddiff are not updated anywhere in this file.
local numcfails = 0
local numfwfails = 0
local numcardsame = 0
local numcarddiff = 0
-- Constraint constructors, keyed by the name used in newSolver's
-- `constraints` list. Each is called as ctor(vars, params), where `vars` is
-- the array of variables the constraint covers. It returns a propagator
-- function(solver, var) that narrow() re-runs whenever `var`'s domain
-- shrinks. A propagator may narrow other variables, and it returns false
-- when the current domains cannot satisfy the constraint.
local _constraints = {
  -- vars[1] and vars[2] must take different values.
  NotEqual =
    function ( vars )
      assert(#vars == 2)
      local var1, var2 = vars[1], vars[2]
      return
        function ( solver, var )
          assert(var == var1 or var == var2)
          local other = (var == var1) and var2 or var1
          -- Once `var` is decided, its value is ruled out for `other`.
          if set_singular(var.values) then
            return narrow(solver, other, set_subtract(other.values, var.values))
          end
          return true
        end
    end,
  -- vars[1] and vars[2] must take the same value.
  Equal =
    function ( vars )
      assert(#vars == 2)
      local var1, var2 = vars[1], vars[2]
      return
        function ( solver, var )
          assert(var == var1 or var == var2)
          local other = (var == var1) and var2 or var1
          -- Once `var` is decided, `other` must take that same value.
          if set_singular(var.values) then
            return narrow(solver, other, set_copy(var.values))
          end
          return true
        end
    end,
  -- Between params.min and params.max (inclusive) of `vars` must take
  -- params.value.
  Cardinality =
    function ( vars, params )
      local min = params.min
      local max = params.max
      local value = params.value
      assert(0 <= min)
      assert(min <= max)
      assert(value)
      assert(min <= #vars)
      local variables = set_fromarray(vars)
      return
        function ( solver, var )
          assert(variables[var])
          -- possible: variables that can still take `value`;
          -- definite: the subset already fixed to it.
          local possible = {}
          local definite = {}
          local numpossible = 0
          local numdefinite = 0
          for variable in pairs(variables) do
            if variable.values[value] then
              possible[variable] = true
              numpossible = numpossible + 1
              if set_singular(variable.values) then
                definite[variable] = true
                numdefinite = numdefinite + 1
              end
            end
          end
          if numpossible < min or numdefinite > max then
            numcfails = numcfails + 1
            return false
          end
          -- Only just enough candidates remain, so every one of them must
          -- take `value`.
          if numpossible == min then
            local valueset = set(value)
            for variable in pairs(possible) do
              if not narrow(solver, variable, valueset) then
                numcfails = numcfails + 1
                return false
              end
            end
            return true
          end
          -- The maximum is already reached, so no undecided candidate may
          -- take `value`.
          if numdefinite == max then
            local valueset = set(value)
            for variable in pairs(possible) do
              if not definite[variable] then
                local newvalues = set_subtract(variable.values, valueset)
                if not narrow(solver, variable, newvalues) then
                  numcfails = numcfails + 1
                  return false
                end
              end
            end
          end
          return true
        end
    end,
  -- Once every variable is decided, their numeric values must add up to
  -- params.total.
  Sum =
    function ( vars, params )
      assert(#vars > 0)
      local total = params.total
      assert(total)
      return
        function ( solver, var )
          local count = 0
          for _, variable in ipairs(vars) do
            if not set_singular(variable.values) then
              -- A partial assignment cannot be judged yet.
              return true
            end
            count = count + next(variable.values)
          end
          return count == total
        end
    end,
  -- `vars` is a params.width x params.height grid in row-major order. Once
  -- every cell is decided, the cells whose value is in params.passable must
  -- form a single 4-way connected region.
  FourWayConnected =
    function ( vars, params )
      local height = params.height
      local width = params.width
      local passable = params.passable
      assert(#vars == width * height)
      local function at( x, y )
        local result = x + (y-1)*width
        assert(1 <= result and result <= width * height)
        return result
      end
      -- neighbours[variable] = array of its North/East/South/West cells.
      local neighbours = {}
      for y = 1, height do
        for x = 1, width do
          local peers = {}
          if y > 1 then peers[#peers+1] = vars[at(x, y-1)] end      -- North
          if x < width then peers[#peers+1] = vars[at(x+1, y)] end  -- East
          if y < height then peers[#peers+1] = vars[at(x, y+1)] end -- South
          if x > 1 then peers[#peers+1] = vars[at(x-1, y)] end      -- West
          neighbours[vars[at(x, y)]] = peers
        end
      end
      -- Number of cells in `passables` reachable from `start` by a
      -- breadth-first search through passable neighbours.
      local function reachable( start, passables )
        local seen = { [start] = true }
        local frontier = { [start] = true }
        local count = 1
        while next(frontier) do
          local nextFrontier = {}
          for cell in pairs(frontier) do
            for _, other in ipairs(neighbours[cell]) do
              if not seen[other] and passables[other] then
                count = count + 1
                seen[other] = true
                nextFrontier[other] = true
              end
            end
          end
          frontier = nextFrontier
        end
        return count
      end
      return
        function ( solver, var )
          -- Only a complete grid is checked.
          for _, variable in ipairs(vars) do
            if not set_singular(variable.values) then
              return true
            end
          end
          local passables = {}
          local numpassables = 0
          for _, variable in ipairs(vars) do
            if passable[next(variable.values)] then
              passables[variable] = true
              numpassables = numpassables + 1
            end
          end
          if numpassables == 0 then
            return true
          end
          if reachable(next(passables), passables) ~= numpassables then
            numfwfails = numfwfails + 1
            return false
          end
          return true
        end
    end,
}
-- Build a solver and return an iterator over its solutions.
--
-- tbl = {
--   vars = { [<string>] = { <value> }+ },
--   constraints = { { <name>, vars = { <string> }+, params = <table>? } }*,
--   random = <boolean>?,  -- search variables and values in random order
--   dump = <boolean>?,    -- print variables and constraint wiring while building
-- }
--
-- <name> must be a key of _constraints; `constraints` may be omitted.
-- Each call of the returned function yields the next solution table
-- ({ [<string>] = <value> }) and nil once the search is exhausted, so it
-- can drive a generic `for` loop. Errors raised during the search are
-- re-raised by that call.
function newSolver( tbl )
  local dump = tbl.dump == true
  local variables = {}
  local varmap = {}
  for name, values in pairs(tbl.vars) do
    assert(#values >= 1,
      string.format("variable '%s' has an empty domain", tostring(name)))
    local variable = { name = name, values = set_fromarray(values) }
    variables[variable] = true
    varmap[name] = variable
    if dump then
      print(vartos(variable))
    end
  end
  local constraints = {}
  for variable in pairs(variables) do
    constraints[variable] = {}
  end
  for index, spec in ipairs(tbl.constraints or {}) do
    local name = spec[1]
    local ctor = _constraints[name]
    assert(ctor, string.format("unknown constraint '%s'", tostring(name)))
    local cvars = environment(spec.vars, varmap)
    local constraint = ctor(cvars, spec.params)
    for _, variable in ipairs(cvars) do
      if dump then
        printf('constraint #%d %s', index, variable.name)
      end
      constraints[variable][constraint] = true
    end
  end
  local solver = {
    variables = variables,
    varmap = varmap,
    constraints = constraints,
    stack = {},
    random = tbl.random and true or false,
  }
  local coro = coroutine.create(function () solve(solver) end)
  return
    function ()
      -- After the search has finished, keep reporting exhaustion instead of
      -- failing with "cannot resume dead coroutine".
      if coroutine.status(coro) == 'dead' then
        return nil
      end
      local ok, result = coroutine.resume(coro)
      if not ok then
        -- Level 0: the message already carries its original position.
        error(result, 0)
      end
      return result
    end
end
--- Render a solution as "{ name=value ... }", with names in sorted order so
-- the output does not depend on table traversal order.
-- @param solution { [name] = value } table, as yielded by the solver
-- @return the rendered string ("{  }" for an empty solution)
function solution_tostring( solution )
  local names = {}
  for varname in pairs(solution) do
    names[#names+1] = varname
  end
  table.sort(names)
  -- ipairs, not pairs: `pairs` order is unspecified even for arrays.
  local parts = {}
  for i, varname in ipairs(names) do
    parts[i] = string.format("%s=%s", varname, solution[varname])
  end
  return string.format("{ %s }", table.concat(parts, ' '))
end
--- Drain a solver iterator, printing every solution and then a summary
-- line followed by a blank line.
-- @param solver iterator returned by newSolver
-- @return the number of solutions found
function enumerate( solver )
  local found = 0
  local solution = solver()
  while solution ~= nil do
    found = found + 1
    printf('#%d %s', found, solution_tostring(solution))
    solution = solver()
  end
  printf('%d solutions', found)
  print()
  return found
end
--- Like `enumerate`, but for large solution spaces: count every solution
-- while printing only every `every`-th one.
-- @param solver iterator returned by newSolver
-- @param every optional print interval (default 1000)
-- @return the number of solutions found, like `enumerate`
function enumerate2( solver, every )
  every = every or 1000
  local count = 0
  for solution in solver do
    count = count + 1
    if count % every == 0 then
      printf('#%d %s', count, solution_tostring(solution))
    end
  end
  printf('%d solutions', count)
  print()
  return count
end
-- Smoke tests for the basic constraints on three variables with three
-- values each. Each demo is scoped so its solver does not leak.
do
  -- No constraints: all 3^3 assignments are solutions.
  local solver = newSolver {
    vars = {
      a = { 'x', 'y', 'z' },
      b = { 'x', 'y', 'z' },
      c = { 'x', 'y', 'z' },
    },
    constraints = {},
  }
  assert(enumerate(solver) == 27)
end
do
  -- All different: the 3! permutations.
  local solver = newSolver {
    dump = true,
    vars = {
      a = { 'x', 'y', 'z' },
      b = { 'x', 'y', 'z' },
      c = { 'x', 'y', 'z' },
    },
    constraints = {
      { 'NotEqual', vars = { 'a', 'b' } },
      { 'NotEqual', vars = { 'a', 'c' } },
      { 'NotEqual', vars = { 'b', 'c' } },
    },
  }
  assert(enumerate(solver) == 6)
end
do
  -- All equal: one solution per value.
  local solver = newSolver {
    vars = {
      a = { 'x', 'y', 'z' },
      b = { 'x', 'y', 'z' },
      c = { 'x', 'y', 'z' },
    },
    constraints = {
      { 'Equal', vars = { 'a', 'b' } },
      { 'Equal', vars = { 'a', 'c' } },
      { 'Equal', vars = { 'b', 'c' } },
    },
  }
  assert(enumerate(solver) == 3)
end
--- Build the constraint list forcing all of `varnames` to be pairwise
-- different: one NotEqual spec per unordered pair, in (i < j) order.
function genNotEquals( varnames )
  local constraints = {}
  local n = #varnames
  for i = 1, n - 1 do
    for j = i + 1, n do
      table.insert(constraints, { 'NotEqual', vars = { varnames[i], varnames[j] } })
    end
  end
  return constraints
end
-- Larger demos. The expensive ones are built but not enumerated by default.
do
  -- All-different over 12 variables with 12 values each. This should have
  -- 12! = 479,001,600 solutions.
  local twelve = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', }
  local solver = newSolver {
    vars = {
      a = twelve,
      b = twelve,
      c = twelve,
      d = twelve,
      e = twelve,
      f = twelve,
      g = twelve,
      h = twelve,
      i = twelve,
      j = twelve,
      k = twelve,
      l = twelve,
    },
    constraints = genNotEquals { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l' } ,
  }
  -- enumerate2(solver)
end
do
  -- Two or three of four bits set: C(4,2) + C(4,3) = 10 solutions.
  local solver = newSolver {
    vars = {
      a = { '0', '1', },
      b = { '0', '1', },
      c = { '0', '1', },
      d = { '0', '1', },
    },
    constraints = {
      { 'Cardinality', vars = { 'a', 'b', 'c', 'd' }, params = { min = 2, max = 3, value = '1' } },
    },
  }
  assert(enumerate(solver) == 10)
end
do
  -- Exactly 13 of 24 bits set: C(24,13) solutions.
  local bit = { '0', '1', }
  local solver = newSolver {
    vars = {
      a = bit, b = bit, c = bit, d = bit, e = bit, f = bit,
      g = bit, h = bit, i = bit, j = bit, k = bit, l = bit,
      m = bit, n = bit, o = bit, r = bit, s = bit, t = bit,
      u = bit, v = bit, w = bit, x = bit, y = bit, z = bit,
    },
    constraints = {
      { 'Cardinality', vars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' }, params = { min = 13, max = 13, value = '1' } },
    },
  }
  -- enumerate2(solver)
end
do
  -- Four numbers in 0..5 summing to 2: C(5,3) = 10 solutions.
  local solver = newSolver {
    vars = {
      a = { 0, 1, 2, 3, 4, 5, },
      b = { 0, 1, 2, 3, 4, 5, },
      c = { 0, 1, 2, 3, 4, 5, },
      d = { 0, 1, 2, 3, 4, 5, },
    },
    constraints = {
      { 'Sum', vars = { 'a', 'b', 'c', 'd' }, params = { total = 2 } },
    },
  }
  assert(enumerate(solver) == 10)
end
do
  -- 3x3 grid whose '.' cells must be 4-way connected.
  local solver = newSolver {
    vars = {
      a11 = { '#', '.' },
      a21 = { '#', '.' },
      a31 = { '#', '.' },
      a12 = { '#', '.' },
      a22 = { '#', '.' },
      a32 = { '#', '.' },
      a13 = { '#', '.' },
      a23 = { '#', '.' },
      a33 = { '#', '.' },
    },
    constraints = {
      { 'FourWayConnected', vars = { 'a11', 'a21', 'a31', 'a12', 'a22', 'a32', 'a13', 'a23', 'a33' }, params = { width = 3, height = 3, passable = { ['.'] = true } } },
    },
  }
  -- enumerate(solver)
end
--- Search, in random order, for width x height grids of '#' and '.' cells
-- in which the '.' (passable) cells are 4-way connected and number between
-- floor(width*height*minf) and floor(width*height*maxf). Each grid found
-- is printed as rows of characters.
-- @param width, height grid size, each at least 3
-- @param minf, maxf bounds on the '.' fraction, 0 <= minf <= maxf <= 1
-- @param limit optional maximum number of grids to produce; nil means keep
--        going until the search is exhausted
-- @return number of grids printed
function enumeratemap( width, height, minf, maxf, limit )
  assert(width >= 3)
  assert(height >= 3)
  assert(0 <= minf and minf <= 1)
  assert(0 <= maxf and maxf <= 1)
  assert(minf <= maxf)
  assert(limit == nil or limit >= 0)
  local cells = width * height
  local min = math.floor(cells * minf)
  local max = math.floor(cells * maxf)
  local domain = { '#', '.' }
  local passable = { ['.'] = true }
  local vars = {}
  local varnames = {}
  -- names[y][x] is the variable name of cell (x, y), computed once and
  -- reused when printing.
  local names = {}
  for y = 1, height do
    names[y] = {}
    for x = 1, width do
      local var = string.format("%d_%d", x, y)
      names[y][x] = var
      vars[var] = domain
      varnames[#varnames+1] = var
    end
  end
  local solver = newSolver {
    random = true,
    vars = vars,
    constraints = {
      {
        'FourWayConnected',
        vars = varnames,
        params = {
          width = width,
          height = height,
          passable = passable
        },
      },
      {
        'Cardinality',
        vars = varnames,
        params = {
          min = min,
          max = max,
          value = '.'
        }
      }
    },
  }
  local count = 0
  while limit == nil or count < limit do
    local solution = solver()
    if solution == nil then
      break
    end
    count = count + 1
    printf('#%d', count)
    for y = 1, height do
      local line = {}
      for x = 1, width do
        line[x] = solution[names[y][x]]
      end
      print(table.concat(line))
    end
    print()
  end
  printf('%d solutions', count)
  print()
  return count
end
enumeratemap(10, 9, 0.2, 0.5)