Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def cache(func):
    """Memoize *func* for the lifetime of the process.

    Results are keyed on both the positional and the keyword arguments, so
    every argument must be hashable. The cache is unbounded.

    :param func: the function to memoize.
    :return: a wrapper that forwards all arguments to *func* and keeps its
        ``__name__``/``__doc__``.
    """
    # The previous hand-rolled version keyed on positional args only and
    # called func(*args), silently dropping **kwargs (so f(1, b=5) returned
    # the cached f(1)). lru_cache forwards kwargs, includes them in the key
    # and copies the wrapped function's metadata.
    import functools
    return functools.lru_cache(maxsize=None)(func)


# Value of a single grid square (n samples drawn, h of them positive).
# TODO: does a square need a second coordinate (n, m), or is (n, h) enough?
# TODO: n is the total number of samples so far; h is the number of positive
# values within that sample set.
@cache
def values(total, n, h, reward, cost, continue_val=0):
    """Return ``(vstop, vcont)`` for the grid square ``(n, h)``.

    With ``phit = h / n``:

    * ``vstop`` is ``max(phit, 1 - phit) * reward - n * cost``.
    * ``vcont`` is ``reward - total * cost`` on the outer edge
      (``n == total``). Elsewhere it is ``phit`` times the ``vcont`` of
      square ``(n + 1, h + 1)`` plus ``1 - phit`` times the ``vcont`` of
      square ``(n + 1, h)``, minus one more ``cost``.

    A trace line for every non-edge square is emitted via ``logging`` at
    DEBUG level.

    :param total: maximum number of samples; recursion stops at ``n == total``.
    :param n: samples drawn so far; must satisfy ``0 < n <= total``.
    :param h: positive samples among those ``n``; must satisfy ``0 <= h <= n``.
    :param reward: scale factor applied to ``max(phit, 1 - phit)``.
    :param cost: amount charged per sample.
    :param continue_val: currently only forwarded to the recursive calls
        (and therefore part of the cache key); its use in the edge formula
        is not active.
    :return: tuple ``(vstop, vcont)``.
    :raises ValueError: if ``(n, h)`` is out of range for ``total``.
    """
    if out_of_range(total, n, h):
        raise ValueError(
            'Out of range: need 0 < n <= total and 0 <= h <= n, '
            'got total={}, n={}, h={}'.format(total, n, h))
    # Fraction of positive samples so far (h + (n - h) simplifies to n).
    phit = h / n
    u = max(phit, 1 - phit) * reward
    vstop = u - (n * cost)
    if is_base_case(total, n):
        # TODO: confirm the formula for the outer-edge value in terms of n
        # and h. An earlier draft used
        # phit * continue_val + (1 - phit) * continue_val - n * cost.
        return (vstop, reward - (total * cost))
    # Debug trace. This was an unconditional print() that fired once per
    # uncached square. Note that "EV" here is phit*u - n*cost, which is not
    # the vcont returned below.
    import logging
    logging.getLogger(__name__).debug(
        'n=%s, h=%s, phit=%s, qhit=%s, u=%s phit_u=%s, EV=%s',
        n, h, phit, 1 - phit, u, phit * u, phit * u - n * cost)
    # TODO: vstop does not depend on the continuation values. Consider
    # returning a per-square object such as {x, y, vstop, vcont} instead of
    # a tuple.
    vcont_hit = values(total, n + 1, h + 1, reward, cost, continue_val)[1]
    vcont_miss = values(total, n + 1, h, reward, cost, continue_val)[1]
    return (vstop, (phit * vcont_hit) + ((1 - phit) * vcont_miss) - cost)


def out_of_range(total, n, h):
    """
    Report whether ``(n, h)`` falls outside the valid region for ``total``.

    The valid region is ``0 < n <= total`` and ``0 <= h <= n``. ``h`` is
    additionally compared against ``total`` directly.

    :param total: upper bound for both ``n`` and ``h``.
    :param n: must be strictly positive and must not exceed ``total``.
    :param h: must be non-negative and must not exceed ``n``.
    :return: True when any bound is violated, False otherwise.
    """
    # Upper bounds: neither count may exceed the total.
    exceeds_total = n > total or h > total
    if exceeds_total:
        return exceeds_total
    # Lower bounds and ordering: n positive, h non-negative, h never above n.
    return n <= 0 or h < 0 or n < h


def is_base_case(total, n):
    """
    Tell whether the sample count ``n`` has reached the limit ``total``.

    Squares with ``n == total`` form the outer edge of the grid.

    :param total: the maximum sample count.
    :param n: the current sample count.
    :return: the result of ``n == total``.
    """
    on_edge = n == total
    return on_edge


def main():
    """Evaluate one sample grid square and print its (vstop, vcont) tuple."""
    # Positional arguments map to values(total, n, h, reward, cost).
    result = values(25, 20, 7, 5, .1)
    print(result)


if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment