TabularSHAP hangs indefinitely

Open Ishitori opened this issue 2 years ago • 11 comments

Hi SynapseML team,

I am using TabularSHAP to explain predictions of a LightGBM classification model, using SynapseML version com.microsoft.azure:synapseml_2.12:0.9.5-35-e962330b-SNAPSHOT.

Based on the tutorial, I have prepared my code below. I am using only one executor with five cores to ensure the code doesn't fail due to concurrency, and I repartition all data frames to use a single partition.

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import functions as F  # needed for F.rand() below
from synapse.ml.lightgbm import JavaMLReadable, LightGBMClassifier, LightGBMClassificationModel
from synapse.ml.explainers import TabularSHAP

model = LightGBMClassifier.load('<path_to_model>')
train_data = spark.read.parquet('<path_to_data>').repartition(1).cache()  # about 15K records
features = [...list of features...]

vector_assembler = VectorAssembler(inputCols = features, outputCol = 'features').setHandleInvalid("skip")
pipeline = Pipeline(stages=[vector_assembler, model])
model = pipeline.fit(train_data)

explain_instances = model.transform(train_data.limit(5).repartition(1).cache()).cache()
train_sample = train_data.orderBy(F.rand()).limit(100).repartition(1).cache()

shap = TabularSHAP(
    inputCols=features,
    outputCol="shapValues",
    numSamples=10,
    model=model,
    targetCol="probability",
    targetClasses=[1],
    backgroundData=train_sample
)

explained = shap.transform(explain_instances).cache()

On the last line, when I call shap.transform(), the notebook shows a progress bar but never finishes. I suspect it runs into some kind of lock, but I am not sure. I am sure it is not waiting for the data frames to be computed: although I have removed those lines from the code above, I call .show() on each data frame to force caching...

What am I missing?

Thank you.

Ishitori avatar Apr 27 '23 13:04 Ishitori

Hey @Ishitori :wave:! Thank you so much for reporting the issue/feature request :rotating_light:. Someone from SynapseML Team will be looking to triage this issue soon. We appreciate your patience.

github-actions[bot] avatar Apr 27 '23 13:04 github-actions[bot]

Any suggestions?

Ishitori avatar May 01 '23 11:05 Ishitori

@Ishitori I'm looping in @memoryz, who created this code, to help out here.

mhamilton723 avatar May 01 '23 17:05 mhamilton723

Looking at the code, I don't see anything obviously wrong, but can you try the following and see if it helps:

from pyspark.sql.functions import broadcast  # broadcast() is used below

explain_instances = model.transform(train_data).limit(5).repartition(200).cache()
train_sample = broadcast(train_data.orderBy(F.rand()).limit(100).cache())

Also, out of curiosity, how many features do you have?

memoryz avatar May 02 '23 07:05 memoryz

@Ishitori -- awaiting your response here.

ppruthi avatar May 11 '23 00:05 ppruthi

I have the exact same issue. I tried repartitioning as @memoryz suggested and it didn't work. It also happens with a small dataset of 100 entries. I have 17 features in my dataset. Was a solution to the problem ever found?

elyesMi avatar Jun 30 '23 13:06 elyesMi

Py4JError: com.microsoft.azure.synapse.ml.explainers.TabularSHAP does not exist in the JVM. Please, I need help

omolewadavids avatar Oct 29 '23 13:10 omolewadavids

    se = connection.send_command(command)
at py4j.commands.ReflectionCommand.execute(ReflectionCommand.java:87)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
  File "/Users/bluelambda/anaconda3/envs/pyspark-env/lib/python3.10/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving
2023-10-29 09:28:49,190 - INFO : Closing down clientserver connection
Traceback (most recent call last):
  File "/Users/bluelambda/Documents/GitHub/code/pyspark-anomaly-detection/test.py", line 248, in <module>
    shap = TabularSHAP(
  File "/Users/bluelambda/anaconda3/envs/pyspark-env/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/__init__.py", line 139, in wrapper
    return func(self, **kwargs)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
  File "/Users/bluelambda/.ivy2/jars/com.microsoft.azure_synapseml-core_2.12-0.11.4.jar/synapse/ml/explainers/TabularSHAP.py", line 78, in __init__
    self._java_obj = self._new_java_obj("com.microsoft.azure.synapse.ml.explainers.TabularSHAP", self.uid)
  File "/Users/bluelambda/anaconda3/envs/pyspark-env/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/ml/wrapper.py", line 84, in _new_java_obj
    java_obj = getattr(java_obj, name)
  File "/Users/bluelambda/anaconda3/envs/pyspark-env/lib/python3.10/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1664, in __getattr__
    raise Py4JError("{0} does not exist in the JVM".format(new_fqn))
py4j.protocol.Py4JError: com.microsoft.azure.synapse.ml.explainers.TabularSHAP does not exist in the JVM
2023-10-29 09:28:49,192 - INFO : Closing down clientserver connection

omolewadavids avatar Oct 29 '23 13:10 omolewadavids

Looks like I'm having the exact same issue with TabularLIME. I tried reducing numSamples, with no result.

svolchkov avatar Aug 01 '24 23:08 svolchkov

TabularSHAP also hangs indefinitely. I'm using a pyspark PipelineModel that includes a VectorAssembler and a RandomForestClassificationModel.

svolchkov avatar Aug 02 '24 00:08 svolchkov

I get the feeling this is not intended for a large number of features, at least in the case of TabularSHAP, because of https://github.com/microsoft/SynapseML/blob/master/core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/KernelSHAPBase.scala#L134 . We have 38 input features, and I suspect this results in an integer overflow, since the number is forced to an int a couple of lines below.
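
To give a sense of scale, here is a minimal Python sketch of the arithmetic behind that suspicion. It assumes, without confirmation from the linked source, that the quantity in question grows as 2^M for M input features. `coalition_count` and `fits_in_int32` are hypothetical helper names used only for illustration, not SynapseML functions.

```python
# Largest value a 32-bit signed integer (a JVM Int) can hold.
INT32_MAX = 2**31 - 1  # 2147483647


def coalition_count(num_features: int) -> int:
    """Number of feature subsets for M features, assuming 2^M growth (an assumption)."""
    return 2 ** num_features


def fits_in_int32(value: int) -> bool:
    """True if the value can be stored in a 32-bit signed integer without loss."""
    return value <= INT32_MAX


# Feature counts reported in this thread: 17 and 38.
for m in (17, 38):
    n = coalition_count(m)
    print(f"{m} features -> {n} subsets, fits in Int: {fits_in_int32(n)}")
```

Under that assumption, 38 features give 274877906944 subsets, about 128 times more than an Int can hold, while the 17-feature case reported earlier in this thread gives 131072, which fits comfortably. An overflow alone would therefore not explain the 17-feature report. Also, on the JVM, converting an out-of-range Double to Int saturates at the maximum value rather than wrapping around, so what actually happens depends on the exact expression at that line.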

I'm not sure yet what's causing TabularLIME to fail, as the code there seems a bit more convoluted, but the above looks so sloppy that I'm ready to abandon my hope for this package.

svolchkov avatar Aug 02 '24 19:08 svolchkov