Advertisement
Guest User

Simple Constraint Solver in Lua

a guest
May 11th, 2013
118
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 19.94 KB | None | 0 0
  1. --
  2. -- Simple Finite Domain, Constraint Propagation Solver
  3. --
  4. -- Inspired by 'Fast Procedural Level Population with Playability Constraints'
  5. -- by Ian Horswill and Leif Foged.
  6. --
  7. -- paper: http://www.aaai.org/ocs/index.php/AIIDE/AIIDE12/paper/view/5466/5691
  8. -- code:  http://code.google.com/p/constraint-thingy/
  9. --
  10. -- I haven't implemented any of the more advanced parts of the paper like
  11. -- interval domains of the DAG based pathing stuff.
  12. --
  13. -- It's not very efficient as it uses tables (Lua's built in hash map) for
  14. -- representing domains instead of bit arrays but it is simple.
  15. --
  16.  
  17.  
  18. -- variable = {
  19. --     name = <string>,
  20. --     values = { [<value>] = true }*
  21. -- }
  22.  
  23. -- constraint = function ( solver, variable )
  24.  
  25. -- solver = {
  26. --     variables = { [variable] = true }+
  27. --     varmap = { [<string>] = variable }+
  28. --     constraints = { [variable] = { constraint }+ }+
  29. --     stack = { { variables } }*
  30. -- }
  31.  
  32. function printf( ... ) print(string.format(...)) end
  33.  
  34. function set( value )
  35.     return { [value] = true }
  36. end
  37.  
  38. function set_empty( set )
  39.     return next(set) == nil
  40. end
  41.  
  42. function set_singular( set )
  43.     local k = next(set)
  44.  
  45.     return k ~= nil and next(set, k) == nil
  46. end
  47.  
  48. function set_count( set )
  49.     local result = 0
  50.  
  51.     for k, _ in pairs(set) do
  52.         result = result + 1
  53.     end
  54.  
  55.     return result
  56. end
  57.  
  58. function set_copy( set )
  59.     local result = {}
  60.  
  61.     for k, _ in pairs(set) do
  62.         result[k] = true
  63.     end
  64.  
  65.     return result
  66. end
  67.  
  68. function set_intersection( set1, set2 )
  69.     local result = {}
  70.  
  71.     for k, _ in pairs(set1) do
  72.         if set2[k] then
  73.             result[k] = true
  74.         end
  75.     end
  76.  
  77.     return result
  78. end
  79.  
  80. function set_subtract( set1, set2 )
  81.     local result = {}
  82.  
  83.     for k, _ in pairs(set1) do
  84.         if not set2[k] then
  85.             result[k] = true
  86.         end
  87.     end
  88.  
  89.     return result
  90. end
  91.  
  92. function set_equal( set1, set2 )
  93.     for k, _ in pairs(set1) do
  94.         if not set2[k] then
  95.             return false
  96.         end
  97.     end
  98.  
  99.     for k,_ in pairs(set2) do
  100.         if not set1[k] then
  101.             return false
  102.         end
  103.     end
  104.  
  105.     return true
  106. end
  107.  
  108. function set_toarray( set )
  109.     local result = {}
  110.  
  111.     for k, _ in pairs(set) do
  112.         result[#result+1] = k
  113.     end
  114.  
  115.     return result
  116. end
  117.  
  118. function set_fromarray( array )
  119.     local result = {}
  120.  
  121.     for _, v in ipairs(array) do
  122.         result[v] = true
  123.     end
  124.  
  125.     return result
  126. end
  127.  
  128. function set_tostring( set )
  129.     local parts = {}
  130.  
  131.     for k, _ in pairs(set) do
  132.         parts[#parts+1] = k
  133.     end
  134.  
  135.     return string.format("{%s}", table.concat(parts, ' '))
  136. end
  137.  
  138. function print_stack( solver )
  139.     for index, frame in ipairs(solver.stack) do
  140.         printf('%d ->', index)
  141.  
  142.         for variable, values in pairs(frame) do
  143.             printf('  %s = %s', variable.name, set_tostring(values))
  144.         end
  145.     end
  146. end
  147.  
  148. function newframe( solver )
  149.     local stack = solver.stack
  150.     local frame = {}
  151.     stack[#stack+1] = frame
  152. end
  153.  
  154. function save( solver, var )
  155.     local stack = solver.stack
  156.     local frame = stack[#stack]
  157.    
  158.     frame[var] = var.values
  159. end
  160.  
  161. function unwind( solver )
  162.     local stack = solver.stack
  163.     local frame = stack[#stack]
  164.  
  165.     for variable, values in pairs(frame) do
  166.         variable.values = values
  167.     end
  168.  
  169.     stack[#stack] = nil
  170. end
  171.  
  172. function solution( solver )
  173.     local result = {}
  174.  
  175.     for variable, _ in pairs(solver.variables) do
  176.         assert(set_singular(variable.values))
  177.  
  178.         result[variable.name] = next(variable.values)
  179.     end
  180.  
  181.     return result
  182. end
  183.  
  184. function narrow( solver, var, set )
  185.     -- printf('  narrow %s to %s', vartos(var), set_tostring(set))
  186.     local newvalues = set_intersection(var.values, set)
  187.  
  188.     if set_empty(newvalues) then
  189.         return false
  190.     end
  191.  
  192.     if not set_equal(var.values, newvalues) then
  193.         save(solver, var)
  194.  
  195.         var.values = newvalues
  196.  
  197.         local constraints = solver.constraints[var]
  198.  
  199.         -- printf('  propagate %s over %d constraints', vartos(var), set_count(constraints))
  200.  
  201.         for constraint, _ in pairs(constraints) do
  202.             if not constraint(solver, var) then
  203.                 return false
  204.             end
  205.         end
  206.     end
  207.  
  208.     return true
  209. end
  210.  
  211. function shuffle( array )
  212.     for i = 1, #array-1 do
  213.         local j = math.random(i, #array)
  214.         array[i], array[j] = array[j], array[i]
  215.     end
  216.  
  217.     return array
  218. end
  219.    
  220. function randpairs( tbl )
  221.     local kvs = {}
  222.  
  223.     for k, v in pairs(tbl) do
  224.         kvs[#kvs+1] = { k, v }
  225.     end
  226.  
  227.     kvs = shuffle(kvs)
  228.  
  229.     local index = 0
  230.  
  231.     return
  232.         function ()
  233.             index = index + 1
  234.             local kv = kvs[index]
  235.  
  236.             if kv then
  237.                 return kv[1], kv[2]
  238.             else
  239.                 return nil
  240.             end
  241.         end
  242. end
  243.  
  244. function solve( solver )
  245.     local varsarray = set_toarray(solver.variables)
  246.  
  247.     -- We want to try and solve the 'smallest' variables first. This also
  248.     -- allows a quick check for any empty variables.
  249.     table.sort(varsarray,
  250.         function ( lhs, rhs )
  251.             return set_count(lhs.values) < set_count(rhs.values)
  252.         end)
  253.  
  254.     -- If there's any empty variables we're buggered.
  255.     if set_empty(varsarray[1].values) then
  256.         error('empty')
  257.     end
  258.  
  259.     if solver.random then
  260.         shuffle(varsarray)
  261.     end
  262.  
  263.     local iterator = (solver.random) and randpairs or pairs
  264.  
  265.     local function aux( solver, index )
  266.         if index > #varsarray then
  267.             coroutine.yield(solution(solver))
  268.  
  269.             return
  270.         end
  271.  
  272.         local variable = varsarray[index]
  273.  
  274.         -- If the variable is singular (i.e. has a value) we just carry on.    
  275.         if set_singular(variable.values) then
  276.             aux(solver, index+1)
  277.         else
  278.             -- local failed = true
  279.  
  280.             for value, _ in iterator(variable.values) do
  281.                 newframe(solver)
  282.  
  283.                 -- printf(' try %s = %s', variable.name, value)
  284.  
  285.                 local success = narrow(solver, variable, set(value))
  286.  
  287.                 if success then
  288.                     -- failed = false
  289.                     -- print(debug.traceback())
  290.                     aux(solver, index+1)
  291.                 end
  292.                
  293.                 unwind(solver)
  294.             end
  295.  
  296.             -- if not failed then
  297.                 return
  298.             -- else
  299.                 -- error('unsatisfiable')
  300.             -- end
  301.         end
  302.     end
  303.  
  304.     aux(solver, 1)
  305. end
  306.  
  307. function printTable( tbl )
  308.     for k, v in pairs(tbl) do
  309.         print(k, v)
  310.     end
  311. end
  312.  
  313. function environment( varnames, varmap )
  314.     local result = {}
  315.  
  316.     local repeated = {}
  317.  
  318.     for _, varname in ipairs(varnames) do
  319.         local variable = varmap[varname]
  320.         assert(variable)
  321.         assert(not repeated[variable], variable.name)
  322.         result[#result+1] = variable
  323.         repeated[variable] = true
  324.     end
  325.  
  326.     return result
  327. end
  328.  
  329. function vartos( var, values )
  330.     values = values or var.values
  331.     return string.format("%s=%s", var.name, set_tostring(values))
  332. end
  333.  
  334. local numcfails = 0
  335. local numfwfails = 0
  336.  
  337. local numcardsame = 0
  338. local numcarddiff = 0
  339.  
  340. local _constraints = {
  341.     NotEqual =
  342.         function ( vars )
  343.             assert(#vars == 2)
  344.  
  345.             local var1, var2 = vars[1], vars[2]
  346.  
  347.             return
  348.                 function ( solver, var )
  349.                     assert(var == var1 or var == var2)
  350.                     local other = (var == var1) and var2 or var1
  351.  
  352.                     if set_singular(var.values) then
  353.                         local newvalues = set_subtract(other.values, var.values)
  354.                         local oldvalues = other.values
  355.  
  356.                         -- printf('  NotEqual %s %s => %s', vartos(var), vartos(other, oldvalues), set_tostring(newvalues))
  357.  
  358.                         local result = narrow(solver, other, newvalues)
  359.  
  360.                         -- printf('   %s', result and 'succeeded' or 'failed')
  361.  
  362.                         return result
  363.                     end
  364.  
  365.                     return true
  366.                 end
  367.         end,
  368.     Equal =
  369.         function ( vars )
  370.             assert(#vars == 2)
  371.  
  372.             local var1, var2 = vars[1], vars[2]
  373.  
  374.             return
  375.                 function ( solver, var )
  376.                     assert(var == var1 or var == var2)
  377.                     local other = (var == var1) and var2 or var1
  378.  
  379.                     if set_singular(var.values) then
  380.                         local newvalues = set_copy(var.values)
  381.                         local oldvalues = other.values
  382.  
  383.                         -- printf('  Equal %s %s => %s', vartos(var), vartos(other, oldvalues), set_tostring(newvalues))
  384.  
  385.                         local result = narrow(solver, other, newvalues)
  386.  
  387.                         -- printf('   %s', result and 'succeeded' or 'failed')
  388.  
  389.                         return result
  390.                     end
  391.  
  392.                     return true
  393.                 end
  394.         end,
  395.     Cardinality =
  396.         function ( vars, params )
  397.             local min = params.min
  398.             local max = params.max
  399.             local value = params.value
  400.  
  401.             assert(0 <= min)
  402.             assert(min <= max)
  403.             assert(value)
  404.             assert(min <= #vars)
  405.  
  406.             local variables = set_fromarray(varss)
  407.  
  408.             return
  409.                 function ( solver, var )
  410.                     assert(variables[var])
  411.  
  412.                     local possible = {}
  413.                     local definite = {}
  414.                     local numpossible = 0
  415.                     local numdefinite = 0
  416.  
  417.                     for variable, _ in pairs(variables) do
  418.                         if variable.values[value] then
  419.                             possible[variable] = true
  420.                             numpossible = numpossible + 1
  421.  
  422.                             if set_singular(variable.values) then
  423.                                 definite[variable] = true
  424.                                 numdefinite = numdefinite + 1
  425.                             end
  426.                         end
  427.                     end
  428.  
  429.                     if numpossible < min or numdefinite > max then
  430.                         numcfails = numcfails + 1
  431.                         return false
  432.                     end
  433.  
  434.                     -- We have enough possibles, but only just, so try and make
  435.                     -- them definite to be sure.
  436.                     if numpossible == min then
  437.                         for variable, _ in pairs(possible) do
  438.                             local result = narrow(solver, variable, set(value))
  439.  
  440.                             if not result then
  441.                                 numcfails = numcfails + 1
  442.                                 return false
  443.                             end
  444.                         end
  445.  
  446.                         return true
  447.                     end
  448.  
  449.                     -- We've got the maximum number of definites so make the
  450.                     -- possibles that are not also definite into impossibles.
  451.                     if numdefinite == max then
  452.                         local valueset = set(value)
  453.  
  454.                         for variable, _ in pairs(possible) do
  455.                             if not definite[variable] then
  456.                                 local newvalues = set_subtract(variable.values, valueset)
  457.  
  458.                                 -- printf('  Cardinality %s => %s', vartos(variable), set_tostring(newvalues))
  459.  
  460.                                 local result = narrow(solver, variable, newvalues)
  461.  
  462.                                 if not result then
  463.                                     numcfails = numcfails + 1
  464.                                     return false
  465.                                 end
  466.                             end
  467.                         end
  468.  
  469.                         return true
  470.                     end
  471.  
  472.                     return true
  473.                 end
  474.         end,
  475.     Sum =
  476.         function ( vars, params )
  477.             assert(#vars > 0)
  478.  
  479.             local total = params.total
  480.             assert(total)
  481.  
  482.             return
  483.                 function ( solver, var )
  484.                     local count = 0
  485.  
  486.                     -- print(#vars)
  487.  
  488.                     for _, variable in ipairs(vars) do
  489.                         if not set_singular(variable.values) then
  490.                             return true
  491.                         else
  492.                             -- TODO: Need a 'get unique value' function.
  493.                             count = count + next(variable.values)
  494.                         end
  495.                     end
  496.  
  497.                     -- printf('  Sum count:%d total:%d', count, total)
  498.  
  499.                     return count == total
  500.                 end
  501.         end,
  502.     FourWayConnected =
  503.         function ( vars, params )
  504.             local height = params.height
  505.             local width = params.width
  506.             local passable = params.passable
  507.  
  508.             assert(#vars == width * height)
  509.  
  510.             -- Build up the 'neighbour' relation.
  511.             local neighbours = {}
  512.  
  513.             local at =
  514.                 function ( x, y )
  515.                     local result = x + (y-1)*width
  516.                     assert(1 <= result and result <= width * height)
  517.                     return result
  518.                 end
  519.  
  520.             for y = 1, height do
  521.                 for x = 1, width do
  522.                     local peers = {}
  523.                    
  524.                     -- North
  525.                     if y > 1 then
  526.                         peers[#peers+1] = vars[at(x, y-1)]
  527.                     end
  528.  
  529.                     -- East
  530.                     if x < width then
  531.                         peers[#peers+1] = vars[at(x+1, y)]
  532.                     end
  533.  
  534.                     -- South
  535.                     if y < height then
  536.                         peers[#peers+1] = vars[at(x, y+1)]
  537.                     end
  538.  
  539.                     -- West
  540.                     if x > 1 then
  541.                         peers[#peers+1] = vars[at(x-1, y)]
  542.                     end
  543.  
  544.                     neighbours[vars[at(x, y)]] = peers
  545.                 end
  546.             end
  547.  
  548.             return
  549.                 function ( solver, var )
  550.                     for _, variable in ipairs(vars) do
  551.                         if not set_singular(variable.values) then
  552.                             return true
  553.                         end
  554.                     end
  555.  
  556.                     -- First create a set of passable variables.
  557.                     local passables = {}
  558.                     local numpassables = 0
  559.  
  560.                     for _, variable in ipairs(vars) do
  561.                         if passable[next(variable.values)] then
  562.                             passables[variable] = true
  563.                             numpassables = numpassables + 1
  564.                         end
  565.                     end
  566.  
  567.                     if set_count(passables) == 0 then
  568.                         return true
  569.                     else
  570.                         local bfs = function( variable )
  571.                             local result = { [variable] = true }
  572.                             local frontier = { [variable] = true }
  573.                             local count = 1
  574.  
  575.                             while next(frontier) do
  576.                                 local newFrontier = {}
  577.  
  578.                                 for variable, _  in pairs(frontier) do
  579.                                     for _, other in ipairs(neighbours[variable]) do
  580.                                         if not result[other] and not frontier[other] and passables[other] then
  581.                                             count = count + 1
  582.                                             result[other] = true
  583.                                             newFrontier[other] = true
  584.                                         end
  585.                                     end
  586.                                 end
  587.  
  588.                                 frontier = newFrontier
  589.                             end
  590.  
  591.                             return count
  592.                         end
  593.  
  594.                         local count = bfs(next(passables))
  595.  
  596.                         if count ~= numpassables then
  597.                             numfwfails = numfwfails + 1
  598.  
  599.                             if numcfails % 1000 == 0 or numfwfails % 1000 == 0 then
  600.                                 -- printf('#card:%d #fourw:%d', numcfails, numfwfails)
  601.                             end
  602.                         end
  603.  
  604.                         return count == numpassables
  605.                     end
  606.                 end
  607.         end,
  608. }
  609.  
  610. -- tbl = {
  611. --     vars = { [<string>] = { <string> }+ },
  612. --     constraints = { { <string>, vars = { <string> }+ } }*
  613. -- }
  614.  
  615.  
  616. function newSolver( tbl )
  617.     local dump = tbl.dump == true
  618.  
  619.     local vars = tbl.vars
  620.  
  621.     local variables = {}
  622.     local varmap = {}
  623.  
  624.     for name, values in pairs(vars) do
  625.         assert(#values >= 1)
  626.  
  627.         local variable = { name = name, values = set_fromarray(values) }
  628.         variables[variable] = true
  629.         varmap[name] = variable
  630.  
  631.         if dump then
  632.             print(vartos(variable))
  633.         end
  634.     end
  635.  
  636.     local constraints = {}
  637.  
  638.     for variable, _ in pairs(variables) do
  639.         constraints[variable] = {}
  640.     end
  641.  
  642.     for index, v in ipairs(tbl.constraints) do
  643.         local name = v[1]
  644.         local vars = environment(v.vars, varmap)
  645.         local params = v.params
  646.  
  647.         local constraint = _constraints[name](vars, params)
  648.  
  649.         for _, variable, _ in ipairs(vars) do
  650.             if dump then
  651.                 printf('constraint #%d %s', index, variable.name)
  652.             end
  653.  
  654.             constraints[variable][constraint] = true
  655.         end
  656.     end
  657.  
  658.     local random = tbl.random and true or false
  659.  
  660.     local result = {
  661.         variables = variables,
  662.         varmap = varmap,
  663.         constraints = constraints,
  664.         stack = {},
  665.         random = random,
  666.     }
  667.  
  668.     local coro = coroutine.create(function () solve(result) end)
  669.    
  670.     return
  671.         function ()
  672.             local status, result = coroutine.resume(coro)
  673.  
  674.             if status then
  675.                 return result
  676.             else
  677.                 error(result)
  678.                 return nil
  679.             end
  680.         end
  681. end
  682.  
  683. function solution_tostring( solution )
  684.     local sorted = {}
  685.  
  686.     for varname, value in pairs(solution) do
  687.         sorted[#sorted+1] = { varname, value }
  688.     end
  689.  
  690.     table.sort(sorted,
  691.         function ( lhs, rhs )
  692.             return lhs[1] < rhs[1]
  693.         end)
  694.  
  695.     local parts = {}
  696.  
  697.     for _, data in pairs(sorted) do
  698.         parts[#parts+1] = string.format("%s=%s", data[1], data[2])
  699.     end
  700.  
  701.     return string.format("{ %s }", table.concat(parts, ' '))
  702. end
  703.  
  704. function enumerate( solver )
  705.     local count = 0
  706.  
  707.     for solution in solver do
  708.         count = count + 1
  709.         printf('#%d %s', count, solution_tostring(solution))
  710.     end
  711.  
  712.     printf('%d solutions', count)
  713.     print()
  714.  
  715.     return count
  716. end
  717.  
  718. function enumerate2( solver )
  719.     local count = 0
  720.  
  721.     for solution in solver do
  722.         count = count + 1
  723.  
  724.         if count % 1000 == 0 then
  725.             printf('#%d %s', count, solution_tostring(solution))
  726.         end
  727.     end
  728.  
  729.     printf('%d solutions', count)
  730.     print()
  731. end
  732.  
  733. local solver = newSolver {
  734.     vars = {
  735.         a = { 'x', 'y', 'z' },
  736.         b = { 'x', 'y', 'z' },
  737.         c = { 'x', 'y', 'z' },
  738.     },
  739.     constraints = {},
  740. }
  741.  
  742. enumerate(solver)
  743.  
  744. local solver = newSolver {
  745.     dump = true,
  746.     vars = {
  747.         a = { 'x', 'y', 'z' },
  748.         b = { 'x', 'y', 'z' },
  749.         c = { 'x', 'y', 'z' },
  750.     },
  751.     constraints = {
  752.         { 'NotEqual', vars = { 'a', 'b' } },
  753.         { 'NotEqual', vars = { 'a', 'c' } },
  754.         { 'NotEqual', vars = { 'b', 'c' } },
  755.     },
  756. }
  757.  
  758. assert(enumerate(solver) == 6)
  759.  
  760. local solver = newSolver {
  761.     vars = {
  762.         a = { 'x', 'y', 'z' },
  763.         b = { 'x', 'y', 'z' },
  764.         c = { 'x', 'y', 'z' },
  765.     },
  766.     constraints = {
  767.         { 'Equal', vars = { 'a', 'b' } },
  768.         { 'Equal', vars = { 'a', 'c' } },
  769.         { 'Equal', vars = { 'b', 'c' } },
  770.     },
  771. }
  772.  
  773. assert(enumerate(solver) == 3)
  774.  
  775. function genNotEquals( varnames )
  776.     local result = {}
  777.  
  778.     for i = 1, #varnames do
  779.         for j = i+1, #varnames do
  780.             result[#result+1] = { 'NotEqual', vars = { varnames[i], varnames[j] } }
  781.         end
  782.     end
  783.  
  784.     return result
  785. end
  786.  
  787. local solver = newSolver {
  788.     vars = {
  789.         a = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  790.         b = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  791.         c = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  792.         d = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  793.         e = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  794.         f = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  795.         g = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  796.         h = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  797.         i = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  798.         j = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  799.         k = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  800.         l = { '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',  },
  801.     },
  802.     constraints = genNotEquals { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l' } ,
  803. }
  804.  
  805. -- This should have 12! = 479,001,600 solutions.
  806. -- enumerate2(solver)
  807.  
  808. local solver = newSolver {
  809.     vars = {
  810.         a = { '0', '1', },
  811.         b = { '0', '1', },
  812.         c = { '0', '1', },
  813.         d = { '0', '1', },
  814.     },
  815.     constraints = {
  816.         { 'Cardinality', vars = { 'a', 'b', 'c', 'd' }, params = { min = 2, max = 3, value = '1' } },
  817.     },
  818. }
  819.  
  820. assert(enumerate(solver) == 10)
  821.  
  822. local solver = newSolver {
  823.     vars = {
  824.         a = { '0', '1', },
  825.         b = { '0', '1', },
  826.         c = { '0', '1', },
  827.         d = { '0', '1', },
  828.         e = { '0', '1', },
  829.         f = { '0', '1', },
  830.         g = { '0', '1', },
  831.         h = { '0', '1', },
  832.         i = { '0', '1', },
  833.         j = { '0', '1', },
  834.         k = { '0', '1', },
  835.         l = { '0', '1', },
  836.         m = { '0', '1', },
  837.         n = { '0', '1', },
  838.         o = { '0', '1', },
  839.         r = { '0', '1', },
  840.         s = { '0', '1', },
  841.         t = { '0', '1', },
  842.         u = { '0', '1', },
  843.         v = { '0', '1', },
  844.         w = { '0', '1', },
  845.         x = { '0', '1', },
  846.         y = { '0', '1', },
  847.         z = { '0', '1', },
  848.     },
  849.     constraints = {
  850.         { 'Cardinality', vars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' }, params = { min = 13, max = 13, value = '1' } },
  851.     },
  852. }
  853.  
  854. -- enumerate2(solver)
  855.  
  856. local solver = newSolver {
  857.     vars = {
  858.         a = { 0, 1, 2, 3, 4, 5, },
  859.         b = { 0, 1, 2, 3, 4, 5, },
  860.         c = { 0, 1, 2, 3, 4, 5, },
  861.         d = { 0, 1, 2, 3, 4, 5, },
  862.     },
  863.     constraints = {
  864.         { 'Sum', vars = { 'a', 'b', 'c', 'd' }, params = { total = 2 } },
  865.     },
  866. }
  867.  
  868. enumerate(solver)
  869.  
  870. local solver = newSolver {
  871.     vars = {
  872.         a11 = { '#', '.' },
  873.         a21 = { '#', '.' },
  874.         a31 = { '#', '.' },
  875.         a12 = { '#', '.' },
  876.         a22 = { '#', '.' },
  877.         a32 = { '#', '.' },
  878.         a13 = { '#', '.' },
  879.         a23 = { '#', '.' },
  880.         a33 = { '#', '.' },
  881.     },
  882.     constraints = {
  883.         { 'FourWayConnected', vars = { 'a11', 'a21', 'a31', 'a12', 'a22', 'a32', 'a13', 'a23', 'a33' }, params = { width = 3, height = 3, passable = { ['.'] = true } } },
  884.     },
  885. }
  886.  
  887. -- enumerate3(solver)
  888.  
  889. function enumeratemap( width, height, minf, maxf )
  890.     assert(width >= 3)
  891.     assert(height >= 3)
  892.     assert(0 <= minf and minf <= 1)
  893.     assert(0 <= maxf and maxf <= 1)
  894.     assert(minf <= maxf)
  895.  
  896.     local min = math.floor((width*height)*minf)
  897.     local max = math.floor((width*height)*maxf)
  898.  
  899.     local vars = {}
  900.     local varnames = {}
  901.     local domain = { '#', '.' }
  902.     local passable = { ['.'] = true }
  903.  
  904.     for y = 1, height do
  905.         for x = 1, width do
  906.             local var = string.format("%d_%d", x, y)
  907.             vars[var] = domain
  908.             varnames[#varnames+1] = var
  909.         end
  910.     end
  911.  
  912.     local solver = newSolver {
  913.         random = true,
  914.         vars = vars,
  915.         constraints = {
  916.             {
  917.                 'FourWayConnected',
  918.                 vars = varnames,
  919.                 params = {
  920.                     width = width,
  921.                     height = height,
  922.                     passable = passable
  923.                 },
  924.             },
  925.             {
  926.                 'Cardinality',
  927.                 vars = varnames,
  928.                 params = {
  929.                     min = min,
  930.                     max = max,
  931.                     value = '.'
  932.                 }
  933.             }
  934.         },     
  935.     }
  936.  
  937.     local count = 0
  938.  
  939.     for solution in solver do
  940.         count = count + 1
  941.  
  942.         printf('#%d', count)
  943.  
  944.         for y = 1, height do
  945.             local line = {}
  946.             for x = 1, width do
  947.                 local var = string.format("%d_%d", x, y)
  948.                 line[#line+1] = solution[var]
  949.             end
  950.             print(table.concat(line))
  951.         end
  952.        
  953.         print()
  954.     end
  955.  
  956.     printf('%d solutions', count)
  957.     print()
  958.  
  959.     return count
  960. end
  961.  
  962. enumeratemap(10, 9, 0.2, 0.5)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement