import json
from datetime import datetime
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from django.db.models import Q
from django.utils.translation import gettext as _
from sqlalchemy import bindparam, text
from sqlalchemy.orm import Session

from config.settings import EMAIL_FROM, FRONTEND_SCENARIO_RESULT_URL, FORECAST_LEVELS_SEPARATOR
from forecasts.clickhouse.db import get_session
from forecasts.clickhouse.models import ForecastResultData, ForecastResultDataCorrected
from forecasts.exceptions.classes import BlankDatesInPanelData, InvalidPartitionInPanelData, \
    NoActualShipmentsForMLPayloadCreation, MLServerUnavailable, MissingMlSettings, KeyCountExceededInMLBatch
from forecasts.models import ForecastScenario, ForecastResult, MasterDataElement, ActualShipment, MasterDataLevel, \
    MasterDataHierarchy, MasterDataInnovation, MasterDataElementStatus
from forecasts.services.ml_interaction.timeline_extender import TimelineExtender
from forecasts.services.scenario_result.aggregated_scenario_data import AggregatedScenarioData
from general.tasks import task_send_email


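# MLServerDataSender assembles the panel data of a forecast scenario, sends it
# to an external ML service as gzip-compressed parquet over HTTP and stores the
# returned forecast records in ClickHouse.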
class MLServerDataSender:
    default_dimensions = ['key', 'date', 'partition']
    partition_forecast = 'fc'
    partition_train = 'train'
    min_train_rows = 3

    def __init__(
            self,
            aggregated_scenario_data: AggregatedScenarioData,
            timeline_extender: TimelineExtender,
            session: Session = None,
    ):
        # Do not call get_session() in the default argument: it would be
        # evaluated once at import time and shared by every instance.
        self._session = session if session is not None else get_session()
        self._ml_settings = aggregated_scenario_data.scenario.project.forecast_general_settings.ml_settings
        self._scenario_data = aggregated_scenario_data
        self._timeline_extender = timeline_extender
        project = self._scenario_data.scenario.project

        # The configured cannibalization level is only honored when it belongs
        # to the product hierarchy; otherwise it is ignored.
        product_level_ids = set(
            MasterDataLevel.objects.filter(
                hierarchy__name=MasterDataHierarchy.PRODUCT_HIERARCHY
            ).values_list('id', flat=True)
        )
        cannibalization_level = project.general_settings.cannibalization_level
        self._cannibalization_level = (
            cannibalization_level
            if cannibalization_level and cannibalization_level.id in product_level_ids
            else None
        )

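    # Full pipeline: collect and validate the panel, send it to the ML server
    # in batches of at most max_calculation_key_count keys, persist the results
    # and notify project users. Any failure is recorded on the scenario.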
    def execute(self):
        max_key_count = self._scenario_data.scenario.project.forecast_general_settings.max_calculation_key_count

        headers = {
            'api_key': self._ml_settings.api_key,
            'feature_types': json.dumps(
                {'can_level': self._cannibalization_level.name if self._cannibalization_level else None}
            ),
        }

        scenario_id = self._scenario_data.scenario.id
        scenario = ForecastScenario.objects.filter(id=scenario_id).first()
        scenario.last_process_error = None
        scenario.last_process_start = datetime.now()

        try:
            data = self._collect_data()
        except Exception as ex:
            scenario.last_process_error = str(ex)
            scenario.is_in_process = False
            scenario.save()
            raise

        scenario.save()
        all_records = []

        result_model = self._create_forecast_result(scenario_id)

        try:
            if self._cannibalization_level:
                # One request per cannibalization group; a group that exceeds
                # the key limit cannot be split, so it is a hard error.
                grouped_by_level = data.groupby([self._cannibalization_level.name])

                for _, group in grouped_by_level:
                    group_key_count = group['key'].nunique()
                    if group_key_count > max_key_count:
                        raise KeyCountExceededInMLBatch(
                            group_key_count,
                            max_key_count,
                            language_code=self._scenario_data.scenario.project.language.code,
                        )

                    records = self._call_ml(group, headers, result_model)
                    all_records.extend(records)
            else:
                # Without cannibalization the keys are independent, so they are
                # sent in batches of at most max_key_count keys.
                grouped_by_key = data.groupby('key')
                group_names = list(grouped_by_key.groups.keys())

                for i in range(0, len(group_names), max_key_count):
                    sliced = pd.concat(
                        [grouped_by_key.get_group(name) for name in group_names[i:i + max_key_count]]
                    )

                    records = self._call_ml(sliced, headers, result_model)
                    all_records.extend(records)

        except Exception as ex:
            scenario.last_process_error = str(ex)
            scenario.is_in_process = False
            scenario.save()

            result_model.delete()
            raise

        # Replace any previously stored data for this forecast result.
        self._session.query(ForecastResultData) \
            .filter(ForecastResultData.forecast_result_id == result_model.id) \
            .delete()

        self._session.query(ForecastResultDataCorrected) \
            .filter(ForecastResultDataCorrected.forecast_result_id == result_model.id) \
            .delete()

        cleaned_records = self._remove_actuals_from_forecast_results(all_records)

        if cleaned_records:
            self._session.execute(
                ForecastResultData.__table__.insert(),
                cleaned_records
            )

        now = datetime.now()
        scenario.last_process_end = now
        scenario.processed = now
        scenario.is_in_process = False
        scenario.is_proposed = False
        scenario.save()

        self._send_email_on_forecast_processing_finish(scenario)

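    # A single round trip to the ML server: the batch goes out as parquet and
    # the response parquet comes back with 'fc', 'uplift', 'bl' and 'cn'
    # columns, which are renamed to the local result schema.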
    def _call_ml(self, df: pd.DataFrame, headers: dict, result: ForecastResult) -> list[dict]:
        bytes_io = BytesIO()

        df.to_parquet(bytes_io, index=False, compression='gzip')
        payload = bytes_io.getvalue()

        try:
            response = self._make_request(payload, headers)
        except Exception:
            language_code = self._scenario_data.scenario.project.language.code \
                if self._scenario_data.scenario.project.language else None
            raise MLServerUnavailable(language_code=language_code)

        result_df = pd.read_parquet(BytesIO(response.content))
        result_df['date'] = pd.to_datetime(result_df['date'], format='%Y-%m-%d')
        result_df['forecast_result_id'] = result.id
        result_df.rename(
            columns={'fc': 'value', 'uplift': 'promo', 'bl': 'baseline', 'cn': 'cannibalization'},
            inplace=True,
        )

        return result_df.to_dict('records')

    def _create_forecast_result(self, scenario_id: int) -> ForecastResult:
        return ForecastResult.objects.create(
            scenario_id=scenario_id,
            cannibalization_level=self._cannibalization_level
        )

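    # Builds the panel sent to the ML server: actual shipments merged with the
    # extended forecast timeline, split into 'train'/'fc' partitions and
    # validated for date gaps and minimum training size.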
    def _collect_data(self) -> pd.DataFrame:
        if not all((self._ml_settings.host, self._ml_settings.endpoint, self._ml_settings.api_key)):
            raise MissingMlSettings(language_code=self._scenario_data.scenario.project.language.code)

        actual_shipments = self._get_actual_shipments_df(self._scenario_data.scenario.project_id)
        if actual_shipments.empty:
            raise NoActualShipmentsForMLPayloadCreation(
                language_code=self._scenario_data.scenario.project.language.code
            )

        if self._cannibalization_level:
            actual_shipments = self._apply_cannibalization(actual_shipments)

        resulting_df = self._update_data_based_on_forecast_horizon(actual_shipments)

        # Rows without an actual value form the forecast horizon; rows with one
        # form the training partition.
        resulting_df.rename(columns={'actual': 'target'}, inplace=True)
        resulting_df['partition'] = resulting_df.apply(
            lambda row: self.partition_forecast if pd.isnull(row['target']) else self.partition_train, axis=1
        )
        resulting_df['target'] = resulting_df['target'].fillna(0)
        resulting_df['date'] = pd.to_datetime(resulting_df['date'], format='%Y-%m-%d')
        resulting_df.sort_values(['key', 'date'], inplace=True)

        granularity_errors = self._check_granularity(resulting_df)
        if granularity_errors:
            raise BlankDatesInPanelData(
                ', '.join(str(error) for error in granularity_errors[:3]),
                len(granularity_errors),
                language_code=self._scenario_data.scenario.project.language.code,
            )

        partition_errors = self._check_partitions(resulting_df)
        if partition_errors:
            raise InvalidPartitionInPanelData(
                ', '.join(error for error in partition_errors[:3]),
                len(partition_errors),
                language_code=self._scenario_data.scenario.project.language.code,
            )

        for dimension in self.default_dimensions:
            resulting_df[dimension] = pd.Categorical(resulting_df[dimension])

        without_innovations = self._remove_innovations_from_df(resulting_df)

        print(f"""
_collect_data method:

actual_shipments:
{actual_shipments}

resulting_df:
{resulting_df}

without_innovations:
{without_innovations}

---------- End _collect_data ---------
""")

        return without_innovations

    def _update_data_based_on_forecast_horizon(self, actual_shipments: pd.DataFrame):
        if self._scenario_data.df.empty:
            return self._timeline_extender.execute(
                actual_shipments,
                self._cannibalization_level.name if self._cannibalization_level else None
            )
        else:
            return self._remove_excess_rows(actual_shipments)

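    # Thin HTTP wrapper; raise_for_status() turns any non-2xx answer into an
    # exception that _call_ml converts to MLServerUnavailable.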
    def _make_request(self, payload: bytes, headers: dict):
        response = requests.post(
            url=f'{self._ml_settings.host}/{self._ml_settings.endpoint}',
            data=payload,
            headers=headers,
            timeout=self._scenario_data.scenario.project.forecast_general_settings.ml_settings.timeout
        )

        response.raise_for_status()

        return response

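    # Detects gaps in each key's timeline by validating every pair of adjacent
    # dates against the project's granularity validator.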
    def _check_granularity(self, df: pd.DataFrame) -> list:
        granularity_validator = (
            self._scenario_data.scenario.project.forecast_general_settings
            .granularity.validators.adjacent_dates
        )

        # iterrows() yields label-indexed Series, so the columns are accessed
        # by name rather than by (deprecated) positional indexing.
        iterator = df.iterrows()
        _, first_row = next(iterator)
        previous_date = first_row['date']
        previous_key = first_row['key']

        errors = []

        # df is sorted by (key, date), so it suffices to validate each pair of
        # adjacent rows that belong to the same key.
        for _, row in iterator:
            is_same_key = row['key'] == previous_key
            is_valid = granularity_validator(previous_date, row['date'])
            if is_same_key and not is_valid:
                errors.append(
                    f'{row["key"]} ({previous_date.strftime("%Y-%m-%d")} - {row["date"].strftime("%Y-%m-%d")})'
                )

            previous_date = row['date']
            previous_key = row['key']

        return errors

    def _check_partitions(self, df: pd.DataFrame) -> list:
        keys_with_errors = []
        grouped = df.groupby('key')

        # Every key needs at least one forecast row and enough training rows.
        for _, group in grouped:
            has_forecast_rows = self.partition_forecast in set(group['partition'])
            train_row_count = len(group[group['partition'] == self.partition_train])
            if not has_forecast_rows or train_row_count < self.min_train_rows:
                keys_with_errors.append(group.iloc[0]['key'])

        return keys_with_errors

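    # Trims each key's scenario timeline to the forecast horizon and overlays
    # the actual shipments on top of it.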
    def _remove_excess_rows(self, actuals_df: pd.DataFrame) -> pd.DataFrame:
        forecast_horizon = self._scenario_data.scenario.project.forecast_general_settings.forecast_horizon
        self._scenario_data.df.sort_values(['key', 'date'], inplace=True)
        grouped = self._scenario_data.df.groupby('key')

        def _remove_from_tail_for_group(group):
            # Keep at most forecast_horizon rows per key, dropping from the tail.
            if len(group) > forecast_horizon:
                group.drop(group.tail(len(group) - forecast_horizon).index, inplace=True)

            return group

        with_cut_timeline = grouped.apply(_remove_from_tail_for_group).reset_index(drop=True)
        # Actuals win over scenario rows on (key, date) collisions.
        concat = pd.concat([with_cut_timeline, actuals_df], ignore_index=True)
        concat.drop_duplicates(subset=['key', 'date'], keep='last', inplace=True)
        concat.drop(labels=list(self._scenario_data.fields.values()), axis=1, inplace=True)

        return concat

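    # The ML server echoes training rows back; rows whose (key, date) pair
    # exists in the actuals must not be stored as forecast results.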
    def _remove_actuals_from_forecast_results(self, results: list[dict]) -> list[dict]:
        actual_shipments = self._get_actual_shipments_df(self._scenario_data.scenario.project_id)
        actual_pairs = {
            (actual['key'], actual['date'].strftime('%Y-%m-%d'))
            for _, actual in actual_shipments.iterrows()
        }
        cleaned_records = []

        for result in results:
            if (result['key'], result['date'].strftime('%Y-%m-%d')) not in actual_pairs:
                cleaned_records.append(result)

        return cleaned_records

    def _send_email_on_forecast_processing_finish(self, scenario):
        scenario_in_db = ForecastScenario.objects.filter(id=scenario.id).first()

        if scenario_in_db:
            frontend_url = f': {FRONTEND_SCENARIO_RESULT_URL.format(scenario.id)}' if FRONTEND_SCENARIO_RESULT_URL else ''

            task_send_email.delay({
                'subject': _('Scenario "{}" has been calculated').format(scenario.project.title),
                'recipients': list(scenario.project.users.values_list('email', flat=True)),
                'sender': EMAIL_FROM,
                'body': _('Calculation of the forecast scenario "{}" has finished{}').format(
                    scenario.project.title, frontend_url
                ),
            })

    def _apply_cannibalization(self, actual_shipments: pd.DataFrame):
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key

        actual_shipments['product_key'] = actual_shipments.apply(
            lambda row: row['key'].split(FORECAST_LEVELS_SEPARATOR)[order_by_key], axis=1
        )
        system_keys = set(actual_shipments['product_key'].tolist())

        fc_levels = MasterDataElement.get_forecast_level_elements_by_system_keys_with_cannibalization_level(
            system_keys,
            self._scenario_data.scenario.project_id,
            self._cannibalization_level.id
        )
        df = pd.DataFrame(list(fc_levels), columns=['product_key', self._cannibalization_level.name])

        merged_df = actual_shipments.merge(df, on=['product_key'], how='left')
        merged_df.drop(labels=['product_key'], axis=1, inplace=True)

        return merged_df

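    # Loads actual shipment rows from ClickHouse and deduplicates them so that
    # each (key, date) pair keeps only the most recently created value.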
    def _get_actual_shipments_df(self, project_id: int) -> pd.DataFrame:
        actual_shipments_ids = list(ActualShipment.objects.filter(
            project_id=project_id,
            deleted__isnull=True,
        ).values_list('id', flat=True))

        # An expanding bindparam is needed so SQLAlchemy renders the Python
        # list as individual values inside the IN clause.
        actual_shipments_data = self._session.execute(
            text(
                """
                SELECT
                    key,
                    date,
                    value,
                    created
                FROM actual_shipment_data
                WHERE actual_shipment_id IN :actual_shipments_ids
                """
            ).bindparams(bindparam('actual_shipments_ids', actual_shipments_ids, expanding=True))
        ).all()

        actual_shipments = pd.DataFrame.from_records(
            actual_shipments_data,
            columns=['key', 'date', 'value', 'created']
        )

        # Blank string values become NaN; for duplicated (key, date) pairs the
        # most recently created row wins.
        actual_shipments['actual'] = actual_shipments['value'].replace(r'^\s*$', np.nan, regex=True).astype(float)
        actual_shipments.sort_values(['date', 'key', 'created'], inplace=True)
        actual_shipments.drop_duplicates(subset=['key', 'date'], keep='last', inplace=True)
        actual_shipments.drop(labels=['value', 'created'], axis=1, inplace=True)

        print(f"""
_get_actual_shipments_df method:
{actual_shipments}

---------- End _get_actual_shipments_df ---------
""")
        return actual_shipments

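    # Note: this helper mirrors _remove_innovations_from_df for RELAUNCH
    # elements but is not called anywhere in this module.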
    def _processing_transition_elements(self, df: pd.DataFrame) -> pd.DataFrame:
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key

        df['product_key'] = df.apply(lambda row: row['key'].split(FORECAST_LEVELS_SEPARATOR)[order_by_key], axis=1)
        system_keys = set(df['product_key'].tolist())

        # Materialize the queryset once instead of rebuilding a set per row.
        elements_with_transition_status = set(MasterDataElement.objects.filter(
            Q(info__manual_status_id=MasterDataElementStatus.Options.RELAUNCH.value) |
            Q(info__calculated_status_id=MasterDataElementStatus.Options.RELAUNCH.value)
        ).filter(system_key__in=system_keys).values_list('system_key', flat=True))

        return df[df['product_key'].map(lambda v: v not in elements_with_transition_status)]

    def _remove_innovations_from_df(self, df: pd.DataFrame) -> pd.DataFrame:
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key

        df['product_key'] = df.apply(lambda row: row['key'].split(FORECAST_LEVELS_SEPARATOR)[order_by_key], axis=1)
        system_keys = set(df['product_key'].tolist())

        # Innovation elements are excluded from the ML payload unless the
        # related innovation has been archived.
        elements_with_innovation_status = set(MasterDataElement.objects.filter(
            Q(info__manual_status_id=MasterDataElementStatus.Options.INNOVATION.value) |
            Q(info__calculated_status_id=MasterDataElementStatus.Options.INNOVATION.value)
        ).filter(system_key__in=system_keys).exclude(
            Q(innovation_as_dummy__is_archived=True) | Q(innovation_as_real__is_archived=True)
        ).values_list('system_key', flat=True))

        filtered_df = df[df['product_key'].map(lambda v: v not in elements_with_innovation_status)]

        print(f"""
_remove_innovations_from_df method:

system_keys:
{system_keys}

elements_with_innovation_status:
{elements_with_innovation_status}

filtered_df:
{filtered_df}

---------- End _remove_innovations_from_df ---------
""")
        return filtered_df

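
# Hedged usage sketch (not part of the original module): how the sender could
# be wired together. The constructor arguments of AggregatedScenarioData and
# TimelineExtender shown below are assumptions; their real signatures may
# differ.
#
#     scenario = ForecastScenario.objects.get(id=scenario_id)
#     sender = MLServerDataSender(
#         aggregated_scenario_data=AggregatedScenarioData(scenario),
#         timeline_extender=TimelineExtender(scenario),
#     )
#     sender.execute()  # may raise MissingMlSettings, MLServerUnavailable, etc.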