torch_ecg
Training with patient demographics
Thank you for this amazing work. I am trying to apply the available architectures to my problem, for which I need not only the ECG but also patient data such as age, weight, etc.
Is there a readily available way to pass them in (maybe after the conv layers?). If not, what would you recommend?
Thanks, Omar
Hi, Omar. The models in this library are designed to accept time-series signal data only. However, integrating the demographic features is straightforward. For example, if you are doing a classification task, you can use one of these models as a base model (also called a backbone) and wrap it as a child module of your own model, like this:
```python
import torch
import torch.nn as nn


class CustomModel(nn.Module):
    def __init__(self, base_model, num_dem_features, num_classes):
        super().__init__()  # required before assigning submodules
        self.base_model = base_model
        self.num_dem_features = num_dem_features
        self.num_classes = num_classes
        dem_feature_embed_dim = 32  # this can be a parameter for the __init__ function
        # If you do not want to embed the demographic features, you can omit this layer;
        # the input size of self.fc then becomes base_model_output_dim + num_dem_features
        self.dem_feature_embedding = nn.Linear(num_dem_features, dem_feature_embed_dim)
        # you know it beforehand or compute it with self.base_model.compute_output_shape
        base_model_output_dim = 256
        self.fc = nn.Linear(base_model_output_dim + dem_feature_embed_dim, num_classes)

    def forward(self, ecg_signal, dem_features):
        # assume that ecg_signal is of shape (batch_size, num_leads, sig_len)
        # assume that dem_features is of shape (batch_size, num_dem_features)
        ecg_features = self.base_model(ecg_signal)
        # assume this base model contains a global pooling layer on its top,
        # so ecg_features is of shape (batch_size, base_model_output_dim)
        dem_features = self.dem_feature_embedding(dem_features)
        total_features = torch.cat((ecg_features, dem_features), dim=-1)
        logits = self.fc(total_features)  # the logits can be further passed to your loss function
        return logits
```
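To sanity-check the wiring before plugging in a real torch_ecg model, here is a minimal, self-contained sketch. It assumes a hypothetical `ToyBackbone` (a single Conv1d followed by global average pooling) standing in for the library model, and a hypothetical `LateFusionModel` that follows the same pattern of concatenating pooled ECG features with demographic features. It also covers the option of skipping the demographic embedding (`dem_embed_dim=None`). The age/weight values are made-up placeholders, standardized with the batch's own statistics purely for illustration; in practice you would use statistics computed on your training set.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a library model whose top is a global pooling layer."""

    def __init__(self, num_leads: int, out_dim: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(num_leads, out_dim, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # global pooling over the time axis
            nn.Flatten(),  # (batch_size, out_dim, 1) -> (batch_size, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)


class LateFusionModel(nn.Module):
    """Hypothetical wrapper: concatenates pooled ECG features with demographic features.

    If dem_embed_dim is None, the demographic vector is passed through unchanged
    (the "omit the embedding layer" option).
    """

    def __init__(self, base_model, base_model_output_dim, num_dem_features,
                 num_classes, dem_embed_dim=None):
        super().__init__()
        self.base_model = base_model
        if dem_embed_dim is None:
            self.dem_feature_embedding = nn.Identity()
            dem_dim = num_dem_features
        else:
            self.dem_feature_embedding = nn.Linear(num_dem_features, dem_embed_dim)
            dem_dim = dem_embed_dim
        self.fc = nn.Linear(base_model_output_dim + dem_dim, num_classes)

    def forward(self, ecg_signal, dem_features):
        ecg_features = self.base_model(ecg_signal)  # (batch_size, base_model_output_dim)
        dem_features = self.dem_feature_embedding(dem_features)  # (batch_size, dem_dim)
        return self.fc(torch.cat((ecg_features, dem_features), dim=-1))


batch_size, num_leads, sig_len = 4, 12, 5000
# Hypothetical raw demographics, one row per patient: (age in years, weight in kg)
raw = torch.tensor([[54.0, 80.0], [67.0, 65.0], [31.0, 58.0], [72.0, 90.0]])
# Standardize each column so it is on a scale comparable to the network activations
dem_features = (raw - raw.mean(dim=0)) / raw.std(dim=0)

model = LateFusionModel(ToyBackbone(num_leads, out_dim=256), base_model_output_dim=256,
                        num_dem_features=2, num_classes=5)
logits = model(torch.randn(batch_size, num_leads, sig_len), dem_features)
print(logits.shape)  # torch.Size([4, 5])
```

Concatenating after global pooling leaves the backbone untouched, so any model that ends in a pooled feature vector can be swapped in; only `base_model_output_dim` has to match its output size.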