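"""Streaming Apache Beam pipeline: reads JSON messages from a Cloud Pub/Sub
subscription, MD5-hashes the configured PII fields with a static salt, routes
messages that match the expected schema and rule to BigQuery, and writes the
remaining (faulty) messages to Cloud Storage in fixed-window batches.
"""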
import argparse
import datetime
import hashlib
import json
import logging

import apache_beam as beam
import apache_beam.transforms.window as window
from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
class GroupWindowsIntoBatches(beam.PTransform):
    """
    A composite transform that groups Pub/Sub messages based on publish
    time and outputs a list of dictionaries, where each contains one message
    and its publish timestamp.
    """

    def __init__(self, window_size):
        # Convert minutes into seconds.
        self.window_size = int(window_size * 60)

    def expand(self, pcoll):
        return (
            pcoll
            # Assign window info to each Pub/Sub message based on its
            # publish timestamp.
            | "Window into Fixed Intervals"
            >> beam.WindowInto(window.FixedWindows(self.window_size))
            | "Add timestamps to messages" >> beam.ParDo(AddTimestamps())
            # Attach a dummy key so GroupByKey collects every message in the
            # window into a single batch, then drop the key again.
            | "Add Dummy Key" >> beam.Map(lambda elem: (None, elem))
            | "Groupby" >> beam.GroupByKey()
            | "Abandon Dummy Key" >> beam.MapTuple(lambda _, val: val)
        )
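# Sketch of what the transform above produces, with hypothetical data and a
# 1-minute window: every message published inside the same window comes out as
# one Python list, e.g.
#
#   [{"user_id": "a1b2...", "status": "ok", "publish_time": "2021-01-01 10:00:01.000000"},
#    {"user_id": "c3d4...", "status": "ok", "publish_time": "2021-01-01 10:00:42.000000"}]
#
# The dummy key plus GroupByKey is what collapses a window into a single batch;
# beam.GroupIntoBatches would be an alternative if size-based batching were needed.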
class AddTimestamps(beam.DoFn):

    def process(self, element, publish_time=beam.DoFn.TimestampParam):
        """Processes each incoming windowed element by extracting the Pub/Sub
        message and its publish timestamp into a dictionary. `publish_time`
        defaults to the publish timestamp returned by the Pub/Sub server. It
        is bound to each element by Beam at runtime.
        """
        element["publish_time"] = datetime.datetime.utcfromtimestamp(
            float(publish_time)
        ).strftime("%Y-%m-%d %H:%M:%S.%f")
        yield element
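# Example of what AddTimestamps yields, using a hypothetical element: a message
# already parsed by Split into {"user_id": "...", "status": "ok"} and published
# at Unix time 1609495200.0 is emitted as
# {"user_id": "...", "status": "ok", "publish_time": "2021-01-01 10:00:00.000000"}.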
class WriteBatchesToGCS(beam.DoFn):

    def __init__(self, output_path):
        self.output_path = output_path

    def process(self, batch, window=beam.DoFn.WindowParam):
        """Write one batch per file to a Google Cloud Storage bucket."""
        ts_format = "%H:%M"
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        filename = f"{self.output_path}{'-'.join([window_start, window_end])}"

        with beam.io.gcp.gcsio.GcsIO().open(filename=filename, mode="w") as f:
            for element in batch:
                f.write("{}\n".format(json.dumps(element)).encode("utf-8"))
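# Example of the object name produced above, with a hypothetical prefix: for
# output_path="gs://my-bucket/faulty/" and a window spanning 10:30-10:31 UTC, the
# batch is written to "gs://my-bucket/faulty/10:30-10:31" as newline-delimited JSON.
# Because ts_format keeps only hours and minutes, batches from different days that
# share a window time overwrite the same object.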
class Split(beam.DoFn):
    # These tags will be used to tag the outputs of this DoFn.
    OUTPUT_TAG_BQ = 'BigQuery'
    OUTPUT_TAG_GCS = 'GCS'

    def __init__(self, required_fields, rule_value, rule_key, pii_fields):
        self.required_fields = required_fields
        self.rule_value = rule_value
        self.rule_key = rule_key
        self.pii_fields = pii_fields.split(',')

    def process(self, element):
        """Tags the input as it processes the original PCollection and hashes PII fields."""
        # Load the message as JSON and hash the PII fields with a static salt.
        element_raw = json.loads(element.decode("utf-8"))
        salt = 'CairoTokyoLondon'
        element = {
            k: (hashlib.md5(f"{str(v)}{salt}".encode()).hexdigest()
                if k in self.pii_fields else v)
            for (k, v) in element_raw.items()
        }
        # Check that the message has the expected structure and satisfies the rule.
        # `.get` routes messages missing the rule key to GCS instead of raising KeyError.
        if (set(element.keys()) == set(self.required_fields)
                and element.get(self.rule_key) == self.rule_value):
            yield pvalue.TaggedOutput(self.OUTPUT_TAG_BQ, element)
        else:
            yield pvalue.TaggedOutput(self.OUTPUT_TAG_GCS, element)
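# Routing sketch for the DoFn above, with a hypothetical message and configuration:
# given required_fields=["user_id", "status"], rule_key="status", rule_value="ok" and
# pii_fields="user_id", the message b'{"user_id": "alice", "status": "ok"}' is emitted
# on the 'BigQuery' tag with user_id replaced by
# hashlib.md5("aliceCairoTokyoLondon".encode()).hexdigest(); a message with missing or
# extra keys, or with status != "ok", is emitted on the 'GCS' tag instead.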
def run(input_subscription, output_path_gcs, output_table_bq, output_table_bq_schema,
        rule_key, rule_value, pii_fields, window_size=1.0, pipeline_args=None):
    # Derive the expected message fields from the BQ schema; `publish_time` is added
    # by the pipeline itself, so it is not expected in the incoming messages.
    required_fields = [i.split(':')[0] for i in output_table_bq_schema.split(',')]
    required_fields.remove('publish_time')

    pipeline_options = PipelineOptions(
        pipeline_args, streaming=True, save_main_session=True,
        direct_running_mode='in_memory', direct_num_workers=2,
    )

    with beam.Pipeline(options=pipeline_options) as pipeline:
        tagged_lines_result = (
            pipeline
            | beam.io.ReadFromPubSub(subscription=input_subscription)
            | beam.ParDo(
                Split(required_fields=required_fields, rule_key=rule_key,
                      rule_value=rule_value, pii_fields=pii_fields)
            ).with_outputs(Split.OUTPUT_TAG_BQ, Split.OUTPUT_TAG_GCS)
        )

        # Faulty messages are batched per window and written to GCS.
        faulty_messages = (
            tagged_lines_result[Split.OUTPUT_TAG_GCS]
            | "Window into GCS" >> GroupWindowsIntoBatches(window_size)
            | "Write to GCS" >> beam.ParDo(WriteBatchesToGCS(output_path_gcs))
        )

        # Accepted messages are windowed, flattened back into individual rows and
        # streamed into BigQuery.
        accepted_messages = (
            tagged_lines_result[Split.OUTPUT_TAG_BQ]
            | "Window into BQ" >> GroupWindowsIntoBatches(window_size)
            | "FlatMap" >> beam.FlatMap(lambda elements: elements)
            | "Write to BQ" >> beam.io.gcp.bigquery.WriteToBigQuery(
                table=output_table_bq,
                schema=output_table_bq_schema,
                write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
                create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
            )
        )
        # The `with` block runs the pipeline and waits for it to finish on exit, so
        # no explicit pipeline.run().wait_until_finish() call is needed (it would
        # otherwise run the pipeline a second time).
if __name__ == "__main__":  # noqa
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_subscription",
        dest='input_subscription',
        help="The Cloud Pub/Sub subscription to read from.\n"
             '"projects/<PROJECT_NAME>/subscriptions/<SUBSCRIPTION_NAME>".',
    )
    parser.add_argument(
        "--window_size",
        dest='window_size',
        type=float,
        default=1.0,
        help="Output file's window size in number of minutes.",
    )
    parser.add_argument(
        "--output_path_gcs",
        dest='output_path_gcs',
        required=True,
        help="GCS path of the output file, including the filename prefix.",
    )
    parser.add_argument(
        "--output_table_bq",
        dest='output_table_bq',
        required=True,
        help="BQ table for output. Format: <project_id:dataset.table>",
    )
    parser.add_argument(
        "--output_table_bq_schema",
        dest='output_table_bq_schema',
        required=True,
        help="Output BQ table schema, without spaces. Format: <col_name:type,col_name:type>. "
             "Must include publish_time.",
    )
    parser.add_argument(
        "--rule_key",
        dest='rule_key',
        required=True,
        help="The key that should hold <rule_value>. Used to determine whether a message is accepted.",
    )
    parser.add_argument(
        "--rule_value",
        dest='rule_value',
        required=True,
        help="The value that <rule_key> should hold. Used to determine whether a message is accepted.",
    )
    parser.add_argument(
        "--pii_fields",
        dest='pii_fields',
        default='',
        required=False,
        help="Comma-separated list of keys whose values will be hashed before loading to GCS or BQ.",
    )
    known_args, pipeline_args = parser.parse_known_args()
    run(
        input_subscription=known_args.input_subscription,
        output_path_gcs=known_args.output_path_gcs,
        output_table_bq=known_args.output_table_bq,
        window_size=known_args.window_size,
        output_table_bq_schema=known_args.output_table_bq_schema,
        rule_key=known_args.rule_key,
        rule_value=known_args.rule_value,
        pii_fields=known_args.pii_fields,
        pipeline_args=pipeline_args,
    )
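# Example invocation, with hypothetical script, project, bucket and table names;
# runner-specific flags are passed through to Beam via pipeline_args:
#
#   python pipeline.py \
#       --input_subscription projects/my-project/subscriptions/my-sub \
#       --output_path_gcs gs://my-bucket/faulty/ \
#       --output_table_bq my-project:my_dataset.events \
#       --output_table_bq_schema user_id:STRING,status:STRING,publish_time:TIMESTAMP \
#       --rule_key status \
#       --rule_value ok \
#       --pii_fields user_id \
#       --runner DataflowRunner --project my-project --region europe-west1 \
#       --temp_location gs://my-bucket/temp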