What is a PySpark broadcast join?

A PySpark broadcast join is a technique used in PySpark (the Python API for Apache Spark) to improve join performance when one of the joined tables is small. The primary goal of a broadcast join is to eliminate the data shuffling and network overhead associated with join operations, which can result in considerable speed benefits. A broadcast join sends the smaller table (or DataFrame) to all worker nodes, ensuring each worker node has a complete copy of it in memory. This allows the join to be performed locally on each worker node, eliminating the need to shuffle and transfer data across the network.
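Note that Spark can also broadcast a table automatically: any join side whose estimated size is below the spark.sql.autoBroadcastJoinThreshold setting (10 MB by default) is broadcast without an explicit hint. Here is a minimal sketch of tuning this setting (the app name is arbitrary):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Broadcast Threshold Example").getOrCreate()

# Raise the automatic broadcast threshold to 50 MB (the value is in bytes);
# setting it to -1 disables automatic broadcasting entirely
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)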

To apply a broadcast join explicitly, we can use the broadcast() function from the pyspark.sql.functions module.

Let's create two sample tables in PySpark to demonstrate the broadcast join: a larger sales table and a smaller products table.

Sales Table

| order_id | product_id | quantity |
|----------|------------|----------|
| 1        | 101        | 2        |
| 2        | 102        | 1        |
| 3        | 103        | 3        |
| 4        | 101        | 1        |
| 5        | 104        | 4        |

Products Table

| product_id | product_name | price |
|------------|--------------|-------|
| 101        | Learn C++    | 10    |
| 102        | Mobile: X1   | 20    |
| 103        | LCD          | 30    |
| 104        | Laptop       | 40    |

Now, let's try a broadcast join in PySpark with the tables above:

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

# Initialize the Spark session
spark = SparkSession.builder.appName("Broadcast Join Example").getOrCreate()

# Create DataFrames from sample data
sales_data = [(1, 101, 2), (2, 102, 1), (3, 103, 3), (4, 101, 1), (5, 104, 4)]
products_data = [(101, "Learn C++", 10), (102, "Mobile: X1", 20), (103, "LCD", 30), (104, "Laptop", 40)]

sales_columns = ["order_id", "product_id", "quantity"]
products_columns = ["product_id", "product_name", "price"]

sales_df = spark.createDataFrame(sales_data, schema=sales_columns)
products_df = spark.createDataFrame(products_data, schema=products_columns)

# Perform broadcast join
result = sales_df.join(broadcast(products_df), sales_df["product_id"] == products_df["product_id"])

# Show result
result.show()

Explanation

This PySpark code performs a broadcast join between two DataFrames, sales_df and products_df, using the "product_id" column as the key. Here's the explanation of each part:

  • Lines 1–2: Import SparkSession and the broadcast function from PySpark.

  • Line 5: Initialize a SparkSession with the name "Broadcast Join Example".

  • Lines 8–9: Create sample sales and product data as lists of tuples.

  • Lines 11–12: Define column names for the sales and products DataFrames.

  • Lines 14–15: Create the sales and products DataFrames using the sample data and column names.

  • Line 18: Perform a broadcast join between the sales and products DataFrames using the "product_id" column as the key. The broadcast function hints that the smaller DataFrame (products_df) should be broadcast to all worker nodes, optimizing the join performance (an equivalent SQL form of this hint is sketched after this list).

  • Line 21: Show the result of the join by calling the show() method on the resulting DataFrame.
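The same hint can also be expressed in Spark SQL. A short sketch, assuming the two DataFrames from the example above have been registered as temporary views:

sales_df.createOrReplaceTempView("sales")
products_df.createOrReplaceTempView("products")

# The BROADCAST hint below is equivalent to wrapping products_df in broadcast()
result_sql = spark.sql("""
    SELECT /*+ BROADCAST(products) */ *
    FROM sales
    JOIN products ON sales.product_id = products.product_id
""")
result_sql.show()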

The result of the broadcast join above will be as follows:

Result Table

| order_id | product_id | quantity | product_id | product_name | price |
|----------|------------|----------|------------|--------------|-------|
| 1        | 101        | 2        | 101        | Learn C++    | 10    |
| 2        | 102        | 1        | 102        | Mobile: X1   | 20    |
| 3        | 103        | 3        | 103        | LCD          | 30    |
| 4        | 101        | 1        | 101        | Learn C++    | 10    |
| 5        | 104        | 4        | 104        | Laptop       | 40    |
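To confirm that Spark actually planned a broadcast join rather than a shuffle-based join, we can inspect the physical plan. This check isn't part of the example above, and the exact plan text varies by Spark version, but it should contain a BroadcastHashJoin node:

# Print the physical plan for the joined DataFrame
result.explain()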
