import argparse
import datetime
import hashlib
import json
import logging

import apache_beam as beam
from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
import apache_beam.transforms.window as window


class GroupWindowsIntoBatches(beam.PTransform):
    """A composite transform that groups Pub/Sub messages based on publish
    time and outputs a list of dictionaries, where each contains one message
    and its publish timestamp.
    """

    def __init__(self, window_size):
        super().__init__()
        # Convert minutes into seconds.
        self.window_size = int(window_size * 60)

    def expand(self, pcoll):
        return (
            pcoll
            # Assign window info to each Pub/Sub message based on its
            # publish timestamp.
            | "Window into Fixed Intervals"
            >> beam.WindowInto(window.FixedWindows(self.window_size))
            | "Add timestamps to messages" >> beam.ParDo(AddTimestamps())
            # GroupByKey needs keyed input, so give every element the same
            # dummy key; each window then collapses into a single batch.
            | "Add Dummy Key" >> beam.Map(lambda elem: (None, elem))
            | "Groupby" >> beam.GroupByKey()
            | "Abandon Dummy Key" >> beam.MapTuple(lambda _, val: val)
        )
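

# A minimal local sketch of the transform above on bounded data; the
# payloads and timestamps here are made up for illustration:
#
#   with beam.Pipeline() as p:
#       (p
#        | beam.Create([{"msg": "a"}, {"msg": "b"}])
#        | beam.Map(lambda e: window.TimestampedValue(e, 0))
#        | GroupWindowsIntoBatches(window_size=1)
#        | beam.Map(print))  # one batch of dicts per one-minute window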


class AddTimestamps(beam.DoFn):
    def process(self, element, publish_time=beam.DoFn.TimestampParam):
        """Processes each incoming windowed element by extracting the Pub/Sub
        message and its publish timestamp into a dictionary. `publish_time`
        defaults to the publish timestamp returned by the Pub/Sub server. It
        is bound to each element by Beam at runtime.
        """
        # Copy rather than mutate: Beam DoFns must not modify their inputs.
        yield dict(
            element,
            publish_time=datetime.datetime.utcfromtimestamp(
                float(publish_time)
            ).strftime("%Y-%m-%d %H:%M:%S.%f"),
        )
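

# For example, a Pub/Sub publish timestamp of 1614556800.0 (2021-03-01
# 00:00:00 UTC) becomes publish_time "2021-03-01 00:00:00.000000".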


class WriteBatchesToGCS(beam.DoFn):
    def __init__(self, output_path):
        self.output_path = output_path

    def process(self, batch, window=beam.DoFn.WindowParam):
        """Write one batch per file to a Google Cloud Storage bucket."""
        ts_format = "%H:%M"
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        filename = f"{self.output_path}{window_start}-{window_end}"
        with beam.io.gcp.gcsio.GcsIO().open(filename=filename, mode="w") as f:
            for element in batch:
                f.write(f"{json.dumps(element)}\n".encode("utf-8"))
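

# For example (bucket and prefix are placeholders, not values from this
# script): output_path "gs://my-bucket/faulty/" and a 10:00-10:01 UTC window
# produce the object "gs://my-bucket/faulty/10:00-10:01". Note the name
# carries no date, so the same window on a later day overwrites the file.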


class Split(beam.DoFn):
    # These tags will be used to tag the outputs of this DoFn.
    OUTPUT_TAG_BQ = 'BigQuery'
    OUTPUT_TAG_GCS = 'GCS'

    def __init__(self, required_fields, rule_value, rule_key, pii_fields):
        self.required_fields = required_fields
        self.rule_value = rule_value
        self.rule_key = rule_key
        self.pii_fields = pii_fields.split(',')

    def process(self, element):
        """Tags the input as it processes the original PCollection and
        hashes the PII fields.
        """
        # Load the message as JSON and hash the PII fields. MD5 with a
        # hard-coded salt is weak protection for real PII; a keyed hash
        # (e.g. HMAC with a secret key) would be a stronger choice.
        element_raw = json.loads(element.decode("utf-8"))
        salt = 'CairoTokyoLondon'
        element = {
            k: (hashlib.md5(f"{str(v)}{salt}".encode()).hexdigest()
                if k in self.pii_fields else v)
            for (k, v) in element_raw.items()
        }

        # Check that the message has the expected structure and satisfies
        # the rule; route it to BigQuery if so, otherwise to GCS.
        if (set(element.keys()) == set(self.required_fields)
                and element[self.rule_key] == self.rule_value):
            yield pvalue.TaggedOutput(self.OUTPUT_TAG_BQ, element)
        else:
            yield pvalue.TaggedOutput(self.OUTPUT_TAG_GCS, element)
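

# For illustration (hypothetical values, not taken from this script): with
# required_fields ["id", "event"], rule_key "event", rule_value "purchase"
# and pii_fields "id", the message b'{"id": "u1", "event": "purchase"}' comes
# out as {"id": md5("u1CairoTokyoLondon"), "event": "purchase"} tagged
# 'BigQuery', while b'{"id": "u1", "event": "refund"}' is tagged 'GCS'.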
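

# Overall shape of the pipeline built below:
#
#   Pub/Sub --> Split --+-- rule holds --> window/batch --> BigQuery
#                       +-- otherwise --> window/batch --> one file per
#                                                          window on GCS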
def run(input_subscription, output_path_gcs, output_table_bq, output_table_bq_schema,
        rule_key, rule_value, pii_fields, window_size=1.0, pipeline_args=None):
    # Derive the expected message fields from the BQ schema string, e.g.
    # "id:STRING,event:STRING,publish_time:TIMESTAMP" -> ["id", "event"].
    # publish_time is excluded because Split validates raw messages before
    # AddTimestamps attaches that field.
    required_fields = [i.split(':')[0].strip() for i in output_table_bq_schema.split(',')]
    required_fields.remove('publish_time')
    pipeline_options = PipelineOptions(
        pipeline_args, streaming=True, save_main_session=True,
        direct_running_mode='in_memory', direct_num_workers=2
    )

    with beam.Pipeline(options=pipeline_options) as pipeline:
        tagged_lines_result = (
            pipeline
            | beam.io.ReadFromPubSub(subscription=input_subscription)
            | beam.ParDo(
                Split(required_fields=required_fields, rule_key=rule_key,
                      rule_value=rule_value, pii_fields=pii_fields)
            ).with_outputs(Split.OUTPUT_TAG_BQ, Split.OUTPUT_TAG_GCS)
        )

        faulty_messages = (
            tagged_lines_result[Split.OUTPUT_TAG_GCS]
            | "Window into GCS" >> GroupWindowsIntoBatches(window_size)
            | "Write to GCS" >> beam.ParDo(WriteBatchesToGCS(output_path_gcs))
        )
        accepted_messages = (
            tagged_lines_result[Split.OUTPUT_TAG_BQ]
            | "Window into BQ" >> GroupWindowsIntoBatches(window_size)
            # Unbatch: WriteToBigQuery expects individual rows, not lists.
            | "FlatMap" >> beam.FlatMap(lambda elements: elements)
            | "Write to BQ" >> beam.io.WriteToBigQuery(
                table=output_table_bq,
                schema=output_table_bq_schema,
                write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
                create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED)
        )
        # No explicit pipeline.run() here: the `with` block already runs the
        # pipeline and waits for it to finish on exit; calling run() inside
        # the block would execute the pipeline twice.


if __name__ == "__main__":  # noqa
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_subscription",
        help="The Cloud Pub/Sub subscription to read from. Format: "
             '"projects/<PROJECT_NAME>/subscriptions/<SUBSCRIPTION_NAME>".',
    )
    parser.add_argument(
        "--window_size",
        type=float,
        default=1.0,
        help="Output file's window size in minutes.",
    )
    parser.add_argument(
        "--output_path_gcs",
        required=True,
        help="GCS path of the output files, including the filename prefix.",
    )
    parser.add_argument(
        "--output_table_bq",
        required=True,
        help="BQ table for output. Format: <project_id:dataset.table>",
    )
    parser.add_argument(
        "--output_table_bq_schema",
        required=True,
        help="Output BQ table schema. Format: <col_name:type,col_name:type>",
    )
    parser.add_argument(
        "--rule_key",
        required=True,
        help="The message key that should hold <rule_value>; used to decide "
             "whether a message is accepted.",
    )
    parser.add_argument(
        "--rule_value",
        required=True,
        help="The value that <rule_key> must hold; used to decide whether a "
             "message is accepted.",
    )
    parser.add_argument(
        "--pii_fields",
        default='',
        required=False,
        help="Comma-separated keys whose values are hashed before loading "
             "to GCS or BQ.",
    )
    known_args, pipeline_args = parser.parse_known_args()

    run(
        input_subscription=known_args.input_subscription,
        output_path_gcs=known_args.output_path_gcs,
        output_table_bq=known_args.output_table_bq,
        window_size=known_args.window_size,
        output_table_bq_schema=known_args.output_table_bq_schema,
        rule_key=known_args.rule_key,
        rule_value=known_args.rule_value,
        pii_fields=known_args.pii_fields,
        pipeline_args=pipeline_args,
    )
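
# Example invocation on the DirectRunner (every resource name below is a
# placeholder, not a value from this script):
#
#   python pipeline.py \
#       --input_subscription projects/my-project/subscriptions/my-sub \
#       --output_path_gcs gs://my-bucket/faulty/ \
#       --output_table_bq my-project:my_dataset.events \
#       --output_table_bq_schema id:STRING,event:STRING,publish_time:TIMESTAMP \
#       --rule_key event \
#       --rule_value purchase \
#       --pii_fields id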