
Spark: Doing a coalesce and foreachPartition in Spark directly on an Iceberg table is leaking memory-heavy iterators

Open jkolash opened this issue 7 months ago • 21 comments

Apache Iceberg version

1.5.0

Query engine

Spark

Please describe the bug 🐞

Summary

Doing the following should not leak any significant amount of memory.

sparkSession.sql("select * from icebergcatalog.db.table").coalesce(4).foreachPartition( (iterator) -> {
 while (iterator.hasNext()) iterator.next();
});

A workaround is to use repartition() instead; however, this requires more resources to handle spilling, shuffling, etc.
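For reference, a minimal sketch of the workaround, mirroring the repro above with only coalesce(4) swapped for repartition(4):

sparkSession.sql("select * from icebergcatalog.db.table").repartition(4).foreachPartition((iterator) -> {
    // Same no-op consumption as above. repartition() introduces a shuffle, so each upstream read
    // runs in its own task and its reader can be released when that task finishes.
    while (iterator.hasNext()) iterator.next();
});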

Spark version: Spark 3.4.X

Details

The below code can be run on a sufficiently large Iceberg table.

static AtomicInteger partitionCounter = new AtomicInteger(0);

static void reproduceBug(SparkSession sparkSession, String table) {
    sparkSession.sql("select * from " + table).coalesce(4).foreachPartition((iterator) -> {
        int partition = partitionCounter.getAndIncrement();
        AtomicLong rowCounter = ThreadLocal.withInitial(() -> new AtomicLong(0)).get();

        while (iterator.hasNext()) {
            iterator.next();
            if (rowCounter.getAndIncrement() % 100000 == 0) {
                System.out.println(partition + " " + rowCounter.get());
            }
        }
    });
}

The following image shows me running the reproduceBug method over a sufficiently large table that we have in our environment, with ~500 columns.

Image

The following image shows the "Dominators" report in VisualVM for org.apache.spark.TaskContextImpl.

Image

Digging deeper, we see that onCompleteCallbacks is keeping an anonymous class created inside org.apache.spark.sql.execution.datasources.v2.DataSourceRDD, and that class is holding a reference to org.apache.iceberg.spark.source.RowDataReader.

I believe this callback is added here: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala#L90. We also see the org.apache.iceberg.util.Filter iterator holding a heavy reference.

Image

Exploring the problem

Is this inherently a bug in org.apache.spark.sql.execution.datasources.v2.DataSourceRDD? Or should iterators not hold onto state that is no longer needed once they have advanced to the end? Is the iterator even exhausted? Once an iterator is exhausted, there is no longer a need to keep referencing its internal state. However, this somewhat breaks the concept of a CloseableIterator, which has an explicit close() versus an implicit close where you could detect hasNext() == false and auto-close, and then ignore a duplicate close() because it was already handled by the implicit close on iterator exhaustion. I believe an iterator accumulating hundreds of megabytes of state breaks the implicit "expected" contract of an iterator being a streaming set of objects. There might even be a distinction between closing and simply releasing large object references.
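To illustrate that implicit-close idea, a minimal, hypothetical sketch (the class name and shape are made up for illustration, not existing Iceberg code): the wrapper closes its underlying resource as soon as exhaustion is detected and treats a later explicit close() as a no-op.

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Iterator;
import java.util.NoSuchElementException;

class AutoClosingIterator<T> implements Iterator<T>, Closeable {
    private Iterator<T> delegate;
    private Closeable resource;

    AutoClosingIterator(Iterator<T> delegate, Closeable resource) {
        this.delegate = delegate;
        this.resource = resource;
    }

    @Override
    public boolean hasNext() {
        if (delegate != null && delegate.hasNext()) {
            return true;
        }
        close(); // implicit close on exhaustion: drop heavy references early
        return false;
    }

    @Override
    public T next() {
        if (delegate == null) {
            throw new NoSuchElementException("iterator already exhausted and closed");
        }
        return delegate.next();
    }

    @Override
    public void close() {
        if (resource != null) {
            try {
                resource.close();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            } finally {
                resource = null;
                delegate = null; // allow GC of the heavy iterator state even if callers keep this wrapper
            }
        }
    }
}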

Digging deeper, I see items = org.apache.iceberg.parquet.ParquetReader$FileIterator#6 holding onto a model reference. It might be possible to null out the model reference when hasNext() is false:

@Override
public boolean hasNext() {
  boolean hasNext = valuesRead < totalValues;
  if (!hasNext) {
    this.model = null; // release the heavy reader model once the iterator is exhausted
  }
  return hasNext;
}
Image

Willingness to contribute

  • [x] I can contribute a fix for this bug independently
  • [x] I would be willing to contribute a fix for this bug with guidance from the Iceberg community
  • [ ] I cannot contribute a fix for this bug at this time

jkolash avatar Jun 11 '25 14:06 jkolash

I was able to make some hacky changes that reduced memory usage in this draft PR: https://github.com/apache/iceberg/pull/13298. It is mainly meant to show that these were the critical, no-longer-needed objects that had to be GC'd.

I also tested

String[] paths = sparkSession.sql("select file_path from " + table + ".files")
    .collectAsList().stream()
    .map(row -> row.getString(0))
    .toArray(String[]::new);
System.out.println(Arrays.asList(paths));
sparkSession.read().load(paths).coalesce(4).foreachPartition((iterator) -> {
    int partition = partitionCounter.getAndIncrement();
    AtomicLong rowCounter = ThreadLocal.withInitial(() -> new AtomicLong(0)).get();

    while (iterator.hasNext()) {
        iterator.next();
        if (rowCounter.getAndIncrement() % 100000 == 0) {
            System.out.println(partition + " " + rowCounter.get());
        }
    }
});

This shows that the new read path through Iceberg is the issue: when loading and processing the Parquet files directly, the issue doesn't manifest.

The TaskContextImpl still has the onCompleteCallbacks, but the FileScanRDD is much smaller.

Image

jkolash avatar Jun 11 '25 19:06 jkolash

Can you force V2 sources and try the Parquet version again (FileScanRDD is a different code path)? It feels odd to me that the iterator we are making should null itself out when fully iterated. It feels like that should happen on the last "next" call, as is currently in the code, or in the close method?

Or is the issue here that we are saving this task context for the UI so we don't actually ever drop the iterator reference and close is never called?

RussellSpitzer avatar Jun 12 '25 19:06 RussellSpitzer

I'm not sure what you mean by force V2 sources and try the Parquet version again. I was comparing both the v1 and v2 memory usage and the v2 path using iceberg consumes significantly more memory.

Is there a v2 equivalent to

sparkSession.read().load(paths)

Speaking of handling this on the last next() call: I don't see that happening in my experimental changes, which did reduce memory usage. https://github.com/apache/iceberg/pull/13298/files#diff-d80c15b3e5376265436aeab8b79d5a92fb629c6b81f58ad10a11b9b9d3bfcffcR134

jkolash avatar Jun 12 '25 19:06 jkolash

Or is the issue here that we are saving this task context for the UI so we don't actually ever drop the iterator reference and close is never called?

The close() will likely not be called until the job completes; the callback being held onto is what performs the close() operation.

          context.addTaskCompletionListener[Unit] { _ =>
            // In case of early stopping before consuming the entire iterator,
            // we need to do one more metric update at the end of the task.
            CustomMetrics
              .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
            iter.forceUpdateMetrics()
            reader.close()
          }

https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala#L90

jkolash avatar Jun 12 '25 19:06 jkolash

Let me know if you want me to write something to produce synthetic data for this so it can be reproduced without using real data. That way you can reproduce it in your environment.

jkolash avatar Jun 12 '25 19:06 jkolash

I'm not sure what you mean by force V2 sources and try the Parquet version again. I was comparing both the v1 and v2 memory usage and the v2 path using iceberg consumes significantly more memory.

Is there a v2 equivalent to

sparkSession.read().load(paths) Speaking of handling this on the last next() call: I don't see that happening in my experimental changes, which did reduce memory usage. https://github.com/apache/iceberg/pull/13298/files#diff-d80c15b3e5376265436aeab8b79d5a92fb629c6b81f58ad10a11b9b9d3bfcffcR134

https://github.com/apache/iceberg/pull/13298/files#diff-7f0d08eb9c160db9e4fefa082993c58f2fe749f9be4f108c26cb4201547f2521L139

RussellSpitzer avatar Jun 12 '25 19:06 RussellSpitzer

Ok - for my own recap here

  1. Spark is holding onto task contexts in order to invoke callbacks to get metric information until the end of the job.
  2. The TaskContext, through this callback, holds onto the iterator the DSV2 source makes.
  3. Our iterator ends up being somewhat heavy: it holds onto a Parquet reader for every column in the table, which means for a very wide table this adds up to a significant amount of memory.

Now I think there are a few avenues to make this better

  1. As @jkolash suggests, we can preemptively null out the reference to the readers when the iterator has been exhausted. One issue here is that a partially consumed iterator will still hold onto the model until the job has completed.

  2. We can work on the Spark side to make sure that the iterator is actually released when exhausted or the task has completed.

Thoughts @szehon-ho + @aokolnychyi ?

RussellSpitzer avatar Jun 12 '25 20:06 RussellSpitzer

I don't immediately see a Spark API to easily remove the callback once an iterator is exhausted.

Maybe the Iceberg change is easier, but if you have some ideas on the Spark side, it makes sense to explore them?

szehon-ho avatar Jun 12 '25 23:06 szehon-ho

I wanted to bring up another point that @RussellSpitzer helped me identify: this issue reproduces in a similar way with the v2 datasources.

When I ran https://github.com/apache/iceberg/issues/13297#issuecomment-2963910701 but set the Spark option spark.sql.sources.useV1SourceList to an empty string, it reproduced via the new v2 datasource code, but not the v1 datasource.
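For anyone repeating this comparison, this is roughly how I forced the V2 path from the session (assumed setup; the config can also be passed as --conf at submit time):

// An empty useV1SourceList means no source falls back to the V1 FileScanRDD path,
// so sparkSession.read().load(paths) goes through the DSV2 Parquet reader.
sparkSession.conf().set("spark.sql.sources.useV1SourceList", "");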

Image

So this issue needs to be addressed in Spark as well.

On the Spark side, I think the close() call should likely be done on task completion, not via a callback on job completion, and the metrics can be collected before that close?

jkolash avatar Jun 13 '25 00:06 jkolash

The issue is that coalesce is combining N tasks into 1 task when really there are N underlying tasks, so coalesce should be more subtask-aware?

jkolash avatar Jun 13 '25 00:06 jkolash

Reading more deeply in DataSourceRDD.scala, it seems the callback is there to handle a partially consumed iterator, but Spark wraps the raw iterator in a MetricsBatchIterator, which could detect that the iterator was exhausted via hasNext() == false and remove the onComplete callback?

jkolash avatar Jun 13 '25 00:06 jkolash

https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/TaskContextImpl.scala#L71 is currently a stack, so that would need to change to something more like a linked list to allow removing arbitrary listeners.

jkolash avatar Jun 13 '25 00:06 jkolash

Actually, I don't think any Spark API changes are needed: the callback and the MetricsBatchIterator are both created in the DataSourceRDD.

In Java you could write something like this:

// requires: import java.util.concurrent.atomic.AtomicReference;

// Simulates the heavy reader state held by the task-completion callback.
static class HeavyCallback implements Runnable {
    byte[] heavyReference = new byte[1000000];

    @Override
    public void run() {
        // metrics update / close would happen here
    }
}

// Wraps the real callback and drops the reference after the first invocation,
// so the heavy state becomes eligible for GC even if the wrapper stays registered.
static class InvokeOnceCallback implements Runnable {
    final AtomicReference<Runnable> callbackHandle = new AtomicReference<>();

    InvokeOnceCallback(Runnable target) {
        callbackHandle.set(target);
    }

    @Override
    public void run() {
        Runnable target = callbackHandle.getAndSet(null);
        if (target != null) {
            target.run();
        }
    }
}

static void showcaseIndirectCallback() {
    InvokeOnceCallback callback = new InvokeOnceCallback(new HeavyCallback());

    // This can be called from either the iterator exhaustion or the original completion callback.
    // Once invoked, the reference is removed and can be GC'd; a second call is a no-op.
    callback.run();
}

jkolash avatar Jun 13 '25 10:06 jkolash

Hi, my team and I had a similar problem while querying and optimizing tables with a huge number of data and delete files. Spark always ran OOM, because resources (Parquet readers, ...) weren't freed properly. I did some debugging and identified CloseableIterable as the culprit. Unfortunately, I never had time to upstream it properly.

CloseableIterable has the following two problems:

  1. When turning a CloseableIterable into an iterator, the returned CloseableIterator doesn't properly close the iterable, thereby causing a resource leak (see the sketch below). By returning a custom instance of CloseableIterator that keeps track of the closeable iterable, we properly close the iterable when the iterator gets closed.
  2. When combining an iterable with a closeable using CloseableIterable.combine, we have a resource leak if the iterable implements CloseableIterable, because we never call close on the iterable. By overloading combine for CloseableIterable, we can ensure that we close the iterable in addition to the closeable.
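A small Java sketch of the leak pattern in problem 1 (the method and element type are illustrative; the point is that closing only the iterator does not close the wrapping iterable for the cases patched below):

import java.io.IOException;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.CloseableIterator;

static long countRecords(CloseableIterable<Object> records) throws IOException {
    long count = 0;
    try (CloseableIterator<Object> it = records.iterator()) {
        while (it.hasNext()) {
            it.next();
            count++;
        }
    }
    // it.close() ran, but records.close() was never called, so for the wrappers described above
    // the underlying reader stays referenced until records itself is closed (or GC'd).
    return count;
}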

With the following patch applied, the queries and table optimizations ran without OOM and with much lower memory limits.

From 6e2cfd71c2dcecf432f2db2119ac530674af5042 Mon Sep 17 00:00:00 2001
From: Emmanuel Pescosta <[email protected]>
Date: Thu, 16 Jan 2025 12:19:46 +0100
Subject: [PATCH] fix(iceberg.io): Properly close CloseableIterable when
 turning it into an iterator

A CloseableIterator created from a CloseableIterable didn't close
the iterable when the iterator is closed. This may cause a resource
leak.
---
 .../apache/iceberg/io/CloseableIterable.java  | 40 ++++++++++++++++++-
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java b/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
index 06323612a..240c1ef8f 100644
--- a/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
+++ b/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
@@ -51,7 +51,7 @@ public interface CloseableIterable<T> extends Iterable<T>, Closeable {
 
       @Override
       public CloseableIterator<E> iterator() {
-        return CloseableIterator.withClose(iterable.iterator());
+        return closeableIteratorOf(iterable, this);
       }
     };
   }
@@ -69,7 +69,43 @@ public interface CloseableIterable<T> extends Iterable<T>, Closeable {
 
       @Override
       public CloseableIterator<E> iterator() {
-        return CloseableIterator.withClose(iterable.iterator());
+        return closeableIteratorOf(iterable, this);
+      }
+    };
+  }
+
+  static <E> CloseableIterable<E> combine(CloseableIterable<E> iterable, Closeable closeable) {
+    return new CloseableIterable<E>() {
+      @Override
+      public void close() throws IOException {
+        closeable.close();
+        iterable.close();
+      }
+
+      @Override
+      public CloseableIterator<E> iterator() {
+        return closeableIteratorOf(iterable, this);
+      }
+    };
+  }
+
+  private static <E> CloseableIterator<E> closeableIteratorOf(
+      Iterable<E> iterable, Closeable closeable) {
+    Iterator<E> iterator = iterable.iterator();
+    return new CloseableIterator<E>() {
+      @Override
+      public void close() throws IOException {
+        closeable.close();
+      }
+
+      @Override
+      public boolean hasNext() {
+        return iterator.hasNext();
+      }
+
+      @Override
+      public E next() {
+        return iterator.next();
       }
     };
   }
-- 
2.49.0

@jkolash Can you please check if this would fix your problem?

emmanuel099 avatar Jun 16 '25 08:06 emmanuel099

Hi @emmanuel099, those changes had no impact on my issue.

jkolash avatar Jun 16 '25 11:06 jkolash

@szehon-ho Let me know whether a PR on the Spark side would be welcome or not. I don't think the fix for this in DataSourceRDD will be difficult. To be extra safe, we could compute a much smaller value object on iterator exhaustion that can be applied in the callback close, so multiple threads aren't updating the metrics, as that may not be thread safe.
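A rough conceptual sketch of that "small value object" idea (all names here are hypothetical, not Spark or Iceberg APIs): the reading thread publishes an immutable snapshot when it exhausts the iterator, and the completion listener only ever touches the snapshot, never the heavy reader.

import java.util.concurrent.atomic.AtomicReference;

// Immutable, tiny carrier for the final metric values.
final class MetricsSnapshot {
    final long rowsRead;
    final long bytesRead;

    MetricsSnapshot(long rowsRead, long bytesRead) {
        this.rowsRead = rowsRead;
        this.bytesRead = bytesRead;
    }
}

final class ExhaustionMetrics {
    // Shared between the reading thread and the task-completion listener.
    private final AtomicReference<MetricsSnapshot> finalMetrics = new AtomicReference<>();

    // Called by the reading thread once hasNext() returns false, right before it
    // closes and drops the heavy reader.
    void onExhausted(long rowsRead, long bytesRead) {
        finalMetrics.set(new MetricsSnapshot(rowsRead, bytesRead));
    }

    // Called by the task-completion listener, possibly much later; it only touches
    // the small snapshot, never the reader.
    MetricsSnapshot snapshotOrNull() {
        return finalMetrics.get();
    }
}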

jkolash avatar Jun 16 '25 11:06 jkolash

So I made the following changes to just Spark, without my Iceberg changes, and I was able to avoid the OOM.

Image
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
 
+import scala.language.existentials
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+
 class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
   extends Partition with Serializable
 
@@ -74,24 +76,33 @@ class DataSourceRDD(
           val inputPartition = inputPartitions(currentIndex)
           currentIndex += 1
 
+          val exhaustCallback = new InvokeOnceCallback()
+
           // TODO: SPARK-25083 remove the type erasure hack in data source scan
           val (iter, reader) = if (columnarReads) {
             val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
             val iter = new MetricsBatchIterator(
-              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+              new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
             (iter, batchReader)
           } else {
             val rowReader = partitionReaderFactory.createReader(inputPartition)
             val iter = new MetricsRowIterator(
-              new PartitionIterator[InternalRow](rowReader, customMetrics))
+              new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
             (iter, rowReader)
           }
+
+          exhaustCallback.setCallback(callback = new Runnable() {
+            override def run(): Unit = {
+              // In case of early stopping before consuming the entire iterator,
+              // we need to do one more metric update at the end of the task.
+              CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+              iter.forceUpdateMetrics()
+              reader.close()
+            }
+          })
+
           context.addTaskCompletionListener[Unit] { _ =>
-            // In case of early stopping before consuming the entire iterator,
-            // we need to do one more metric update at the end of the task.
-            CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
-            iter.forceUpdateMetrics()
-            reader.close()
+            exhaustCallback.run()
           }
           currentIter = Some(iter)
           hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
   }
 }
 
+private class InvokeOnceCallback extends Runnable {
+  val originalCallback = new AtomicReference[Runnable](null)
+
+  override def run(): Unit = {
+    if (originalCallback.get() != null) {
+      originalCallback.get().run()
+      originalCallback.set(null);
+    }
+  }
+
+  def setCallback(callback: Runnable): Unit = {
+    originalCallback.set(callback);
+  }
+}
+
 private class PartitionIterator[T](
     reader: PartitionReader[T],
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }
 
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler
 
   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
-      forceUpdateMetrics()
+      exhaustionCallback.run()
       false
     }
   }
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
 }
 
 private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+    iter: Iterator[InternalRow],
+    callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
   override def next(): InternalRow = {
     val item = iter.next
     metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
 }
 
 private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+    iter: Iterator[ColumnarBatch],
+    callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next
     metricsHandler.updateMetrics(batch.numRows)

jkolash avatar Jun 16 '25 17:06 jkolash

Yea, of course, it is welcome on the Spark side. BTW, to learn: why is it only for coalesce + foreach? Is it for all wide Iceberg tables, and coalesce just makes it more vulnerable?

szehon-ho avatar Jun 16 '25 18:06 szehon-ho

I believe it is happening because normally these would be separate tasks, but coalesce kind of hides each task and combines multiple partitions into 1 partition, so the task cannot "complete" and the callbacks are held much longer.

Also, I ran with the Parquet v2 code (https://github.com/apache/iceberg/issues/13297#issuecomment-2968557949), and I believe a similar fix needs to be applied here:

https://github.com/apache/spark/blob/59e6b5b7d350a1603502bc92e3c117311ab2cbb6/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala#L312

Image

Is it for all wide Iceberg tables, and coalesce just makes it more vulnerable?

This particular table is ~500 columns wide, with nesting. I can produce a synthetic dataset later, or as part of this issue, so it can be reproduced by anyone.

jkolash avatar Jun 16 '25 20:06 jkolash

On the v2 Parquet reader side quest: this total set of Spark changes allows the v2 Parquet reader to work.

Image

I introduce a GarbageCollectableRecordReader that nulls out the delegate once close() has been called. close() was already being called correctly by callers, but the callback still prevented GC even after close had been called.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
 
+import scala.language.existentials
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+
 class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
   extends Partition with Serializable
 
@@ -74,24 +76,33 @@ class DataSourceRDD(
           val inputPartition = inputPartitions(currentIndex)
           currentIndex += 1
 
+          val exhaustCallback = new InvokeOnceCallback()
+
           // TODO: SPARK-25083 remove the type erasure hack in data source scan
           val (iter, reader) = if (columnarReads) {
             val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
             val iter = new MetricsBatchIterator(
-              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+              new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
             (iter, batchReader)
           } else {
             val rowReader = partitionReaderFactory.createReader(inputPartition)
             val iter = new MetricsRowIterator(
-              new PartitionIterator[InternalRow](rowReader, customMetrics))
+              new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
             (iter, rowReader)
           }
+
+          exhaustCallback.setCallback(callback = new Runnable() {
+            override def run(): Unit = {
+              // In case of early stopping before consuming the entire iterator,
+              // we need to do one more metric update at the end of the task.
+              CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+              iter.forceUpdateMetrics()
+              reader.close()
+            }
+          })
+
           context.addTaskCompletionListener[Unit] { _ =>
-            // In case of early stopping before consuming the entire iterator,
-            // we need to do one more metric update at the end of the task.
-            CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
-            iter.forceUpdateMetrics()
-            reader.close()
+            exhaustCallback.run()
           }
           currentIter = Some(iter)
           hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
   }
 }
 
+private class InvokeOnceCallback extends Runnable {
+  val originalCallback = new AtomicReference[Runnable](null)
+
+  override def run(): Unit = {
+    if (originalCallback.get() != null) {
+      originalCallback.get().run()
+      originalCallback.set(null);
+    }
+  }
+
+  def setCallback(callback: Runnable): Unit = {
+    originalCallback.set(callback);
+  }
+}
+
 private class PartitionIterator[T](
     reader: PartitionReader[T],
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }
 
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler
 
   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
-      forceUpdateMetrics()
+      exhaustionCallback.run()
       false
     }
   }
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
 }
 
 private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+    iter: Iterator[InternalRow],
+    callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
   override def next(): InternalRow = {
     val item = iter.next
     metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
 }
 
 private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+    iter: Iterator[ColumnarBatch],
+    callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next
     metricsHandler.updateMetrics(batch.numRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 5951c1d8dd..165fe88bef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.datasources.v2.parquet
 
 import java.time.ZoneId
-
 import org.apache.hadoop.mapred.FileSplit
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -26,7 +25,6 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
 import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
 import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
 import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
-
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
@@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.parquet._
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
@@ -45,6 +43,8 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 
+import java.util.concurrent.atomic.AtomicReference
+
 /**
  * A factory used to create Parquet readers.
  *
@@ -158,7 +158,7 @@ case class ParquetPartitionReaderFactory(
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val fileReader = if (aggregation.isEmpty) {
       val vectorizedReader = createVectorizedReader(file)
-      vectorizedReader.enableReturningBatches()
+      vectorizedReader.delegate.asInstanceOf[VectorizedParquetRecordReader].enableReturningBatches()
 
       new PartitionReader[ColumnarBatch] {
         override def next(): Boolean = vectorizedReader.nextKeyValue()
@@ -205,7 +205,8 @@ case class ParquetPartitionReaderFactory(
         InternalRow,
           Option[FilterPredicate], Option[ZoneId],
           RebaseSpec,
-          RebaseSpec) => RecordReader[Void, T]): RecordReader[Void, T] = {
+          RebaseSpec) =>
+        GarbageCollectableRecordReader[Void, T]): GarbageCollectableRecordReader[Void, T] = {
     val conf = broadcastedConf.value.value
 
     val filePath = file.toPath
@@ -279,7 +280,7 @@ case class ParquetPartitionReaderFactory(
       pushed: Option[FilterPredicate],
       convertTz: Option[ZoneId],
       datetimeRebaseSpec: RebaseSpec,
-      int96RebaseSpec: RebaseSpec): RecordReader[Void, InternalRow] = {
+      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
     logDebug(s"Falling back to parquet-mr")
     val taskContext = Option(TaskContext.get())
     // ParquetRecordReader returns InternalRow
@@ -296,17 +297,57 @@ case class ParquetPartitionReaderFactory(
     }
     val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(
       reader, readDataSchema)
-    val iter = new RecordReaderIterator(readerWithRowIndexes)
+    val delegatingRecordReader =
+      new GarbageCollectableRecordReader[Void, InternalRow](readerWithRowIndexes)
     // SPARK-23457 Register a task completion listener before `initialization`.
-    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
-    readerWithRowIndexes
+    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
+    delegatingRecordReader
+  }
+
+  private class GarbageCollectableRecordReader[K, V](reader: RecordReader[K, V])
+    extends RecordReader[K, V] {
+    val delegate = new AtomicReference[RecordReader[K, V]](reader)
+
+    override def initialize(inputSplit: InputSplit,
+                            taskAttemptContext: TaskAttemptContext): Unit = {
+      delegate.get().initialize(inputSplit, taskAttemptContext)
+    }
+
+    override def nextKeyValue(): Boolean = {
+      delegate.get().nextKeyValue()
+    }
+
+    override def getCurrentKey: K = {
+      delegate.get().getCurrentKey
+    }
+
+    override def getCurrentValue: V = {
+      delegate.get().getCurrentValue
+    }
+
+    override def getProgress: Float = {
+      if (delegate.get() == null) {
+        1.0f
+      } else {
+        delegate.get().getProgress
+      }
+    }
+
+    override def close(): Unit = {
+      if (delegate.get() != null) {
+        delegate.get().close()
+        delegate.set(null)
+      }
+    }
   }
 
-  private def createVectorizedReader(file: PartitionedFile): VectorizedParquetRecordReader = {
-    val vectorizedReader = buildReaderBase(file, createParquetVectorizedReader)
+  private def createVectorizedReader(file: PartitionedFile):
+    GarbageCollectableRecordReader[Void, InternalRow] = {
+    val gcReader = buildReaderBase(file, createParquetVectorizedReader)
+    val vectorizedReader = gcReader.delegate.get()
       .asInstanceOf[VectorizedParquetRecordReader]
     vectorizedReader.initBatch(partitionSchema, file.partitionValues)
-    vectorizedReader
+    gcReader
   }
 
   private def createParquetVectorizedReader(
@@ -314,7 +355,7 @@ case class ParquetPartitionReaderFactory(
       pushed: Option[FilterPredicate],
       convertTz: Option[ZoneId],
       datetimeRebaseSpec: RebaseSpec,
-      int96RebaseSpec: RebaseSpec): VectorizedParquetRecordReader = {
+      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
     val taskContext = Option(TaskContext.get())
     val vectorizedReader = new VectorizedParquetRecordReader(
       convertTz.orNull,
@@ -323,11 +364,14 @@ case class ParquetPartitionReaderFactory(
       int96RebaseSpec.mode.toString,
       int96RebaseSpec.timeZone,
       enableOffHeapColumnVector && taskContext.isDefined,
-      capacity)
-    val iter = new RecordReaderIterator(vectorizedReader)
+      capacity).asInstanceOf[RecordReader[Void, InternalRow]]
+
+    val delegatingRecordReader =
+      new GarbageCollectableRecordReader[Void, InternalRow](vectorizedReader)
+
     // SPARK-23457 Register a task completion listener before `initialization`.
-    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
+    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
     logDebug(s"Appending $partitionSchema $partitionValues")
-    vectorizedReader
+    delegatingRecordReader
   }
 }

jkolash avatar Jun 17 '25 02:06 jkolash

On the ParquetPartitionReaderFactory side quest: looking deeper at https://github.toasttab.com/toasttab/spork/commit/23ebd389b5cb528a7ba04113a12929bebfaf1e9a#diff-392c885fe00cf03dceb1d295a06034b279cee152b2f4a1ee0a4cfa3aec3b3660R199, it looks like that iter was just copy/pasted; it isn't actually iterated, it is used just for closing.

jkolash avatar Jun 17 '25 02:06 jkolash

I filed https://issues.apache.org/jira/projects/SPARK/issues/SPARK-52516 and have a WIP branch that I need to test against the latest Spark master: https://github.com/apache/spark/compare/master...jkolash:spark:fix-memory-leak-completionListener

I also created a gist here to generate synthetic data to reproduce the issue: https://gist.github.com/jkolash/c13d48f97657787068bedc464e0c43c4. It reproduces the issue at least with raw Parquet; if you were to execute the snapshot procedure (https://iceberg.apache.org/docs/nightly/spark-procedures/#snapshot), you would be able to reproduce it with Iceberg.
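For anyone reproducing with the gist, the rough flow I have in mind (catalog and table names are placeholders) is to generate the raw Parquet data, register it as a table, snapshot it into Iceberg, and then rerun the repro:

// Placeholders: "spark_catalog" for the Iceberg catalog, "db.synthetic_parquet" for the table
// created from the gist's generated Parquet files, "db.synthetic_iceberg" for the snapshot copy.
sparkSession.sql(
    "CALL spark_catalog.system.snapshot('db.synthetic_parquet', 'db.synthetic_iceberg')");
reproduceBug(sparkSession, "spark_catalog.db.synthetic_iceberg");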

Some tuning needs to be done on the synthetic data generation (so it hits this issue without hitting other issues) and on the amount of RAM you run the test with.

jkolash avatar Jun 20 '25 15:06 jkolash

I've stopped working on this issue for now but will get back to it later. I plan on writing some integration tests that take as parameters a Spark Docker image version and an Iceberg version, to find when this regression started. I also suspect there may be a leak in the VectorizedParquetRecordReader and the corresponding classes in Iceberg, but I need to prove that.

jkolash avatar Jul 15 '25 13:07 jkolash

Sorry, I am a bit behind on this thread. So the MetricsRowIterator and ParquetPartitionReaderFactory fixes together did not fix the issue? cc @viirya as well

szehon-ho avatar Jul 15 '25 18:07 szehon-ho

Just read through the early discussion. So the iterator held by the task completion listener is heavy for some Iceberg tables. Once the iterator is exhausted before the task finishes, the iterator and related resources cannot be released because they are still held by the listener.

On the Spark side, the issue looks easy to fix.

viirya avatar Jul 15 '25 19:07 viirya

I proposed a fix at https://github.com/apache/spark/pull/51503

viirya avatar Jul 15 '25 20:07 viirya

I will verify the fix, @viirya, on the original dataset that triggered this issue. I think there may still be a similar bug here:

https://github.com/apache/spark/blob/branch-3.5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala#L307

jkolash avatar Jul 17 '25 13:07 jkolash

For ParquetPartitionReaderFactory, I proposed a fix at https://github.com/apache/spark/pull/51528 yesterday.

viirya avatar Jul 17 '25 15:07 viirya

Ok, it looks like

https://hub.docker.com/layers/apache/spark/4.1.0-preview1/images/sha256-6c8ce99c91b278336894bb204a4eb20b726511d78242a185798e288a89c1dbdf

is available, so I can write an OOM test against a released container image.

jkolash avatar Sep 22 '25 11:09 jkolash