
sdv helper function for generating generator-discriminator loss charts

Open · srinify opened this issue 2 years ago · 1 comment

Problem Description

When working with GAN models in SDV, visualizing the generator and discriminator losses helps you understand how the model is performing and decide which experiment to try next.

Creating a good-looking generator-discriminator loss chart requires importing a visualization library, plotting the DataFrame of loss values, and tweaking a number of chart settings. Here is an example of one such chart, from the post on interpreting the progress of CTGANs:

[Image: generator and discriminator loss per epoch, CTGAN on the Census dataset]

This was the code needed to make this chart:

import plotly.express as px

# Tidy up the loss values data
loss_df = synthesizer.get_loss_values()
loss_df['Generator Loss'] = loss_df['Generator Loss'].apply(lambda x: x.item())
loss_df['Discriminator Loss'] = loss_df['Discriminator Loss'].apply(lambda x: x.item())

# Create a pretty chart using Plotly Express
fig = px.line(loss_df, x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
fig.update_layout(template='plotly_white', legend=dict(title_text='', orientation='v', x=1.1, y=0.3))
title = 'CTGAN loss function for Census dataset'
fig.update_layout(title=title, xaxis_title='Epoch', yaxis_title='Loss')
fig.show()
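The `.item()` calls above assume every loss value is tensor-like (for example a PyTorch scalar tensor). A plain Python float has no `.item()` method, so the same lines would raise `AttributeError` if the values ever arrive as ordinary numbers. Below is a minimal sketch of a more tolerant conversion, assuming only that tensor-like values expose `.item()`; the helper name `to_float` is hypothetical and not part of SDV.

```python
from typing import Any


def to_float(value: Any) -> float:
    """Convert one loss value to a plain Python float (hypothetical helper).

    Tensor-like objects (e.g. a PyTorch scalar tensor or a NumPy scalar)
    expose .item(); anything else is passed straight to float().
    """
    item = getattr(value, "item", None)
    if callable(item):
        return float(item())
    return float(value)


# Stand-in for a scalar tensor so the example runs without PyTorch installed.
class FakeScalarTensor:
    def __init__(self, v: float) -> None:
        self._v = v

    def item(self) -> float:
        return self._v


print([to_float(v) for v in [FakeScalarTensor(0.75), 1.25, 3]])
# prints [0.75, 1.25, 3.0]
```

With pandas, the helper drops into the original tidy-up step unchanged, e.g. `loss_df['Generator Loss'] = loss_df['Generator Loss'].apply(to_float)`.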

Suggested Improvement

It would be great if SDV itself had a function that did all of this for the user. We've already done this for other evaluation plots: https://docs.sdv.dev/sdv/single-table-data/evaluation/visualization

srinify, Mar 04 '24 16:03

API

CTGANSynthesizer.get_loss_values_plot(): Use this function on a fitted CTGAN synthesizer to plot the generator and discriminator loss values. Under the hood, this will use the data from CTGANSynthesizer.get_loss_values().

Parameters: None

Output: A plotly.graph_objects.Figure object containing a line plot of the generator and discriminator loss values for each trained epoch.

Error State: If the CTGANSynthesizer has not been fitted yet, raise an error. This should be the same error that is raised if using get_loss_values() before fitting.

UX: Align the design/colors with SDMetrics visualizations. This includes all elements such as:

  • Line colors (use green/blue)
  • Plot background (should be gray)
  • Font sizes should be large (for titles, legend, x/y-axis, etc.)

npatki, Mar 05 '24 19:03

Hello, is there any similar function for CopulaGAN to plot the generator and discriminator loss values?

shahenoor, Mar 18 '24 16:03