Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from datetime import date
import itertools
import os
import re
from collections import defaultdict
from typing import Dict, List, Union

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from environment_handler import Environment
from mapping import CAST_TYPES_INDEXES, PARQUET_SCHEMAS
class StorageData:
    """Persist scraped tables as parquet files under ``Environment.parquets_dir``.

    Tables arrive row-oriented as ``List[List[str]]``.  Before writing, each
    table is de-duplicated, transposed to column order, and selected columns
    are normalised according to ``CAST_TYPES_INDEXES``; the pyarrow schema for
    each (method, table_name) pair comes from ``PARQUET_SCHEMAS``.  Large
    extracts can be written incrementally with ``save_partitions`` — the part
    files are staged in a tmp directory and later merged into a single file
    by ``concat_partitions``.
    """

    # Staging directory for partition files produced by save_partitions().
    _tmp_storage = os.path.join(Environment.parquets_dir, 'tmp')
    # Date stamp (YYYYMMDD) embedded in every output filename.
    _cdate: str = Environment.change_date().replace('-', '')
    _logger = Environment.get_logger('Storage')

    # Runs once at class-definition time, mirroring the original behaviour.
    os.makedirs(_tmp_storage, exist_ok=True)

    @staticmethod
    def _transpose_table(table: List[List[str]]) -> List[List[str]]:
        """Turn a list of rows into a list of columns.

        Ragged rows are padded with ``None`` (``zip_longest``) rather than
        silently truncated.
        """
        return list(map(list, itertools.zip_longest(*table)))

    @staticmethod
    def _remove_duplicates(table: List[List[str]]) -> List[List[str]]:
        """Drop duplicate rows, keeping the first occurrence of each.

        ``dict.fromkeys`` preserves insertion order, so the output row order
        is deterministic (the previous ``set()``-based implementation
        produced an arbitrary order on every run).
        """
        return [list(row) for row in dict.fromkeys(map(tuple, table))]

    # Backward-compatible alias for the historical, misspelled name.
    _remove_dublicates = _remove_duplicates

    @staticmethod
    def _cast_types(
            method: str,
            table_name: str,
            table: List[List[str]]
    ) -> List[List[Union[str, int]]]:
        """Normalise raw string columns in place per ``CAST_TYPES_INDEXES``.

        The mapping gives, per target type, the column index (or list of
        indexes) to transform:
          * 'bool'  -> 'true'/'1' becomes 1, 'false'/'0' becomes 0,
                       anything else (including None) becomes None;
          * 'float' -> decimal comma becomes a point; 'N/A' becomes None;
          * 'int'   -> 'N/A' becomes None.
        The actual type conversion is left to the parquet schema on write.
        Returns the (mutated) table for call-chaining convenience.
        """
        for _type, numcols in CAST_TYPES_INDEXES[method].get(table_name, {}).items():
            if isinstance(numcols, int):
                numcols = [numcols]  # allow a bare index in the mapping
            if _type == 'bool':
                for numcol in numcols:
                    table[numcol] = [
                        1 if value is not None and value.lower() in ('true', '1')
                        else 0 if value is not None and value.lower() in ('false', '0')
                        else None
                        for value in table[numcol]
                    ]
            elif _type == 'float':
                for numcol in numcols:
                    table[numcol] = [
                        None if value is None or value == 'N/A'
                        else value.replace(',', '.')
                        for value in table[numcol]
                    ]
            elif _type == 'int':
                for numcol in numcols:
                    table[numcol] = [
                        None if value == 'N/A' else value
                        for value in table[numcol]
                    ]
        return table

    @classmethod
    def _get_schema(cls, method: str, table_name: str):
        """Look up the parquet schema, failing loudly when it is missing.

        Raises:
            Exception: when PARQUET_SCHEMAS has no entry for this pair.
        """
        try:
            return PARQUET_SCHEMAS[method][table_name]
        except KeyError:
            raise Exception(
                f'Parquet schema not found for {table_name} table for method {method}')

    @classmethod
    def _save_partition(
            cls,
            method: str,
            table_name: str,
            table: List[List[Union[str, int]]],
            ipart: int
    ) -> None:
        """Write one column-oriented partition into the tmp staging dir."""
        filename = '_'.join(
            ['Spark', method, table_name, cls._cdate, 'part' + str(ipart)]
        ) + '.parquet'
        schema = cls._get_schema(method, table_name)
        arrow_table = pa.Table.from_arrays(
            [pa.array(col) for col in table], schema=schema)
        pq.write_table(arrow_table, os.path.join(cls._tmp_storage, filename))
        cls._logger.info('Saved: ' + filename)

    @classmethod
    def save_partitions(
            cls,
            method: str,
            partitions: Dict[str, List[List[str]]],
            ipart: int
    ) -> None:
        """Clean, cast and stage one partition for every table in `partitions`."""
        for table_name, rows in partitions.items():
            columns = cls._cast_types(
                method,
                table_name,
                cls._transpose_table(cls._remove_duplicates(rows))
            )
            cls._save_partition(method, table_name, columns, ipart)

    @classmethod
    def _save_table(
            cls,
            method: str,
            table_name: str,
            table: List[List[Union[str, int]]]
    ) -> None:
        """Write a complete column-oriented table straight to parquets_dir."""
        filename = '_'.join(['Spark', method, table_name, cls._cdate]) + '.parquet'
        schema = cls._get_schema(method, table_name)
        arrow_table = pa.Table.from_arrays(
            [pa.array(col) for col in table], schema=schema)
        pq.write_table(arrow_table, os.path.join(Environment.parquets_dir, filename))
        cls._logger.info('Saved: ' + filename)

    @classmethod
    def save_tables(
            cls,
            method: str,
            tables: Dict[str, List[List[str]]]
    ) -> None:
        """Clean, cast and write every table in `tables` as a single file."""
        for table_name, rows in tables.items():
            columns = cls._cast_types(
                method,
                table_name,
                cls._transpose_table(cls._remove_duplicates(rows))
            )
            cls._save_table(method, table_name, columns)

    @classmethod
    def concat_partitions(cls, method: str) -> None:
        """Merge staged partition files into one parquet file per table.

        Fixes vs. the previous version:
          * the '.' before 'parquet' is escaped in the regex (it previously
            matched any character);
          * partitions are concatenated in numeric order ('part10' would
            otherwise be interleaved before 'part2', since os.listdir order
            is arbitrary).
        """
        pattern = rf'Spark_{method}_([a-zA-Z0-9]*?)_{cls._cdate}_part(\d+)\.parquet'
        matching = defaultdict(list)
        for filename in os.listdir(cls._tmp_storage):
            match = re.fullmatch(pattern, filename)
            if match:
                table_name, part_index = match.groups()
                matching[table_name].append(part_index)
        for table_name, part_indexes in matching.items():
            filename = '_'.join(['Spark', method, table_name, cls._cdate])
            parts = []
            for i in sorted(part_indexes, key=int):
                full_path = os.path.join(
                    cls._tmp_storage, filename + f'_part{i}.parquet')
                parts.append(pq.read_table(full_path))
            full_table = pa.concat_tables(parts)
            pq.write_table(
                full_table,
                os.path.join(Environment.parquets_dir, filename + '.parquet'))
            cls._logger.info(
                f"Concatenated {len(part_indexes)} partitions into {filename + '.parquet'}")

    @classmethod
    def cleanup_tmp_dir(cls):
        """Delete every staged partition file from the tmp directory."""
        for filename in os.listdir(cls._tmp_storage):
            os.remove(os.path.join(cls._tmp_storage, filename))
        cls._logger.info('Temp parquets storage cleared')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement