Caching and Persisting DataFrames
Every time you call an action on a DataFrame, Spark re-executes the full lineage from the source. If your pipeline reads the same data multiple times – for example, computing several different aggregations – caching avoids redundant I/O and computation.
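To make the idea concrete without a cluster, here is a toy model in plain Python. It is only an analogy, not how Spark is implemented: `LazyFrame`, `read_source`, and the data are hypothetical. Each "action" replays the lineage from the source unless the frame has been marked with `cache()`. In that case the first action stores the rows and later actions reuse them.

```python
# Toy, pure-Python analogy of lazy lineage and caching (NOT the PySpark API).

class LazyFrame:
    def __init__(self, compute):
        self._compute = compute        # the "lineage": rebuilds rows from the source
        self._cache_requested = False  # set by cache(); lazy, like Spark
        self._cached_rows = None       # filled by the first action after cache()

    def filter(self, predicate):
        # Transformation: extends the lineage, runs nothing yet
        return LazyFrame(lambda: [r for r in self._rows() if predicate(r)])

    def cache(self):
        self._cache_requested = True
        return self

    def _rows(self):
        if self._cached_rows is not None:
            return self._cached_rows   # reuse the stored rows
        rows = self._compute()         # replay the lineage from the source
        if self._cache_requested:
            self._cached_rows = rows
        return rows

    def count(self):
        # Action: forces evaluation
        return len(self._rows())


reads = {"source": 0}

def read_source():
    # Hypothetical stand-in for scanning flights.csv; the data is made up
    reads["source"] += 1
    return list(range(100))

df = LazyFrame(read_source)

# Two actions without caching: the source is scanned twice
df.count()
df.filter(lambda r: r > 60).count()
reads_without_cache = reads["source"]

# Mark for caching, then run the same two actions: the source is scanned once
reads["source"] = 0
df.cache()
df.count()                           # populates the cache
df.filter(lambda r: r > 60).count()  # served from the cached rows
reads_with_cache = reads["source"]

print(reads_without_cache, reads_with_cache)  # 2 1
```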
cache() and persist()
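cache() persists the DataFrame with the default storage level. For DataFrames this level keeps partitions in memory and spills those that do not fit to disk (MEMORY_AND_DISK; PySpark 3.0 and later report it as MEMORY_AND_DISK_DESER). persist() lets you choose the storage level explicitly, for example MEMORY_ONLY, DISK_ONLY, or MEMORY_AND_DISK, depending on how much RAM is available and whether recomputing a partition is cheaper than reading it back from disk: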
```python
import urllib.request

from pyspark import StorageLevel
from pyspark.sql import SparkSession

urllib.request.urlretrieve(
    "https://staging-content-media-cdn.codefinity.com/courses/aa80ac56-0d50-49e8-9231-2c2374cd3e9d/flights.csv",
    "flights.csv"
)

spark = SparkSession.builder \
    .appName("Caching") \
    .master("local[*]") \
    .getOrCreate()

flights_df = spark.read.csv("flights.csv", header=True, inferSchema=True) \
    .fillna(0, subset=["DEPARTURE_DELAY", "ARRIVAL_DELAY"])

# cache() uses the default storage level (memory, spilling to disk if needed)
flights_df.cache()

# Triggering the cache by running an action
flights_df.count()

# Unpersisting before switching storage level
flights_df.unpersist()

# persist() with an explicit storage level – spills to disk if data does not fit in memory
flights_df.persist(StorageLevel.MEMORY_AND_DISK)
flights_df.count()
```
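cache() and persist() are lazy: they only mark the DataFrame for caching. Partitions are stored when the first action after the call computes them, and subsequent actions reuse the cached data. An action such as count() touches every partition and fills the whole cache. An action that reads only some partitions, such as take(), caches only those.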
Unpersisting
When you no longer need a cached DataFrame, unpersist it to release the memory and disk space it occupies:
```python
flights_df.unpersist()
```
When Caching Helps
```python
import time

from pyspark.sql.functions import avg, col

# Without cache – Spark reads the file twice
start = time.time()
flights_df.filter(col("ARRIVAL_DELAY") > 60).count()
flights_df.groupBy("AIRLINE").agg(avg("ARRIVAL_DELAY")).collect()
print(f"Without cache: {time.time() - start:.2f}s")

# With cache – file is read once
flights_df.cache()
flights_df.count()  # Populating the cache

start = time.time()
flights_df.filter(col("ARRIVAL_DELAY") > 60).count()
flights_df.groupBy("AIRLINE").agg(avg("ARRIVAL_DELAY")).collect()
print(f"With cache: {time.time() - start:.2f}s")

flights_df.unpersist()
```