Advertisement
Guest User

Untitled

a guest
Oct 26th, 2016
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.10 KB | None | 0 0
  1. from collections import OrderedDict, defaultdict
  2. import os
  3. import sys
  4. import subprocess
  5. from copy import deepcopy
  6. # from monosat import *
  7. # enable using multiple levels of dict keys automatically, even if nested levels don't yet exist
  8. NestedDict = lambda: defaultdict(NestedDict)
  9.  
  10.  
  11. class SATGenerator(object):
  12. def __init__(self, traces, maxx, maxy, maxz):
  13. self.maxx = maxx
  14. self.maxy = maxy
  15. self.maxz = maxz
  16.  
  17. self.output_path = 'output.cnf'
  18. self.vars = []
  19. self.nodes = []
  20. self.traces = traces
  21. self.visited_neighbor_edges = []
  22.  
  23. self.setup_output()
  24. self.create_vars()
  25. self.create_clauses()
  26. self.parse_solution(self.solve())
  27.  
  28. def setup_output(self):
  29. self.output = open(self.output_path, 'w')
  30. self.output.write('p cnf 0 0\ndigraph int 0 0 0\n')
  31. # self.g = Graph()
  32.  
  33. def create_vars(self):
  34. # setup nodes
  35. self.grid_by_xyz = OrderedDict()
  36. for x in range(self.maxx):
  37. ys = OrderedDict()
  38. for y in range(self.maxy):
  39. zs = OrderedDict()
  40. for z in range(self.maxz):
  41. v = self.node(('x {} y {} z {} node'.format(x, y, z), (x, y, z)))
  42. zs[z] = {'node':v, 'edges':[], 'xyz':(x,y,z), 'edges_xyz': NestedDict()}
  43. ys[y] = zs
  44. self.grid_by_xyz[x] = ys
  45.  
  46. # setup all possible edges
  47. for x in self.grid_by_xyz:
  48. for y in self.grid_by_xyz[x]:
  49. for z in self.grid_by_xyz[x][y]:
  50. self._neighbor_edges(x, y, z)
  51.  
  52. def create_clauses(self):
  53. # # for every space
  54. all_edges = set()
  55. for x in range(self.maxx):
  56. for y in range(self.maxy):
  57. for z in range(self.maxz):
  58. for e in self.grid_by_xyz[x][y][z]['edges']:
  59. all_edges.add(e)
  60. for cl_node in self._get_circumferential_locs(x, y, z):
  61. starting_edge = self._node_edge_to(self.grid_by_xyz[x][y][z], cl_node)
  62. self._neighbor_constraint(self.grid_by_xyz[x][y][z], cl_node, starting_edge)
  63. # for edge_vars_for_direction in zip(*locs):
  64. # # allow only one trace's edge
  65. # self._naive_mutex(edge_vars_for_direction)
  66. # the following line should be handled by the neighbor clauses, falling like dominoes from the start edges
  67. # self.clause(all_edges)
  68.  
  69. # go through the traces, OR the start node edges, then setup reachability to the end node
  70. self.start_end = []
  71. ios_overall = []
  72. for trace in self.traces:
  73. ios=[]
  74. for io in ['input', 'output']:
  75. x,y,z = self.traces[trace][io]
  76. v = self.grid_by_xyz[x][y][z]
  77. # print(v)
  78. ios.append(v)
  79. # at least one of these edges is True
  80. self.clause(v['edges'])
  81. self.start_end.append((trace, x, y, z))
  82.  
  83. # new var for the clause of reachability between two nodes (start and end)
  84. rv = self.var('reach {}'.format(trace))
  85. # reach <graphID> node1 node2 var
  86. self.output.write('reach 0 {} {} {}\n'.format(ios[0]['node'], ios[1]['node'], rv))
  87. # the var is True
  88. self.clause([rv])
  89. ios_overall.append(ios)
  90.  
  91. # disallow an input to connect to any other trace's output
  92. for ios in ios_overall:
  93. inp = ios[0]
  94. # get all other traces beside this current one
  95. others = list(ios_overall); others.remove(ios)
  96. # get all other traces' outputs
  97. other_outputs = [other_io[1] for other_io in others]
  98.  
  99. rv = self.var('anti reach {}'.format(trace)) # TODO pull this line outside FOR loop?
  100. for other_out in other_outputs:
  101. self.output.write('reach 0 {} {} {}\n'.format(inp['node'], other_out['node'], rv))
  102. self.clause([-rv]) # TODO pull this line outside FOR loop?
  103.  
  104. def solve(self):
  105. # close the clause file
  106. self.output.close()
  107. # call the SAT/SMT solver, pass the clause file path
  108. proc = subprocess.Popen(['/home/nathan/Projects/github/monosat/monosat', self.output_path, '-witness'], stdout=subprocess.PIPE,
  109. universal_newlines=True)
  110. # wait for solver output
  111. stdout, stderr = proc.communicate()
  112. stdout = stdout.strip()
  113. print('num nodes {} num edges {}'.format(len(self.nodes), len(self.vars)))
  114.  
  115. with open('solver_out', 'w') as so:
  116. so.write(stdout)
  117. if stdout.endswith('UNSATISFIABLE'):
  118. print('FAILED')
  119. sys.exit(1)
  120. else:
  121. return stdout
  122.  
  123. def parse_solution(self, stdout):
  124. """ for a given trace, walk all edges and save them in a grid for easy printing """
  125. sol_org = {}
  126. for line in stdout.split(os.linesep):
  127. # var lines are prefixed with letter v
  128. if not line.startswith('v'):
  129. continue
  130. # split the string on whitespace, after getting rid of the prefixed letter v
  131. for i in line[1:].split():
  132. # only interested in non-negative (True) vars, because those are the utilized edges
  133. if not i.startswith('-'):
  134. # convert the variable number back into a list index (the first item is the 0th index)
  135. ii = int(i)-1
  136. # there is no var 0, just like there isn't a negative list index
  137. if ii<0:
  138. continue
  139. # grab the variable's metadata
  140. v = self.vars[ii]
  141. # make sure it is an edge, which are the only metadata saved as tuples at this point
  142. if isinstance(v, tuple):
  143. sol_org[v[0][1]] = v[1][1]
  144. print('{} {}\n'.format(v[0][1] , v[1][1]))
  145. else:
  146. print(v)
  147.  
  148. # open a new file to print the readable solution to
  149. o = open('sol_out', 'w')
  150.  
  151. # find the longest trace name
  152. l = ''
  153. for trace in list(self.traces) + ['*']:
  154. if len(trace) > len(l):
  155. l = trace
  156. # store the length of the longest trace name
  157. l = len(l)
  158.  
  159. for trace in self.traces:
  160. out_xyz = []
  161. sol = deepcopy(sol_org)
  162.  
  163. print(trace)
  164. o.write('\n trace OUT\n')
  165.  
  166. start = place = self.traces[trace]['input']
  167. end = self.traces[trace]['output']
  168.  
  169. # because the graph can cycle, we need to be able to backtrack
  170. trace_progression = []
  171. backing_up = False
  172.  
  173. while True:
  174. # save where we are, so we can backtrack later if needed... unless we just hit a dead-end
  175. if not backing_up:
  176. trace_progression.append(place)
  177. try:
  178. # edges have a source and destination, using the source (i.e. starting point), get the next location
  179. place = sol.pop(place)
  180. # if we popped off the next location with no error, we are not backtracking
  181. # and on the next loop, this place will be saved in case we have to revisit
  182. backing_up = False
  183. except KeyError:
  184. # the current location didn't point elsewhere, so start backing up
  185. backing_up = True
  186. try:
  187. np = trace_progression.pop()
  188. print("end-point {} wasn't found in the solution from point {}!! BACKTRACKING to {}".format(end, place, np))
  189. place = np
  190. except Exception as e:
  191. print("couldn't backtrack anymore!!!")
  192. raise e
  193.  
  194. xx, yy, zz = place
  195. out_xyz.append((xx, yy, zz))
  196. if place == end:
  197. print('hit end')
  198. break
  199.  
  200. # now that we're done walking the graph, we can make a crude sort-of bitmap
  201. for z in range(self.maxz):
  202. o.write('Z {}\n'.format(z))
  203. for x in range(self.maxx):
  204. for y in range(self.maxy):
  205. # print a * for start/end points
  206. if (trace, x, y, z) in self.start_end:
  207. o.write('*{} '.format(' ' * (l - 1)))
  208. # print the trace-name for points in a trace
  209. elif (x,y,z) in out_xyz:
  210. o.write('{}{} '.format(trace, ' ' * (l - len(trace))))
  211. # if a space is unused, print a 0
  212. else:
  213. o.write('0{} '.format(' ' * (l - 1)))
  214. o.write('\n')
  215. o.write('\n')
  216. o.close()
  217.  
  218. # try printing all traces on the same "bitmap"... harder to read, but I wanted a sanity check
  219. o = open('sol_out2', 'w')
  220. for z in range(self.maxz):
  221. o.write('Z {}\n'.format(z))
  222. for x in range(self.maxx):
  223. for y in range(self.maxy):
  224. # check if this point is in any trace's start or end
  225. if [t for t in self.start_end if t[1:]==(x, y, z)]:
  226. o.write('*{} '.format(' ' * (l - 1)))
  227. # check if this point is in any trace, source or destination node
  228. elif (x,y,z) in sol_org or (x,y,z) in sol_org.values():
  229. o.write('{} '.format('T '))
  230. else:
  231. o.write('0{} '.format(' ' * (l - 1)))
  232. o.write('\n')
  233. o.write('\n')
  234.  
  235. def var(self, name):
  236. self.vars.append(name)
  237. # n = self.g.addNode()
  238. # self.real_vars.append(n)
  239. #return (self.num_vars, n)
  240. return self.num_vars
  241.  
  242. def node(self, name):
  243. self.nodes.append(name)
  244. return len(self.nodes)
  245.  
  246. @property
  247. def num_vars(self):
  248. return len(self.vars)
  249.  
  250. def clause(self, v_list, comment=None):
  251. svs = [str(v) for v in v_list]
  252. if comment is not None:
  253. self.comment(comment)
  254. self.output.write('{} 0\n'.format(' '.join(svs)))
  255.  
  256. def comment(self, comment):
  257. self.output.write('c {}\n'.format(comment))
  258.  
  259. def _kleiber_kwon(self, vs):
  260. pass
  261.  
  262. def _naive_mutex(self, vs, cmdr_list=None):
  263. self.comment('_naive_mutex to follow{}'.format(' with cmdrs {} and {}'.format(cmdr_list, vs) if cmdr_list is not None else ''))
  264. for i, vi in enumerate(vs):
  265. if (i+1)>=len(vs):
  266. continue
  267. for vj in vs[i+1:]:
  268. self.clause((cmdr_list if cmdr_list is not None else []) + [-vi, -vj])
  269. self.comment('_naive_mutex finished')
  270.  
  271. def _neighbor_edges(self, x, y, z):
  272. n = self.grid_by_xyz[x][y][z]
  273. #bvs = []
  274. circumferential_locs = self._get_circumferential_locs(x, y, z)
  275. for cl in circumferential_locs:
  276. others = list(circumferential_locs); others.remove(cl)
  277. # if v and cl, then one of the others
  278. #bv1 = BitVector(1)
  279. #bvs.append(bv1)
  280. t = n['node']
  281. f = cl['node']
  282. tt = self.nodes[t-1]
  283. ff = self.nodes[f-1]
  284. ee = (tt, ff)
  285. ev = self.var(ee)
  286. self._edge(n['node'], cl['node'], ev)#, bv1)
  287. n['edges'].append(ev)
  288. xx,yy,zz = cl['xyz']
  289. n['edges_xyz'][xx][yy][zz] = ev
  290.  
  291. def _node_edge_to(self, n, on):
  292. return n['edges_xyz'][on['xyz'][0]][on['xyz'][1]][on['xyz'][2]]
  293.  
  294. def _neighbor_constraint(self, from_node, center_node, edge_came_from):
  295. if (from_node, center_node) in self.visited_neighbor_edges or (center_node, from_node) in self.visited_neighbor_edges:
  296. return False
  297. if center_node['xyz'] in [self.traces[t]['output'] for t in self.traces]:
  298. return False
  299. self.visited_neighbor_edges.append((from_node, center_node))
  300.  
  301. circumferential_loc_nodes = self._get_circumferential_locs(*center_node['xyz'])
  302. if from_node not in circumferential_loc_nodes:
  303. raise Exception('{} {}'.format(circumferential_loc_nodes, from_node))
  304. circumferential_loc_nodes.remove(from_node)
  305. # if the edge_came_from is True, then one of the outgoing edges is True
  306. self.clause([-edge_came_from] + [self._node_edge_to(center_node, n) for n in circumferential_loc_nodes])
  307.  
  308. # Assert(Or(Not(v[1]), *es))
  309.  
  310. def _edge(self, _from, _to, v, w=1, graph_id=0):
  311. """edge <GraphID> <from> <to> <CNF Variable> [weight]"""
  312. self.output.write('edge {} {} {} {} {}\n'.format(graph_id, _from, _to, v, w if w>1 else ''))
  313.  
  314. def _get_circumferential_locs(self, x,y,z):
  315. # up, down (in Z), left, right, ahead, behind (in-plane), diags
  316. ensure = 2 + 4 + 4
  317. # up, down (in Z), left, right, ahead, behind (in-plane)
  318. ensure = 2 + 4
  319. nvs = []
  320.  
  321. disallow_fortyfives = True
  322.  
  323. for xx in [x-1,x,x+1]:
  324. if xx<0 or xx>=self.maxx:
  325. ensure = None
  326. continue
  327. for yy in [y-1,y,y+1]:
  328. if yy<0 or yy>=self.maxy:
  329. ensure = None
  330. continue
  331. for zz in [z-1,z,z+1]:
  332. if xx<0 or xx>=self.maxx or yy<0 or yy>=self.maxy or zz<0 or zz>=self.maxz:
  333. ensure = None
  334. continue
  335.  
  336. # skip the center point
  337. if x==xx and y==yy and z==zz:
  338. continue
  339.  
  340. # restrict vias to vertical Z transitions only
  341. if (xx!=x or yy!=y) and zz!=z:
  342. continue
  343.  
  344. # forty-five degree angles are disabled for now
  345. # they can be enabled when diagonally crossing edges are disallowed
  346. if disallow_fortyfives:
  347. if abs(xx-x) == 1 and abs(yy-y) == 1:
  348. continue
  349.  
  350. try:
  351. nvs.append(self.grid_by_xyz[xx][yy][zz])
  352. except:
  353. print('{} {} {}'.format(xx, yy, zz))
  354. raise
  355. if ensure is not None:
  356. assert len(nvs) == ensure, 'nvs was {}'.format(nvs)
  357. return nvs
  358.  
  359.  
  360. # make up some start and end points, to be placed within a grid of size: 20 x 20 x 2
  361. traces = OrderedDict([('t1', {'input': (0,0,1),
  362. 'output': (10,10,0)}),
  363. ('t2', {'input': (0,10,0),
  364. 'output': (10,1,0)}),
  365. ('t3', {'input': (0,0,0),
  366. 'output': (10,10,1)}),
  367. ('t4', {'input': (1,3,0),
  368. 'output': (10,4,0)})])
  369. # pass the traces and the grid size, it begins to generate clauses and proceed to call the solver
  370. SATGenerator(traces, maxx = 20, maxy = 20, maxz = 2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement