import random
from random import randint

import pyspark.sql.functions as F
from pyspark.sql import Row, Window
from pyspark.sql.functions import pandas_udf
# Produce example data: B is null exactly where A is true.
T = range(10**6)
A = [random.uniform(0, 1) < 0.5 for _ in T]
B = [None if a else randint(0, 9) for a in A]
# sc is the SparkContext provided by the pyspark shell / notebook session.
rdd = sc.parallelize([Row(T=t, A=a, B=b) for (t, a, b) in zip(T, A, B)])
df = rdd.toDF().cache()
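# Quick sanity check: B should be null exactly on the rows where A is true.
df.show(5)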
# Window from the current row to the end of the ordering by T; taking the
# first non-null B over this frame is equivalent to a backfill.
w = Window.orderBy('T').rowsBetween(Window.currentRow, Window.unboundedFollowing)
@pandas_udf('long')
def backfill_udf(col):
    # Backward-fill nulls within the pandas Series this UDF receives.
    # (fillna(method='backfill') is deprecated in recent pandas; bfill() is equivalent.)
    return col.bfill()
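# Note: a scalar pandas UDF sees one Arrow record batch at a time, so the
# backfill can only propagate within a batch, never across batch or partition
# boundaries. Batch size is governed by spark.sql.execution.arrow.maxRecordsPerBatch.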
# Time both approaches on the cached DataFrame.
%timeit -n1 -r1 df.withColumn('B', F.first(df.B, ignorenulls=True).over(w)).show()
%timeit -n1 -r1 df.orderBy('T').withColumn('B', backfill_udf(df.B)).show()
# The window version still works after a filter
df.withColumn('B', F.first(df.B, ignorenulls=True).over(w)).filter(df.A).show()
# The UDF version doesn't -- every B comes back null
df.orderBy('T').withColumn('B', backfill_udf(df.B)).filter(df.A).show()
# Caching the result before the filter makes it work again
df.orderBy('T').withColumn('B', backfill_udf(df.B)).cache().filter(df.A).show()
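# A sketch of the likely mechanics (inferred from the behavior above, not
# stated in the original paste): the filter on A does not depend on the
# recomputed B, so the optimizer can push it below the projection. The UDF
# then runs only on rows where A is true -- exactly the rows where B is null --
# and backfilling all-null batches yields all nulls. cache() materializes the
# filled column first, so the later filter sees the completed values.
# Inspect the physical plan to check where the filter lands:
df.orderBy('T').withColumn('B', backfill_udf(df.B)).filter(df.A).explain()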