Thursday, 8 November 2018

How to aggregate a Spark DataFrame using collect_set


To explain how collect_set is used, let's create a DataFrame with 3 columns.
spark-shell --queue= *;

To adjust logging level use sc.setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 1.6.0
Spark context available as sc 
SQL context available as sqlContext.

scala>  val sqlcontext = new org.apache.spark.sql.SQLContext(sc)
sqlcontext: org.apache.spark.sql.SQLContext = org.apache.spark.sql.SQLContext@4f9a8d71  
 
scala> import org.apache.spark.sql.Column
scala> val BazarDF = Seq(
        ("Veg", "tomato", 1.99),
        ("Veg", "potato", 0.45),
        ("Fruit", "apple", 0.99),
        ("Fruit", "pineapple", 2.59)
         ).toDF("Type", "Item", "Price")
BazarDF: org.apache.spark.sql.DataFrame = [Type: string, Item: string, Price: double]

Now let's group by the Type column and collect the distinct values of the Item column using collect_set().
scala> val aggBazarDF = BazarDF.groupBy($"Type").
         agg(collect_set($"Item").as("All_Items"))
aggBazarDF: org.apache.spark.sql.DataFrame = [Type: string, All_Items: array<string>]
collect_set() returns the distinct values of a column within each group as an array; the order of the elements is not guaranteed.
Let's look at the resulting DataFrame.
scala>  aggBazarDF.show()
+-----+------------------+
| Type|         All_Items|
+-----+------------------+
|  Veg|  [tomato, potato]|
|Fruit|[apple, pineapple]|
+-----+------------------+
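
To make the de-duplication concrete without a cluster, the same grouping can be mimicked with plain Scala collections. This is only an analogy under stated assumptions, not how Spark executes the aggregation: the duplicate ("Veg", "tomato") row and the helper name collectSetLocal are introduced here for illustration and are not part of the session above.

```scala
// Plain-Scala analogy of groupBy + collect_set (no Spark required).
// The repeated ("Veg", "tomato") row is hypothetical, added to show de-duplication.
val rows = Seq(
  ("Veg", "tomato"),
  ("Veg", "potato"),
  ("Veg", "tomato"),
  ("Fruit", "apple"),
  ("Fruit", "pineapple")
)

// Group by the first field, then keep each group's items as a Set (distinct values).
def collectSetLocal(data: Seq[(String, String)]): Map[String, Set[String]] =
  data.groupBy(_._1).map { case (key, group) => key -> group.map(_._2).toSet }

val grouped = collectSetLocal(rows)
println(grouped("Veg").size) // prints 2: "tomato" is kept only once
```

Like the array produced by collect_set(), a Set carries no meaningful element order.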

What if we need to remove the square brackets?
We can use concat_ws(), which joins the array elements into a single string with the given separator.

scala> val aggBazarDFNew = BazarDF.groupBy($"Type").
         agg(concat_ws(",", collect_set($"Item")).as("All_Items"))
aggBazarDFNew: org.apache.spark.sql.DataFrame = [Type: string, All_Items: string]

scala> aggBazarDFNew.show()
+-----+---------------+
| Type|      All_Items|
+-----+---------------+
|  Veg|  tomato,potato|
|Fruit|apple,pineapple|
+-----+---------------+
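
Because collect_set() makes no promise about element order, the joined string can differ from run to run. A minimal sketch of one way to stabilise it, assuming the same spark-shell session as above (BazarDF defined and the $"..." syntax in scope), is to wrap the set in sort_array() before joining. The name sortedItemsDF is introduced here for illustration.

```scala
import org.apache.spark.sql.functions.{collect_set, concat_ws, sort_array}

// Assumes the spark-shell session above: BazarDF exists and $"..." is available.
// sort_array() orders each group's items ascending before concat_ws() joins them.
val sortedItemsDF = BazarDF.groupBy($"Type").
  agg(concat_ws(",", sort_array(collect_set($"Item"))).as("All_Items"))

sortedItemsDF.show()
```

With the four rows used in this post, the Veg group would then read potato,tomato regardless of the order in which the rows were collected. The order of the rows themselves is still not guaranteed.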