Advertisement
Guest User

Untitled

a guest
May 25th, 2016
206
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.00 KB | None | 0 0
  1. # PyAlgoTrade
  2. #
  3. # Copyright 2011-2015 Gabriel Martin Becedillas Ruiz
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16.  
  17. """
  18. .. moduleauthor:: Gabriel Martin Becedillas Ruiz <[email protected]>
  19. """
  20.  
  21. import numpy as np
  22. import matplotlib.pyplot as plt
  23.  
  24. from pyalgotrade.technical import roc
  25. from pyalgotrade import dispatcher
  26.  
  27.  
  28. class Results(object):
  29. """Results from the profiler."""
  30. def __init__(self, eventsDict, lookBack, lookForward):
  31. assert(lookBack > 0)
  32. assert(lookForward > 0)
  33. self.__lookBack = lookBack
  34. self.__lookForward = lookForward
  35. self.__values = [[] for i in xrange(lookBack+lookForward+1)]
  36. self.__eventCount = 0
  37.  
  38. # Process events.
  39. for instrument, events in eventsDict.items():
  40. for event in events:
  41. # Skip events which are on the boundary or for some reason are not complete.
  42. if event.isComplete():
  43. self.__eventCount += 1
  44. # Compute cumulative returns: (1 + R1)*(1 + R2)*...*(1 + Rn)
  45. values = np.cumprod(event.getValues() + 1)
  46. # Normalize everything to the time of the event
  47. values = values / values[event.getLookBack()]
  48. for t in range(event.getLookBack()*-1, event.getLookForward()+1):
  49. self.setValue(t, values[t+event.getLookBack()])
  50.  
  51. def __mapPos(self, t):
  52. assert(t >= -1*self.__lookBack and t <= self.__lookForward)
  53. return t + self.__lookBack
  54.  
  55. def setValue(self, t, value):
  56. if value is None:
  57. raise Exception("Invalid value at time %d" % (t))
  58. pos = self.__mapPos(t)
  59. self.__values[pos].append(value)
  60.  
  61. def getValues(self, t):
  62. pos = self.__mapPos(t)
  63. return self.__values[pos]
  64.  
  65. def getLookBack(self):
  66. return self.__lookBack
  67.  
  68. def getLookForward(self):
  69. return self.__lookForward
  70.  
  71. def getEventCount(self):
  72. """Returns the number of events occurred. Events that are on the boundary are skipped."""
  73. return self.__eventCount
  74.  
  75.  
  76. class Predicate(object):
  77. """Base class for event identification. You should subclass this to implement
  78. the event identification logic."""
  79.  
  80. def eventOccurred(self, instrument, bards):
  81. """Override (**mandatory**) to determine if an event took place in the last bar (bards[-1]).
  82.  
  83. :param instrument: Instrument identifier.
  84. :type instrument: string.
  85. :param bards: The BarDataSeries for the given instrument.
  86. :type bards: :class:`pyalgotrade.dataseries.bards.BarDataSeries`.
  87. :rtype: boolean.
  88. """
  89. raise NotImplementedError()
  90.  
  91.  
  92. class Event(object):
  93. def __init__(self, lookBack, lookForward):
  94. assert(lookBack > 0)
  95. assert(lookForward > 0)
  96. self.__lookBack = lookBack
  97. self.__lookForward = lookForward
  98. self.__values = np.empty((lookBack + lookForward + 1))
  99. self.__values[:] = np.NAN
  100.  
  101. def __mapPos(self, t):
  102. assert(t >= -1*self.__lookBack and t <= self.__lookForward)
  103. return t + self.__lookBack
  104.  
  105. def isComplete(self):
  106. return not any(np.isnan(self.__values))
  107.  
  108. def getLookBack(self):
  109. return self.__lookBack
  110.  
  111. def getLookForward(self):
  112. return self.__lookForward
  113.  
  114. def setValue(self, t, value):
  115. if value is not None:
  116. pos = self.__mapPos(t)
  117. self.__values[pos] = value
  118.  
  119. def getValue(self, t):
  120. pos = self.__mapPos(t)
  121. return self.__values[pos]
  122.  
  123. def getValues(self):
  124. return self.__values
  125.  
  126.  
  127. class Profiler(object):
  128. """This class is responsible for scanning over historical data and analyzing returns before
  129. and after the events.
  130.  
  131. :param predicate: A :class:`Predicate` subclass responsible for identifying events.
  132. :type predicate: :class:`Predicate`.
  133. :param lookBack: The number of bars before the event to analyze. Must be > 0.
  134. :type lookBack: int.
  135. :param lookForward: The number of bars after the event to analyze. Must be > 0.
  136. :type lookForward: int.
  137. """
  138.  
  139. def __init__(self, predicate, lookBack, lookForward):
  140. assert(lookBack > 0)
  141. assert(lookForward > 0)
  142. self.__predicate = predicate
  143. self.__lookBack = lookBack
  144. self.__lookForward = lookForward
  145. self.__feed = None
  146. self.__rets = {}
  147. self.__futureRets = {}
  148. self.__events = {}
  149.  
  150. def __addPastReturns(self, instrument, event):
  151. begin = (event.getLookBack() + 1) * -1
  152. for t in xrange(begin, 0):
  153. try:
  154. ret = self.__rets[instrument][t]
  155. if ret is not None:
  156. event.setValue(t+1, ret)
  157. except IndexError:
  158. pass
  159.  
  160. def __addCurrentReturns(self, instrument):
  161. nextTs = []
  162. for event, t in self.__futureRets[instrument]:
  163. event.setValue(t, self.__rets[instrument][-1])
  164. if t < event.getLookForward():
  165. t += 1
  166. nextTs.append((event, t))
  167. self.__futureRets[instrument] = nextTs
  168.  
  169. def __onBars(self, dateTime, bars):
  170. for instrument in bars.getInstruments():
  171. self.__addCurrentReturns(instrument)
  172. eventOccurred = self.__predicate.eventOccurred(instrument, self.__feed[instrument])
  173. if eventOccurred:
  174. event = Event(self.__lookBack, self.__lookForward)
  175. self.__events[instrument].append(event)
  176. self.__addPastReturns(instrument, event)
  177. # Add next return for this instrument at t=1.
  178. self.__futureRets[instrument].append((event, 1))
  179.  
  180. def getResults(self):
  181. """Returns the results of the analysis.
  182.  
  183. :rtype: :class:`Results`.
  184. """
  185. return Results(self.__events, self.__lookBack, self.__lookForward)
  186.  
  187. def run(self, feed, useAdjustedCloseForReturns=True):
  188. """Runs the analysis using the bars supplied by the feed.
  189.  
  190. :param barFeed: The bar feed to use to run the analysis.
  191. :type barFeed: :class:`pyalgotrade.barfeed.BarFeed`.
  192. :param useAdjustedCloseForReturns: True if adjusted close values should be used to calculate returns.
  193. :type useAdjustedCloseForReturns: boolean.
  194. """
  195.  
  196. if useAdjustedCloseForReturns:
  197. assert feed.barsHaveAdjClose(), "Feed doesn't have adjusted close values"
  198.  
  199. try:
  200. self.__feed = feed
  201. self.__rets = {}
  202. self.__futureRets = {}
  203. for instrument in feed.getRegisteredInstruments():
  204. self.__events.setdefault(instrument, [])
  205. self.__futureRets[instrument] = []
  206. if useAdjustedCloseForReturns:
  207. ds = feed[instrument].getAdjCloseDataSeries()
  208. else:
  209. ds = feed[instrument].getCloseDataSeries()
  210. self.__rets[instrument] = roc.RateOfChange(ds, 1)
  211.  
  212. feed.getNewValuesEvent().subscribe(self.__onBars)
  213. disp = dispatcher.Dispatcher()
  214. disp.addSubject(feed)
  215. disp.run()
  216. finally:
  217. feed.getNewValuesEvent().unsubscribe(self.__onBars)
  218.  
  219.  
  220. def build_plot(profilerResults):
  221. # Calculate each value.
  222. x = []
  223. y = []
  224. std = []
  225. for t in xrange(profilerResults.getLookBack()*-1, profilerResults.getLookForward()+1):
  226. x.append(t)
  227. values = np.asarray(profilerResults.getValues(t))
  228. y.append(values.mean())
  229. std.append(values.std())
  230.  
  231. # Plot
  232. plt.clf()
  233. plt.plot(x, y, color='#0000FF')
  234. eventT = profilerResults.getLookBack()
  235. # stdBegin = eventT + 1
  236. # plt.errorbar(x[stdBegin:], y[stdBegin:], std[stdBegin:], alpha=0, ecolor='#AAAAFF')
  237. plt.errorbar(x[eventT+1:], y[eventT+1:], std[eventT+1:], alpha=0, ecolor='#AAAAFF')
  238. # plt.errorbar(x, y, std, alpha=0, ecolor='#AAAAFF')
  239. plt.axhline(y=y[eventT], xmin=-1*profilerResults.getLookBack(), xmax=profilerResults.getLookForward(), color='#000000')
  240. plt.xlim(profilerResults.getLookBack()*-1-0.5, profilerResults.getLookForward()+0.5)
  241. plt.xlabel('Time')
  242. plt.ylabel('Cumulative returns')
  243.  
  244.  
  245. def plot(profilerResults):
  246. """Plots the result of the analysis.
  247.  
  248. :param profilerResults: The result of the analysis
  249. :type profilerResults: :class:`Results`.
  250. """
  251.  
  252. build_plot(profilerResults)
  253. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement