TabularSHAP hangs indefinitely
Hi SynapseML team,
I am using TabularSHAP to explain the predictions of a LightGBM classification model with SynapseML version com.microsoft.azure:synapseml_2.12:0.9.5-35-e962330b-SNAPSHOT.
Based on the tutorial, I have prepared the code below. I am using a single executor with five cores so that the code doesn't fail due to concurrency, and I repartition all data frames down to a single partition.
import pyspark.sql.functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from synapse.ml.lightgbm import LightGBMClassifier, LightGBMClassificationModel
from synapse.ml.explainers import TabularSHAP

model = LightGBMClassifier.load('<path_to_model>')
train_data = spark.read.parquet('<path_to_data>').repartition(1).cache()  # about 15K records
features = [...list of features...]
vector_assembler = VectorAssembler(inputCols=features, outputCol='features').setHandleInvalid("skip")
pipeline = Pipeline(stages=[vector_assembler, model])
model = pipeline.fit(train_data)

explain_instances = model.transform(train_data.limit(5).repartition(1).cache()).cache()
train_sample = train_data.orderBy(F.rand()).limit(100).repartition(1).cache()

shap = TabularSHAP(
    inputCols=features,
    outputCol="shapValues",
    numSamples=10,
    model=model,
    targetCol="probability",
    targetClasses=[1],
    backgroundData=train_sample,
)

explained = shap.transform(explain_instances).cache()
On the last line, when I call shap.transform(), the notebook shows a progress bar but then waits forever. I suspect it gets stuck in some kind of lock, but I am not sure. I am sure it isn't waiting for the upstream data frames to be computed: although I removed those lines from the snippet above, I call .show() on each data frame to force the cache to materialize.
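For reference, the materialization calls I removed look roughly like this (a reconstruction from memory, not the exact lines):

train_data.show(1)         # force the cached training data to materialize
explain_instances.show(1)  # force the pipeline transform and its cache
train_sample.show(1)       # force the background sample to materialize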
What am I missing?
Thank you.
Hey @Ishitori :wave:! Thank you so much for reporting the issue/feature request :rotating_light:. Someone from SynapseML Team will be looking to triage this issue soon. We appreciate your patience.
Any suggestions?
@Ishitori I'm looping in @memoryz, who created this code, to help out here.
Looking at the code, I don't see anything obviously wrong, but can you try the following and see if it helps:
explain_instances = model.transform(train_data).limit(5).repartition(200).cache()
train_sample = broadcast(train_data.orderBy(F.rand()).limit(100).cache())
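For completeness, broadcast here is pyspark.sql.functions.broadcast; a self-contained version of the suggestion (reusing the variable names from the original snippet) would be:

import pyspark.sql.functions as F
from pyspark.sql.functions import broadcast

# spread the explain instances across many partitions, and ship the small
# background sample to every executor instead of shuffling it
explain_instances = model.transform(train_data).limit(5).repartition(200).cache()
train_sample = broadcast(train_data.orderBy(F.rand()).limit(100).cache())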
Also, out of curiosity, how many features do you have?
@Ishitori -- awaiting your response here.
I have the exact same issue. I tried repartitioning as @memoryz suggested, and it didn't work; this is the case even for a small dataset of 100 entries. I have 17 features in my dataset. Was any solution to this problem ever found?
Py4JError: com.microsoft.azure.synapse.ml.explainers.TabularSHAP does not exist in the JVM. Please, I need help.

Traceback (most recent call last):
  File "/Users/bluelambda/Documents/GitHub/code/pyspark-anomaly-detection/test.py", line 248, in <module>
    se = connection.send_command(command)
  File "/Users/bluelambda/anaconda3/envs/pyspark-env/lib/python3.10/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving
    at py4j.commands.ReflectionCommand.execute(ReflectionCommand.java:87)
    at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
2023-10-29 09:28:49,190 - INFO : Closing down clientserver connection
Looks like I'm having the exact same issue with TabularLIME. I tried reducing numSamples, with no result.
TabularSHAP also hangs indefinitely for me. I'm using PySpark's PipelineModel, including a VectorAssembler and a RandomForestClassificationModel.
I get the feeling this is not intended for a large number of features, at least in TabularSHAP's case, because of https://github.com/microsoft/SynapseML/blob/master/core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/KernelSHAPBase.scala#L134. We have 38 input features, and I suspect this results in an integer overflow, since the sample count is forced to an Int a couple of lines below.
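Quick arithmetic behind that suspicion (just a sketch; the min() guards around that line may keep the cast from ever seeing this value):

JVM_INT_MAX = 2**31 - 1          # Int.MaxValue on the JVM
coalitions = 2**38               # feature coalitions for 38 features
print(coalitions)                # 274877906944
print(coalitions > JVM_INT_MAX)  # True: naively forcing this to Int would overflow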
I'm not sure yet what's causing TabularLIME to fail, as the code there seems a bit more convoluted, but the above looks sloppy enough that I'm ready to abandon hope for this package.