sdv helper function for generating generator-discriminator loss charts
Problem Description
When working with GAN models in sdv, visualizing the generator and discriminator loss helps you understand how the GAN model is performing and decide which experiment to try next.
Creating a good-looking generator-discriminator loss chart requires you to import a visualization library, plot the DataFrame of loss values, and tweak a number of chart settings. The post on interpreting the progress of CTGANs shows an example of one such chart.
This is the code that was needed to make that chart:
```python
import plotly.express as px

# `synthesizer` is a CTGANSynthesizer that has already been fitted
loss_df = synthesizer.get_loss_values()

# Tidy up the loss values: float() converts 0-dim tensors and plain numbers alike
for column in ['Generator Loss', 'Discriminator Loss']:
    loss_df[column] = loss_df[column].apply(float)

# Create a pretty chart using Plotly Express
fig = px.line(loss_df, x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
fig.update_layout(
    template='plotly_white',
    title='CTGAN loss function for Census dataset',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    legend_title_text='',
    legend=dict(orientation='v', x=1.1, y=0.3),
)
fig.show()
```
Suggested Improvement
It would be great if sdv itself had a function that did all this for the user. We've already done this for other evaluation plots: https://docs.sdv.dev/sdv/single-table-data/evaluation/visualization
API
CTGANSynthesizer.get_loss_values_plot(): Use this function on a fitted CTGAN synthesizer to plot the generator and discriminator loss values. Under the hood, it will use the data from CTGANSynthesizer.get_loss_values().
Parameters: None
Output: A Plotly figure object (plotly.graph_objects.Figure) containing a line plot of the generator and discriminator loss values for each training epoch.
Error State: If the CTGANSynthesizer has not been fitted yet, raise an error. This should be the same error that is raised if using get_loss_values() before fitting.
UX: Align the design and colors with SDMetrics visualizations. This includes elements such as:
- Line colors (use green and blue)
- Plot background (should be gray)
- Font sizes (should be large for titles, legend, x/y-axis labels, etc.)
Hello, is there any similar function for CopulaGAN to plot the generator and discriminator loss values?