# Untitled

a guest Jun 19th, 2017
1. #!/usr/bin/env python
2. # -*- coding: iso-8859-1 -*-
3.
4. from sympy import symbols, Eq, solve, StrictInequality, Interval
5.
class Switch(object):
    """
    Two-state toggle whose states are reachable as attributes.

    A switch stores two state values and is "in" the first one.  Looking
    up an attribute named like one of the states yields a new switch that
    is in that state, and ``~switch`` yields a new switch in the other
    state.  Switches compare (and hash) by their current state only.

    >>> BOUNDS = Switch("MIN", "MAX")
    >>> BOUNDS
    'MIN'
    >>> BOUNDS.MIN
    'MIN'
    >>> ~BOUNDS.MIN
    'MAX'
    >>> BOUNDS.MAX == ~BOUNDS.MIN
    True
    >>> BOUNDS.MAX != BOUNDS.MIN
    True
    >>> hasattr(BOUNDS, "MIDDLE")
    False
    """

    def __init__(self, state1, state2):
        """Create a switch over `state1` and `state2`, currently in `state1`."""
        self.state1 = state1
        self.state2 = state2
        self.state = state1

    def __repr__(self):
        """Return the repr of the current state (not of the switch itself)."""
        return repr(self.state)

    def __eq__(self, other_switch):
        """Compare by current state; operands without `.state` are unsupported."""
        try:
            other_state = other_switch.state
        except AttributeError:
            # Let Python fall back to its default handling instead of
            # crashing when compared with an unrelated object.
            return NotImplemented
        return self.state == other_state

    def __ne__(self, other_switch):
        """Inverse of `__eq__` (Python 2 does not derive it automatically)."""
        outcome = self.__eq__(other_switch)
        if outcome is NotImplemented:
            return outcome
        return not outcome

    def __hash__(self):
        """Hash by current state so that equal switches hash equally."""
        return hash(self.state)

    def __getattr__(self, state):
        """
        Return a new switch in the state named `state`.

        Raises AttributeError for any name that is not one of the two
        states, so hasattr(), copy and pickle behave normally.
        """
        # __getattr__ only runs after normal lookup has failed.  Read the
        # stored states from the instance dict directly: on an instance
        # whose __init__ has not run, going through self.state1 would
        # re-enter __getattr__ and recurse without end.
        try:
            first = self.__dict__["state1"]
            second = self.__dict__["state2"]
        except KeyError:
            raise AttributeError(state)
        if state == first:
            return self.__class__(first, second)
        if state == second:
            return self.__class__(second, first)
        raise AttributeError(
            "%r object has no attribute %r" % (self.__class__.__name__, state))

    def __invert__(self):
        """Return a new switch whose current state is the other state."""
        if self.state == self.state2:
            return self.__class__(self.state1, self.state2)
        else:
            return self.__class__(self.state2, self.state1)
37.
# Module-wide switch naming the two ends of a domain; BOUNDS.MIN and
# BOUNDS.MAX are compared against throughout the functions below.
BOUNDS = Switch(state1="MIN", state2="MAX")
39.
def solve_inequality(equation, symbol):
    """
    Solve the boundary of a relational for `symbol`.

    The relational's two sides are equated and the resulting equation is
    solved for `symbol`; the first solution returned by `solve` is used.
    An empty solution list makes the indexing raise IndexError.
    """
    boundary = Eq(equation.lhs, equation.rhs)
    solutions = solve(boundary, symbol)
    return solutions[0]
42.
def symbol_bounds(equation, symbol, other_symbol):
    """
    Pick which bound of each symbol's domain the constraint relates.

    Returns a pair ``(bound for symbol, bound for other_symbol)``.
    `symbol` gets BOUNDS.MAX when it occurs in the left-hand side and
    BOUNDS.MIN otherwise; `other_symbol` gets BOUNDS.MIN when it occurs
    in the left-hand side and BOUNDS.MAX otherwise.

    NOTE(review): the examples below only hold if the relational keeps
    its smaller side in ``lhs`` (i.e. ``x > y`` is stored as ``y < x``)
    -- confirm against the sympy version in use.

    >>> x, y = symbols("x y")
    >>> symbol_bounds(x > y, x, y)
    ('MIN', 'MIN')
    >>> symbol_bounds(x < y, x, y)
    ('MAX', 'MAX')
    >>> symbol_bounds(x + y > 0, x, y)
    ('MIN', 'MAX')
    >>> symbol_bounds(x + y < 0, x, y)
    ('MAX', 'MIN')
    """
    lower_side = equation.lhs
    symbol_on_lower_side = symbol in lower_side
    other_on_lower_side = other_symbol in lower_side

    if symbol_on_lower_side:
        own_bound = BOUNDS.MAX
    else:
        own_bound = BOUNDS.MIN

    if other_on_lower_side:
        other_bound = BOUNDS.MIN
    else:
        other_bound = BOUNDS.MAX

    return (own_bound, other_bound)
61.
def bounds_offset(equation, symbol):
    """
    Return the correction applied to a bound for strict inequalities.

    Non-strict relationals give 0.  Strict ones give -1 when `symbol`
    occurs in the left-hand side and +1 otherwise.

    >>> x, y = symbols("x y")
    >>> bounds_offset(x >= 0, x)
    0
    >>> bounds_offset(x > 0, x)
    1
    >>> bounds_offset(x <= 0, x)
    0
    >>> bounds_offset(x < 0, x)
    -1
    """
    is_strict = isinstance(equation, StrictInequality)
    symbol_on_lhs = symbol in equation.lhs
    if not is_strict:
        return 0
    if symbol_on_lhs:
        return -1
    return 1
76.
def get_bounds(interval, bounds):
    """
    Read one end of `interval`.

    Returns ``interval.start`` for BOUNDS.MIN, ``interval.end`` for
    BOUNDS.MAX, and None for anything else.

    >>> interval = Interval(4, 7)
    >>> get_bounds(interval, BOUNDS.MIN)
    4
    >>> get_bounds(interval, BOUNDS.MAX)
    7
    """
    # Checked in order; the interval attribute is only read on a match.
    selectors = ((BOUNDS.MIN, "start"), (BOUNDS.MAX, "end"))
    for side, attribute in selectors:
        if bounds == side:
            return getattr(interval, attribute)
    return None
90.
def set_bounds(interval, bounds, value):
    """
    Build a new interval with one end of `interval` replaced by `value`.

    BOUNDS.MIN replaces the start, BOUNDS.MAX replaces the end; any other
    `bounds` yields None.  `interval` itself is not modified.

    >>> interval = Interval(4, 7)
    >>> set_bounds(interval, BOUNDS.MIN, 3)
    [3, 7]
    >>> set_bounds(interval, BOUNDS.MAX, 8)
    [4, 8]
    """
    if bounds == BOUNDS.MIN:
        lower, upper = value, interval.end
    elif bounds == BOUNDS.MAX:
        lower, upper = interval.start, value
    else:
        return None
    return Interval(lower, upper)
104.
def bounds_consistency(constraint, symbol_1, symbol_2, domain_1, domain_2):
    """
    Tighten two interval domains against a binary inequality constraint.

    `domain_1` is revised first, using `domain_2`; `domain_2` is then
    revised using the already revised `domain_1`.  Returns the two
    revised domains as a list ``[domain_1, domain_2]``.

    A revision never widens a domain: a derived bound is applied only if
    it is at least as tight as the one the domain already has.

    NOTE(review): strict inequalities shift the bound by one
    (see `bounds_offset`), which presumably assumes integer-valued
    domains -- confirm with callers.

    >>> x, y = symbols("x y")
    >>> domain_x = Interval(1, 10)
    >>> domain_y = Interval(1, 10)

    >>> bounds_consistency(x < y, x, y, domain_x, domain_y)
    [[1, 9], [2, 10]]
    >>> bounds_consistency(x + 2 < y, x, y, domain_x, domain_y)
    [[1, 7], [4, 10]]
    >>> bounds_consistency(x < y, x, y, Interval(1, 5), domain_y)
    [[1, 5], [2, 10]]
    """
    def revise_domain(constraint, symbol, other_symbol, domain, other_domain):
        # Which end of each domain the constraint ties together.
        bounds, other_bounds = symbol_bounds(constraint, symbol, other_symbol)
        # Boundary of the constraint, expressed as symbol = f(other_symbol).
        transposed_constraint = solve_inequality(constraint, symbol)
        offset = bounds_offset(constraint, symbol)
        substitution = get_bounds(other_domain, other_bounds) + offset
        result = transposed_constraint.subs(other_symbol, substitution)
        # Keep the tighter of the existing and the derived bound; blindly
        # overwriting would enlarge a domain that is already narrower
        # than the constraint requires.
        current = get_bounds(domain, bounds)
        if bounds == BOUNDS.MIN:
            result = max(current, result)
        else:
            result = min(current, result)
        return set_bounds(domain, bounds, result)

    domain_1 = revise_domain(constraint, symbol_1, symbol_2, domain_1, domain_2)
    domain_2 = revise_domain(constraint, symbol_2, symbol_1, domain_2, domain_1)
    return [domain_1, domain_2]
127.
if __name__ == "__main__":
    # Run the doctests embedded in this module first, then a small demo
    # that propagates two constraints over the same pair of domains.
    import doctest
    doctest.testmod()

    x, y = symbols("x y")
    domain_x = Interval(1, 10)
    domain_y = Interval(1, 10)

    domain_x, domain_y = bounds_consistency(x + 2 < y, x, y,
                                            domain_x, domain_y)
    # A single formatted string prints the same text under both the
    # Python 2 print statement and the Python 3 print function.
    print("%s %s" % (domain_x, domain_y))

    domain_x, domain_y = bounds_consistency(x + y >= 16, x, y,
                                            domain_x, domain_y)
    print("%s %s" % (domain_x, domain_y))
