
Grouping and Aggregating Data: .groupBy()

Overview

The groupBy() function groups the rows of a DataFrame by one or more columns. It returns a GroupedData object, and applying aggregate functions to that object summarizes each group, letting you perform complex aggregations efficiently and extract meaningful insights from your data.

Grouping by Single Column

You can use the groupBy() function to group the DataFrame based on a single column.

python
from pyspark.sql import SparkSession

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("GroupByExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30, "department": "HR"},
    {"name": "Bob", "age": 25, "department": "Finance"},
    {"name": "Charlie", "age": 35, "department": "HR"},
    {"name": "David", "age": 28, "department": "Finance"},
    {"name": "Eva", "age": 32, "department": "HR"},
]

# Create a DataFrame
df = spark.createDataFrame(data)

# Group by a single column
grouped_data = df.groupBy("department")

grouped_data.count().show()

Output:

+----------+-----+
|department|count|
+----------+-----+
|        HR|    3|
|   Finance|    2|
+----------+-----+

Grouping by Multiple Columns

You can use the groupBy() function to group the DataFrame based on multiple columns.

python
from pyspark.sql import SparkSession

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("GroupByExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30, "department": "HR"},
    {"name": "Bob", "age": 25, "department": "Finance"},
    {"name": "Charlie", "age": 35, "department": "HR"},
    {"name": "David", "age": 28, "department": "Finance"},
    {"name": "Eva", "age": 32, "department": "HR"},
]

# Create a DataFrame
df = spark.createDataFrame(data)

# Group by multiple columns
grouped_data = df.groupBy("department", "age")

grouped_data.count().show()

Output:

+----------+---+-----+
|department|age|count|
+----------+---+-----+
|        HR| 30|    1|
|   Finance| 25|    1|
|        HR| 35|    1|
|   Finance| 28|    1|
|        HR| 32|    1|
+----------+---+-----+

Single Aggregations

You can apply single aggregate functions to the grouped data.

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("GroupByExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30, "department": "HR"},
    {"name": "Bob", "age": 25, "department": "Finance"},
    {"name": "Charlie", "age": 35, "department": "HR"},
    {"name": "David", "age": 28, "department": "Finance"},
    {"name": "Eva", "age": 32, "department": "HR"},
]

# Create a DataFrame
df = spark.createDataFrame(data)

# Group by department and calculate the average age
grouped_data = df.groupBy("department")

grouped_data.agg(avg("age")).show()

Output:

+----------+------------------+
|department|          avg(age)|
+----------+------------------+
|        HR|32.333333333333336|
|   Finance|              26.5|
+----------+------------------+

Multiple Aggregations

You can apply multiple aggregate functions to the grouped data.

python
from pyspark.sql import SparkSession
# Note: importing max/min this way shadows Python's built-ins in this script
from pyspark.sql.functions import avg, max, min

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("GroupByExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30, "department": "HR"},
    {"name": "Bob", "age": 25, "department": "Finance"},
    {"name": "Charlie", "age": 35, "department": "HR"},
    {"name": "David", "age": 28, "department": "Finance"},
    {"name": "Eva", "age": 32, "department": "HR"},
]

# Create a DataFrame
df = spark.createDataFrame(data)

# Group by department and calculate multiple aggregate functions
grouped_data = df.groupBy("department")

grouped_data.agg(avg("age"), max("age"), min("age")).show()

Output:

+----------+------------------+--------+--------+
|department|          avg(age)|max(age)|min(age)|
+----------+------------------+--------+--------+
|        HR|32.333333333333336|      35|      30|
|   Finance|              26.5|      28|      25|
+----------+------------------+--------+--------+

The groupBy() function in PySpark groups a DataFrame by one or more columns, and the resulting GroupedData object provides the aggregation functions needed to summarize each group efficiently. Whether you apply a single aggregation or several at once, groupBy() is the foundation for analyzing structured data and extracting insights from it.

📖👉 Official Doc