Pastebin launched a little side project called VERYVIRAL.com, check it out ;-) Want more features on Pastebin? Sign Up, it's FREE!
Guest

Simple Constraint Solver in Lua

By: a guest on May 11th, 2013  |  syntax: Lua  |  size: 19.94 KB  |  views: 71  |  expires: Never
download  |  raw  |  embed  |  report abuse  |  print
Text below is selected. Please press Ctrl+C to copy to your clipboard. (⌘+C on Mac)
  1. --
  2. -- Simple Finite Domain, Constraint Propagation Solver
  3. --
  4. -- Inspired by 'Fast Procedural Level Population with Playability Constraints'
  5. -- by Ian Horswill and Leif Foged.
  6. --
  7. -- paper: http://www.aaai.org/ocs/index.php/AIIDE/AIIDE12/paper/view/5466/5691
  8. -- code:  http://code.google.com/p/constraint-thingy/
  9. --
  10. -- I haven't implemented any of the more advanced parts of the paper like
  11. -- interval domains of the DAG based pathing stuff.
  12. --
  13. -- It's not very efficient as it uses tables (Lua's built in hash map) for
  14. -- representing domains instead of bit arrays but it is simple.
  15. --
  16.  
  17.  
  18. -- variable = {
  19. --     name = <string>,
  20. --     values = { [<value>] = true }*
  21. -- }
  22.  
  23. -- constraint = function ( solver, variable )
  24.  
  25. -- solver = {
  26. --     variables = { [variable] = true }+
  27. --     varmap = { [<string>] = variable }+
  28. --     constraints = { [variable] = { constraint }+ }+
  29. --     stack = { { variables } }*
  30. -- }
  31.  
  32. function printf( ... ) print(string.format(...)) end
  33.  
  34. function set( value )
  35.         return { [value] = true }
  36. end
  37.  
  38. function set_empty( set )
  39.         return next(set) == nil
  40. end
  41.  
  42. function set_singular( set )
  43.         local k = next(set)
  44.  
  45.         return k ~= nil and next(set, k) == nil
  46. end
  47.  
  48. function set_count( set )
  49.         local result = 0
  50.  
  51.         for k, _ in pairs(set) do
  52.                 result = result + 1
  53.         end
  54.  
  55.         return result
  56. end
  57.  
  58. function set_copy( set )
  59.         local result = {}
  60.  
  61.         for k, _ in pairs(set) do
  62.                 result[k] = true
  63.         end
  64.  
  65.         return result
  66. end
  67.  
  68. function set_intersection( set1, set2 )
  69.         local result = {}
  70.  
  71.         for k, _ in pairs(set1) do
  72.                 if set2[k] then
  73.                         result[k] = true
  74.                 end
  75.         end
  76.  
  77.         return result
  78. end
  79.  
  80. function set_subtract( set1, set2 )
  81.         local result = {}
  82.  
  83.         for k, _ in pairs(set1) do
  84.                 if not set2[k] then
  85.                         result[k] = true
  86.                 end
  87.         end
  88.  
  89.         return result
  90. end
  91.  
  92. function set_equal( set1, set2 )
  93.         for k, _ in pairs(set1) do
  94.                 if not set2[k] then
  95.                         return false
  96.                 end
  97.         end
  98.  
  99.         for k,_ in pairs(set2) do
  100.                 if not set1[k] then
  101.                         return false
  102.                 end
  103.         end
  104.  
  105.         return true
  106. end
  107.  
  108. function set_toarray( set )
  109.         local result = {}
  110.  
  111.         for k, _ in pairs(set) do
  112.                 result[#result+1] = k
  113.         end
  114.  
  115.         return result
  116. end
  117.  
  118. function set_fromarray( array )
  119.         local result = {}
  120.  
  121.         for _, v in ipairs(array) do
  122.                 result[v] = true
  123.         end
  124.  
  125.         return result
  126. end
  127.  
  128. function set_tostring( set )
  129.         local parts = {}
  130.  
  131.         for k, _ in pairs(set) do
  132.                 parts[#parts+1] = k
  133.         end
  134.  
  135.         return string.format("{%s}", table.concat(parts, ' '))
  136. end
  137.  
  138. function print_stack( solver )
  139.         for index, frame in ipairs(solver.stack) do
  140.                 printf('%d ->', index)
  141.  
  142.                 for variable, values in pairs(frame) do
  143.                         printf('  %s = %s', variable.name, set_tostring(values))
  144.                 end
  145.         end
  146. end
  147.  
  148. function newframe( solver )
  149.         local stack = solver.stack
  150.         local frame = {}
  151.         stack[#stack+1] = frame
  152. end
  153.  
  154. function save( solver, var )
  155.         local stack = solver.stack
  156.         local frame = stack[#stack]
  157.        
  158.         frame[var] = var.values
  159. end
  160.  
  161. function unwind( solver )
  162.         local stack = solver.stack
  163.         local frame = stack[#stack]
  164.  
  165.         for variable, values in pairs(frame) do
  166.                 variable.values = values
  167.         end
  168.  
  169.         stack[#stack] = nil
  170. end
  171.  
  172. function solution( solver )
  173.         local result = {}
  174.  
  175.         for variable, _ in pairs(solver.variables) do
  176.                 assert(set_singular(variable.values))
  177.  
  178.                 result[variable.name] = next(variable.values)
  179.         end
  180.  
  181.         return result
  182. end
  183.  
  184. function narrow( solver, var, set )
  185.         -- printf('  narrow %s to %s', vartos(var), set_tostring(set))
  186.         local newvalues = set_intersection(var.values, set)
  187.  
  188.         if set_empty(newvalues) then
  189.                 return false
  190.         end
  191.  
  192.         if not set_equal(var.values, newvalues) then
  193.                 save(solver, var)
  194.  
  195.                 var.values = newvalues
  196.  
  197.                 local constraints = solver.constraints[var]
  198.  
  199.                 -- printf('  propagate %s over %d constraints', vartos(var), set_count(constraints))
  200.  
  201.                 for constraint, _ in pairs(constraints) do
  202.                         if not constraint(solver, var) then
  203.                                 return false
  204.                         end
  205.                 end
  206.         end
  207.  
  208.         return true
  209. end
  210.  
  211. function shuffle( array )
  212.         for i = 1, #array-1 do
  213.                 local j = math.random(i, #array)
  214.                 array[i], array[j] = array[j], array[i]
  215.         end
  216.  
  217.         return array
  218. end
  219.        
  220. function randpairs( tbl )
  221.         local kvs = {}
  222.  
  223.         for k, v in pairs(tbl) do
  224.                 kvs[#kvs+1] = { k, v }
  225.         end
  226.  
  227.         kvs = shuffle(kvs)
  228.  
  229.         local index = 0
  230.  
  231.         return
  232.                 function ()
  233.                         index = index + 1
  234.                         local kv = kvs[index]
  235.  
  236.                         if kv then
  237.                                 return kv[1], kv[2]
  238.                         else
  239.                                 return nil
  240.                         end
  241.                 end
  242. end
  243.  
  244. function solve( solver )
  245.         local varsarray = set_toarray(solver.variables)
  246.  
  247.         -- We want to try and solve the 'smallest' variables first. This also
  248.         -- allows a quick check for any empty variables.
  249.         table.sort(varsarray,
  250.                 function ( lhs, rhs )
  251.                         return set_count(lhs.values) < set_count(rhs.values)
  252.                 end)
  253.  
  254.         -- If there's any empty variables we're buggered.
  255.         if set_empty(varsarray[1].values) then
  256.                 error('empty')
  257.         end
  258.  
  259.         if solver.random then
  260.                 shuffle(varsarray)
  261.         end
  262.  
  263.         local iterator = (solver.random) and randpairs or pairs
  264.  
  265.         local function aux( solver, index )
  266.                 if index > #varsarray then
  267.                         coroutine.yield(solution(solver))
  268.  
  269.                         return
  270.                 end
  271.  
  272.                 local variable = varsarray[index]
  273.  
  274.                 -- If the variable is singular (i.e. has a value) we just carry on.            
  275.                 if set_singular(variable.values) then
  276.                         aux(solver, index+1)
  277.                 else
  278.                         -- local failed = true
  279.  
  280.                         for value, _ in iterator(variable.values) do
  281.                                 newframe(solver)
  282.  
  283.                                 -- printf(' try %s = %s', variable.name, value)
  284.  
  285.                                 local success = narrow(solver, variable, set(value))
  286.  
  287.                                 if success then
  288.                                         -- failed = false
  289.                                         -- print(debug.traceback())
  290.                                         aux(solver, index+1)
  291.                                 end
  292.                                
  293.                                 unwind(solver)
  294.                         end
  295.  
  296.                         -- if not failed then
  297.                                 return
  298.                         -- else
  299.                                 -- error('unsatisfiable')
  300.                         -- end
  301.                 end
  302.         end
  303.  
  304.         aux(solver, 1)
  305. end
  306.  
  307. function printTable( tbl )
  308.         for k, v in pairs(tbl) do
  309.                 print(k, v)
  310.         end
  311. end
  312.  
  313. function environment( varnames, varmap )
  314.         local result = {}
  315.  
  316.         local repeated = {}
  317.  
  318.         for _, varname in ipairs(varnames) do
  319.                 local variable = varmap[varname]
  320.                 assert(variable)
  321.                 assert(not repeated[variable], variable.name)
  322.                 result[#result+1] = variable
  323.                 repeated[variable] = true
  324.         end
  325.  
  326.         return result
  327. end
  328.  
  329. function vartos( var, values )
  330.         values = values or var.values
  331.         return string.format("%s=%s", var.name, set_tostring(values))
  332. end
  333.  
  334. local numcfails = 0
  335. local numfwfails = 0
  336.  
  337. local numcardsame = 0
  338. local numcarddiff = 0
  339.  
  340. local _constraints = {
  341.         NotEqual =
  342.                 function ( vars )
  343.                         assert(#vars == 2)
  344.  
  345.                         local var1, var2 = vars[1], vars[2]
  346.  
  347.                         return
  348.                                 function ( solver, var )
  349.                                         assert(var == var1 or var == var2)
  350.                                         local other = (var == var1) and var2 or var1
  351.  
  352.                                         if set_singular(var.values) then
  353.                                                 local newvalues = set_subtract(other.values, var.values)
  354.                                                 local oldvalues = other.values
  355.  
  356.                                                 -- printf('  NotEqual %s %s => %s', vartos(var), vartos(other, oldvalues), set_tostring(newvalues))
  357.  
  358.                                                 local result = narrow(solver, other, newvalues)
  359.  
  360.                                                 -- printf('   %s', result and 'succeeded' or 'failed')
  361.  
  362.                                                 return result
  363.                                         end
  364.  
  365.                                         return true
  366.                                 end
  367.                 end,
  368.         Equal =
  369.                 function ( vars )
  370.                         assert(#vars == 2)
  371.  
  372.                         local var1, var2 = vars[1], vars[2]
  373.  
  374.                         return
  375.                                 function ( solver, var )
  376.                                         assert(var == var1 or var == var2)
  377.                                         local other = (var == var1) and var2 or var1
  378.  
  379.                                         if set_singular(var.values) then
  380.                                                 local newvalues = set_copy(var.values)
  381.                                                 local oldvalues = other.values
  382.  
  383.                                                 -- printf('  Equal %s %s => %s', vartos(var), vartos(other, oldvalues), set_tostring(newvalues))
  384.  
  385.                                                 local result = narrow(solver, other, newvalues)
  386.  
  387.                                                 -- printf('   %s', result and 'succeeded' or 'failed')
  388.  
  389.                                                 return result
  390.                                         end
  391.  
  392.                                         return true
  393.                                 end
  394.                 end,
  395.         Cardinality =
  396.                 function ( vars, params )
  397.                         local min = params.min
  398.                         local max = params.max
  399.                         local value = params.value
  400.  
  401.                         assert(0 <= min)
  402.                         assert(min <= max)
  403.                         assert(value)
  404.                         assert(min <= #vars)
  405.  
  406.                         local variables = set_fromarray(varss)
  407.  
  408.                         return
  409.                                 function ( solver, var )
  410.                                         assert(variables[var])
  411.  
  412.                                         local possible = {}
  413.                                         local definite = {}
  414.                                         local numpossible = 0
  415.                                         local numdefinite = 0
  416.  
  417.                                         for variable, _ in pairs(variables) do
  418.                                                 if variable.values[value] then
  419.                                                         possible[variable] = true
  420.                                                         numpossible = numpossible + 1
  421.  
  422.                                                         if set_singular(variable.values) then
  423.                                                                 definite[variable] = true
  424.                                                                 numdefinite = numdefinite + 1
  425.                                                         end
  426.                                                 end
  427.                                         end
  428.  
  429.                                         if numpossible < min or numdefinite > max then
  430.                                                 numcfails = numcfails + 1
  431.                                                 return false
  432.                                         end
  433.  
  434.                                         -- We have enough possibles, but only just, so try and make
  435.                                         -- them definite to be sure.
  436.                                         if numpossible == min then
  437.                                                 for variable, _ in pairs(possible) do
  438.                                                         local result = narrow(solver, variable, set(value))
  439.  
  440.                                                         if not result then
  441.                                                                 numcfails = numcfails + 1
  442.                                                                 return false
  443.                                                         end
  444.                                                 end
  445.  
  446.                                                 return true
  447.                                         end
  448.  
  449.                                         -- We've got the maximum number of definites so make the
  450.                                         -- possibles that are not also definite into impossibles.
  451.                                         if numdefinite == max then
  452.                                                 local valueset = set(value)
  453.  
  454.                                                 for variable, _ in pairs(possible) do
  455.                                                         if not definite[variable] then
  456.                                                                 local newvalues = set_subtract(variable.values, valueset)
  457.  
  458.                                                                 -- printf('  Cardinality %s => %s', vartos(variable), set_tostring(newvalues))
  459.  
  460.                                                                 local result = narrow(solver, variable, newvalues)
  461.  
  462.                                                                 if not result then
  463.                                                                         numcfails = numcfails + 1
  464.                                                                         return false
  465.                                                                 end
  466.                                                         end
  467.                                                 end
  468.  
  469.                                                 return true
  470.                                         end
  471.  
  472.                                         return true
  473.                                 end
  474.                 end,
  475.         Sum =
  476.                 function ( vars, params )
  477.                         assert(#vars > 0)
  478.  
  479.                         local total = params.total
  480.                         assert(total)
  481.  
  482.                         return
  483.                                 function ( solver, var )
  484.                                         local count = 0
  485.  
  486.                                         -- print(#vars)
  487.  
  488.                                         for _, variable in ipairs(vars) do
  489.                                                 if not set_singular(variable.values) then
  490.                                                         return true
  491.                                                 else
  492.                                                         -- TODO: Need a 'get unique value' function.
  493.                                                         count = count + next(variable.values)
  494.                                                 end
  495.                                         end
  496.  
  497.                                         -- printf('  Sum count:%d total:%d', count, total)
  498.  
  499.                                         return count == total
  500.                                 end
  501.                 end,
  502.         FourWayConnected =
  503.                 function ( vars, params )
  504.                         local height = params.height
  505.                         local width = params.width
  506.                         local passable = params.passable
  507.  
  508.                         assert(#vars == width * height)
  509.  
  510.                         -- Build up the 'neighbour' relation.
  511.                         local neighbours = {}
  512.  
  513.                         local at =
  514.                                 function ( x, y )
  515.                                         local result = x + (y-1)*width
  516.                                         assert(1 <= result and result <= width * height)
  517.                                         return result
  518.                                 end
  519.  
  520.                         for y = 1, height do
  521.                                 for x = 1, width do
  522.                                         local peers = {}
  523.                                        
  524.                                         -- North
  525.                                         if y > 1 then
  526.                                                 peers[#peers+1] = vars[at(x, y-1)]
  527.                                         end
  528.  
  529.                                         -- East
  530.                                         if x < width then
  531.                                                 peers[#peers+1] = vars[at(x+1, y)]
  532.                                         end
  533.  
  534.                                         -- South
  535.                                         if y < height then
  536.                                                 peers[#peers+1] = vars[at(x, y+1)]
  537.                                         end
  538.  
  539.                                         -- West
  540.                                         if x > 1 then
  541.                                                 peers[#peers+1] = vars[at(x-1, y)]
  542.                                         end
  543.  
  544.                                         neighbours[vars[at(x, y)]] = peers
  545.                                 end
  546.                         end
  547.  
  548.                         return
  549.                                 function ( solver, var )
  550.                                         for _, variable in ipairs(vars) do
  551.                                                 if not set_singular(variable.values) then
  552.                                                         return true
  553.                                                 end
  554.                                         end
  555.  
  556.                                         -- First create a set of passable variables.
  557.                                         local passables = {}
  558.                                         local numpassables = 0
  559.  
  560.                                         for _, variable in ipairs(vars) do
  561.                                                 if passable[next(variable.values)] then
  562.                                                         passables[variable] = true
  563.                                                         numpassables = numpassables + 1
  564.                                                 end
  565.                                         end
  566.  
  567.                                         if set_count(passables) == 0 then
  568.                                                 return true
  569.                                         else
  570.                                                 local bfs = function( variable )
  571.                                                         local result = { [variable] = true }
  572.                                                         local frontier = { [variable] = true }
  573.                                                         local count = 1
  574.  
  575.                                                         while next(frontier) do
  576.                                                                 local newFrontier = {}
  577.  
  578.                                                                 for variable, _  in pairs(frontier) do
  579.                                                                         for _, other in ipairs(neighbours[variable]) do
  580.                                                                                 if not result[other] and not frontier[other] and passables[other] then
  581.                                                                                         count = count + 1
  582.                                                                                         result[other] = true
  583.                                                                                         newFrontier[other] = true
  584.                                                                                 end
  585.                                                                         end
  586.                                                                 end
  587.  
  588.                                                                 frontier = newFrontier
  589.                                                         end
  590.  
  591.                                                         return count
  592.                                                 end
  593.  
  594.                                                 local count = bfs(next(passables))
  595.  
  596.                                                 if count ~= numpassables then
  597.                                                         numfwfails = numfwfails + 1
  598.  
  599.                                                         if numcfails % 1000 == 0 or numfwfails % 1000 == 0 then
  600.                                                                 -- printf('#card:%d #fourw:%d', numcfails, numfwfails)
  601.                                                         end
  602.                                                 end
  603.  
  604.                                                 return count == numpassables
  605.                                         end
  606.                                 end
  607.                 end,
  608. }
  609.  
  610. -- tbl = {
  611. --     vars = { [<string>] = { <string> }+ },
  612. --     constraints = { { <string>, vars = { <string> }+ } }*
  613. -- }
  614.  
  615.  
  616. function newSolver( tbl )
  617.         local dump = tbl.dump == true
  618.  
  619.         local vars = tbl.vars
  620.  
  621.         local variables = {}
  622.         local varmap = {}
  623.  
  624.         for name, values in pairs(vars) do
  625.                 assert(#values >= 1)
  626.  
  627.                 local variable = { name = name, values = set_fromarray(values) }
  628.                 variables[variable] = true
  629.                 varmap[name] = variable
  630.  
  631.                 if dump then
  632.                         print(vartos(variable))
  633.                 end
  634.         end
  635.  
  636.         local constraints = {}
  637.  
  638.         for variable, _ in pairs(variables) do
  639.                 constraints[variable] = {}
  640.         end
  641.  
  642.         for index, v in ipairs(tbl.constraints) do
  643.                 local name = v[1]
  644.                 local vars = environment(v.vars, varmap)
  645.                 local params = v.params
  646.  
  647.                 local constraint = _constraints[name](vars, params)
  648.  
  649.                 for _, variable, _ in ipairs(vars) do
  650.                         if dump then
  651.                                 printf('constraint #%d %s', index, variable.name)
  652.                         end
  653.  
  654.                         constraints[variable][constraint] = true
  655.                 end
  656.         end
  657.  
  658.         local random = tbl.random and true or false
  659.  
  660.         local result = {
  661.                 variables = variables,
  662.                 varmap = varmap,
  663.                 constraints = constraints,
  664.                 stack = {},
  665.                 random = random,
  666.         }
  667.  
  668.         local coro = coroutine.create(function () solve(result) end)
  669.        
  670.         return
  671.                 function ()
  672.                         local status, result = coroutine.resume(coro)
  673.  
  674.                         if status then
  675.                                 return result
  676.                         else
  677.                                 error(result)
  678.                                 return nil
  679.                         end
  680.                 end
  681. end
  682.  
  683. function solution_tostring( solution )
  684.         local sorted = {}
  685.  
  686.         for varname, value in pairs(solution) do
  687.                 sorted[#sorted+1] = { varname, value }
  688.         end
  689.  
  690.         table.sort(sorted,
  691.                 function ( lhs, rhs )
  692.                         return lhs[1] < rhs[1]
  693.                 end)
  694.  
  695.         local parts = {}
  696.  
  697.         for _, data in pairs(sorted) do
  698.                 parts[#parts+1] = string.format("%s=%s", data[1], data[2])
  699.         end
  700.  
  701.         return string.format("{ %s }", table.concat(parts, ' '))
  702. end
  703.  
  704. function enumerate( solver )
  705.         local count = 0
  706.  
  707.         for solution in solver do
  708.                 count = count + 1
  709.                 printf('#%d %s', count, solution_tostring(solution))
  710.         end
  711.  
  712.         printf('%d solutions', count)
  713.         print()
  714.  
  715.         return count
  716. end
  717.  
  718. function enumerate2( solver )
  719.         local count = 0
  720.  
  721.         for solution in solver do
  722.                 count = count + 1
  723.  
  724.                 if count % 1000 == 0 then
  725.                         printf('#%d %s', count, solution_tostring(solution))
  726.                 end
  727.         end
  728.  
  729.         printf('%d solutions', count)
  730.         print()
  731. end
  732.  
  733. local solver = newSolver {
  734.         vars = {
  735.                 a = { 'x', 'y', 'z' },
  736.                 b = { 'x', 'y', 'z' },
  737.                 c = { 'x', 'y', 'z' },
  738.         },
  739.         constraints = {},
  740. }
  741.  
  742. enumerate(solver)
  743.  
  744. local solver = newSolver {
  745.         dump = true,
  746.         vars = {
  747.                 a = { 'x', 'y', 'z' },
  748.                 b = { 'x', 'y', 'z' },
  749.                 c = { 'x', 'y', 'z' },
  750.         },
  751.         constraints = {
  752.                 { 'NotEqual', vars = { 'a', 'b' } },
  753.                 { 'NotEqual', vars = { 'a', 'c' } },
  754.                 { 'NotEqual', vars = { 'b', 'c' } },
  755.         },
  756. }
  757.  
  758. assert(enumerate(solver) == 6)
  759.  
  760. local solver = newSolver {
  761.         vars = {
  762.                 a = { 'x', 'y', 'z' },
  763.                 b = { 'x', 'y', 'z' },
  764.                 c = { 'x', 'y', 'z' },
  765.         },
  766.         constraints = {
  767.                 { 'Equal', vars = { 'a', 'b' } },
  768.                 { 'Equal', vars = { 'a', 'c' } },
  769.                 { 'Equal', vars = { 'b', 'c' } },
  770.         },
  771. }
  772.  
  773. assert(enumerate(solver) == 3)
  774.  
  775. function genNotEquals( varnames )
  776.         local result = {}
  777.  
  778.         for i = 1, #varnames do
  779.                 for j = i+1, #varnames do
  780.                         result[#result+1] = { 'NotEqual', vars = { varnames[i], varnames[j] } }
  781.                 end
  782.         end
  783.  
  784.         return result
  785. end
  786.  
  787. local solver = newSolver {
  788.         vars = {
  789.                 a = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  790.                 b = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  791.                 c = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  792.                 d = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  793.                 e = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  794.                 f = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  795.                 g = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  796.                 h = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  797.                 i = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  798.                 j = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  799.                 k = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  800.                 l = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  801.         },
  802.         constraints = genNotEquals { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l' } ,
  803. }
  804.  
  805. -- This should have 12! = 479,001,600 solutions.
  806. -- enumerate2(solver)
  807.  
  808. local solver = newSolver {
  809.         vars = {
  810.                 a = { '0', '1', },
  811.                 b = { '0', '1', },
  812.                 c = { '0', '1', },
  813.                 d = { '0', '1', },
  814.         },
  815.         constraints = {
  816.                 { 'Cardinality', vars = { 'a', 'b', 'c', 'd' }, params = { min = 2, max = 3, value = '1' } },
  817.         },
  818. }
  819.  
  820. assert(enumerate(solver) == 10)
  821.  
  822. local solver = newSolver {
  823.         vars = {
  824.                 a = { '0', '1', },
  825.                 b = { '0', '1', },
  826.                 c = { '0', '1', },
  827.                 d = { '0', '1', },
  828.                 e = { '0', '1', },
  829.                 f = { '0', '1', },
  830.                 g = { '0', '1', },
  831.                 h = { '0', '1', },
  832.                 i = { '0', '1', },
  833.                 j = { '0', '1', },
  834.                 k = { '0', '1', },
  835.                 l = { '0', '1', },
  836.                 m = { '0', '1', },
  837.                 n = { '0', '1', },
  838.                 o = { '0', '1', },
  839.                 r = { '0', '1', },
  840.                 s = { '0', '1', },
  841.                 t = { '0', '1', },
  842.                 u = { '0', '1', },
  843.                 v = { '0', '1', },
  844.                 w = { '0', '1', },
  845.                 x = { '0', '1', },
  846.                 y = { '0', '1', },
  847.                 z = { '0', '1', },
  848.         },
  849.         constraints = {
  850.                 { 'Cardinality', vars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' }, params = { min = 13, max = 13, value = '1' } },
  851.         },
  852. }
  853.  
  854. -- enumerate2(solver)
  855.  
  856. local solver = newSolver {
  857.         vars = {
  858.                 a = { 0, 1, 2, 3, 4, 5, },
  859.                 b = { 0, 1, 2, 3, 4, 5, },
  860.                 c = { 0, 1, 2, 3, 4, 5, },
  861.                 d = { 0, 1, 2, 3, 4, 5, },
  862.         },
  863.         constraints = {
  864.                 { 'Sum', vars = { 'a', 'b', 'c', 'd' }, params = { total = 2 } },
  865.         },
  866. }
  867.  
  868. enumerate(solver)
  869.  
  870. local solver = newSolver {
  871.         vars = {
  872.                 a11 = { '#', '.' },
  873.                 a21 = { '#', '.' },
  874.                 a31 = { '#', '.' },
  875.                 a12 = { '#', '.' },
  876.                 a22 = { '#', '.' },
  877.                 a32 = { '#', '.' },
  878.                 a13 = { '#', '.' },
  879.                 a23 = { '#', '.' },
  880.                 a33 = { '#', '.' },
  881.         },
  882.         constraints = {
  883.                 { 'FourWayConnected', vars = { 'a11', 'a21', 'a31', 'a12', 'a22', 'a32', 'a13', 'a23', 'a33' }, params = { width = 3, height = 3, passable = { ['.'] = true } } },
  884.         },
  885. }
  886.  
  887. -- enumerate3(solver)
  888.  
  889. function enumeratemap( width, height, minf, maxf )
  890.         assert(width >= 3)
  891.         assert(height >= 3)
  892.         assert(0 <= minf and minf <= 1)
  893.         assert(0 <= maxf and maxf <= 1)
  894.         assert(minf <= maxf)
  895.  
  896.         local min = math.floor((width*height)*minf)
  897.         local max = math.floor((width*height)*maxf)
  898.  
  899.         local vars = {}
  900.         local varnames = {}
  901.         local domain = { '#', '.' }
  902.         local passable = { ['.'] = true }
  903.  
  904.         for y = 1, height do
  905.                 for x = 1, width do
  906.                         local var = string.format("%d_%d", x, y)
  907.                         vars[var] = domain
  908.                         varnames[#varnames+1] = var
  909.                 end
  910.         end
  911.  
  912.         local solver = newSolver {
  913.                 random = true,
  914.                 vars = vars,
  915.                 constraints = {
  916.                         {
  917.                                 'FourWayConnected',
  918.                                 vars = varnames,
  919.                                 params = {
  920.                                         width = width,
  921.                                         height = height,
  922.                                         passable = passable
  923.                                 },
  924.                         },
  925.                         {
  926.                                 'Cardinality',
  927.                                 vars = varnames,
  928.                                 params = {
  929.                                         min = min,
  930.                                         max = max,
  931.                                         value = '.'
  932.                                 }
  933.                         }
  934.                 },             
  935.         }
  936.  
  937.         local count = 0
  938.  
  939.         for solution in solver do
  940.                 count = count + 1
  941.  
  942.                 printf('#%d', count)
  943.  
  944.                 for y = 1, height do
  945.                         local line = {}
  946.                         for x = 1, width do
  947.                                 local var = string.format("%d_%d", x, y)
  948.                                 line[#line+1] = solution[var]
  949.                         end
  950.                         print(table.concat(line))
  951.                 end
  952.                
  953.                 print()
  954.         end
  955.  
  956.         printf('%d solutions', count)
  957.         print()
  958.  
  959.         return count
  960. end
  961.  
  962. enumeratemap(10, 9, 0.2, 0.5)
clone this paste RAW Paste Data