Slow performance with plotly chart builder
Hello all,
I need to use plotly as the backend of a microservice that generates charts dynamically.
Unfortunately, after a little benchmarking, I found that plotly.express is very slow (around 5 seconds to generate a chart from a 500-row dataset).
Here is the script I use to generate a scatter matrix:
import sys
import os
import traceback
import json
import time
sys.path.append('c:\\statwolf\\python\packages\Lib\site-packages')
input = json.loads('{\"file\":\"/tmp/data.tsv",\"color\":\"club_country\",\"dimensions\":[\"nolo\",\"tolo\",\"yolo\"]}')

def action():
    def run():
        import plotly.express as px
        from pandas import read_csv
        color = None if input['color'] == "" else input['color']
        first = time.time()
        d = read_csv(input['file'], sep='\t')
        second = time.time()
        fig = px.scatter_matrix(d, dimensions=input['dimensions'], color=color)
        third = time.time()
        j = fig.to_json()
        fourth = time.time()
        print('read: ' + str(second - first))
        print('plot: ' + str(third - second))
        print('json: ' + str(fourth - third))
        return j
    import time
    for i in range(0, 3):
        start = time.time()
        result = run()
        end = time.time()
        print('iteration: ' + str(i) + '\ntime: ' + str(end - start))
    return result

result = None
try:
    result = { 'outcome': action() }
except Exception as e:
    traceback.print_exc()
    result = { 'error': str(e) }
resultDir = os.path.dirname(os.path.realpath(__file__))
resultFile = open(resultDir + '/result.json', 'w')
json.dump(result, resultFile)
resultFile.close()
from the dataset: https://www.dropbox.com/s/cm9i3pfv10exbba/data.tsv?dl=1
and this is the report with timing: https://www.dropbox.com/s/l2x3jqzea4i4xqw/report.txt?dl=1
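The manual `first`/`second`/`third` timestamps in the script above can be collapsed into a small helper. This is only a minimal sketch, not part of the original script: the `timed` context manager and the `timings` dict are hypothetical names. `time.perf_counter` is swapped in for `time.time` because it is a monotonic clock meant for measuring intervals. The two workloads are stand-ins for `read_csv` and `px.scatter_matrix`.

```python
import time
from contextlib import contextmanager


# Hypothetical helper: stores the elapsed seconds of each stage in a dict.
@contextmanager
def timed(label, timings):
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = time.perf_counter() - start


timings = {}

with timed("read", timings):
    # Stand-in for read_csv(input['file'], sep='\t')
    rows = [line.split("\t") for line in ["a\tb", "c\td"]]

with timed("plot", timings):
    # Stand-in for px.scatter_matrix(...)
    total = sum(range(100_000))

for label, seconds in timings.items():
    print(f"{label}: {seconds:.4f}s")
```

Because each stage is recorded under its own label, adding a new stage (for example the `to_json` call) needs only one more `with` block.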
Now:
- Is there any tweak I can apply to improve performance?
- Do you plan to focus on speed in upcoming releases?
Hello all, I am also disappointed with the performance of plotly.py, because my dataset is not that small. plotly only used one of my CPUs to execute its chart drawing. Does plotly plan to improve performance through parallel processing or something else?
The same issue. 18 seconds for 800 data points for my data_frame with 10+ columns. Absolutely unreasonable timings for such small data sets.
settings["x"] = "time"
settings["y"] = "value"
settings["labels"] = {"value": y_title, "time_converted": x_title}
settings["facet_col"] = "temperature"
temperatures = str(sorted(data_frame["temperature"].astype(int).unique()))
settings["category_orders"] = {"temperature": temperatures}
# information to be shown in the plot
settings["hover_data"] = sorted(data_frame.columns)
px.scatter(data_frame, **settings)
I'm sorry that you're seeing such poor performance! In general, Plotly Express is much, much faster than this. For example, the test suite, which generates hundreds of plots with varying amounts of data, runs in around 30 seconds on my 5-year-old laptop.
@ievgennaida if you can provide a fully-runnable example with data that shows this 18-second performance I'd be happy to take a look.
@nicolaskruchten I cannot share the data but will try to prepare the script as the topic starter did with some fake data.
@nicolaskruchten
I have generated fake data. Depending on the PC in our team, it takes 4-8 seconds, which is longer than fetching the query results from the DB. In production it is even slower for me, but that might be for other reasons.
Basically, I have noticed that using a numeric group instead of a string group drops the time to 1 second!
# Fast -> column_data.append(current_step)
# Slow -> column_data.append(f"Group Number #{current_step}")
We will also try go.Scatter with multithreading to see whether it helps.
import plotly
import time
import plotly.express as px
import pandas as pd
import json
import random

# Generate fake data:
data_frame = pd.DataFrame()
rows_count = 800
for column in range(10):
    column += 1
    current_step = 0
    data_points_step = 1
    # just generate some groups, not all rows should be different
    if column == 2:
        # y
        data_points_step = rows_count / 20
    elif column == 3:
        # color 50 groups
        data_points_step = rows_count / 50
    elif column == 4:
        # facet_col
        data_points_step = rows_count / 5
    column_data = []
    for row in range(rows_count):
        # x
        if column == 1:
            column_data.append(row)
        # y
        elif column == 2:
            column_data.append(current_step)
        # color
        elif column == 3:
            # THIS CAN IMPROVE PERFORMANCE
            # column_data.append(current_step)
            column_data.append(f"Group Number #{current_step}")
        # facet_col
        elif column == 4:
            column_data.append(current_step)
        # hover_data
        else:
            column_data.append(f'row-value-{row}')
        if current_step + data_points_step <= row:
            current_step = current_step + data_points_step
    # unsort
    random.shuffle(column_data)
    data_frame[f'x{column}'] = column_data
print(data_frame.shape)

# starting time
start = time.time()
plot_settings = dict()
plot_settings["hover_data"] = sorted(data_frame.columns)
plot_settings["x"] = "x1"
plot_settings["y"] = "x2"
plot_settings["color"] = 'x3'
plot_settings["facet_col"] = "x4"
fig = px.scatter(data_frame, **plot_settings)
figure = json.loads(plotly.io.to_json(fig, validate=True))
# ending time
end = time.time()
print(f"Runtime of the program is {end - start}")
# fig.show()
Hey @ievgennaida,
I ran your code snippet on my PC, and it finished in 2.866 seconds. The largest chunk (> 90%) of that time is spent in px.scatter.
I also tried plotly.express for my use case (scattering a lot of points; the following test example has 2 million points). However, when comparing plotly.express against plotly.graph_objects, I observe that px.scatter is 100x slower than its plotly.graph_objects equivalent.

When further line-profiling px.scatter for this simple use case (i.e., plotting 1 large sequence), I observed that most of the time is spent on a groupby operation.
As far as I know, such an operation is not required for creating a scatter plot of a 1D array?
In case you wonder why I am interested in this: for plotly-resampler I am looking to create a registration method that adds scalability (under the hood) to the plotly.express interface (see https://github.com/predict-idlab/plotly-resampler/issues/68). However, it seems that plotly.express is significantly slower than plotly's graph_objects interface. I would love to hear some more advice on how to use the plotly.express functionality efficiently for very large scatter / line plots.
Profiling of px.scatter:
x = np.arange(2_000_000)
%lprun -f px._core.make_figure px.scatter(x) # line-profile the make_figure function
Profiling result
plotly==5.8.0 python==3.8.10
Timer unit: 1e-06 s
Total time: 18.0257 s
File: /home/jeroen/venv/lib/python3.8/site-packages/plotly/express/_core.py
Function: make_figure at line 1943
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1943 def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1944 1 75.0 75.0 0.0 trace_patch = trace_patch or {}
1945 1 6.0 6.0 0.0 layout_patch = layout_patch or {}
1946 1 1691.0 1691.0 0.0 apply_default_cascade(args)
1947
1948 1 1035154.0 1035154.0 5.7 args = build_dataframe(args, constructor)
1949 1 602.0 602.0 0.0 if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
1950 args = process_dataframe_hierarchy(args)
1951 1 8.0 8.0 0.0 if constructor == "timeline":
1952 constructor = go.Bar
1953 args = process_dataframe_timeline(args)
1954
1955 2 3157.0 1578.5 0.0 trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
1956 1 8.0 8.0 0.0 args, constructor, trace_patch, layout_patch
1957 )
1958 1 24.0 24.0 0.0 grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1959 1 3078711.0 3078711.0 17.1 grouped = args["data_frame"].groupby(grouper, sort=False)
1960
1961 1 10608398.0 10608398.0 58.9 orders, sorted_group_names = get_orderings(args, grouper, grouped)
1962
1963 1 6.0 6.0 0.0 col_labels = []
1964 1 3.0 3.0 0.0 row_labels = []
1965 1 4.0 4.0 0.0 nrows = ncols = 1
1966 6 26.0 4.3 0.0 for m in grouped_mappings:
1967 5 23.0 4.6 0.0 if m.grouper not in orders:
1968 4 22.0 5.5 0.0 m.val_map[""] = m.sequence[0]
1969 else:
1970 1 4.0 4.0 0.0 sorted_values = orders[m.grouper]
1971 1 5.0 5.0 0.0 if m.facet == "col":
1972 prefix = get_label(args, args["facet_col"]) + "="
1973 col_labels = [prefix + str(s) for s in sorted_values]
1974 ncols = len(col_labels)
1975 1 4.0 4.0 0.0 if m.facet == "row":
1976 prefix = get_label(args, args["facet_row"]) + "="
1977 row_labels = [prefix + str(s) for s in sorted_values]
1978 nrows = len(row_labels)
1979 2 9.0 4.5 0.0 for val in sorted_values:
1980 1 5.0 5.0 0.0 if val not in m.val_map: # always False if it's an IdentityMap
1981 1 10.0 10.0 0.0 m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
1982
1983 1 1749.0 1749.0 0.0 subplot_type = _subplot_type_for_trace_type(constructor().type)
1984
1985 1 14.0 14.0 0.0 trace_names_by_frame = {}
1986 1 13.0 13.0 0.0 frames = OrderedDict()
1987 1 8.0 8.0 0.0 trendline_rows = []
1988 1 8.0 8.0 0.0 trace_name_labels = None
1989 1 14.0 14.0 0.0 facet_col_wrap = args.get("facet_col_wrap", 0)
1990 2 19.0 9.5 0.0 for group_name in sorted_group_names:
1991 1 2456203.0 2456203.0 13.6 group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
1992 1 20.0 20.0 0.0 mapping_labels = OrderedDict()
1993 1 9.0 9.0 0.0 trace_name_labels = OrderedDict()
1994 1 8.0 8.0 0.0 frame_name = ""
1995 6 55.0 9.2 0.0 for col, val, m in zip(grouper, group_name, grouped_mappings):
1996 5 31.0 6.2 0.0 if col != one_group:
1997 1 45.0 45.0 0.0 key = get_label(args, col)
1998 1 25.0 25.0 0.0 if not isinstance(m.val_map, IdentityMap):
1999 1 15.0 15.0 0.0 mapping_labels[key] = str(val)
2000 1 5.0 5.0 0.0 if m.show_in_trace_name:
2001 1 4.0 4.0 0.0 trace_name_labels[key] = str(val)
2002 1 4.0 4.0 0.0 if m.variable == "animation_frame":
2003 frame_name = val
2004 1 19.0 19.0 0.0 trace_name = ", ".join(trace_name_labels.values())
2005 1 10.0 10.0 0.0 if frame_name not in trace_names_by_frame:
2006 1 11.0 11.0 0.0 trace_names_by_frame[frame_name] = set()
2007 1 4.0 4.0 0.0 trace_names = trace_names_by_frame[frame_name]
2008
2009 2 13.0 6.5 0.0 for trace_spec in trace_specs:
2010 # Create the trace
2011 1 653.0 653.0 0.0 trace = trace_spec.constructor(name=trace_name)
2012 2 23.0 11.5 0.0 if trace_spec.constructor not in [
2013 1 286.0 286.0 0.0 go.Parcats,
2014 1 213.0 213.0 0.0 go.Parcoords,
2015 1 167.0 167.0 0.0 go.Choropleth,
2016 1 64.0 64.0 0.0 go.Choroplethmapbox,
2017 1 70.0 70.0 0.0 go.Densitymapbox,
2018 1 92.0 92.0 0.0 go.Histogram2d,
2019 1 170.0 170.0 0.0 go.Sunburst,
2020 1 142.0 142.0 0.0 go.Treemap,
2021 1 147.0 147.0 0.0 go.Icicle,
2022 ]:
2023 2 2045.0 1022.5 0.0 trace.update(
2024 1 4.0 4.0 0.0 legendgroup=trace_name,
2025 1 7.0 7.0 0.0 showlegend=(trace_name != "" and trace_name not in trace_names),
2026 )
2027 1 789.0 789.0 0.0 if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
2028 trace.update(alignmentgroup=True, offsetgroup=trace_name)
2029 1 5.0 5.0 0.0 trace_names.add(trace_name)
2030
2031 # Init subplot row/col
2032 1 12.0 12.0 0.0 trace._subplot_row = 1
2033 1 5.0 5.0 0.0 trace._subplot_col = 1
2034
2035 6 57.0 9.5 0.0 for i, m in enumerate(grouped_mappings):
2036 5 44.0 8.8 0.0 val = group_name[i]
2037 5 41.0 8.2 0.0 try:
2038 5 5480.0 1096.0 0.0 m.updater(trace, m.val_map[val]) # covers most cases
2039 except ValueError:
2040 # this catches some odd cases like marginals
2041 if (
2042 trace_spec != trace_specs[0]
2043 and (
2044 trace_spec.constructor in [go.Violin, go.Box]
2045 and m.variable in ["symbol", "pattern", "dash"]
2046 )
2047 or (
2048 trace_spec.constructor in [go.Histogram]
2049 and m.variable in ["symbol", "dash"]
2050 )
2051 ):
2052 pass
2053 elif (
2054 trace_spec != trace_specs[0]
2055 and trace_spec.constructor in [go.Histogram]
2056 and m.variable == "color"
2057 ):
2058 trace.update(marker=dict(color=m.val_map[val]))
2059 elif (
2060 trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]
2061 and m.variable == "color"
2062 ):
2063 trace.update(
2064 z=[1] * len(group),
2065 colorscale=[m.val_map[val]] * 2,
2066 showscale=False,
2067 showlegend=True,
2068 )
2069 else:
2070 raise
2071
2072 # Find row for trace, handling facet_row and marginal_x
2073 5 79.0 15.8 0.0 if m.facet == "row":
2074 1 7.0 7.0 0.0 row = m.val_map[val]
2075 else:
2076 4 74.0 18.5 0.0 if (
2077 4 54.0 13.5 0.0 args.get("marginal_x") is not None # there is a marginal
2078 and trace_spec.marginal != "x" # and we're not it
2079 ):
2080 row = 2
2081 else:
2082 4 39.0 9.8 0.0 row = 1
2083
2084 # Find col for trace, handling facet_col and marginal_y
2085 5 84.0 16.8 0.0 if m.facet == "col":
2086 1 9.0 9.0 0.0 col = m.val_map[val]
2087 1 9.0 9.0 0.0 if facet_col_wrap: # assumes no facet_row, no marginals
2088 row = 1 + ((col - 1) // facet_col_wrap)
2089 col = 1 + ((col - 1) % facet_col_wrap)
2090 else:
2091 4 55.0 13.8 0.0 if trace_spec.marginal == "y":
2092 col = 2
2093 else:
2094 4 55.0 13.8 0.0 col = 1
2095
2096 5 69.0 13.8 0.0 if row > 1:
2097 trace._subplot_row = row
2098
2099 5 60.0 12.0 0.0 if col > 1:
2100 trace._subplot_col = col
2101 1 13.0 13.0 0.0 if (
2102 1 291.0 291.0 0.0 trace_specs[0].constructor == go.Histogram2dContour
2103 and trace_spec.constructor == go.Box
2104 and trace.line.color
2105 ):
2106 trace.update(marker=dict(color=trace.line.color))
2107
2108 1 11.0 11.0 0.0 if "ecdfmode" in args:
2109 base = args["x"] if args["orientation"] == "v" else args["y"]
2110 var = args["x"] if args["orientation"] == "h" else args["y"]
2111 ascending = args.get("ecdfmode", "standard") != "reversed"
2112 group = group.sort_values(by=base, ascending=ascending)
2113 group_sum = group[var].sum() # compute here before next line mutates
2114 group[var] = group[var].cumsum()
2115 if not ascending:
2116 group = group.sort_values(by=base, ascending=True)
2117
2118 if args.get("ecdfmode", "standard") == "complementary":
2119 group[var] = group_sum - group[var]
2120
2121 if args["ecdfnorm"] == "probability":
2122 group[var] = group[var] / group_sum
2123 elif args["ecdfnorm"] == "percent":
2124 group[var] = 100.0 * group[var] / group_sum
2125
2126 2 10811.0 5405.5 0.1 patch, fit_results = make_trace_kwargs(
2127 1 13.0 13.0 0.0 args, trace_spec, group, mapping_labels.copy(), sizeref
2128 )
2129 1 35173.0 35173.0 0.2 trace.update(patch)
2130 1 10.0 10.0 0.0 if fit_results is not None:
2131 trendline_rows.append(mapping_labels.copy())
2132 trendline_rows[-1]["px_fit_results"] = fit_results
2133 1 9.0 9.0 0.0 if frame_name not in frames:
2134 1 24.0 24.0 0.0 frames[frame_name] = dict(data=[], name=frame_name)
2135 1 13.0 13.0 0.0 frames[frame_name]["data"].append(trace)
2136 1 24.0 24.0 0.0 frame_list = [f for f in frames.values()]
2137 1 11.0 11.0 0.0 if len(frame_list) > 1:
2138 frame_list = sorted(
2139 frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
2140 )
2141
2142 1 7.0 7.0 0.0 if show_colorbar:
2143 colorvar = "z" if constructor in [go.Histogram2d, go.Densitymapbox] else "color"
2144 range_color = args["range_color"] or [None, None]
2145
2146 colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
2147 layout_patch["coloraxis1"] = dict(
2148 colorscale=colorscale_validator.validate_coerce(
2149 args["color_continuous_scale"]
2150 ),
2151 cmid=args["color_continuous_midpoint"],
2152 cmin=range_color[0],
2153 cmax=range_color[1],
2154 colorbar=dict(
2155 title_text=get_decorated_label(args, args[colorvar], colorvar)
2156 ),
2157 )
2158 3 25.0 8.3 0.0 for v in ["height", "width"]:
2159 2 17.0 8.5 0.0 if args[v]:
2160 layout_patch[v] = args[v]
2161 1 10.0 10.0 0.0 layout_patch["legend"] = dict(tracegroupgap=0)
2162 1 7.0 7.0 0.0 if trace_name_labels:
2163 1 22.0 22.0 0.0 layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)
2164 1 43.0 43.0 0.0 if args["title"]:
2165 layout_patch["title_text"] = args["title"]
2166 1 617.0 617.0 0.0 elif args["template"].layout.margin.t is None:
2167 1 12.0 12.0 0.0 layout_patch["margin"] = {"t": 60}
2168 2 14.0 7.0 0.0 if (
2169 1 7.0 7.0 0.0 "size" in args
2170 1 8.0 8.0 0.0 and args["size"]
2171 and args["template"].layout.legend.itemsizing is None
2172 ):
2173 layout_patch["legend"]["itemsizing"] = "constant"
2174
2175 1 10.0 10.0 0.0 if facet_col_wrap:
2176 nrows = math.ceil(ncols / facet_col_wrap)
2177 ncols = min(ncols, facet_col_wrap)
2178
2179 1 11.0 11.0 0.0 if args.get("marginal_x") is not None:
2180 nrows += 1
2181
2182 1 8.0 8.0 0.0 if args.get("marginal_y") is not None:
2183 ncols += 1
2184
2185 2 76233.0 38116.5 0.4 fig = init_figure(
2186 1 8.0 8.0 0.0 args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
2187 )
2188
2189 # Position traces in subplots
2190 2 18.0 9.0 0.0 for frame in frame_list:
2191 2 14.0 7.0 0.0 for trace in frame["data"]:
2192 1 338.0 338.0 0.0 if isinstance(trace, go.Splom):
2193 # Special case that is not compatible with make_subplots
2194 continue
2195
2196 2 2027.0 1013.5 0.0 _set_trace_grid_reference(
2197 1 11.0 11.0 0.0 trace,
2198 1 79.0 79.0 0.0 fig.layout,
2199 1 13.0 13.0 0.0 fig._grid_ref,
2200 1 16.0 16.0 0.0 nrows - trace._subplot_row + 1,
2201 1 11.0 11.0 0.0 trace._subplot_col,
2202 )
2203
2204 # Add traces, layout and frames to figure
2205 1 493599.0 493599.0 2.7 fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
2206 1 18070.0 18070.0 0.1 fig.update_layout(layout_patch)
2207 1 24.0 24.0 0.0 if "template" in args and args["template"] is not None:
2208 1 167201.0 167201.0 0.9 fig.update_layout(template=args["template"], overwrite=True)
2209 1 150.0 150.0 0.0 fig.frames = frame_list if len(frames) > 1 else []
2210
2211 1 11.0 11.0 0.0 if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
2212 trendline_spec = make_trendline_spec(args, constructor)
2213 trendline_trace = trendline_spec.constructor(
2214 name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
2215 )
2216 if "line" not in trendline_spec.trace_patch: # no color override
2217 for m in grouped_mappings:
2218 if m.variable == "color":
2219 next_color = m.sequence[len(m.val_map) % len(m.sequence)]
2220 trendline_spec.trace_patch["line"] = dict(color=next_color)
2221 patch, fit_results = make_trace_kwargs(
2222 args, trendline_spec, args["data_frame"], {}, sizeref
2223 )
2224 trendline_trace.update(patch)
2225 fig.add_trace(
2226 trendline_trace, row="all", col="all", exclude_empty_subplots=True
2227 )
2228 fig.update_traces(selector=-1, showlegend=True)
2229 if fit_results is not None:
2230 trendline_rows.append(dict(px_fit_results=fit_results))
2231
2232 1 2555.0 2555.0 0.0 fig._px_trendlines = pd.DataFrame(trendline_rows)
2233
2234 1 20715.0 20715.0 0.1 configure_axes(args, constructor, fig, orders)
2235 1 14.0 14.0 0.0 configure_animation_controls(args, constructor, fig)
2236 1 3.0 3.0 0.0 return fig
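The table above comes from the `%lprun` magic of line_profiler inside IPython. Outside IPython, a coarser function-level view is available from the standard library's `cProfile`. This is a minimal sketch: `profile_call` is a hypothetical helper, and `slow_sum` stands in for the call under investigation (for example `px.scatter(x)`).

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=10, **kwargs):
    """Run func under cProfile; return its result and a text report
    of the `top` entries sorted by cumulative time."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("cumulative").print_stats(top)
    return result, buffer.getvalue()


def slow_sum(n):
    # Stand-in workload; replace with the plotting call being investigated.
    return sum(i * i for i in range(n))


result, report = profile_call(slow_sum, 100_000)
print(report)
```

Unlike `%lprun`, this reports time per function rather than per source line, so it points at `groupby` or `get_orderings` as a whole but not at individual statements inside `make_figure`.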
This thing is horrendously slow with financial data spanning more than a few years.