Slider in `scatter_3d` and `scatter` makes some data points go missing
Only 2 of the 4 categories are plotted when I use the animation slider; the remaining data points do not appear at all. Moving the slider changes which categories are plotted. For example, in the MWE below, only TP and FP show up when the slider is below 0.9, and at 0.9 only TN and FN show up.
This happens with both 2D (`scatter`) and 3D (`scatter_3d`) plots. The MWE below uses `scatter_3d`.
Bug Report: Data Points Missing in scatter_3d and scatter with Slider
Issue Description:
When an animation slider is used with `scatter_3d` or `scatter`, only some of the data points are shown and the others are missing. As I move the slider, different categories of data points appear or disappear. For example, in the MWE below, only the TP and FP categories are displayed when the slider is below 0.9, while at 0.9 only the TN and FN categories are displayed. The behavior is the same for animated 2D and 3D scatter plots.
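To see which categories actually exist at each slider position, the sketch below recomputes the category labels from the MWE's `due` and `predicted_probabilities` values and counts them per threshold with `pd.crosstab`, skipping the plot and the random `Dim` columns. `categories_per_threshold` is a hypothetical helper written for this check. One deliberate difference from the MWE: the thresholds are built as `round(0.1 * i, 1)` instead of `np.arange(0, 1.1, 0.1)`, because `np.arange` can produce values such as `0.30000000000000004`, which shift points sitting exactly on a threshold.

```python
import numpy as np
import pandas as pd

# Labels and scores copied from the MWE; the random Dim columns are left out
# because only the category assignment matters for this check.
DUE = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
PROBS = np.array([0.9, 0.8, 0.4, 0.2, 0.6, 0.7, 0.1, 0.5, 0.3, 0.95])


def categories_per_threshold(thresholds):
    """Count how many points fall into each category at each threshold (hypothetical helper)."""
    records = []
    actual = DUE == 1
    for t in thresholds:
        predicted = PROBS >= t
        labels = np.select(
            [actual & predicted, ~actual & ~predicted, ~actual & predicted, actual & ~predicted],
            ["TP", "TN", "FP", "FN"],
            default="Unknown",
        )
        records.extend((t, label) for label in labels)
    frame = pd.DataFrame(records, columns=["threshold", "category"])
    # Rows: thresholds; columns: categories seen at any threshold; cells: point counts.
    return pd.crosstab(frame["threshold"], frame["category"])


# Rounded thresholds avoid values such as 0.30000000000000004 (= 3 * 0.1).
thresholds = [round(0.1 * i, 1) for i in range(11)]
table = categories_per_threshold(thresholds)
print(table)
```

At threshold 0.0 only TP and FP have points, and at 1.0 only TN and FN do, so the set of categories differs from frame to frame. That would be consistent with the symptom if the figure's traces are fixed by the first frame, though this is a guess rather than a confirmed cause.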
Minimum Working Example (MWE):
The following example demonstrates the issue with a 3D scatter plot. Each point is categorized as TP, TN, FP, or FN by comparing its predicted probability against a threshold, and each slider step corresponds to one threshold. All 10 points should stay visible in every frame and only change color as their category changes. Instead, many data points go missing as the slider is moved.
```python
import numpy as np
import pandas as pd
import plotly.express as px


def plot_scatter_3d_mwe():
    # Create a small DataFrame with fake data
    data = {
        'Dim1': np.random.rand(10),
        'Dim2': np.random.rand(10),
        'Dim3': np.random.rand(10),
        'due': [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        'serial_number': range(10),
        'predicted_probabilities': [0.9, 0.8, 0.4, 0.2, 0.6, 0.7, 0.1, 0.5, 0.3, 0.95]
    }
    df = pd.DataFrame(data)

    thresholds = np.arange(0, 1.1, 0.1)
    all_frames = []
    for threshold in thresholds:
        # Recalculate predictions based on the threshold
        predicted = (df['predicted_probabilities'] >= threshold).astype(int)

        # Create the 4 categories for coloring: TP, TN, FP, FN
        conditions = [
            (df['due'] == 1) & (predicted == 1),  # TP
            (df['due'] == 0) & (predicted == 0),  # TN
            (df['due'] == 0) & (predicted == 1),  # FP
            (df['due'] == 1) & (predicted == 0),  # FN
        ]
        categories = ['TP', 'TN', 'FP', 'FN']

        # Assign the categories to a new column
        df['category'] = np.select(conditions, categories, default='Unknown')
        df['threshold'] = threshold  # Add threshold as a column for animation frame
        all_frames.append(df.copy())

    # Concatenate all frames for animation
    df_all_frames = pd.concat(all_frames)

    # Plot the scatter 3D with the categories as color and animate over thresholds
    fig = px.scatter_3d(df_all_frames,
                        x='Dim1', y='Dim2', z='Dim3',
                        color='category',
                        animation_frame='threshold',
                        animation_group='serial_number')
    fig.show()


# Call the function
plot_scatter_3d_mwe()
```
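A possible workaround, assuming the cause is that Plotly Express only builds a trace for the categories present in each frame (so frames with different category sets do not line up with the figure's traces), is to pad every frame with one placeholder row per missing category. The placeholder rows have NaN coordinates, so they should not be drawn. `pad_missing_categories`, `CATEGORIES`, and the negative `serial_number` values used for placeholders are hypothetical choices for this sketch, not part of Plotly's API.

```python
import pandas as pd

# Category labels from the MWE in a fixed order (assumption: the same order is
# passed to px via category_orders so every frame lists its traces identically).
CATEGORIES = ("TP", "TN", "FP", "FN")


def pad_missing_categories(df, frame_col="threshold", color_col="category",
                           group_col="serial_number", categories=CATEGORIES):
    """Return df plus one placeholder row per missing (frame, category) pair.

    Hypothetical helper: placeholder rows only set the frame, category and
    group columns; every other column (Dim1, Dim2, Dim3, ...) becomes NaN, so
    the points should not be drawn, but each frame contains every category.
    """
    present = set(zip(df[frame_col], df[color_col]))
    missing = [(frame, cat)
               for frame in df[frame_col].unique()
               for cat in categories
               if (frame, cat) not in present]
    if not missing:
        return df.copy()

    padding = pd.DataFrame(missing, columns=[frame_col, color_col])
    # One stable negative ID per category, so a placeholder keeps the same
    # animation_group key from frame to frame.
    padding[group_col] = [-1 - categories.index(cat) for _, cat in missing]
    # Columns absent from `padding` are filled with NaN by concat.
    return pd.concat([df, padding], ignore_index=True)
```

With the MWE, the plotting call would become `px.scatter_3d(pad_missing_categories(df_all_frames), x='Dim1', y='Dim2', z='Dim3', color='category', animation_frame='threshold', animation_group='serial_number', category_orders={'category': list(CATEGORIES)})`. `category_orders` keeps the category and color order the same in every frame. The sketch only guarantees the data side (every frame has every category); that this restores the missing points in the rendered plot is an assumption.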