
Filter Rows in a DataFrame - .filter()

Overview

The filter() function keeps the rows of a DataFrame that satisfy a given condition and drops the rest. It does not modify the original DataFrame. Instead, it returns a new DataFrame that contains only the matching rows. where() is an alias of filter(), so you can use the two interchangeably.

Single Condition

You can use the filter() function to filter rows based on a single condition.

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30},
    {"name": "Bob", "age": 25},
    {"name": "Charlie", "age": 35},
]

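# Create a DataFrame; the explicit schema keeps the columns in the order shown
# in the output (inferring a schema from dicts can reorder the keys)
df = spark.createDataFrame(data, "name string, age int")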

# Filter rows based on a condition
filtered_df = df.filter(df["age"] > 28)  # option 1
filtered_df = df.filter(col("age") > 28)  # option 2

filtered_df.show()

Output:

+-------+---+
|   name|age|
+-------+---+
|  Alice| 30|
|Charlie| 35|
+-------+---+

Multiple Conditions

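To filter on several conditions at once, combine them with & (and), | (or), and ~ (not). Wrap each condition in parentheses. Python's & and | operators bind more tightly than comparisons such as > and ==, so leaving out the parentheses typically raises an error.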

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30, "city": "New York"},
    {"name": "Bob", "age": 25, "city": "San Francisco"},
    {"name": "Charlie", "age": 35, "city": "Los Angeles"},
]

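# Create a DataFrame; the explicit schema keeps the columns in the order shown
# in the output (inferring a schema from dicts can reorder the keys)
df = spark.createDataFrame(data, "name string, age int, city string")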

# Filter rows based on multiple conditions
filtered_df = df.filter((col("age") > 28) & (col("city") == "New York"))

filtered_df.show()

Output:

+-----+---+--------+
| name|age|    city|
+-----+---+--------+
|Alice| 30|New York|
+-----+---+--------+

SQL Condition

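The filter() function also accepts conditions written as SQL expressions. You can wrap the expression in expr() from pyspark.sql.functions, or you can pass the string to filter() directly. The two forms are equivalent.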

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30},
    {"name": "Bob", "age": 25},
    {"name": "Charlie", "age": 35},
]

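# Create a DataFrame; the explicit schema keeps the columns in the order shown
# in the output (inferring a schema from dicts can reorder the keys)
df = spark.createDataFrame(data, "name string, age int")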

# Filter rows using SQL condition
filtered_df = df.filter(expr("age > 28"))  # option 1
filtered_df = df.filter("age > 28")  # option 2

filtered_df.show()

Output:

+-------+---+
|   name|age|
+-------+---+
|  Alice| 30|
|Charlie| 35|
+-------+---+

Based on Collection

To keep rows whose column value appears in a collection, such as a list or set, combine filter() with the Column.isin() method.

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30},
    {"name": "Bob", "age": 25},
    {"name": "Charlie", "age": 35},
]

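# Create a DataFrame; the explicit schema keeps the columns in the order shown
# in the output (inferring a schema from dicts can reorder the keys)
df = spark.createDataFrame(data, "name string, age int")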

# Filter rows based on a collection of values
names_to_filter = ["Bob", "Charlie"]
filtered_df = df.filter(col("name").isin(names_to_filter))

filtered_df.show()

Output:

+-------+---+
|   name|age|
+-------+---+
|    Bob| 25|
|Charlie| 35|
+-------+---+

like & rlike

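The like() and rlike() Column methods filter rows by pattern. like() uses SQL wildcards, where % matches any sequence of characters and _ matches exactly one character. rlike() matches a Java regular expression.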

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data as a list of dictionaries
data = [
    {"name": "Alice", "age": 30},
    {"name": "Bob", "age": 25},
    {"name": "Charlie", "age": 35},
]

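# Create a DataFrame; the explicit schema keeps the columns in the order shown
# in the output (inferring a schema from dicts can reorder the keys)
df = spark.createDataFrame(data, "name string, age int")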

# Filter rows based on a pattern using 'like'
filtered_df = df.filter(col("name").like("A%"))

filtered_df.show()

Output:

+-----+---+
| name|age|
+-----+---+
|Alice| 30|
+-----+---+

Array Column

The filter() function can be used to filter rows based on conditions applied to array columns.

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_contains, col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data with array column as a list of dictionaries
data = [
    {"name": "Alice", "scores": [90, 85, 95]},
    {"name": "Bob", "scores": [80, 70, 75]},
    {"name": "Charlie", "scores": [95, 92, 87]},
]

# Create a DataFrame
df = spark.createDataFrame(data)

# Filter rows based on the presence of a value in the array column
filtered_df = df.filter(array_contains(col("scores"), 95))

filtered_df.show()

Output:

+-------+-------------+
|   name|       scores|
+-------+-------------+
|  Alice| [90, 85, 95]|
|Charlie| [95, 92, 87]|
+-------+-------------+

Nested Column

The filter() function can also be used to filter rows based on conditions applied to nested columns.

python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession (if not already created)
spark = SparkSession.builder.appName("FilterExample").getOrCreate()

# Sample data with nested columns as a list of dictionaries
data = [
    {"name": "Alice", "address": {"city": "New York", "zipcode": 10001}},
    {"name": "Bob", "address": {"city": "San Francisco", "zipcode": 94105}},
    {"name": "Charlie", "address": {"city": "Los Angeles", "zipcode": 90001}},
]

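# Create a DataFrame; the explicit schema makes "address" a struct
# (by default, schema inference turns a nested dict into a map column)
df = spark.createDataFrame(data, "name string, address struct<city:string,zipcode:int>")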

# Filter rows based on conditions applied to nested columns
filtered_df = df.filter(col("address.city") == "New York")

filtered_df.show()

Output:

+-----+-----------------+
| name|          address|
+-----+-----------------+
|Alice|{New York, 10001}|
+-----+-----------------+

In PySpark, filter() (alias: where()) selects the rows of a DataFrame that match a condition. The condition can be a Column expression, a SQL expression string, a membership test with isin(), a pattern match with like() or rlike(), or a test on array or nested fields.
