
Training with patient demographics

Open OmarAshkar opened this issue 1 year ago • 1 comments

Thank you for this amazing work. I am trying to use the available architectures. For my problem, I need not only the ECG but also patient data such as age, weight, etc.

Is there a readily available way to pass them in (maybe after the conv layers)? If not, what would you recommend?

Thanks, Omar

OmarAshkar, Apr 18 '24 11:04

Hi, Omar. The models in this library assume time-series signal data as their only input. However, integrating demographic features is straightforward. For a classification task, for example, you can use one of these models as a base model (a backbone) and wrap it as a child module of your own model, like this:

import torch
import torch.nn as nn

class CustomModel(nn.Module):

    def __init__(self, base_model, num_dem_features, num_classes):
        super().__init__()  # must run before submodules are assigned to an nn.Module
        self.base_model = base_model
        self.num_dem_features = num_dem_features
        self.num_classes = num_classes

        dem_feature_embed_dim = 32  # this can be a parameter for the __init__ function
        # If you do not want to do embedding for demographic features, you can omit this layer
        self.dem_feature_embedding = nn.Linear(num_dem_features, dem_feature_embed_dim)

        # you know it beforehand or compute by self.base_model.compute_output_shape
        base_model_output_dim = 256
        self.fc = nn.Linear(base_model_output_dim + dem_feature_embed_dim, num_classes)

    def forward(self, ecg_signal, dem_features):
        # assume that ecg_signal is of shape (batch_size, num_leads, sig_len)
        # assume that dem_features is of shape (batch_size, num_dem_features)
        ecg_features = self.base_model(ecg_signal)
        # assume this base model contains a global pooling layer on its top
        # then ecg_features is of shape (batch_size, base_model_output_dim)
        dem_features = self.dem_feature_embedding(dem_features)
        total_features = torch.cat((ecg_features, dem_features), dim=-1)
        logits = self.fc(total_features)  # the logits can be further passed to your loss function
        return logits
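
The wrapper above hardcodes the backbone output dimension. As a minimal, self-contained sketch of the same late-fusion idea, the variant below infers that dimension with a dummy forward pass instead. `ToyBackbone`, `FusionModel`, and the `probe_len` parameter are hypothetical names introduced only for illustration; `ToyBackbone` is a stand-in for a library model that ends in a global pooling layer, not an actual torch_ecg class.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a backbone with global pooling on top."""

    def __init__(self, num_leads, out_dim=256):
        super().__init__()
        self.conv = nn.Conv1d(num_leads, out_dim, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # (batch_size, num_leads, sig_len) -> (batch_size, out_dim)
        return self.pool(torch.relu(self.conv(x))).squeeze(-1)


class FusionModel(nn.Module):
    """Concatenates backbone features with embedded demographic features."""

    def __init__(self, base_model, num_leads, num_dem_features, num_classes,
                 dem_embed_dim=32, probe_len=1000):
        super().__init__()
        self.base_model = base_model
        self.dem_feature_embedding = nn.Linear(num_dem_features, dem_embed_dim)

        # Infer the backbone feature width from a dummy input (assumption:
        # the backbone returns a tensor of shape (batch_size, feature_dim)).
        was_training = base_model.training
        base_model.eval()
        with torch.no_grad():
            probe = torch.zeros(1, num_leads, probe_len)
            base_dim = base_model(probe).shape[-1]
        base_model.train(was_training)

        self.fc = nn.Linear(base_dim + dem_embed_dim, num_classes)

    def forward(self, ecg_signal, dem_features):
        ecg_features = self.base_model(ecg_signal)
        dem_embedded = self.dem_feature_embedding(dem_features)
        return self.fc(torch.cat((ecg_features, dem_embedded), dim=-1))


model = FusionModel(ToyBackbone(num_leads=12), num_leads=12,
                    num_dem_features=3, num_classes=5)
ecg = torch.randn(4, 12, 2000)  # (batch_size, num_leads, sig_len)
dem = torch.randn(4, 3)         # e.g. standardized age, weight, height
logits = model(ecg, dem)
print(logits.shape)  # torch.Size([4, 5])
```

Demographic values such as age and weight usually live on very different scales, so it is probably worth standardizing them before they reach the embedding layer.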

wenh06, Apr 19 '24 15:04