
[Question] How to define a custom loss function for my Torch Forecasting Models (TFMs) based on y_pred, y_true and *_covariates.

Open fmerinocasallo opened this issue 2 years ago • 3 comments

I have been using torchmetrics.regression.MeanSquaredError as the default loss function for my Torch Forecasting Models (TFMs).

However, now I am interested in using a custom loss function loss_fn for my TFMs that would integrate two terms f1 and f2, where f1 computes a ratio based on y_pred and y_true and is artificially upper-bounded by 1. Conversely, f2 returns 0 or 1 based on y_pred and *_covariates. Something like:

def f1(y_pred, y_true):
    total_y_pred = y_pred.sum()
    total_y_true = y_true.sum()

    return min(abs(total_y_pred - total_y_true)/total_y_true, 1)


def f2(y_pred, covariates):
    return int(y_pred[(covariates == 0) & (y_pred != 0)].any())


def loss_fn(y_pred, y_true, covariates):
    return (1-alpha)*f1(y_pred, y_true)  + alpha*f2(y_pred, covariates)

Note that alpha is a weighting hyperparameter that I would tune manually, based on some experimentation, to balance the two terms (f1 and f2).

How could I implement and pass such a custom loss function depending not only on y_pred and y_true (as usual) but also on *_covariates to my TFMs in Darts? :thinking:

fmerinocasallo avatar Dec 21 '23 20:12 fmerinocasallo

Hi @fmerinocasallo, there is a bit of work required to get this working.

For this you need to (I take TFTModel as an example here):

  • subclass from TFM Module (for TFTModel it's _TFTModule in darts.models.forecasting.tft_model.py)
  • overwrite _TFTModule._produce_train_output() to return forward output and batch covariates.
  • overwrite _TFTModule._compute_loss() so that it works with what _produce_train_output() returns.
  • overwrite _TFTModule.training_step() that calls _produce_train_output. Pass forward output and covariates to _compute_loss(). Only pass forward output to _calculate_metrics().
  • overwrite _TFTModule.validation_step() that calls _produce_train_output. Pass forward output and covariates to _compute_loss(). Only pass forward output to _calculate_metrics().
  • subclass from TFT Model and overwrite _create_model() to return the TFM Module from above.
  • Write your custom loss function that handles the covariates and pass it to the new TFT Model at model creation.

I haven't tested this, but it should work theoretically. It might require some additional work though.

Hope this helped.

dennisbader avatar Dec 24 '23 10:12 dennisbader

Hey @dennisbader, thank you once again for your thoughtful reply :relaxed:

For several days, I have been trying to follow the steps from your proposal using, as suggested, the TFTModel as an example of the procedure. Once I have a working solution, it should be easily adapted to other predictive models supporting future_covariates [1], right?

I think I have finally understood the whole process :tada: Let's see if I have actually done it :crossed_fingers:, step by step:

  • overwrite _TFTModule._produce_train_output() to return forward output and batch covariates.

My understanding is that the original _TFTModule._produce_train_output() is inherited from PLMixedCovariatesModule._produce_train_output(), which, in turn, calls PLMixedCovariatesModule._process_input_batch().

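This PLMixedCovariatesModule._process_input_batch() takes the five-element input batch (past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates) and returns the following tuple, where the past tensors that are not None are concatenated along the variable dimension:

(
    torch.cat([past_target, past_covariates, historic_future_covariates], dim=2),
    future_covariates,
    static_covariates,
)

_produce_train_output() then passes this tuple to _TFTModule.forward() to compute the forward output you mentioned.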

My _produce_train_output() should return not only this forward output but also the batch covariates, which I think are future_covariates. Therefore, it could be defined as:

    def _produce_train_output(
        self, input_batch: Tuple
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Feeds MixedCovariatesTorchModel with input and output chunks of a MixedCovariatesSequentialDataset for
        training.

        Parameters:
        ----------
        input_batch
            ``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
        """
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
        ) = input_batch
        return (self(self._process_input_batch(input_batch)), future_covariates)

  • overwrite _TFTModule.training_step() that calls _produce_train_output. Pass forward output and covariates to _compute_loss(). Only pass forward output to _calculate_metrics().

PLForecastingModule (a parent class of PLMixedCovariatesModule, which in turn is a parent class of _TFTModule) defines the method training_step() as follows:

    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
        """performs the training step"""
        output = self._produce_train_output(train_batch[:-1])
        target = train_batch[
            -1
        ]  # By convention target is always the last element returned by datasets
        loss = self._compute_loss(output, target)
        self.log(
            "train_loss",
            loss,
            batch_size=train_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True,
        )
        self._calculate_metrics(output, target, self.train_metrics)
        return loss

My training_step() method could be defined as:

    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
        """performs the training step"""
        output, batch_cov = self._produce_train_output(train_batch[:-1])
        target = train_batch[
            -1
        ]  # By convention target is always the last element returned by datasets
        loss = self._compute_loss(output, target, batch_cov)
        self.log(
            "train_loss",
            loss,
            batch_size=train_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True,
        )
        self._calculate_metrics(output, target, self.train_metrics)
        return loss

Still, assuming that by convention the batch covariates you mentioned (future_covariates) are always the fourth element returned by datasets (which may be too much of an assumption :sweat_smile:), I would not have to overwrite the original _TFTModule._produce_train_output() method and instead do something like this:

    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
        """performs the training step"""
        output = self._produce_train_output(train_batch[:-1])
        batch_cov = train_batch[3]
        target = train_batch[
            -1
        ]  # By convention target is always the last element returned by datasets
        loss = self._compute_loss(output, target, batch_cov)
        self.log(
            "train_loss",
            loss,
            batch_size=train_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True,
        )
        self._calculate_metrics(output, target, self.train_metrics)
        return loss

  • overwrite _TFTModule.validation_step() that calls _produce_train_output. Pass forward output and covariates to _compute_loss(). Only pass forward output to _calculate_metrics().

PLForecastingModule also defines the method validation_step() as follows:

    def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
        """performs the validation step"""
        output = self._produce_train_output(val_batch[:-1])
        target = val_batch[-1]
        loss = self._compute_loss(output, target)
        self.log(
            "val_loss",
            loss,
            batch_size=val_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True,
        )
        self._calculate_metrics(output, target, self.val_metrics)
        return loss

Conversely, mine could be defined as:

    def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
        """performs the validation step"""
        output = self._produce_train_output(val_batch[:-1])
        batch_cov = val_batch[3]
        target = val_batch[-1]
        loss = self._compute_loss(output, target, batch_cov)
        self.log(
            "val_loss",
            loss,
            batch_size=val_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True,
        )
        self._calculate_metrics(output, target, self.val_metrics)
        return loss

  • overwrite _TFTModule._compute_loss() so that it works with what _produce_train_output() returns.

PLForecastingModule also defines the method _compute_loss() as follows:

    def _compute_loss(self, output, target):
        # output is of shape (batch_size, n_timesteps, n_components, n_params)
        if self.likelihood:
            return self.likelihood.compute_loss(output, target)
        else:
            # If there's no likelihood, nr_params=1, and we need to squeeze out the
            # last dimension of model output, for properly computing the loss.
            return self.criterion(output.squeeze(dim=-1), target)

My _compute_loss() method could be instead defined as:

    def _compute_loss(self, output, target, batch_cov):
        # output is of shape (batch_size, n_timesteps, n_components, n_params)
        if self.likelihood:
            return self.likelihood.compute_loss(output, target)
        else:
            # If there's no likelihood, nr_params=1, and we need to squeeze out the
            # last dimension of model output, for properly computing the loss.
            try:
                return self.criterion(output.squeeze(dim=-1), target, batch_cov)
            except TypeError:
                return self.criterion(output.squeeze(dim=-1), target)

I wanted my new subclass to support not only these new custom loss functions but also the usual ones taking only y_pred and y_true as input arguments. Therefore, I opted for this try block. However, it is only safe as long as the loss function itself never raises a TypeError internally: such an error would be silently caught and the two-argument call retried, which may be too much to assume :grimacing: Please let me know if there is a more appropriate way of doing this :pray:
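One possible alternative to the try block, as a minimal sketch only: decide up front whether the criterion accepts a third positional argument by inspecting its signature, so that a TypeError raised inside the loss is never swallowed. The helper names accepts_covariates and call_criterion are hypothetical (they are not part of Darts or PyTorch), and the sketch assumes the criterion is either a plain function or an object with a forward() method, as torch.nn.Module subclasses have.

```python
import inspect


def accepts_covariates(criterion) -> bool:
    """Return True if `criterion` accepts at least three positional arguments."""
    # torch.nn.Module instances expose a generic __call__(*args, **kwargs),
    # so look at `forward` when it exists and at the callable itself otherwise.
    fn = getattr(criterion, "forward", criterion)
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return True
    positional = [
        p
        for p in params
        if p.kind
        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    return len(positional) >= 3


def call_criterion(criterion, output, target, batch_cov):
    """Pass the covariates only to criteria that can take them."""
    if accepts_covariates(criterion):
        return criterion(output, target, batch_cov)
    return criterion(output, target)


# Toy criteria on plain floats, standing in for real loss functions.
def plain_loss(y_pred, y_true):
    return abs(y_pred - y_true)


def covariate_loss(y_pred, y_true, covariates):
    return abs(y_pred - y_true) + covariates


print(call_criterion(plain_loss, 3.0, 1.0, 10.0))
print(call_criterion(covariate_loss, 3.0, 1.0, 10.0))
```

In _compute_loss() the check could be done once (for instance when the module is created) rather than on every batch, since the signature of the criterion does not change during training.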

Are there any pitfalls or shortcomings in my line of thought that I should take care of and resolve? :thinking:


[1] As discussed in #2099, predictive models supporting only past_covariates would require one of the following alternatives:

  1. I shift my availability data stored in past_covariates by -output_chunk_length time steps so that I have this info available in the input_chunk (only valid if input_chunk_length >= output_chunk_length).
  2. I define a new pair of subclasses from darts.utils.data.TrainingDataset and darts.utils.data.InferenceDataset to specify a custom way of slicing the data (target and *_covariates series) to obtain training and inference samples.
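To make the first alternative concrete, here is a toy illustration with plain Python lists standing in for the series. The availability values and the helper shift_back are made up for the example; with real data the shift would be applied to the past_covariates series itself.

```python
def shift_back(values, steps, fill=None):
    """Shift `values` by `steps` positions toward the past, padding the tail with `fill`."""
    return values[steps:] + [fill] * steps


input_chunk_length = 4
output_chunk_length = 2

# Hypothetical availability flags for time steps 0..7.
availability = [1, 1, 0, 1, 0, 1, 1, 0]
shifted = shift_back(availability, output_chunk_length)

# A sample whose input chunk covers steps 0..3 has its output chunk at steps 4..5.
start = 0
split = start + input_chunk_length
input_window = shifted[start:split]
output_truth = availability[split : split + output_chunk_length]

# After the shift, the last `output_chunk_length` entries of the input window
# hold the availability of the output chunk.
print(input_window[-output_chunk_length:], output_truth)
```

This also shows where the condition input_chunk_length >= output_chunk_length comes from: with a shorter input chunk, part of the output-chunk availability would fall outside the input window.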

fmerinocasallo avatar Jan 11 '24 08:01 fmerinocasallo

I have ended up with a couple of files implementing the whole thing based on @dennisbader's feedback. I have some doubts about how to define my custom loss function (I include a couple of questions about this later on). I am currently running some tests to see if everything works as expected :crossed_fingers:

The first file defines the newly implemented _TFTConstrainedModule and TFTConstrainedModel classes:

from typing import Tuple

from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.tft_model import _TFTModule, TFTModel
from darts.models.forecasting.tft_submodels import get_embedding_size
import numpy as np
import pandas as pd
import torch

logger = get_logger(__name__)

MixedCovariatesTrainTensorType = Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]


class _TFTConstrainedModule(_TFTModule):
    def _produce_train_output(self, input_batch: Tuple) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Feeds MixedCovariatesTorchModel with input and output chunks of a
        MixedCovariatesSequentialDataset for training.

        Parameters:
        ----------
        input_batch
            ``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
        """
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates
        ) = input_batch
        return (self(self._process_input_batch(input_batch)), future_covariates)

    def training_step(self, train_batch, batch_idx) -> torch.Tensor:
        """performs the training step"""
        output, batch_cov = self._produce_train_output(train_batch[:-1])
        # By convention target is always the last element returned by datasets
        target = train_batch[-1]
        loss = self._compute_loss(output, target, batch_cov)

        self.log(
            "train_loss",
            loss,
            batch_size=train_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True
        )
        self._calculate_metrics(output, target, self.train_metrics)

        return loss

    def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
        """performs the validation step"""
        output, batch_cov = self._produce_train_output(val_batch[:-1])
        target = val_batch[-1]
        loss = self._compute_loss(output, target, batch_cov)

        self.log(
            "val_loss",
            loss,
            batch_size=val_batch[0].shape[0],
            prog_bar=True,
            sync_dist=True
        )
        self._calculate_metrics(output, target, self.val_metrics)

        return loss

    def _compute_loss(self, output, target, batch_cov):
        # output is of shape (batch_size, n_timesteps, n_components, n_params)
        if self.likelihood:
            return self.likelihood.compute_loss(output, target)
        else:
            # If there's no likelihood, nr_params=1, and we need to squeeze out the last dimension of model output,
            # for properly computing the loss.
            try:
                return self.criterion(output.squeeze(dim=-1), target, batch_cov)
            except TypeError:
                return self.criterion(output.squeeze(dim=-1), target)


class TFTConstrainedModel(TFTModel):
    def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> torch.nn.Module:
        """
        `train_sample` contains the following tensors:
            (past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates,
            future_target)

            each tensor has shape (n_timesteps, n_variables)
            - past/historic tensors have shape (input_chunk_length, n_variables)
            - future tensors have shape (output_chunk_length, n_variables)
            - static covariates have shape (component, static variable)

        Darts Interpretation of pytorch-forecasting's TimeSeriesDataSet:
            time_varying_knowns : future_covariates (including historic_future_covariates)
            time_varying_unknowns : past_targets, past_covariates

            time_varying_encoders : [past_targets, past_covariates, historic_future_covariates, future_covariates]
            time_varying_decoders : [historic_future_covariates, future_covariates]

        `variable_meta` is used in TFT to access specific variables
        """
        (
            past_target,
            past_covariate,
            historic_future_covariate,
            future_covariate,
            static_covariates,
            future_target,
        ) = train_sample

        # add a covariate placeholder so that relative index will be included
        if self.add_relative_index:
            time_steps = self.input_chunk_length + self.output_chunk_length

            expand_future_covariate = np.arange(time_steps).reshape((time_steps, 1))

            historic_future_covariate = np.concatenate(
                [
                    ts[: self.input_chunk_length]
                    for ts in [historic_future_covariate, expand_future_covariate]
                    if ts is not None
                ],
                axis=1,
            )
            future_covariate = np.concatenate(
                [
                    ts[-self.output_chunk_length :]
                    for ts in [future_covariate, expand_future_covariate]
                    if ts is not None
                ],
                axis=1,
            )

        self.output_dim = (
            (future_target.shape[1], 1)
            if self.likelihood is None
            else (future_target.shape[1], self.likelihood.num_parameters)
        )

        tensors = [
            past_target,
            past_covariate,
            historic_future_covariate,  # for time varying encoders
            future_covariate,
            future_target,  # for time varying decoders
            static_covariates,  # for static encoder
        ]
        type_names = [
            "past_target",
            "past_covariate",
            "historic_future_covariate",
            "future_covariate",
            "future_target",
            "static_covariate",
        ]
        variable_names = [
            "target",
            "past_covariate",
            "future_covariate",
            "future_covariate",
            "target",
            "static_covariate",
        ]

        variables_meta = {
            "input": {
                type_name: [f"{var_name}_{i}" for i in range(tensor.shape[1])]
                for type_name, var_name, tensor in zip(
                    type_names, variable_names, tensors
                )
                if tensor is not None
            },
            "model_config": {},
        }

        reals_input = []
        categorical_input = []
        time_varying_encoder_input = []
        time_varying_decoder_input = []
        static_input = []
        static_input_numeric = []
        static_input_categorical = []
        categorical_embedding_sizes = {}
        for input_var in type_names:
            if input_var in variables_meta["input"]:
                vars_meta = variables_meta["input"][input_var]
                if input_var in [
                    "past_target",
                    "past_covariate",
                    "historic_future_covariate",
                ]:
                    time_varying_encoder_input += vars_meta
                    reals_input += vars_meta
                elif input_var in ["future_covariate"]:
                    time_varying_decoder_input += vars_meta
                    reals_input += vars_meta
                elif input_var in ["static_covariate"]:
                    if (
                        self.static_covariates is None
                    ):  # when training with fit_from_dataset
                        static_cols = pd.Index(
                            [i for i in range(static_covariates.shape[1])]
                        )
                    else:
                        static_cols = self.static_covariates.columns
                    numeric_mask = ~static_cols.isin(self.categorical_embedding_sizes)
                    for idx, (static_var, col_name, is_numeric) in enumerate(zip(vars_meta, static_cols, numeric_mask)):
                        static_input.append(static_var)
                        if is_numeric:
                            static_input_numeric.append(static_var)
                            reals_input.append(static_var)
                        else:
                            # get embedding sizes for each categorical variable
                            embedding = self.categorical_embedding_sizes[col_name]
                            raise_if_not(
                                isinstance(embedding, (int, tuple)),
                                "Dict values of `categorical_embedding_sizes` must either be integers or tuples. Read "
                                "the TFTModel documentation for more information.",
                                logger,
                            )
                            if isinstance(embedding, int):
                                embedding = (embedding, get_embedding_size(n=embedding))
                            categorical_embedding_sizes[vars_meta[idx]] = embedding

                            static_input_categorical.append(static_var)
                            categorical_input.append(static_var)

        variables_meta["model_config"]["reals_input"] = list(dict.fromkeys(reals_input))
        variables_meta["model_config"]["categorical_input"] = list(
            dict.fromkeys(categorical_input)
        )
        variables_meta["model_config"]["time_varying_encoder_input"] = list(
            dict.fromkeys(time_varying_encoder_input)
        )
        variables_meta["model_config"]["time_varying_decoder_input"] = list(
            dict.fromkeys(time_varying_decoder_input)
        )
        variables_meta["model_config"]["static_input"] = list(
            dict.fromkeys(static_input)
        )
        variables_meta["model_config"]["static_input_numeric"] = list(
            dict.fromkeys(static_input_numeric)
        )
        variables_meta["model_config"]["static_input_categorical"] = list(
            dict.fromkeys(static_input_categorical)
        )

        n_static_components = (
            len(static_covariates) if static_covariates is not None else 0
        )

        self.categorical_embedding_sizes = categorical_embedding_sizes

        return _TFTConstrainedModule(
            output_dim=self.output_dim,
            variables_meta=variables_meta,
            num_static_components=n_static_components,
            hidden_size=self.hidden_size,
            lstm_layers=self.lstm_layers,
            dropout=self.dropout,
            num_attention_heads=self.num_attention_heads,
            full_attention=self.full_attention,
            feed_forward=self.feed_forward,
            hidden_continuous_size=self.hidden_continuous_size,
            categorical_embedding_sizes=self.categorical_embedding_sizes,
            add_relative_index=self.add_relative_index,
            norm_type=self.norm_type,
            **self.pl_module_params,
        )

The second file defines my custom loss function, which handles covariates. I was not sure how I should implement it: would a simple function be enough? Should it be a subclass of torch.nn.Module with a forward() method? Or a subclass of torchmetrics.metric.Metric with forward(), update() and compute() methods? :thinking:

import torch
import torch.nn as nn


class RAELoss(nn.Module):
    def __init__(self):
        """
        RAE (Ratio Absolute Error) loss.

        Given a time series of actual values :math:`y_t` and a time series of predicted values :math:`\\hat{y}_t`
        both of length :math:`T`, it is computed as

        .. math::
            \\frac{\\left| \\sum_{t=1}^{T}{y_t} - \\sum_{t=1}^{T}{\\hat{y}_t} \\right|}{\\sum_{t=1}^{T}{y_t}}.
        """
        super().__init__()

    def __repr__(self):
        return "RAELoss()"

    def __str__(self):
        return "RAELoss()"

    def forward(self, inpt, tgt):
        return torch.abs(torch.sum(inpt) - torch.sum(tgt))/torch.sum(tgt)


class ALLoss(nn.Module):
    def __init__(self):
        """
        AL (Availability Law) loss.

        Given a time series of predicted values :math:`\\hat{y}_t` and a time series of availability values
        :math:`z_t`, both of length :math:`T`, it is computed as

        .. math::
            \\left\\{\\begin{array}{ll}
                0 & \\text{if } \\nexists\\, t \\in \\{1, \\dots, T\\} : (\\hat{y}_t \\neq 0) \\wedge (z_t = 0) \\\\
                1 & \\text{if } \\exists\\, t \\in \\{1, \\dots, T\\} : (\\hat{y}_t \\neq 0) \\wedge (z_t = 0)
            \\end{array}\\right.
        """
        super().__init__()

    def __repr__(self):
        return "ALLoss()"

    def __str__(self):
        return "ALLoss()"

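    def forward(self, inpt, tgt, avlblt):
        # torch.int is a dtype, not a callable, so cast the boolean result instead.
        # Returns 1 if any non-zero prediction falls on a zero-availability position, else 0.
        return torch.any((avlblt == 0) & (inpt != 0)).to(inpt.dtype)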


class RAEALLoss(nn.Module):
    def __init__(self, alpha=0.5):
        """
        Combined RAE (Ratio Absolute Error) and AL (Availability Law) loss.

        Given a time series of actual values :math:`y_t`, a time series of predicted values :math:`\\hat{y}_t` and
        a time series of availability values :math:`z_t`, all of length :math:`T`, it is computed as

        .. math::
            (1-\\alpha)\\frac{\\left|\\sum_{t=1}^{T}{y_t}-\\sum_{t=1}^{T}{\\hat{y}_t}\\right|}{\\sum_{t=1}^{T}{y_t}}
            +\\alpha\\left\\{\\begin{array}{ll}
                0 & \\text{if } \\nexists\\, t \\in \\{1, \\dots, T\\} : (\\hat{y}_t \\neq 0) \\wedge (z_t = 0) \\\\
                1 & \\text{if } \\exists\\, t \\in \\{1, \\dots, T\\} : (\\hat{y}_t \\neq 0) \\wedge (z_t = 0)
            \\end{array}\\right.
        """
        super().__init__()
        self._alpha = alpha
        self._raeloss = RAELoss()
        self._alloss = ALLoss()

    def __repr__(self):
        return f"RAEALLoss(alpha={self._alpha:1.2f})"

    def __str__(self):
        return f"RAEALLoss(alpha={self._alpha:1.2f})"

    def forward(self, inpt, tgt, avlblt):
        # Call the sub-modules directly (rather than .forward()) so that module hooks still run.
        rae = self._raeloss(inpt, tgt)
        al = self._alloss(inpt, tgt, avlblt)
        return (1 - self._alpha) * rae + self._alpha * al

As always, please let me know if there is a more appropriate way of doing this 🙏

Are there any pitfalls or shortcomings in my line of thought that I should take care of and resolve? 🤔
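One pitfall that may be worth checking: the AL term is built from boolean comparisons and torch.any(), which carry no gradient, so it would only add a constant to the loss and could not steer the optimizer. A minimal sketch of one possible substitute follows, assuming the availability tensor has (or broadcasts to) the same shape as the predictions. soft_availability_penalty is a hypothetical helper, not part of Darts or PyTorch, and it replaces the 0/1 indicator with a different quantity (the mean absolute prediction over unavailable positions) instead of reproducing it.

```python
import torch


def soft_availability_penalty(y_pred: torch.Tensor, availability: torch.Tensor) -> torch.Tensor:
    """Mean absolute prediction over positions flagged as unavailable (hypothetical helper)."""
    unavailable = (availability == 0).to(y_pred.dtype)
    # clamp avoids a division by zero when every position is available
    return (y_pred.abs() * unavailable).sum() / unavailable.sum().clamp(min=1.0)


y_pred = torch.tensor([1.0, 0.0, 2.0], requires_grad=True)
availability = torch.tensor([0.0, 1.0, 0.0])

# Hard 0/1 indicator: the comparisons produce boolean tensors, so no gradient flows back.
hard = torch.any((availability == 0) & (y_pred != 0)).to(y_pred.dtype)
print(hard.requires_grad)

# Soft penalty: differentiable with respect to y_pred.
soft = soft_availability_penalty(y_pred, availability)
soft.backward()
print(soft.item(), y_pred.grad)
```

Unlike the indicator, the soft penalty is not bounded by 1, so alpha would probably need to be re-tuned if it were used in place of the AL term.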

fmerinocasallo avatar Jan 22 '24 10:01 fmerinocasallo