Advertisement
Not a member of Pastebin yet?
Sign up: it unlocks many cool features!
/**
 * ScalaTest mixin that exposes a shared `SparkSession` (`spark`) together with
 * helpers for building DataFrames from inline CSV literals.
 *
 * `beforeAll` runs the parent setup and then `sqlBeforeAllTestCases()`.
 * `afterAll` clears `SparkSessionProvider._sparkSession` unless
 * `reuseContextIfPossible` is set (presumably so a later suite builds a fresh
 * session — confirm against `SparkSessionProvider`).
 */
trait DataFrameSuiteBase extends TestSuite
  with SharedSparkContext with DataFrameSuiteBaseLike { self: Suite =>

  import spark.implicits._

  override def beforeAll(): Unit = {
    super.beforeAll()
    super.sqlBeforeAllTestCases()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    if (!reuseContextIfPossible) {
      // Drop the cached session so it is not handed out again.
      SparkSessionProvider._sparkSession = null
    }
  }

  /**
   * Parses an inline CSV literal whose first non-blank line is the header.
   *
   * @param csv    CSV text; `stripMargin` is applied, so lines may be prefixed with `|`.
   *               Blank and whitespace-only lines (e.g. the line holding the closing
   *               `"""`) are dropped.
   * @param schema schema to apply to the columns; when `null` (the default) the schema
   *               is inferred by Spark (`inferSchema = true`).
   * @return the parsed DataFrame
   */
  def csvStringToDataFrame(csv: String, schema: StructType = null): DataFrame = {
    // `split` instead of `.lines`: on JDK 11+ `String.lines()` (a java Stream) shadows
    // Scala's `StringOps.lines` on Scala 2.11/2.12, and `.toList` then fails to compile.
    val csvLines = csv.stripMargin
      .split("\\r?\\n")
      .toList
      .filterNot(_.trim.isEmpty)
    val dataset = spark.createDataset(csvLines)
    val baseReader = spark.read.option("header", true)
    val reader = Option(schema).fold(baseReader.option("inferSchema", true))(s => baseReader.schema(s))
    reader.csv(dataset)
  }
}
/**
 * DataFrame data-quality checks.
 *
 * @param session the Spark session this toolbox was created with
 */
case class Toolbox(session: SparkSession) extends Serializable {

  /**
   * Flags rows whose `colName` value lies outside `bounds`.
   *
   * Appends a boolean column named `<colName>_flag`. It is `true` when the value is
   * outside the inclusive range `[bounds.lower, bounds.upper]` (`Column.between` is
   * inclusive). Otherwise it is `false`, including for null values, because a null
   * condition never selects the `when` branch.
   *
   * @param dataframe input data
   * @param colName   column to check
   * @param bounds    inclusive lower/upper limits
   * @return `dataframe` with the flag column appended, or `dataframe` unchanged when
   *         `hasColumn(dataframe, colName)` is false
   */
  def checkExtremes(dataframe: DataFrame, colName: String, bounds: Bounds): DataFrame =
    if (hasColumn(dataframe, colName)) {
      val flagColumnName = s"${colName}_flag"
      val outOfBounds = !col(colName).between(bounds.lower, bounds.upper)
      // The branch must end in the expression itself; ending it in a `val` definition
      // makes it evaluate to Unit instead of the flagged DataFrame.
      dataframe.withColumn(flagColumnName, when(outOfBounds, true).otherwise(false))
    } else {
      dataframe
    }
}
/** Tests for `Toolbox`. */
class ToolboxSpec extends FunSpec with DataFrameSuiteBase
  with DataFrameComparer
  with BeforeAndAfter {

  // Lower-case on purpose: a member named `Toolbox` shadows the case class's companion
  // inside this class, so `Toolbox(spark)` would not compile.
  private var toolbox: Toolbox = _

  before {
    toolbox = Toolbox(spark)
  }

  describe("Toolbox") {
    describe("checkExtremes") {
      it("should be checking for extreme values") {
        // Decimals use '.', because ',' is the CSV delimiter: comma decimals split every
        // value into two fields and misalign all columns after `t1`.
        val inputCSV =
          """
            |"id","time","code","emi","v","t1","t2","t3","t4","t5","x_acc","y_acc","z_acc"
            |"46","2019-04-01 00:00:57","1",1444,"1",66.12,34.5,29.31,64.56,38.31,67.32,9.64,31.53
            |"46","2019-04-01 00:00:52","1",1515,"1",66,34.5,29.31,64.56,38.31,69.08,24.91,36.7
            |"46","2019-04-01 00:00:46","1",1452,"1",66.12,34.5,29.31,64.5,38.31,66.88,11.12,34.43
            |"47","2019-04-01 00:00:46","1",1452,"1",100.12,34.5,29.31,64.5,38.31,66.88,11.12,34.43
            |"77","2019-04-01 00:00:41","1",1319,"1",66.19,34.5,29.31,64.5,38.31,67.82,8.66,34.79
          """
        val inputColName = "t1"
        val flagColName = s"${inputColName}_flag"
        val expectedCSV =
          s"""
            |"id","time","code","emi","v","t1","t2","t3","t4","t5","x_acc","y_acc","z_acc","$flagColName"
            |"46","2019-04-01 00:00:57","1",1444,"1",66.12,34.5,29.31,64.56,38.31,67.32,9.64,31.53,false
            |"46","2019-04-01 00:00:52","1",1515,"1",66,34.5,29.31,64.56,38.31,69.08,24.91,36.7,false
            |"46","2019-04-01 00:00:46","1",1452,"1",66.12,34.5,29.31,64.5,38.31,66.88,11.12,34.43,false
            |"47","2019-04-01 00:00:46","1",1452,"1",100.12,34.5,29.31,64.5,38.31,66.88,11.12,34.43,true
            |"77","2019-04-01 00:00:41","1",1319,"1",66.19,34.5,29.31,64.5,38.31,67.82,8.66,34.79,false
          """

        val baseFields = Array(
          StructField("id", StringType, false),
          StructField("time", TimestampType, false),
          StructField("code", StringType, true),
          StructField("emi", IntegerType, true),
          StructField("v", StringType, true),
          StructField("t1", DoubleType, true),
          StructField("t2", DoubleType, true),
          StructField("t3", DoubleType, true),
          StructField("t4", DoubleType, true),
          StructField("t5", DoubleType, true),
          StructField("x_acc", DoubleType, true),
          StructField("y_acc", DoubleType, true),
          StructField("z_acc", DoubleType, true)
        )
        val inputSchema = StructType(baseFields)
        // `when(...).otherwise(false)` never yields null, so Spark marks the flag column
        // non-nullable; the expected schema must agree or the schema comparison fails.
        val expectedSchema = StructType(baseFields :+ StructField(flagColName, BooleanType, false))

        val input = csvStringToDataFrame(inputCSV, inputSchema)
        val bounds = Bounds(10, 70)
        val output = toolbox.checkExtremes(input, inputColName, bounds)
        val expected = csvStringToDataFrame(expectedCSV, expectedSchema)

        assertSmallDatasetEquality(output, expected)
      }
    }
  }
}
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement