[SPARK-39931][PYTHON][WIP] Improve applyInPandas performance for very small groups
What changes were proposed in this pull request?
Given a batch size, applyInPandas sends multiple groups to the Python UDF at once if the groups are very small. This improves the performance of applyInPandas for very small groups.
Why are the changes needed?
Spark sends each group to Python individually. When groups are very small, only a few rows are sent and processed per call into Python, which degrades throughput. See SPARK-39931 for a benchmark.
Does this PR introduce any user-facing change?
Adds an optional `batchSize` argument to `applyInPandas`.
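A minimal usage sketch (hedged: the `batchSize` keyword argument is the one proposed by this PR and is not part of released PySpark):

```python
# Sketch only: assumes the optional batchSize argument proposed in this PR.
# `spark` is the SparkSession from the pyspark shell.
import pandas as pd

def summarize(pdf: pd.DataFrame) -> pd.DataFrame:
    # The user function keeps its existing contract: one group per call.
    return pd.DataFrame({"id": [pdf["id"][0]], "size": [pdf.size]})

df = spark.range(1000).selectExpr("id % 100 as id")
# Pack up to 1024 rows worth of small groups into each Arrow batch:
df.groupby("id").applyInPandas(summarize, "id long, size integer", batchSize=1024).show()
```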
How was this patch tested?
Python unit tests.
Hm, the general idea might be fine but I think the implementation is the problem. For example, the current design is that the user defined function always takes one group as pdf. To keep this behaviour, you should send the multiple groups in one batch, and apply the same function once for each group.
This is happening here: https://github.com/apache/spark/pull/37360/files#diff-5862151bb5e9fe7a6b2d1978301c235504dcc6c1bbbd1f9745a204a3ba93146eR218-R229 The user function gets decorated to take the batch of groups, group it by key, and apply the actual user function to each group. This is the trivial part; see the sketch below.
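A hedged sketch of that decoration; the name `wrap_grouped_batch_map_pandas_udf` is from this PR, but the signature and the key handling below are simplified assumptions:

```python
import pandas as pd

def wrap_grouped_batch_map_pandas_udf(f, key_columns):
    def wrapped(batch: pd.DataFrame) -> pd.DataFrame:
        # `batch` contains several complete groups; re-group by key and
        # apply the user's one-group-at-a-time function to each group.
        parts = [f(group) for _, group in batch.groupby(key_columns, sort=False)]
        return pd.concat(parts, ignore_index=True)
    return wrapped
```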
I need guidance on how to prepend the InternalRow key (as a single struct column) to the InternalRow value: https://github.com/apache/spark/pull/37360/files#diff-d153e10db7aa6557eb995300730b7f2b2d437fa5659dfeaa611800b49a09da9dR45-R49
The difficult part is the deduplicated unsafe projection: https://github.com/apache/spark/pull/37360/files#diff-4d4a9c23cb4c92c1f60def46451bc5666ed466922d9b56bbd1accc06aefee4e2R82-R87
Here is a benchmark (core seconds for 10m rows) of the batched `applyInPandasBatched` with batch sizes 65536, 1024, and 16; the last three columns show the change relative to no batching:
| group size | no batch (s) | batch 65536 (s) | batch 1024 (s) | batch 16 (s) | Δ 65536 | Δ 1024 | Δ 16 |
|---|---|---|---|---|---|---|---|
| 65536 | 5.8 | 5.6 | 5.8 | 5.6 | -3.5 % | -1.3 % | -4.1 % |
| 8192 | 8.9 | 7.4 | 9.4 | 9.4 | -16.2 % | +5.3 % | +5.7 % |
| 1024 | 16.2 | 7.2 | 22.7 | 22.3 | -55.5 % | +39.4 % | +37.1 % |
| 512 | 26.7 | 6.9 | 22.5 | 38.6 | -74.3 % | -15.8 % | +44.5 % |
| 256 | 44.5 | 7.1 | 22.8 | 70.5 | -84.1 % | -48.8 % | +58.5 % |
| 128 | 82.7 | 7.3 | 23.5 | 138.0 | -91.1 % | -71.6 % | +66.8 % |
| 64 | 158.2 | 8.9 | 25.3 | 264.3 | -94.4 % | -84.0 % | +67.1 % |
| 32 | 319.8 | 11.4 | 28.2 | 465.0 | -96.4 % | -91.2 % | +45.4 % |
| 16 | 652.6 | 17.1 | 32.9 | 924.9 | -97.4 % | -95.0 % | +41.7 % |
| 8 | 1,376.9 | 28.5 | 46.2 | 971.4 | -97.9 % | -96.6 % | -29.4 % |
| 4 | 2,656.3 | 52.2 | 68.8 | 971.4 | -98.0 % | -97.4 % | -63.4 % |
| 2 | 5,412.5 | 94.2 | 110.9 | 996.2 | -98.3 % | -98.0 % | -81.6 % |
| 1 | 9,491.4 | 187.2 | 204.7 | 1,099.1 | -98.0 % | -97.8 % | -88.4 % |
Improvements exceed 90 % when the batch size is orders of magnitude larger than the group size and the groups are small.
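For intuition: with 10m rows and group size 64, the unbatched path makes 10,000,000 / 64 ≈ 156,250 calls into Python, one per group, while a 65536-row batch holds 65536 / 64 = 1,024 complete groups, so only about 10,000,000 / 65536 ≈ 153 round trips remain.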
Running the code below via `./bin/pyspark --driver-memory 2G --master "local[1]"`:
Reference (no batch):
```python
import time
import pandas as pd
from pyspark.sql.functions import col

for group_size in reversed([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 8192, 65536]):
    df = spark.range(10000000).repartition(200).select((col("id") / group_size).cast("int").alias("id")).cache()
    c = df.count()
    start = time.time_ns()
    c = df.groupby("id").applyInPandas(lambda df: pd.DataFrame({'id': [df['id'][0]], 'size': [df.size]}), "id long, size integer").count()
    end = time.time_ns()
    print(f"groupSize {group_size} took {(end - start) / 1000000000}s")
    df = df.unpersist()
```
Batched:
```python
import time
import pandas as pd
from pyspark.sql.functions import col

for group_size in reversed([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 8192, 65536]):
    print(f"groupSize is {group_size}")
    for batch_size in [65536, 8192, 1024, 128, 16, 2]:
        df = spark.range(10000000).repartition(200).select((col("id") / group_size).cast("int").alias("id")).cache()
        c = df.count()
        start = time.time_ns()
        c = df.groupby("id").applyInPandasBatched(lambda gdf: gdf.apply(lambda df: df.size).to_frame("size").reset_index(), "id long, size integer", batch_size).count()
        end = time.time_ns()
        print(f"groupSize {group_size} batchSize {batch_size} took {(end - start) / 1000000000}s")
        df = df.unpersist()
    print()
```
@HyukjinKwon Two options here:
- provide an alternative for `applyInPandas` that takes the same user function signature in batch mode
  - Python (`wrap_grouped_batch_map_pandas_udf`) calls that function multiple times for one invocation from Scala
- provide an alternative for `applyInPandas` that takes a different user function signature in batch mode
  - provides a `pandas.DataFrameGroupBy` to give user code access to multiple groups at once, and the Pandas Group API
Both alternatives could be supported based on annotations of the user function (via `inspect.getfullargspec`); see the sketch below.
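A hedged sketch of such annotation-based dispatch (the helper name and the annotation check are illustrative assumptions, not code from this PR):

```python
import inspect

import pandas as pd
from pandas.core.groupby import DataFrameGroupBy

def takes_group_by(f) -> bool:
    """Return True if the user function's first argument is annotated as a
    DataFrameGroupBy, i.e. it expects multiple groups per call."""
    spec = inspect.getfullargspec(f)
    first_arg = spec.args[0] if spec.args else None
    return spec.annotations.get(first_arg) is DataFrameGroupBy

# One function per proposed variant:
def per_group(pdf: pd.DataFrame) -> pd.DataFrame: ...
def per_batch(gdf: DataFrameGroupBy) -> pd.DataFrame: ...

assert not takes_group_by(per_group)
assert takes_group_by(per_batch)
```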
I wouldn't touch applyInPandas, as this might introduce some penalty for user code that has large groups.
@zhengruifeng how do you feel about this potential performance improvement?
We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable. If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!
@xinrong-meng what do you think about this?
@xinrong-meng is there interest in this improvement?
@EnricoMi This PR is really useful. Moreover, it looks like a basic feature that is absolutely required if you use pandas UDFs in Spark.
@EnricoMi do you think it makes sense to re-open this PR?
@sergun happy to re-open, but before I invest more time into this, I'd like to see some committer support for this approach.
@EnricoMi Am I right that there is a similar problem with Series-to-Scalar UDFs, which are called in cases like:
df.withColumn('mean_v', mean_udf("v").over(w)).show()
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html?highlight=pandas_udf#pyspark.sql.functions.pandas_udf
So the number of UDF calls / the number of pd.Series objects to be created is similar to the cardinality of the grouping fields.
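For reference, a self-contained version of that pattern, following the example in the linked `pandas_udf` documentation (`spark` assumed from a pyspark shell):

```python
import pandas as pd
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf

# Series-to-Scalar UDF: receives one pd.Series per window partition.
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
w = Window.partitionBy("id")
df.withColumn("mean_v", mean_udf("v").over(w)).show()
```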
> Am I right that there is a similar problem with Series-to-Scalar UDFs
I suppose so, yes.