Spark: Doing a coalesce and foreachPartition directly on an Iceberg table leaks memory-heavy iterators
Apache Iceberg version
1.5.0
Query engine
Spark
Please describe the bug 🐞
Summary
Doing the following should not leak any significant amount of memory.
sparkSession.sql("select * from icebergcatalog.db.table").coalesce(4).foreachPartition((iterator) -> {
  while (iterator.hasNext()) iterator.next();
});
A workaround is to use repartition() instead; however, this requires more resources to handle spilling, shuffling, etc.
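A minimal sketch of that workaround, mirroring the snippet above (same call shape, only coalesce is swapped for repartition):

sparkSession.sql("select * from icebergcatalog.db.table")
    .repartition(4)  // shuffle: the scan runs as separate upstream tasks whose readers are freed as each task finishes
    .foreachPartition((iterator) -> {
      while (iterator.hasNext()) iterator.next();
    });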
Spark version: Spark 3.4.X
Details
The below code can be run on a sufficiently large Iceberg table.
static AtomicInteger partitionCounter = new AtomicInteger(0);

static void reproduceBug(SparkSession sparkSession, String table) {
  sparkSession.sql("select * from " + table).coalesce(4).foreachPartition((iterator) -> {
    int partition = partitionCounter.getAndIncrement();
    AtomicLong rowCounter = ThreadLocal.withInitial(() -> new AtomicLong(0)).get();
    while (iterator.hasNext()) {
      iterator.next();
      if (rowCounter.getAndIncrement() % 100000 == 0) {
        System.out.println(partition + " " + rowCounter.get());
      }
    }
  });
}
The following image shows me running the reproduceBug method over a sufficiently large table (~500 columns) that we have in our environment.
The following image shows the "Dominators" report in VisualVM: org.apache.spark.TaskContextImpl.
Digging deeper, we see that the onCompleteCallbacks list is keeping an anonymous class from org.apache.spark.sql.execution.datasources.v2.DataSourceRDD alive, and that class is holding a reference to org.apache.iceberg.spark.source.RowDataReader.
I believe this callback is added here https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala#L90
We also see the org.apache.iceberg.util.Filter iterator holding a heavy reference.
Exploring the problem
Is this inherently a bug in org.apache.spark.sql.execution.datasources.v2.DataSourceRDD? Or should iterators not hold onto state that is no longer needed once they have advanced to the end? Is the iterator even exhausted? Once an iterator is exhausted, there is no longer any need to reference its state. However, this somewhat breaks the concept of a CloseableIterator, which has an explicit close(), versus an implicit close where you could detect hasNext() == false and auto-close, and then ignore a subsequent close() because it was already handled by the implicit close on iterator exhaustion. I believe an iterator accumulating hundreds of megabytes of state breaks the implicit "expected" contract of an iterator being a streaming sequence of objects. There might even be a distinction between closing and simply releasing large object references.
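To make the implicit-close idea concrete, here is a minimal sketch of a hypothetical wrapper (not an existing Iceberg API) that closes itself on exhaustion and treats a later explicit close() as a no-op:

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.NoSuchElementException;
import org.apache.iceberg.io.CloseableIterator;

// Hypothetical sketch: an iterator that implicitly closes (and drops) its delegate once
// hasNext() observes exhaustion, so a later explicit close() becomes a harmless no-op.
static <T> CloseableIterator<T> selfClosing(CloseableIterator<T> delegate) {
  return new CloseableIterator<T>() {
    private CloseableIterator<T> inner = delegate;

    @Override
    public boolean hasNext() {
      if (inner == null) {
        return false;
      }
      if (inner.hasNext()) {
        return true;
      }
      close(); // implicit close on exhaustion
      return false;
    }

    @Override
    public T next() {
      if (inner == null) {
        throw new NoSuchElementException();
      }
      return inner.next();
    }

    @Override
    public void close() {
      if (inner != null) {
        try {
          inner.close();
        } catch (IOException e) {
          throw new UncheckedIOException(e);
        }
        inner = null; // drop the heavy reference so it can be GC'd
      }
    }
  };
}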
Digging deeper, I see items = org.apache.iceberg.parquet.ParquetReader$FileIterator#6 holding onto a model reference. It might be possible to null out the model reference when hasNext() is false:
@Override
public boolean hasNext() {
  boolean hasNext = valuesRead < totalValues;
  if (!hasNext) {
    this.model = null;
  }
  return hasNext;
}
Willingness to contribute
- [x] I can contribute a fix for this bug independently
- [x] I would be willing to contribute a fix for this bug with guidance from the Iceberg community
- [ ] I cannot contribute a fix for this bug at this time
I was able to make some hacky changes that reduced memory usage in this draft PR (https://github.com/apache/iceberg/pull/13298), mainly to show that these were the critical objects that were no longer needed and had to be GC'd.
I also tested
String[] paths = sparkSession.sql("select file_path from " + table + ".files").collectAsList().stream()
    .map(row -> row.getString(0))
    .toArray(String[]::new);
System.out.println(Arrays.asList(paths));
sparkSession.read().load(paths).coalesce(4).foreachPartition((iterator) -> {
  int partition = partitionCounter.getAndIncrement();
  AtomicLong rowCounter = ThreadLocal.withInitial(() -> new AtomicLong(0)).get();
  while (iterator.hasNext()) {
    iterator.next();
    if (rowCounter.getAndIncrement() % 100000 == 0) {
      System.out.println(partition + " " + rowCounter.get());
    }
  }
});
This shows that the issue is in the new read path through Iceberg: when loading/processing the Parquet files directly, the issue doesn't manifest.
The TaskContextImpl still has the onCompleteCallbacks, but the FileScanRDD is much smaller.
Can you force V2 sources and try the Parquet version again (FileScanRDD is a different code path)? It feels odd to me that the iterator we are making should null itself out when fully iterated. It feels like that should happen on the last "next" call, as is currently in the code, or in the close method?
Or is the issue here that we are saving this task context for the UI so we don't actually ever drop the iterator reference and close is never called?
I'm not sure what you mean by "force V2 sources and try the Parquet version again". I was comparing both the v1 and v2 memory usage, and the v2 path using Iceberg consumes significantly more memory.
Is there a v2 equivalent to
sparkSession.read().load(paths)
Speaking of handling it on the last next() call: I don't see that in my experimental changes, which did reduce memory usage. https://github.com/apache/iceberg/pull/13298/files#diff-d80c15b3e5376265436aeab8b79d5a92fb629c6b81f58ad10a11b9b9d3bfcffcR134
> Or is the issue here that we are saving this task context for the UI so we don't actually ever drop the iterator reference and close is never called?
The close() will likely not be called until the job completes. The callback being held onto is what will do the close() operation, and that won't likely happen until the job completes.
context.addTaskCompletionListener[Unit] { _ =>
  // In case of early stopping before consuming the entire iterator,
  // we need to do one more metric update at the end of the task.
  CustomMetrics
    .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
  iter.forceUpdateMetrics()
  reader.close()
}
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala#L90
Let me know if you want me to write something to produce synthetic data for this so it can be reproduced without using real data. That way you can reproduce it on your environment.
> I'm not sure what you mean by "force V2 sources and try the Parquet version again". I was comparing both the v1 and v2 memory usage, and the v2 path using Iceberg consumes significantly more memory.
> Is there a v2 equivalent to sparkSession.read().load(paths)? Speaking of handling it on the last next() call: I don't see that in my experimental changes, which did reduce memory usage. https://github.com/apache/iceberg/pull/13298/files#diff-d80c15b3e5376265436aeab8b79d5a92fb629c6b81f58ad10a11b9b9d3bfcffcR134
https://github.com/apache/iceberg/pull/13298/files#diff-7f0d08eb9c160db9e4fefa082993c58f2fe749f9be4f108c26cb4201547f2521L139
Ok - for my own recap here:
- Spark is holding onto task contexts in order to invoke callbacks to get metric information until the end of the job.
- The task context, through this callback, holds onto the iterator the DSV2 source makes.
- Our iterator ends up being fairly heavy: it holds onto a Parquet reader for every column in the table, which for a very wide table adds up to a significant amount of memory.
Now I think there are a few avenues to make this better:
- As @jkolash suggests, we can preemptively null out the reference to the readers when the iterator has been exhausted. One issue here is that a partially consumed iterator will still hold onto the model until the job has completed.
- We can work on the Spark side to make sure that the iterator is actually released when exhausted or when the task has completed.
Thoughts @szehon-ho + @aokolnychyi ?
I don't immediately see a Spark API to easily remove the callback once an iterator is exhausted.
Maybe the Iceberg change is easier, but if you have some idea on the Spark side, it makes sense to explore?
I wanted to bring up another point that @RussellSpitzer helped me identify, which is that this issue reproduces in a similar way with the v2 data sources.
When I ran https://github.com/apache/iceberg/issues/13297#issuecomment-2963910701
but set the Spark option spark.sql.sources.useV1SourceList to an empty string, it reproduced via the new v2 datasource code, but not the v1 datasource.
So this issue also needs to be addressed in Spark as well.
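For anyone trying to reproduce this, a sketch of how that option can be set when building the session (its default value is a comma-separated list of sources that still use the V1 path):

SparkSession sparkSession = SparkSession.builder()
    .appName("iceberg-leak-repro") // hypothetical app name
    .config("spark.sql.sources.useV1SourceList", "") // empty list: built-in file sources go through DataSource V2
    .getOrCreate();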
On the Spark side, I think the close() call should likely be done on task completion rather than via a callback on job completion, and the metrics can be collected before that close?
The issue is that coalesce is combining N tasks into 1 task, but really there are N tasks' worth of work, so coalesce should be more subtask-aware?
Reading more deeply in DataSourceRDD.scala, it seems the callback is there to handle a partially consumed iterator, but Spark wraps the raw iterator in a MetricsBatchIterator, which could detect that the iterator was exhausted via hasNext() == false and remove the onComplete callback?
https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/TaskContextImpl.scala#L71 is currently a stack, so that would need to change to something more like a linked list to allow removing arbitrary listeners.
Actually, I don't think any Spark API changes are needed; the callback and the MetricsBatchIterator are both created in the DataSourceRDD.
In Java you could write something like this:

static class HeavyCallback implements Runnable {
  byte[] heavyReference = new byte[1000000];

  @Override
  public void run() {}
}

static class InvokeOnceCallback implements Runnable {
  final AtomicReference<Runnable> callbackHandle = new AtomicReference<>();

  InvokeOnceCallback(Runnable target) {
    callbackHandle.set(target);
  }

  @Override
  public void run() {
    // Atomically swap the handle out so the heavy callback runs at most once
    // and becomes eligible for GC immediately afterwards.
    Runnable target = callbackHandle.getAndSet(null);
    if (target != null) {
      target.run();
    }
  }
}

static void showcaseIndirectCallback() {
  InvokeOnceCallback callback = new InvokeOnceCallback(new HeavyCallback());
  // This can be called from either the iterator exhaustion or the original callback.
  // Once invoked, the reference is removed and can be GC'd.
  callback.run();
}
Hi, my team and I had a similar problem while querying and optimizing tables with a huge number of data and delete files. Spark always ran OOM, because resources (Parquet readers, ...) weren't freed properly. I did some debugging and identified CloseableIterable as the culprit. Unfortunately, I never had time to upstream it properly.
CloseableIterable has the following two problems:
- When turning a CloseableIterable into an iterator, the returned CloseableIterator doesn't properly close the iterable, thereby causing a resource leak. By returning a custom instance of CloseableIterator which keeps track of the closeable iterable, we properly close the iterable when the iterator gets closed.
- When combining an iterable with a closeable using CloseableIterable.combine, we have a resource leak if the iterable implements CloseableIterable, because we never call close on the iterable. By overloading combine for CloseableIterable we can ensure that we close the iterable in addition to the closeable (see the small illustration after this list).
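A small illustration of the second problem; the parameters and method name below are hypothetical, only CloseableIterable.combine() is the real Iceberg API:

import java.io.Closeable;
import java.io.IOException;
import org.apache.iceberg.io.CloseableIterable;

// 'rows' and 'resource' stand in for whatever the caller opened (e.g. a Parquet read and a
// related resource).
static <T> void demonstrateCombineLeak(CloseableIterable<T> rows, Closeable resource)
    throws IOException {
  CloseableIterable<T> combined = CloseableIterable.combine(rows, resource);
  combined.close();
  // Only 'resource' was closed: combine() sees 'rows' as a plain Iterable, so the underlying
  // CloseableIterable (and any file readers it holds) is never closed.
}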
With the following patch applied, the queries and table optimizations ran without OOM and with much lower memory limits.
From 6e2cfd71c2dcecf432f2db2119ac530674af5042 Mon Sep 17 00:00:00 2001
From: Emmanuel Pescosta <[email protected]>
Date: Thu, 16 Jan 2025 12:19:46 +0100
Subject: [PATCH] fix(iceberg.io): Properly close CloseableIterable when
turning it into an iterator
A CloseableIterator created from a CloseableIterable didn't close
the iterable when the iterator is closed. This may cause a resource
leak.
---
.../apache/iceberg/io/CloseableIterable.java | 40 ++++++++++++++++++-
1 file changed, 38 insertions(+), 2 deletions(-)
diff --git a/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java b/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
index 06323612a..240c1ef8f 100644
--- a/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
+++ b/api/src/main/java/org/apache/iceberg/io/CloseableIterable.java
@@ -51,7 +51,7 @@ public interface CloseableIterable<T> extends Iterable<T>, Closeable {
@Override
public CloseableIterator<E> iterator() {
- return CloseableIterator.withClose(iterable.iterator());
+ return closeableIteratorOf(iterable, this);
}
};
}
@@ -69,7 +69,43 @@ public interface CloseableIterable<T> extends Iterable<T>, Closeable {
@Override
public CloseableIterator<E> iterator() {
- return CloseableIterator.withClose(iterable.iterator());
+ return closeableIteratorOf(iterable, this);
+ }
+ };
+ }
+
+ static <E> CloseableIterable<E> combine(CloseableIterable<E> iterable, Closeable closeable) {
+ return new CloseableIterable<E>() {
+ @Override
+ public void close() throws IOException {
+ closeable.close();
+ iterable.close();
+ }
+
+ @Override
+ public CloseableIterator<E> iterator() {
+ return closeableIteratorOf(iterable, this);
+ }
+ };
+ }
+
+ private static <E> CloseableIterator<E> closeableIteratorOf(
+ Iterable<E> iterable, Closeable closeable) {
+ Iterator<E> iterator = iterable.iterator();
+ return new CloseableIterator<E>() {
+ @Override
+ public void close() throws IOException {
+ closeable.close();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return iterator.hasNext();
+ }
+
+ @Override
+ public E next() {
+ return iterator.next();
}
};
}
--
2.49.0
@jkolash Can you please check if this would fix your problem?
Hi @emmanuel099, those changes had no impact on my issue.
@szehon-ho Let me know if a PR on the Spark side would be welcome or not. I don't think the fix for this in DataSourceRDD will be difficult. To be extra safe, we could compute a much smaller value object on iterator exhaustion that can then be applied in the callback on close, so multiple threads aren't updating the metrics (that may not be thread safe).
So I made the following changes to just Spark, without my Iceberg changes, and I was able to avoid the OOM.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
+import scala.language.existentials
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.sql.vectorized.ColumnarBatch
+
class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
extends Partition with Serializable
@@ -74,24 +76,33 @@ class DataSourceRDD(
val inputPartition = inputPartitions(currentIndex)
currentIndex += 1
+ val exhaustCallback = new InvokeOnceCallback()
+
// TODO: SPARK-25083 remove the type erasure hack in data source scan
val (iter, reader) = if (columnarReads) {
val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
val iter = new MetricsBatchIterator(
- new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+ new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
(iter, batchReader)
} else {
val rowReader = partitionReaderFactory.createReader(inputPartition)
val iter = new MetricsRowIterator(
- new PartitionIterator[InternalRow](rowReader, customMetrics))
+ new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
(iter, rowReader)
}
+
+ exhaustCallback.setCallback(callback = new Runnable() {
+ override def run(): Unit = {
+ // In case of early stopping before consuming the entire iterator,
+ // we need to do one more metric update at the end of the task.
+ CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+ iter.forceUpdateMetrics()
+ reader.close()
+ }
+ })
+
context.addTaskCompletionListener[Unit] { _ =>
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
- CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
- iter.forceUpdateMetrics()
- reader.close()
+ exhaustCallback.run()
}
currentIter = Some(iter)
hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
}
}
+private class InvokeOnceCallback extends Runnable {
+ val originalCallback = new AtomicReference[Runnable](null)
+
+ override def run(): Unit = {
+ if (originalCallback.get() != null) {
+ originalCallback.get().run()
+ originalCallback.set(null);
+ }
+ }
+
+ def setCallback(callback: Runnable): Unit = {
+ originalCallback.set(callback);
+ }
+}
+
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
}
}
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+ iter: Iterator[I],
+ exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
protected val metricsHandler = new MetricsHandler
override def hasNext: Boolean = {
if (iter.hasNext) {
true
} else {
- forceUpdateMetrics()
+ exhaustionCallback.run()
false
}
}
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
}
private class MetricsRowIterator(
- iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+ iter: Iterator[InternalRow],
+ callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
override def next(): InternalRow = {
val item = iter.next
metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
}
private class MetricsBatchIterator(
- iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+ iter: Iterator[ColumnarBatch],
+ callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
override def next(): ColumnarBatch = {
val batch: ColumnarBatch = iter.next
metricsHandler.updateMetrics(batch.numRows)
Yeah, of course, it is welcome on the Spark side. BTW, to learn: why is it only for coalesce + foreach? Is it for all wide Iceberg tables, and coalesce just makes it more vulnerable?
I believe it is happening because normally these would be separate tasks, but coalesce hides each task and combines multiple partitions into 1 partition, so the task cannot "complete" and the callbacks are held much longer.
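A rough back-of-envelope illustration of why that matters (the split count is made up for the example; the ~500 columns comes from the table described above):

// Made-up numbers for illustration only.
int inputSplits = 400;        // scan splits the table produces
int coalescedTasks = 4;       // coalesce(4) funnels them into 4 long-running tasks
int columns = 500;            // roughly one Parquet column reader held per column per split

int listenersPerTask = inputSplits / coalescedTasks;      // ~100 completion listeners per task
int columnReadersPerTask = listenersPerTask * columns;    // ~50,000 column readers kept reachable
System.out.println(listenersPerTask + " listeners holding ~" + columnReadersPerTask + " column readers per task");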
Also, I ran with the Parquet v2 code: https://github.com/apache/iceberg/issues/13297#issuecomment-2968557949
A similar fix needs to be applied here as well, I believe: https://github.com/apache/spark/blob/59e6b5b7d350a1603502bc92e3c117311ab2cbb6/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala#L312
> Is it for all wide Iceberg tables, and coalesce just makes it more vulnerable?
This particular table is ~500 columns wide, with nesting. I can produce a synthetic dataset later, or as part of this issue, so it can be reproduced by anyone.
On the v2 Parquet reader side quest: this total set of Spark changes allows the v2 Parquet reader to work.
I introduce a GarbageCollectableRecordReader that nulls out the delegate once close() has been called. close() was called correctly by callers, but the callback still prevented GC even after close had been called.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
+import scala.language.existentials
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.sql.vectorized.ColumnarBatch
+
class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
extends Partition with Serializable
@@ -74,24 +76,33 @@ class DataSourceRDD(
val inputPartition = inputPartitions(currentIndex)
currentIndex += 1
+ val exhaustCallback = new InvokeOnceCallback()
+
// TODO: SPARK-25083 remove the type erasure hack in data source scan
val (iter, reader) = if (columnarReads) {
val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
val iter = new MetricsBatchIterator(
- new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+ new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
(iter, batchReader)
} else {
val rowReader = partitionReaderFactory.createReader(inputPartition)
val iter = new MetricsRowIterator(
- new PartitionIterator[InternalRow](rowReader, customMetrics))
+ new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
(iter, rowReader)
}
+
+ exhaustCallback.setCallback(callback = new Runnable() {
+ override def run(): Unit = {
+ // In case of early stopping before consuming the entire iterator,
+ // we need to do one more metric update at the end of the task.
+ CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+ iter.forceUpdateMetrics()
+ reader.close()
+ }
+ })
+
context.addTaskCompletionListener[Unit] { _ =>
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
- CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
- iter.forceUpdateMetrics()
- reader.close()
+ exhaustCallback.run()
}
currentIter = Some(iter)
hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
}
}
+private class InvokeOnceCallback extends Runnable {
+ val originalCallback = new AtomicReference[Runnable](null)
+
+ override def run(): Unit = {
+ if (originalCallback.get() != null) {
+ originalCallback.get().run()
+ originalCallback.set(null);
+ }
+ }
+
+ def setCallback(callback: Runnable): Unit = {
+ originalCallback.set(callback);
+ }
+}
+
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
}
}
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+ iter: Iterator[I],
+ exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
protected val metricsHandler = new MetricsHandler
override def hasNext: Boolean = {
if (iter.hasNext) {
true
} else {
- forceUpdateMetrics()
+ exhaustionCallback.run()
false
}
}
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
}
private class MetricsRowIterator(
- iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+ iter: Iterator[InternalRow],
+ callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
override def next(): InternalRow = {
val item = iter.next
metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
}
private class MetricsBatchIterator(
- iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+ iter: Iterator[ColumnarBatch],
+ callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
override def next(): ColumnarBatch = {
val batch: ColumnarBatch = iter.next
metricsHandler.updateMetrics(batch.numRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 5951c1d8dd..165fe88bef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.datasources.v2.parquet
import java.time.ZoneId
-
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -26,7 +25,6 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
-
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
@@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet._
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
@@ -45,6 +43,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration
+import java.util.concurrent.atomic.AtomicReference
+
/**
* A factory used to create Parquet readers.
*
@@ -158,7 +158,7 @@ case class ParquetPartitionReaderFactory(
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
val fileReader = if (aggregation.isEmpty) {
val vectorizedReader = createVectorizedReader(file)
- vectorizedReader.enableReturningBatches()
+ vectorizedReader.delegate.asInstanceOf[VectorizedParquetRecordReader].enableReturningBatches()
new PartitionReader[ColumnarBatch] {
override def next(): Boolean = vectorizedReader.nextKeyValue()
@@ -205,7 +205,8 @@ case class ParquetPartitionReaderFactory(
InternalRow,
Option[FilterPredicate], Option[ZoneId],
RebaseSpec,
- RebaseSpec) => RecordReader[Void, T]): RecordReader[Void, T] = {
+ RebaseSpec) =>
+ GarbageCollectableRecordReader[Void, T]): GarbageCollectableRecordReader[Void, T] = {
val conf = broadcastedConf.value.value
val filePath = file.toPath
@@ -279,7 +280,7 @@ case class ParquetPartitionReaderFactory(
pushed: Option[FilterPredicate],
convertTz: Option[ZoneId],
datetimeRebaseSpec: RebaseSpec,
- int96RebaseSpec: RebaseSpec): RecordReader[Void, InternalRow] = {
+ int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
logDebug(s"Falling back to parquet-mr")
val taskContext = Option(TaskContext.get())
// ParquetRecordReader returns InternalRow
@@ -296,17 +297,57 @@ case class ParquetPartitionReaderFactory(
}
val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(
reader, readDataSchema)
- val iter = new RecordReaderIterator(readerWithRowIndexes)
+ val delegatingRecordReader =
+ new GarbageCollectableRecordReader[Void, InternalRow](readerWithRowIndexes)
// SPARK-23457 Register a task completion listener before `initialization`.
- taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
- readerWithRowIndexes
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
+ delegatingRecordReader
+ }
+
+ private class GarbageCollectableRecordReader[K, V](reader: RecordReader[K, V])
+ extends RecordReader[K, V] {
+ val delegate = new AtomicReference[RecordReader[K, V]](reader)
+
+ override def initialize(inputSplit: InputSplit,
+ taskAttemptContext: TaskAttemptContext): Unit = {
+ delegate.get().initialize(inputSplit, taskAttemptContext)
+ }
+
+ override def nextKeyValue(): Boolean = {
+ delegate.get().nextKeyValue()
+ }
+
+ override def getCurrentKey: K = {
+ delegate.get().getCurrentKey
+ }
+
+ override def getCurrentValue: V = {
+ delegate.get().getCurrentValue
+ }
+
+ override def getProgress: Float = {
+ if (delegate.get() == null) {
+ 1.0f
+ } else {
+ delegate.get().getProgress
+ }
+ }
+
+ override def close(): Unit = {
+ if (delegate.get() != null) {
+ delegate.get().close()
+ delegate.set(null)
+ }
+ }
}
- private def createVectorizedReader(file: PartitionedFile): VectorizedParquetRecordReader = {
- val vectorizedReader = buildReaderBase(file, createParquetVectorizedReader)
+ private def createVectorizedReader(file: PartitionedFile):
+ GarbageCollectableRecordReader[Void, InternalRow] = {
+ val gcReader = buildReaderBase(file, createParquetVectorizedReader)
+ val vectorizedReader = gcReader.delegate.get()
.asInstanceOf[VectorizedParquetRecordReader]
vectorizedReader.initBatch(partitionSchema, file.partitionValues)
- vectorizedReader
+ gcReader
}
private def createParquetVectorizedReader(
@@ -314,7 +355,7 @@ case class ParquetPartitionReaderFactory(
pushed: Option[FilterPredicate],
convertTz: Option[ZoneId],
datetimeRebaseSpec: RebaseSpec,
- int96RebaseSpec: RebaseSpec): VectorizedParquetRecordReader = {
+ int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
val taskContext = Option(TaskContext.get())
val vectorizedReader = new VectorizedParquetRecordReader(
convertTz.orNull,
@@ -323,11 +364,14 @@ case class ParquetPartitionReaderFactory(
int96RebaseSpec.mode.toString,
int96RebaseSpec.timeZone,
enableOffHeapColumnVector && taskContext.isDefined,
- capacity)
- val iter = new RecordReaderIterator(vectorizedReader)
+ capacity).asInstanceOf[RecordReader[Void, InternalRow]]
+
+ val delegatingRecordReader =
+ new GarbageCollectableRecordReader[Void, InternalRow](vectorizedReader)
+
// SPARK-23457 Register a task completion listener before `initialization`.
- taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
logDebug(s"Appending $partitionSchema $partitionValues")
- vectorizedReader
+ delegatingRecordReader
}
}
On the ParquetPartitionReaderFactory side quest: looking deeper at https://github.toasttab.com/toasttab/spork/commit/23ebd389b5cb528a7ba04113a12929bebfaf1e9a#diff-392c885fe00cf03dceb1d295a06034b279cee152b2f4a1ee0a4cfa3aec3b3660R199, it looks like that iter was just copy/pasted; it isn't actually used for iteration, only for closing.
I filed https://issues.apache.org/jira/projects/SPARK/issues/SPARK-52516 and have a WIP branch that I need to test against the latest spark master https://github.com/apache/spark/compare/master...jkolash:spark:fix-memory-leak-completionListener
I also created a gist here, https://gist.github.com/jkolash/c13d48f97657787068bedc464e0c43c4, to generate synthetic data that reproduces the issue, at least with raw Parquet; if you execute the snapshot procedure (https://iceberg.apache.org/docs/nightly/spark-procedures/#snapshot) on that data, you would be able to reproduce it with Iceberg.
Some tuning is needed, both of the generated synthetic data (so it hits this issue rather than other issues) and of the amount of RAM you run the test with.
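As a sketch, once the gist's synthetic Parquet output has been saved as a regular (non-Iceberg) Spark table, the snapshot call might look roughly like this (paths, catalog, and table names are placeholders):

// Register the synthetic Parquet output as a plain Spark table first (placeholder path/name).
sparkSession.read().parquet("/tmp/synthetic_wide_table")
    .write().saveAsTable("default.synthetic_wide_table");

// Snapshot it into an Iceberg table, then run the repro against the Iceberg table.
sparkSession.sql(
    "CALL icebergcatalog.system.snapshot('default.synthetic_wide_table', 'db.synthetic_wide_table')");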
I've stopped working on this issue for now but will get back to it later. I plan on writing some integration tests that take as parameters a Spark Docker image version and an Iceberg version, to find when this regression started. I also suspect there may be a leak in the VectorizedParquetRecordReader and the corresponding classes in Iceberg, but I need to prove that.
Sorry, I am a bit behind on this thread. So the MetricsRowIterator and ParquetPartitionReaderFactory fixes together did not fix the issue? cc @viirya as well.
I just read through the early discussion. So the iterator held by the task completion listener is heavy for some Iceberg tables. Even if the iterator is exhausted before the task finishes, the iterator and related resources cannot be released, because they are still held by the listener.
On the Spark side, the issue looks easy to fix.
I proposed a fix at https://github.com/apache/spark/pull/51503
I will verify the fix, @viirya, on the original dataset that triggered this issue. I think there may still be a similar bug here:
https://github.com/apache/spark/blob/branch-3.5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala#L307
For ParquetPartitionReaderFactory, I proposed a fix at https://github.com/apache/spark/pull/51528 yesterday.
Ok, it looks like
https://hub.docker.com/layers/apache/spark/4.1.0-preview1/images/sha256-6c8ce99c91b278336894bb204a4eb20b726511d78242a185798e288a89c1dbdf
is available, so I can write an OOM test against a released container image.