
Add CNN model

Open • sevmag opened this issue 7 months ago • 1 comment

This is the main PR towards adding CNN support to GraphNeT, enabling direct comparisons (see #771).

The CNN support consists of:

  • ImageDefinition to represent data as an image
  • CNN architectures to train
  • Unit tests
  • Example script for CNN training

An ImageDefinition consists of 2 parts:

  1. A NodeDefinition that preprocesses the raw data and makes sure that the pulses are aggregated at the optical modules (e.g. ClusterSummaryFeatures or PercentileClusters)
  2. A PixelMapping, which is responsible for creating the images and mapping the nodes into the right location in the image
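
To make the first part concrete, here is a rough sketch of what aggregating pulses at the optical modules can look like, using pandas on mock data. It is not the ClusterSummaryFeatures or PercentileClusters implementation: the function name and the chosen summaries (total charge, first hit time and pulse count per DOM) are illustrative assumptions.

```python
import pandas as pd


def aggregate_pulses_per_dom(pulses: pd.DataFrame) -> pd.DataFrame:
    """Collapse pulse-level rows into one row (pixel) per optical module.

    Hypothetical helper. Assumes columns 'string', 'dom_number',
    'dom_time' and 'charge'. The summaries below are illustrative only,
    not those computed by any GraphNeT NodeDefinition.
    """
    grouped = pulses.groupby(["string", "dom_number"], sort=True)
    pixels = grouped.agg(
        total_charge=("charge", "sum"),
        first_time=("dom_time", "min"),
        n_pulses=("charge", "count"),
    )
    # Keep the module keys as ordinary columns for the later mapping step
    return pixels.reset_index()


# Mock event: three pulses on two DOMs
pulses = pd.DataFrame(
    {
        "string": [1, 1, 36],
        "dom_number": [10, 10, 25],
        "dom_time": [1000.0, 1003.5, 980.0],
        "charge": [1.2, 0.8, 2.5],
    }
)
out = aggregate_pulses_per_dom(pulses)
print(out)
```

Each output row corresponds to one pixel; placing it in the image is the job of the second part.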

There are 2 CNN architectures implemented:

  1. LCSC from Alexander Harnisch
  2. TheosMuonEUpgoing, the energy reconstruction architecture by Theo Glauch used in IceCube

Timing of the ImageDefinition in Comparison to Other Data Representations

For small numbers of pulses, the bottleneck of the ImageDefinition is the initialisation of the zero tensors.

Timed Modules

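import numpy as np

# Assumed in scope: EdgelessGraph, KNNGraph, PercentileClusters and IceCube86
# from GraphNeT, and IC86DNNImage from this PR.
input_feature_names = ['string', 'dom_number', 'dom_time', 'charge']
node_def = PercentileClusters(
    input_feature_names=input_feature_names,
    cluster_on=['string', 'dom_number'],
    percentiles=np.linspace(0.2, 1.0, 5),
)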
data_rep = {
    'edgeless': EdgelessGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
    ),
    'knn_graph_8NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=8,
    ),
    'knn_graph_16NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=16,
    ),
    'knn_graph_64NN': KNNGraph(
        node_definition=node_def,
        detector=IceCube86(
            replace_with_identity=input_feature_names,
        ),
        input_feature_names=input_feature_names,
        nb_nearest_neighbours=64,
    ),
    'ic86_dnn': IC86DNNImage(
        node_definition=node_def,
        input_feature_names=input_feature_names,
        include_lower_dc=True,
        include_upper_dc=True,
    ),
}
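
For reproducing timing curves like the ones below, a small harness could look as follows. The mock-pulse distributions, the helper names and the keyword call convention rep(input_features=..., input_feature_names=...) are assumptions; a stand-in callable replaces the GraphNeT representations so the snippet runs on its own.

```python
import time
from typing import Callable, Dict, List

import numpy as np


def make_mock_pulses(n_pulses: int, seed: int = 0) -> np.ndarray:
    """Random pulses with columns [string, dom_number, dom_time, charge].

    Ranges loosely follow IceCube86 (86 strings, 60 DOMs per string);
    the distributions themselves are arbitrary.
    """
    rng = np.random.default_rng(seed)
    return np.column_stack(
        [
            rng.integers(1, 87, size=n_pulses),    # string in 1..86
            rng.integers(1, 61, size=n_pulses),    # dom_number in 1..60
            rng.uniform(0.0, 1e4, size=n_pulses),  # dom_time
            rng.exponential(1.0, size=n_pulses),   # charge
        ]
    ).astype(np.float64)


def time_representations(
    reps: Dict[str, Callable[..., object]],
    feature_names: List[str],
    pulse_counts: List[int],
    repeats: int = 5,
) -> Dict[str, List[float]]:
    """Median wall-clock seconds per call, for each representation and size."""
    results: Dict[str, List[float]] = {name: [] for name in reps}
    for n in pulse_counts:
        pulses = make_mock_pulses(n)
        for name, rep in reps.items():
            timings = []
            for _ in range(repeats):
                start = time.perf_counter()
                rep(input_features=pulses, input_feature_names=feature_names)
                timings.append(time.perf_counter() - start)
            results[name].append(float(np.median(timings)))
    return results


# Stand-in for a real data representation, so the harness runs standalone
def dummy_rep(input_features: np.ndarray, input_feature_names: List[str]) -> np.ndarray:
    return np.zeros((len(input_features), len(input_feature_names)))


timings = time_representations(
    {"dummy": dummy_rep},
    ["string", "dom_number", "dom_time", "charge"],
    pulse_counts=[1, 500, 5000],
)
print(timings)
```

To time the representations above, data_rep could be passed in place of the stand-in dictionary, provided they accept these keyword arguments.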

5000-200000 Mock Pulses (log scale)

[Figure: timing plot (screenshot, 2025-07-18)]

1-5000 Mock Pulses

[Figure: timing plot (screenshot, 2025-07-18)]

1-500 Mock Pulses

[Figure: timing plot (screenshot, 2025-07-18)]

sevmag, Jul 18 '25, 15:07

Hey @sevmag, thank you for implementing CNNs!! 🚀

I like the approach you've taken, and I think the PR is generally in pretty good shape. In addition to the specific comments above, I've been thinking that we can simplify the user experience and eliminate the need for new files by slightly refactoring the PixelMapping, which changes the role it plays in the image representation.

In essence, I propose that the PixelMapping (referred to as GridDefinition below) defines the number of images, their sizes, and a method for generating the key-value store(s) used to insert pixels into the grid(s), built from the existing Detector classes. Generating the grids and inserting the pixels would be handled by the image representation. More details below.

Could you take a look and let me know if this fits your use-case?

Preliminary observations

  1. Orthonormal grids for image representations in neutrino telescopes are detector-specific and often manually crafted. That is, picking an image grid amounts to choosing which detector the method will run on. As a result, a grid is not detector-agnostic but depends on a fixed detector geometry.
  2. Pixels are detector-agnostic.

These two observations suggest two central arguments for image representations. I summarize my proposed scope of each below:

PixelDefinition
Defines the meaning of a single pixel. Because a pixel is conceptually similar to a node, existing NodeDefinitions from our graph representations should be compatible here, but users shouldn't need prior knowledge of graphs to use the CNNs. To avoid confusion on the user end, we could name the argument pixel_definition: NodeDefinition and give it a docstring that points out the similarity. That is, given a set of [n,d]-dimensional pulses X, the PixelDefinition/NodeDefinition produces a [p,j]-dimensional set of (unordered) pixels P.

GridDefinition
Defines one or more orthonormal grids for image representations of a single detector. It should depend on the Detector component and generate a mapping that identifies the position of each pixel in the grid(s). Images are assumed to have either two or three spatial dimensions. That is, GridDefinition defines the shape(s) of the image(s) and the map(s) that identify the position of each pixel in P within an image.

Role of ImageRepresentation
The glue between the two components above:
  1. Given a set of [n,d]-dimensional pulses X, a [p,j]-dimensional set of (unordered) pixels P is produced using the PixelDefinition/NodeDefinition.
  2. An empty image is created with the shape defined by GridDefinition.
  3. Each pixel in P is inserted into the image using the key-value store defined in map.

In pseudo-code, the ImageRepresentation could take the form:

from typing import Optional, List, Dict, Union, Tuple, Any, Callable
from numpy.random import Generator
import numpy as np
import pandas as pd

import torch
from torch_geometric.data import Data

from graphnet.models.data_representation import DataRepresentation
from graphnet.models.detector import Detector
from graphnet.models.graphs.nodes import NodeDefinition

class ImageRepresentation(DataRepresentation):
    """ A base class for image representations in GraphNeT."""

    def __init__(self,
                pixel_definition: NodeDefinition,
                grid_definition: GridDefinition,
                input_feature_names: Optional[List[str]] = None,
                dtype: Optional[torch.dtype] = torch.float,
                perturbation_dict: Optional[Dict[str, float]] = None,
                seed: Optional[Union[int, Generator]] = None,
                add_inactive_sensors: bool = False,
                sensor_mask: Optional[List[int]] = None,
                string_mask: Optional[List[int]] = None,
                repeat_labels: bool = False, ) -> None:
        
        # Base class constructor
        super().__init__(
            detector=grid_definition.detector, # defines detector
            input_feature_names=input_feature_names,
            dtype=dtype,
            perturbation_dict=perturbation_dict,
            seed=seed,
            add_inactive_sensors=add_inactive_sensors,
            sensor_mask=sensor_mask,
            string_mask=string_mask,
            repeat_labels=repeat_labels,
        )

        self._pixel_definition = pixel_definition
        self._grid_definition = grid_definition
        self._pixel_mappings = grid_definition.mappings # property; yields key-value store(s)
        self._image_shapes = grid_definition.shape # Shape of image(s)
        self._map_pixels_by = self._grid_definition.map_pixels_by

    def forward(  # type: ignore
        self,
        input_features: np.ndarray,
        input_feature_names: List[str],
        truth_dicts: Optional[List[Dict[str, Any]]] = None,
        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
        loss_weight_column: Optional[str] = None,
        loss_weight: Optional[float] = None,
        loss_weight_default_value: Optional[float] = None,
        data_path: Optional[str] = None,
    ) -> Data:
        """Construct image(s) as `Data` object.

        Args:
            input_features: Input features for image construction.
                Shape `[num_rows, d]`
            input_feature_names: name of each column. Shape `[,d]`.
            truth_dicts: Dictionary containing truth labels.
            custom_label_functions: Custom label functions.
            loss_weight_column: Name of column that holds loss weight.
                                Defaults to None.
            loss_weight: Loss weight associated with event. Defaults to None.
            loss_weight_default_value: default value for loss weight.
                    Used in instances where some events have
                    no pre-defined loss weight. Defaults to None.
            data_path: Path to dataset data files. Defaults to None.

        Returns:
            `Data` object holding the image(s)
        """
        # Process low-level pulses using base-class
        data = super().forward(
            input_features=input_features,
            input_feature_names=input_feature_names,
            truth_dicts=truth_dicts,
            custom_label_functions=custom_label_functions,
            loss_weight_column=loss_weight_column,
            loss_weight=loss_weight,
            loss_weight_default_value=loss_weight_default_value,
            data_path=data_path,
        )

        # Transform pulses to pixels
        x = self._pixel_definition(x = data.x)

        # Map pixels to positions in image(s)
        x = self._map_pixels_to_grid(x = x,
                                     pixel_mappings = self._pixel_mappings,
                                     image_shapes = self._image_shapes)
        
        # Assign to Data
        data.x = x

        # other stuff..

        return data

    def _map_pixels_to_grid(self,
                            x: torch.Tensor,
                            pixel_mappings: List[pd.DataFrame],
                            image_shapes: List[Tuple[int, ...]]) -> List[torch.Tensor]:
        """Insert unordered pixel values in `x`
           into empty image(s) with shape(s) `image_shapes` using the
           key-value store defined by `pixel_mappings`."""

        # Check that the number of image shapes is equal to number of mappings
        assert len(pixel_mappings) == len(image_shapes)

        # Create and fill images with pixels
        images = []
        # We assume the ordering is identical here
        for shape, mapping in zip(image_shapes, pixel_mappings):
            empty_image = torch.zeros(size=shape)
            filled_image = self._apply_map(empty_image=empty_image,
                                           pixels=x,
                                           mapping=mapping,
                                           map_pixels_by=self._map_pixels_by)
            # [F,D,H,W] -> [1, F, D, H, W] for 3D
            # [F,D,H] -> [1, F, D, H] for 2D
            filled_image = filled_image.unsqueeze(0) 
            images.append(filled_image)
        return images

    def _apply_map(self, 
                   empty_image: torch.Tensor,
                   pixels: torch.Tensor,
                   mapping: pd.DataFrame,
                   map_pixels_by: List[str]) -> torch.Tensor:
        """
        Insert values from `pixels` into `empty_image` at positions
        identified by indexing `mapping` with columns `map_pixels_by` in `pixels`.

        `empty_image` can either be [F,D,H,W]-dimensional (3D) or [F,D,H] (2D)
        where F denotes the number of pixel features.
        """
        raise NotImplementedError

    @property
    def shape(self) -> List[Tuple[int, ...]]:
        return self._image_shapes

    def _set_output_feature_names(
        self, input_feature_names: List[str]
    ) -> List[str]:
        """Return ordered list of pixel feature names."""
        return self._pixel_definition.output_feature_names

Note that I didn't write out _apply_map explicitly; this still needs to be implemented.
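
One possible body for _apply_map is sketched below as a standalone function. It assumes a hypothetical mapping layout: a DataFrame indexed by the map_pixels_by columns, with one integer column per spatial axis of the image. Because pixels arrives as a bare tensor, the sketch also takes pixel_feature_names (not part of the signature above) to locate the key columns, and it drops the keys from the feature channels, as the IC86Grid channel count below does.

```python
from typing import List

import pandas as pd
import torch


def apply_map(
    empty_image: torch.Tensor,
    pixels: torch.Tensor,
    pixel_feature_names: List[str],
    mapping: pd.DataFrame,
    map_pixels_by: List[str],
) -> torch.Tensor:
    """Scatter pixel features into `empty_image` of shape [F, *spatial].

    Hypothetical sketch: `mapping` is indexed by `map_pixels_by` and has
    one integer column per spatial axis. Pixels whose key is missing from
    `mapping` are dropped. Assumes at most one pixel per key.
    """
    key_idx = [pixel_feature_names.index(c) for c in map_pixels_by]
    feat_idx = [i for i in range(len(pixel_feature_names)) if i not in key_idx]

    # Build lookup keys from the (integer-valued) key columns
    arrays = [pixels[:, i].long().numpy() for i in key_idx]
    if len(arrays) > 1:
        keys = pd.MultiIndex.from_arrays(arrays, names=map_pixels_by)
    else:
        keys = pd.Index(arrays[0], name=map_pixels_by[0])

    # Grid position of every pixel; rows are NaN where the key is unmapped
    positions = mapping.reindex(keys)
    found = ~positions.isna().any(axis=1).to_numpy()
    coords = torch.as_tensor(positions.to_numpy()[found].astype("int64"))

    values = pixels[torch.as_tensor(found)][:, feat_idx]  # [p_found, F]
    image = empty_image.clone()
    # image[:, a, b, ...] selects shape [F, p_found]; fill it column-wise
    image[(slice(None), *coords.T)] = values.T.to(image.dtype)
    return image


mapping = pd.DataFrame(
    {"row": [0, 1], "col": [2, 0], "depth": [5, 3]},
    index=pd.MultiIndex.from_tuples([(1, 10), (36, 25)], names=["string", "dom_number"]),
)
# Pixel columns: string, dom_number, charge, time; the last key is unmapped
pixels = torch.tensor([[1.0, 10.0, 2.0, 1000.0],
                       [36.0, 25.0, 2.5, 980.0],
                       [99.0, 1.0, 9.9, 0.0]])
image = apply_map(torch.zeros(2, 10, 10, 60), pixels,
                  ["string", "dom_number", "charge", "time"],
                  mapping, ["string", "dom_number"])
print(image[:, 0, 2, 5])
```

Using reindex rather than a Python loop keeps the lookup vectorised, which may matter given that tensor handling already dominates at low pulse counts.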

Given this structure, the GridDefinition could take the form

from abc import abstractmethod
from typing import Optional, List, Dict, Union, Tuple, Any, Callable

from numpy.random import Generator
import numpy as np
import pandas as pd

import torch
from torch_geometric.data import Data

from graphnet.models import Model
from graphnet.models.detector import Detector
from graphnet.models.graphs.nodes import NodeDefinition

class GridDefinition(Model):
    """ Base class for constructing image partitions in GraphNeT.
    
        The image partitions define orthonormal grids from detector geometry."""
    
    def __init__(self,
                detector: Detector,
                pixel_feature_names: List[str],
                map_pixels_by: List[str]) -> None:
        """detector: Regular graphnet detector class that holds geometry.
           pixel_feature_names: list of all available pixel features. Assumed to be ordered.
           map_pixels_by: subset of pixel_feature_names to map by."""
        super().__init__(name=__name__, class_name=self.__class__.__name__)
        # Checks
        assert isinstance(map_pixels_by, list)
        assert isinstance(pixel_feature_names, list)
        assert isinstance(detector, Detector)

        self.detector = detector
        self._pixel_features = pixel_feature_names
        self._map_pixels_by = map_pixels_by
        self._geometry_table = detector.geometry_table


    @abstractmethod
    def _generate_mappings(self,
                           geometry_table: pd.DataFrame,
                           map_pixels_by: List[str],
                           pixel_features: List[str]) -> Union[List[pd.DataFrame], pd.DataFrame]:
        """Generate a single, or a list of, key-value stores that relates
        a pixel position defined by `map_pixels_by` to a position in 
        the orthonormal grid using the detector geometry table.
        
        The resulting key-value store is required to be an indexed
        pd.DataFrame, and may use geometric detector features such as 
        
        `from graphnet.models.detector.icecube import IceCube86

        detector = IceCube86() # or any other

        # Natively indexed on xyz positions
        geometry_table = detector.geometry_table.reset_index(drop = False)
        unique_sensor_id = detector.sensor_id_column
        unique_string_id = detector.string_id_column
        unique_sensor_position = detector.xyz`
        """
        raise NotImplementedError
    
    @abstractmethod
    def _generate_shapes(self,
                        geometry_table: pd.DataFrame,
                        pixel_features: List[str],
                        map_pixels_by: List[str]) -> Union[Tuple[int],
                                                            List[Tuple[int]]]:
        """Generate the shape(s) of the image grid(s).
         
          E.g. [(10, 5, 2,10), (256, 50, 10, 2)] """

        raise NotImplementedError
    
    @property
    def shape(self) -> Union[Tuple[int],List[Tuple[int]]]:
        """Return the shape(s) of the image(s)."""
        if hasattr(self, '_shapes'):
            return self._shapes
        else:
            self._shapes = self._generate_shapes(geometry_table = self._geometry_table,
                                                 pixel_features = self._pixel_features,
                                                 map_pixels_by= self._map_pixels_by)
            return self._shapes

    @property
    def map_pixels_by(self) -> List[str]:
        return self._map_pixels_by
    
    @property
    def mappings(self) -> Union[pd.DataFrame, List[pd.DataFrame]]:
        """Return the key-value stores that map a pixel to a point in the grid(s)."""
        if hasattr(self, "_mappings"):
            return self._mappings
        else:
            self._mappings = self._generate_mappings(geometry_table = self._geometry_table,
                                                 pixel_features = self._pixel_features,
                                                 map_pixels_by= self._map_pixels_by)
            return self._mappings

Within this formalism, your existing IC86 representation could look something like this:

from graphnet.models.detector import IceCube86

from typing import List, Tuple, Union, Dict
import pandas as pd

# Fixed 10x10 placement for strings 1..78 (from your generator)
_IC86_STRING_TO_AX01: Dict[int, Tuple[int, int]] = {
    1:(9,4),  2:(9,5),  3:(9,6),  4:(9,7),  5:(9,8),  6:(9,9),
    7:(8,3),  8:(8,4),  9:(8,5), 10:(8,6), 11:(8,7), 12:(8,8), 13:(8,9),
    14:(7,2), 15:(7,3), 16:(7,4), 17:(7,5), 18:(7,6), 19:(7,7), 20:(7,8), 21:(7,9),
    22:(6,1), 23:(6,2), 24:(6,3), 25:(6,4), 26:(6,5), 27:(6,6), 28:(6,7), 29:(6,8), 30:(6,9),
    31:(5,0), 32:(5,1), 33:(5,2), 34:(5,3), 35:(5,4), 36:(5,5), 37:(5,6), 38:(5,7), 39:(5,8), 40:(5,9),
    41:(4,0), 42:(4,1), 43:(4,2), 44:(4,3), 45:(4,4), 46:(4,5), 47:(4,6), 48:(4,7), 49:(4,8), 50:(4,9),
    51:(3,0), 52:(3,1), 53:(3,2), 54:(3,3), 55:(3,4), 56:(3,5), 57:(3,6), 58:(3,7), 59:(3,8),
    60:(2,0), 61:(2,1), 62:(2,2), 63:(2,3), 64:(2,4), 65:(2,5), 66:(2,6), 67:(2,7),
    68:(1,0), 69:(1,1), 70:(1,2), 71:(1,3), 72:(1,4), 73:(1,5), 74:(1,6),
    75:(0,0), 76:(0,1), 77:(0,2), 78:(0,3),
}

class IC86Grid(GridDefinition):

    def __init__(
        self,
        pixel_feature_names: List[str],
        string_label: str = "string",
        dom_number_label: str = "sensor_id",  # will be aliased to detector.sensor_id_column
        include_main_array: bool = True,
        include_lower_dc: bool = True,
        include_upper_dc: bool = True,
    ) -> None:
        super().__init__(
            detector=IceCube86(),
            pixel_feature_names=pixel_feature_names,
            map_pixels_by=[string_label, dom_number_label],
        )
        if not any([include_main_array, include_lower_dc, include_upper_dc]):
            raise ValueError("Include at least one array type.")

        self._string_label = string_label
        self._dom_number_label = dom_number_label
        self._include_main_array = include_main_array
        self._include_lower_dc = include_lower_dc
        self._include_upper_dc = include_upper_dc

        # channels = all features except the mapping keys
        self._nb_channels = len(pixel_feature_names) - len(self.map_pixels_by)


    # ---- GridDefinition interface ----

    def _generate_mappings(
        self,
        geometry_table: pd.DataFrame,
        map_pixels_by: List[str],
        pixel_features: List[str],
    ) -> Union[List[pd.DataFrame], pd.DataFrame]:
        """
        Build one mapping DataFrame per included grid using 
        detector.geometry_table.
        """
        # Your logic goes here
        # Ideally use the "sensor_id" which defines unique DOMs
        # Or, if you prefer, we can add the non-unique "dom_number"
        # to the geometry table
        # Use the global variable above as you wish


    def _generate_shapes(
        self,
        geometry_table: pd.DataFrame,
        pixel_features: List[str],
        map_pixels_by: List[str],
    ) -> Union[Tuple[int], List[Tuple[int]]]:
        """ Define the dimension(s) of the image(s) here"""
        # Make sure as little as possible is hardcoded
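
For the main array, the two methods could be filled in roughly as below, written as standalone hypothetical helpers that only use the string placement table. The assumptions are illustrative: keys are (string, dom_number) rather than sensor_id, DOM numbers run 1..60 on every string and map to depth index dom_number - 1, and the DeepCore grids and the geometry-table route suggested above are not covered.

```python
from typing import Dict, Tuple

import pandas as pd


def build_main_array_mapping(
    string_to_ax01: Dict[int, Tuple[int, int]],
    n_doms_per_string: int = 60,
) -> pd.DataFrame:
    """Key-value store from (string, dom_number) to (row, col, depth).

    Hypothetical helper: depth = dom_number - 1, i.e. DOM numbers are
    assumed to run 1..n_doms_per_string along each string.
    """
    records = [
        (string, dom, row, col, dom - 1)
        for string, (row, col) in string_to_ax01.items()
        for dom in range(1, n_doms_per_string + 1)
    ]
    mapping = pd.DataFrame(
        records, columns=["string", "dom_number", "row", "col", "depth"]
    )
    return mapping.set_index(["string", "dom_number"]).sort_index()


def main_array_shape(
    string_to_ax01: Dict[int, Tuple[int, int]],
    n_channels: int,
    n_doms_per_string: int = 60,
) -> Tuple[int, int, int, int]:
    """Derive [F, rows, cols, depth] from the placement table, not hardcoded."""
    n_rows = 1 + max(r for r, _ in string_to_ax01.values())
    n_cols = 1 + max(c for _, c in string_to_ax01.values())
    return (n_channels, n_rows, n_cols, n_doms_per_string)


# Small excerpt of _IC86_STRING_TO_AX01 for a self-contained demo;
# the full table would give (F, 10, 10, 60)
placement = {1: (9, 4), 31: (5, 0), 78: (0, 3)}
mapping = build_main_array_mapping(placement)
print(mapping.loc[(31, 12)])                      # row 5, col 0, depth 11
print(main_array_shape(placement, n_channels=7))  # (7, 10, 5, 60)
```

Deriving the shape from the placement table keeps the only hardcoded piece of the grid in one place.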

RasmusOrsoe, Aug 25 '25, 09:08