Skip to content

📝 API Reference

Aegear: a computer vision toolkit for tracking and analyzing fish behavior in controlled aquaculture environments.

calibration

Scene calibration module.

This module is used to calibrate the camera and the scene size to get the pixel to cm ratio. It includes a class SceneCalibration that handles the calibration process, including loading camera parameters, assigning scene reference points, calibrating the scene, and rectifying images. The calibration is performed using a set of screen points and a set of real-world reference points.

The class also provides a method to rectify images based on the calibration parameters. It uses OpenCV for image processing and assumes that the camera calibration parameters are stored in a file. The calibration points are expected to be in a specific order: top left, top right, bottom right, bottom left.

Note that this reference matching system is put in place to allow inconsistent camera placement with respect to the original take of the calibration pattern. This calibration system uses this information to rectify the image for easier tracking of the fish, and to estimate the pixel to cm ratio, hence allowing correct metric tracking of the fish within the experiment.

SceneCalibration

Calibration of the camera and the scene size to get the pixel to cm ratio.

Source code in src\aegear\calibration.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class SceneCalibration:
    """
    Calibration of the camera and the scene size to get the pixel to cm ratio.
    """

    # Sample points used in the Russian Sturgeon experiment, Fazekas et al, 2025.
    DEFAULT_SCENE_REF = np.array([[0, 0], [149.0, 5.0], [149.0, 35.0], [0.0, 40.0]], dtype=np.float32)

    def __init__(self, calibration_path: str, scene_reference=DEFAULT_SCENE_REF):
        """
        Constructor.

        Parameters
        ----------
        calibration_path : str
            Path to the calibration file.
        scene_reference : np.ndarray, optional
            The reference points for the scene. 4x2 array of floats, designating the borders
            of the reference area used for final image rectification and pixel to cm ratio calculation.
            The default value is assumed from the Russian Sturgeon experiment, Fazekas et al, 2025.
        """
        # Camera intrinsic matrix (mtx) and distortion coefficients (dist).
        self.mtx, self.dist = self._load_calibration(calibration_path)
        # NOTE(review): the default is a shared class-level array; callers should not
        # mutate it in place (assign_scene_calibration replaces it wholesale).
        self._scene_reference = scene_reference
        # Set by calibrate(); the rectify/unrectify methods assert it is non-None.
        self._perspectiveTransform = None

    def _load_calibration(self, calibration_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load the camera calibration parameters from a file.
        """

        # The file is an OpenCV FileStorage (YAML/XML) with "mtx" and "dist" matrices.
        storage = cv2.FileStorage(calibration_path, cv2.FILE_STORAGE_READ)
        mtx = storage.getNode("mtx").mat()
        dist = storage.getNode("dist").mat()
        storage.release()

        return (mtx, dist)

    def assign_scene_calibration(self, points: List[Tuple[float, float]]):
        """
        Assign the scene calibration points.

        Parameters
        ----------

        points : list
            The scene reference points to use for calibration.
            The 4x2 array of floats, designating the borders of the reference area used for final image rectification and pixel to cm ratio calculation.
            By convention, the points are in the order: top left, top right, bottom right, bottom left.
        """
        points = np.array(points, dtype=np.float32)
        assert points.shape == (4, 2), "Real points must be a 4x2 array"
        self._scene_reference = points

    def calibrate(self, screen_pts: List[Tuple[float, float]]) -> float:
        """
        Run the scene characterization.

        Parameters
        ----------
        screen_pts : list
            The screen points to use for calibration, which within the scene match the points assigned for the scene reference.
            As for the reference points, the points are in the order: top left, top right, bottom right, bottom left.

        Returns
        -------
        float
            The pixel to cm ratio.
        """
        sample_pts = np.array(screen_pts, dtype=np.float32)
        assert sample_pts.shape == (4, 2), "Screen points must be a 4x2 array"

        # Undistort the clicked points first so the perspective fit is done in
        # ideal (lens-distortion-free) pixel coordinates (P=self.mtx keeps pixel units).
        sample_pts = cv2.undistortPoints(
            np.array(sample_pts, dtype=np.float32).reshape(-1, 1, 2),
            self.mtx,
            self.dist,
            P=self.mtx
        ).reshape(-1, 2) # Reshape to (N, 2) for direct use

        # NOTE(review): np.linalg.norm without axis=1 collapses the 3x2 segment
        # array into a single Frobenius norm, so np.mean is a no-op here; per-segment
        # average length would need norm(..., axis=1) — confirm intended behavior.
        sample_avg_scale = np.mean(np.linalg.norm(np.diff(sample_pts, axis=0)))
        scene_avg_scale = np.mean(np.linalg.norm(np.diff(self._scene_reference, axis=0)))

        img_scaling_factor = sample_avg_scale / scene_avg_scale 

        # move points to match starting x position of samples, and scale up to image scale
        transformed_real_pts = self._scene_reference * img_scaling_factor + sample_pts[0, :]

        # do perspective transform to rectify image
        persp_T = cv2.getPerspectiveTransform(sample_pts, transformed_real_pts)

        # add homogeneous coordinate
        sample_pts = np.hstack((sample_pts, np.ones((4, 1))))

        # also warp points to be able to calculate pixel to cm ratio
        sample_pts = np.dot(persp_T, sample_pts.T).T

        # divide by homogeneous coordinate
        sample_pts = sample_pts[:, 0:2] / sample_pts[:, 2].reshape((4, 1))

        # now calculate pixel to cm ratio
        sample_avg_scale = np.mean(np.linalg.norm(np.diff(sample_pts, axis=0)))
        pixel_to_cm_ratio = scene_avg_scale / sample_avg_scale

        self._perspectiveTransform = persp_T

        return pixel_to_cm_ratio

    def rectify_image(self, image: np.ndarray) -> np.ndarray:
        """
        Rectify the image.

        Parameters
        ----------
        image : np.ndarray
            The image to rectify.

        Returns
        -------
        np.ndarray
            The rectified image.

        """
        assert self._perspectiveTransform is not None, "Need to calibrate first"

        # Undistort, then warp with the transform computed by calibrate().
        ret_image = cv2.undistort(image, self.mtx, self.dist)
        # shape[0:2] is (height, width); warpPerspective expects (width, height).
        ret_image = cv2.warpPerspective(ret_image, self._perspectiveTransform, image.shape[0:2][::-1])

        return ret_image

    def rectify_point(self, point: tuple[float, float]) -> tuple[float, float]:
        """
        Rectify a single point using the current calibration.

        Parameters
        ----------
        point : tuple of float
            The (x, y) coordinates of the point to rectify.

        Returns
        -------
        tuple of float
            The rectified (x, y) coordinates.
        """
        assert self._perspectiveTransform is not None, "Need to calibrate first"

        # Step 1: Undistort (P=self.mtx keeps the result in pixel coordinates)
        undistorted_pt = cv2.undistortPoints(
            np.array([[point]], dtype=np.float32),
            self.mtx,
            self.dist,
            P=self.mtx
        )[0, 0]

        # Step 2: Perspective transform into the rectified image frame
        rectified_pt = cv2.perspectiveTransform(
            np.array([[undistorted_pt]], dtype=np.float32),
            self._perspectiveTransform
        )[0, 0]

        return tuple(rectified_pt)

    def unrectify_point(self, point: tuple[float, float]) -> tuple[float, float]:
        """
        Map a point from the rectified image back to its original (distorted) image coordinates.
        """
        assert self._perspectiveTransform is not None, "Need to calibrate first"

        # 1. undo the perspective warp (work in homogeneous coordinates)
        inv_T = np.linalg.inv(self._perspectiveTransform)
        pt = np.array([point[0], point[1], 1.0], dtype=np.float32)
        undist_h = inv_T.dot(pt)
        undist_px = undist_h[:2] / undist_h[2]

        # 2. convert back to normalized camera coords (inverse intrinsics)
        inv_mtx = np.linalg.inv(self.mtx)
        uv1 = np.array([undist_px[0], undist_px[1], 1.0], dtype=np.float32)
        norm = inv_mtx.dot(uv1).reshape(1, 3)

        # 3. project through intrinsics+distortion to get the original pixel;
        # identity extrinsics (zero rvec/tvec) so only the lens model is applied.
        rvec = np.zeros(3, dtype=np.float32)
        tvec = np.zeros(3, dtype=np.float32)
        img_pts, _ = cv2.projectPoints(norm, rvec, tvec, self.mtx, self.dist)
        x, y = img_pts[0, 0]

        return (float(x), float(y))

assign_scene_calibration(points)

Assign the scene calibration points.

Parameters
list

The scene reference points to use for calibration. The 4x2 array of floats, designating the borders of the reference area used for final image rectification and pixel to cm ratio calculation. By convention, the points are in the order: top left, top right, bottom right, bottom left.

Source code in src\aegear\calibration.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def assign_scene_calibration(self, points: List[Tuple[float, float]]):
    """
    Set the scene reference points used for calibration.

    Parameters
    ----------

    points : list
        The scene reference points to use for calibration: a 4x2 array of floats
        marking the corners of the reference area used for final image rectification
        and for the pixel to cm ratio calculation.
        By convention the order is: top left, top right, bottom right, bottom left.
    """
    reference = np.array(points, dtype=np.float32)
    assert reference.shape == (4, 2), "Real points must be a 4x2 array"
    self._scene_reference = reference

calibrate(screen_pts)

Run the scene characterization.

Parameters

screen_pts : list The screen points to use for calibration, which within the scene match the points assigned for the scene reference. As for the reference points, the points are in the order: top left, top right, bottom right, bottom left.

Returns

float The pixel to cm ratio.

Source code in src\aegear\calibration.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def calibrate(self, screen_pts: List[Tuple[float, float]]) -> float:
    """
    Run the scene characterization.

    Fits a perspective transform that maps the (undistorted) screen points onto a
    scaled copy of the scene reference quadrilateral, then derives the pixel to cm
    ratio from the warped points.

    Parameters
    ----------
    screen_pts : list
        The screen points to use for calibration, which within the scene match the points assigned for the scene reference.
        As for the reference points, the points are in the order: top left, top right, bottom right, bottom left.

    Returns
    -------
    float
        The pixel to cm ratio.
    """
    sample_pts = np.array(screen_pts, dtype=np.float32)
    assert sample_pts.shape == (4, 2), "Screen points must be a 4x2 array"

    # Remove lens distortion so the perspective fit is done in ideal pixel space
    # (P=self.mtx keeps the output in pixel units).
    sample_pts = cv2.undistortPoints(
        sample_pts.reshape(-1, 1, 2),
        self.mtx,
        self.dist,
        P=self.mtx
    ).reshape(-1, 2)  # Reshape to (N, 2) for direct use

    # BUGFIX: without axis=1, np.linalg.norm collapses the 3x2 segment array into a
    # single Frobenius norm, making the surrounding np.mean a no-op. With axis=1 we
    # get one length per polygon segment, so the mean is the intended average
    # segment length ("avg scale").
    sample_avg_scale = np.mean(np.linalg.norm(np.diff(sample_pts, axis=0), axis=1))
    scene_avg_scale = np.mean(np.linalg.norm(np.diff(self._scene_reference, axis=0), axis=1))

    # Ratio between the image (pixel) scale and the real-world (cm) scale.
    img_scaling_factor = sample_avg_scale / scene_avg_scale

    # move points to match starting x position of samples, and scale up to image scale
    transformed_real_pts = self._scene_reference * img_scaling_factor + sample_pts[0, :]

    # do perspective transform to rectify image
    persp_T = cv2.getPerspectiveTransform(sample_pts, transformed_real_pts)

    # add homogeneous coordinate
    sample_pts = np.hstack((sample_pts, np.ones((4, 1))))

    # also warp points to be able to calculate pixel to cm ratio
    sample_pts = np.dot(persp_T, sample_pts.T).T

    # divide by homogeneous coordinate
    sample_pts = sample_pts[:, 0:2] / sample_pts[:, 2].reshape((4, 1))

    # now calculate pixel to cm ratio from the warped (rectified) points
    sample_avg_scale = np.mean(np.linalg.norm(np.diff(sample_pts, axis=0), axis=1))
    pixel_to_cm_ratio = scene_avg_scale / sample_avg_scale

    self._perspectiveTransform = persp_T

    return pixel_to_cm_ratio

rectify_image(image)

Rectify the image.

Parameters

image : np.ndarray The image to rectify.

Returns

np.ndarray The rectified image.

Source code in src\aegear\calibration.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def rectify_image(self, image: np.ndarray) -> np.ndarray:
    """
    Rectify an image using the current scene calibration.

    Parameters
    ----------
    image : np.ndarray
        The image to rectify.

    Returns
    -------
    np.ndarray
        The rectified image.

    """
    assert self._perspectiveTransform is not None, "Need to calibrate first"

    # Undistort first, then apply the perspective warp computed by calibrate().
    # shape[:2] is (height, width); warpPerspective wants (width, height).
    height, width = image.shape[0:2]
    undistorted = cv2.undistort(image, self.mtx, self.dist)
    return cv2.warpPerspective(undistorted, self._perspectiveTransform, (width, height))

rectify_point(point)

Rectify a single point using the current calibration.

Parameters

point : tuple of float The (x, y) coordinates of the point to rectify.

Returns

tuple of float The rectified (x, y) coordinates.

Source code in src\aegear\calibration.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def rectify_point(self, point: tuple[float, float]) -> tuple[float, float]:
    """
    Rectify a single point using the current calibration.

    Parameters
    ----------
    point : tuple of float
        The (x, y) coordinates of the point to rectify.

    Returns
    -------
    tuple of float
        The rectified (x, y) coordinates.
    """
    assert self._perspectiveTransform is not None, "Need to calibrate first"

    # Undistort the point, keeping pixel units via P=self.mtx.
    # cv2.undistortPoints returns shape (1, 1, 2), which perspectiveTransform
    # accepts directly.
    undistorted = cv2.undistortPoints(
        np.array([[point]], dtype=np.float32),
        self.mtx,
        self.dist,
        P=self.mtx,
    )

    # Map the undistorted point into the rectified image frame.
    warped = cv2.perspectiveTransform(
        undistorted.astype(np.float32),
        self._perspectiveTransform,
    )

    return tuple(warped[0, 0])

unrectify_point(point)

Map a point from the rectified image back to its original (distorted) image coordinates.

Source code in src\aegear\calibration.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def unrectify_point(self, point: tuple[float, float]) -> tuple[float, float]:
    """
    Map a point from the rectified image back to its original (distorted) image coordinates.
    """
    assert self._perspectiveTransform is not None, "Need to calibrate first"

    # Step 1: invert the perspective warp (homogeneous coordinates).
    homogeneous = np.array([point[0], point[1], 1.0], dtype=np.float32)
    warped_back = np.linalg.inv(self._perspectiveTransform).dot(homogeneous)
    undistorted_px = warped_back[:2] / warped_back[2]

    # Step 2: back-project the undistorted pixel into normalized camera coordinates.
    pixel_h = np.array([undistorted_px[0], undistorted_px[1], 1.0], dtype=np.float32)
    normalized = np.linalg.inv(self.mtx).dot(pixel_h).reshape(1, 3)

    # Step 3: re-apply intrinsics plus lens distortion (identity extrinsics) to
    # recover the original distorted pixel location.
    rvec = np.zeros(3, dtype=np.float32)
    tvec = np.zeros(3, dtype=np.float32)
    projected, _ = cv2.projectPoints(normalized, rvec, tvec, self.mtx, self.dist)
    x, y = projected[0, 0]

    return (float(x), float(y))

motiondetection

Motion detection module.

This module provides the MotionDetector class that identifies motion by comparing three consecutive frames. The algorithm converts frames to grayscale, computes the absolute difference between frames, applies binary thresholding, combines the results, and uses morphological operations to filter the motion regions before extracting contours.

MotionDetector

Motion detector class that identifies motion by comparing three consecutive frames.

Source code in src\aegear\motiondetection.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class MotionDetector:
    """
    Motion detector class that identifies motion by comparing three consecutive frames.
    """

    # Contours below this area are discarded outright, regardless of min_area/max_area.
    MIN_AREA: int = 10

    def __init__(self, motion_threshold: int, erode_kernel_size: int = 3,
                 dilate_kernel_size: int = 15, min_area: int = 800, max_area: int = 3000) -> None:
        """
        Initialize the MotionDetector.

        Parameters
        ----------
        motion_threshold : int
            Pixel intensity difference above which motion is registered.
        erode_kernel_size : int, optional
            Kernel size for the erosion step (default is 3).
        dilate_kernel_size : int, optional
            Kernel size for the dilation step (default is 15).
        min_area : int, optional
            Smallest contour area counted as good motion (default is 800).
        max_area : int, optional
            Largest contour area counted as good motion (default is 3000).
        """
        self.motion_threshold = motion_threshold
        self.erode_kernel_size = erode_kernel_size
        self.dilate_kernel_size = dilate_kernel_size
        self.min_area = min_area
        self.max_area = max_area

    def detect(self, prev_frame: np.ndarray, this_frame: np.ndarray,
               next_frame: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Detect motion by comparing three consecutive frames.

        Converts the frames to grayscale, thresholds the absolute frame
        differences, combines them, cleans the result with morphological
        operations and a blur, and extracts contours. Contours within the
        configured area range are "good"; larger/smaller ones (still above
        MIN_AREA) are "bad".

        Parameters
        ----------
        prev_frame : numpy.ndarray
            Previous frame in BGR color space.
        this_frame : numpy.ndarray
            Current frame in BGR color space.
        next_frame : numpy.ndarray
            Next frame in BGR color space.

        Returns
        -------
        Tuple[List[numpy.ndarray], List[numpy.ndarray]]
            (good_contours, bad_contours): contours with area in
            [min_area, max_area], and contours outside that range but
            above MIN_AREA.
        """
        # Grayscale versions of the three frames.
        gray_prev, gray_cur, gray_next = (
            cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            for frame in (prev_frame, this_frame, next_frame)
        )

        def _motion_mask(reference: np.ndarray) -> np.ndarray:
            # Absolute difference against the current frame, thresholded to binary.
            delta = np.abs(gray_cur.astype(np.float32) - reference.astype(np.float32)).astype(np.uint8)
            return cv2.threshold(delta, self.motion_threshold, 255, cv2.THRESH_BINARY)[1]

        # Union of motion against both temporal neighbors.
        combined = cv2.bitwise_or(_motion_mask(gray_prev), _motion_mask(gray_next))

        # Erode to drop speckle noise, then dilate to merge nearby motion regions.
        erode_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (self.erode_kernel_size, self.erode_kernel_size))
        dilate_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (self.dilate_kernel_size, self.dilate_kernel_size))
        morphed = cv2.dilate(cv2.erode(combined, erode_kernel), dilate_kernel)

        # Smooth and re-binarize to finalize the mask.
        blurred = cv2.GaussianBlur(morphed, (19, 19), 5.0)
        final_mask = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY)[1]

        # Extract and classify contours by area.
        contours, _ = cv2.findContours(final_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

        good_contours: List[np.ndarray] = []
        bad_contours: List[np.ndarray] = []

        for contour in contours:
            area = cv2.contourArea(contour)
            if area < MotionDetector.MIN_AREA:
                continue
            bucket = good_contours if self.min_area <= area <= self.max_area else bad_contours
            bucket.append(contour)

        return good_contours, bad_contours

detect(prev_frame, this_frame, next_frame)

Detect motion by comparing three consecutive frames.

The function converts the frames to grayscale, computes the absolute differences, thresholds them to produce binary images, combines the thresholded images, applies morphological operations to remove noise, and finally extracts contours. Detected contours are classified into "good" (within the area range) and "bad" (outside the area range but above a minimum threshold).

Parameters

prev_frame : numpy.ndarray Previous frame in BGR color space. this_frame : numpy.ndarray Current frame in BGR color space. next_frame : numpy.ndarray Next frame in BGR color space.

Returns

Tuple[List[numpy.ndarray], List[numpy.ndarray]] A tuple containing two lists of contours: - The first list contains contours with areas between min_area and max_area. - The second list contains contours with areas outside that range but above MIN_AREA.

Source code in src\aegear\motiondetection.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def detect(self, prev_frame: np.ndarray, this_frame: np.ndarray,
           next_frame: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Detect motion by comparing three consecutive frames.

    The frames are converted to grayscale; absolute differences of the current
    frame against both neighbors are thresholded and OR-combined, cleaned up
    with erosion/dilation and a Gaussian blur, and the resulting binary mask is
    searched for contours. Contours are then split by area into "good" (within
    [min_area, max_area]) and "bad" (outside that range but above MIN_AREA).

    Parameters
    ----------
    prev_frame : numpy.ndarray
        Previous frame in BGR color space.
    this_frame : numpy.ndarray
        Current frame in BGR color space.
    next_frame : numpy.ndarray
        Next frame in BGR color space.

    Returns
    -------
    Tuple[List[numpy.ndarray], List[numpy.ndarray]]
        A (good_contours, bad_contours) pair as described above.
    """
    # Grayscale conversion, promoted to float32 for a signed difference.
    gray_before = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
    gray_now = cv2.cvtColor(this_frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
    gray_after = cv2.cvtColor(next_frame, cv2.COLOR_BGR2GRAY).astype(np.float32)

    # Absolute differences against both temporal neighbors.
    delta_before = np.abs(gray_now - gray_before).astype(np.uint8)
    delta_after = np.abs(gray_now - gray_after).astype(np.uint8)

    # Binarize each difference and take their union.
    mask_before = cv2.threshold(delta_before, self.motion_threshold, 255, cv2.THRESH_BINARY)[1]
    mask_after = cv2.threshold(delta_after, self.motion_threshold, 255, cv2.THRESH_BINARY)[1]
    union_mask = cv2.bitwise_or(mask_before, mask_after)

    # Morphological cleanup: erode away speckles, dilate to join regions.
    small_kernel = cv2.getStructuringElement(
        cv2.MORPH_RECT, (self.erode_kernel_size, self.erode_kernel_size))
    big_kernel = cv2.getStructuringElement(
        cv2.MORPH_RECT, (self.dilate_kernel_size, self.dilate_kernel_size))
    cleaned = cv2.dilate(cv2.erode(union_mask, small_kernel), big_kernel)

    # Blur and re-threshold to finalize the binary motion mask.
    smoothed = cv2.GaussianBlur(cleaned, (19, 19), 5.0)
    final_mask = cv2.threshold(smoothed, 50, 255, cv2.THRESH_BINARY)[1]

    # Pull out contours and pair each with its area.
    contours, _ = cv2.findContours(final_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    sized = [(cv2.contourArea(c), c) for c in contours]

    # Classify by area; anything below MIN_AREA is dropped entirely.
    good_contours: List[np.ndarray] = [
        c for area, c in sized
        if area >= MotionDetector.MIN_AREA and self.min_area <= area <= self.max_area
    ]
    bad_contours: List[np.ndarray] = [
        c for area, c in sized
        if area >= MotionDetector.MIN_AREA and not (self.min_area <= area <= self.max_area)
    ]

    return good_contours, bad_contours

nn

datasets

TrackingDataset

Bases: Dataset

Source code in src\aegear\nn\datasets.py
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
class TrackingDataset(Dataset):
    """Template/search frame-pair dataset built from one annotated video.

    Sparse per-frame fish annotations are densified with RBF interpolation;
    each item is a (template, search, heatmap) triple where the two crops are
    taken a few frames apart and the heatmap marks the target position inside
    the search crop (all zeros for negative/background samples).
    """

    _MAX_NEGATIVE_OFFSET = 50  # Maximum offset for negative samples

    # NOTE(review): `future_frame_seek=[1, 3, 5, 7]` is a mutable default
    # argument (bound once at definition time). It is never mutated here,
    # but a tuple default would be the safer idiom.
    def __init__(
        self,
        tracking_data,
        video_dir="",
        output_size=128,
        crop_size=168,
        future_frame_seek=[1, 3, 5, 7],
        random_pick_future_seek=False,
        interpolation_smoothness=0.5,
        temporal_jitter_range=0,
        gaussian_sigma=6.0,
        augmentation_transform=None,
        rotation_range=None,
        scale_range=None,
        negative_sample_prob=0.0,
        centroid_perturbation_range=0.0,
    ):
        """Open the video, interpolate the trajectory, store sampling options.

        Args:
            tracking_data: dict with keys "video" (filename joined onto
                `video_dir`) and "tracking" (list of entries carrying
                "frame_id" and "coordinates").
            video_dir: directory containing the video file.
            output_size: side length in pixels of the returned crops/heatmap.
            crop_size: side length of the larger crop taken before rotation.
            future_frame_seek: candidate frame gaps between template and
                search frames.
            random_pick_future_seek: if True, pick a gap at random per item;
                otherwise cycle through the gaps deterministically.
            interpolation_smoothness: epsilon of the RBF interpolator.
            temporal_jitter_range: max +/- frames of jitter applied to the
                template frame id (0 disables).
            gaussian_sigma: std-dev in pixels of the target heatmap blob.
            augmentation_transform: optional callable applied to the stacked
                (template, search) tensor pair.
            rotation_range: max +/- rotation in degrees (None disables).
            scale_range: max +/- relative scale (None disables).
            negative_sample_prob: probability of emitting a background
                sample with an all-zero heatmap.
            centroid_perturbation_range: max +/- pixel noise added to the
                template centroid.

        Raises:
            Exception: if the video file cannot be opened.
        """

        self.video_path = os.path.join(video_dir, tracking_data["video"])
        self.tracking = sorted(
            tracking_data["tracking"], key=lambda x: x["frame_id"])
        self.smooth_trajectory, self.min_frame, self.max_frame = self._interpolate_tracking(
            interpolation_smoothness)
        self.future_frame_seek = future_frame_seek
        self.output_size = output_size
        self.crop_size = crop_size
        self.random_pick_future_seek = random_pick_future_seek
        self.rotation_range = rotation_range
        self.scale_range = scale_range
        self.negative_sample_prob = negative_sample_prob
        self.centroid_perturbation_range = centroid_perturbation_range
        self.temporal_jitter_range = temporal_jitter_range
        self.gaussian_sigma = gaussian_sigma

        # Estimate FPS from video file
        self.video = cv2.VideoCapture(self.video_path)
        if not self.video.isOpened():
            raise Exception(f"Could not open video file: {self.video_path}")

        self.fps = self.video.get(cv2.CAP_PROP_FPS)
        self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.resolution = np.array([self.frame_width, self.frame_height])

        self.augmentation_transform = augmentation_transform

        # ImageNet normalization constants (matches the other dataset
        # classes in this module).
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    @staticmethod
    def build_split_datasets(json_filepaths, video_dir, train_fraction=0.9,
                             future_frame_seek=[1, 3, 5, 7], interpolation_smoothness=0.5, gaussian_sigma=6.0,
                             augmentation_transforms=None, rotation_range=None, scale_range=None, negative_sample_prob=0.0):
        """Split each JSON annotation file into train/val TrackingDatasets.

        For every JSON file the tracking entries are shuffled and split by
        `train_fraction`. The train split gets augmentation, random gap
        picking, rotation/scale and negative sampling; the val split only
        the interpolation and heatmap settings. Per-video datasets are then
        concatenated across all files.

        Returns:
            (train_dataset, val_dataset): a pair of ConcatDataset objects.
        """

        train_datasets = []
        val_datasets = []

        for path in json_filepaths:
            with open(path, 'r') as f:
                data = json.load(f)

            full_tracking = data['tracking']
            video = data['video']

            # Shuffle and split indices
            indices = list(range(len(full_tracking)))
            random.shuffle(indices)

            split_idx = int(len(indices) * train_fraction)
            train_idx = indices[:split_idx]
            val_idx = indices[split_idx:]

            # Subsets of tracking samples
            train_tracking = [full_tracking[i] for i in train_idx]
            val_tracking = [full_tracking[i] for i in val_idx]

            train_data = {
                "video": video,
                "tracking": train_tracking
            }

            val_data = {
                "video": video,
                "tracking": val_tracking
            }

            # Build train dataset
            train_dataset = TrackingDataset(
                tracking_data=train_data,
                video_dir=video_dir,
                future_frame_seek=future_frame_seek,
                random_pick_future_seek=True,
                interpolation_smoothness=interpolation_smoothness,
                gaussian_sigma=gaussian_sigma,
                rotation_range=rotation_range,
                scale_range=scale_range,
                negative_sample_prob=negative_sample_prob,
                augmentation_transform=augmentation_transforms
            )
            train_datasets.append(train_dataset)

            # Build val dataset
            val_dataset = TrackingDataset(
                tracking_data=val_data,
                video_dir=video_dir,
                future_frame_seek=future_frame_seek,
                random_pick_future_seek=False,
                interpolation_smoothness=interpolation_smoothness,
                gaussian_sigma=gaussian_sigma
            )
            val_datasets.append(val_dataset)

        # Concat across all videos
        final_train_dataset = ConcatDataset(train_datasets)
        final_val_dataset = ConcatDataset(val_datasets)

        return final_train_dataset, final_val_dataset

    def _interpolate_tracking(self, interpolation_smoothness):
        """Densify the sparse annotations into a per-frame trajectory.

        Fits one RBF per coordinate over the annotated frame ids and
        evaluates it on every frame in [min_frame, max_frame).

        Returns:
            (trajectory, min_frame, max_frame): trajectory is an (N, 2)
            array of (x, y) positions, one row per frame starting at
            min_frame (max_frame itself is excluded by np.arange).
        """
        frame_ids = np.array([pt["frame_id"] for pt in self.tracking])
        coords = np.array([pt["coordinates"] for pt in self.tracking])

        min_frame = int(frame_ids.min())
        max_frame = int(frame_ids.max())
        dense_frames = np.arange(min_frame, max_frame)

        rbf_x = Rbf(
            frame_ids, coords[:, 0], function='multiquadric', epsilon=interpolation_smoothness)
        rbf_y = Rbf(
            frame_ids, coords[:, 1], function='multiquadric', epsilon=interpolation_smoothness)

        x_interp = rbf_x(dense_frames)
        y_interp = rbf_y(dense_frames)

        trajectory = np.stack([x_interp, y_interp], axis=1)

        return trajectory, min_frame, max_frame

    def test_sequence_cache(self):
        """Debug helper: display every frame with its interpolated position.

        Opens an OpenCV window per frame and blocks on a key press.
        """
        for frame_id in range(self.min_frame, self.max_frame):
            try:
                frame = self._read_frame(frame_id)
            # NOTE(review): bare `except` also swallows KeyboardInterrupt /
            # SystemExit; `except Exception` would be the safer catch here.
            except:
                print(f"Frame {frame_id} not found in video {self.video_path}")
                continue

            # _read_frame returns RGB; convert back to BGR for cv2.imshow.
            img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            coodinate = self.smooth_trajectory[frame_id - self.min_frame]

            cv2.circle(img, (int(coodinate[0]), int(
                coodinate[1])), 5, (0, 255, 0), -1)

            cv2.imshow("Test", np.array(img))
            cv2.waitKey(0)

    def _read_frame(self, frame_id):
        """Seek to `frame_id` and return that frame as an RGB ndarray.

        Raises:
            Exception: if the frame cannot be read.
        """
        # Seeking on every read is slow but keeps random access simple.
        self.video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
        ret, img = self.video.read()
        if not ret:
            raise Exception(
                f"Could not read frame {frame_id} from video {self.video_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img

    def _get_crop(self, frame_id, center, transform: Tuple[float, float]):
        """Crop a window of frame `frame_id` centered on `center`.

        With `transform` None an `output_size` crop is returned directly.
        Otherwise a larger `crop_size` window is taken, rotated/scaled about
        its own center, then center-cropped back down to `output_size` so
        the warp does not introduce border artifacts.

        Raises:
            IndexError: if the requested window falls outside the frame.
        """
        frame = self._read_frame(frame_id)

        crop_size = self.crop_size
        output_size = self.output_size

        if transform is None:
            x1 = int(center[0] - output_size // 2)
            y1 = int(center[1] - output_size // 2)
            x2 = x1 + output_size
            y2 = y1 + output_size

            if x1 < 0 or y1 < 0 or x2 > frame.shape[1] or y2 > frame.shape[0]:
                raise IndexError("Crop out of bounds")

            return frame[y1:y2, x1:x2, :]
        else:
            rotation_deg, scale = transform
            # Compute top-left corner of the large crop
            x1 = int(center[0] - crop_size // 2)
            y1 = int(center[1] - crop_size // 2)
            x2 = x1 + crop_size
            y2 = y1 + crop_size

            if x1 < 0 or y1 < 0 or x2 > frame.shape[1] or y2 > frame.shape[0]:
                raise IndexError("Crop out of bounds")

            crop = frame[y1:y2, x1:x2, :]

            center_point = (crop_size // 2, crop_size // 2)
            M = cv2.getRotationMatrix2D(center_point, rotation_deg, scale)

            rotated = cv2.warpAffine(
                crop, M, (crop_size, crop_size), flags=cv2.INTER_LINEAR)

            # Final center crop to self.crop_size
            start = crop_size // 2 - output_size // 2
            end = start + output_size

            return rotated[start:end, start:end, :]

    def transform_offset_for_heatmap(self, offset, transform: Tuple[float, float]):
        """
        Apply rotation and scale to an offset vector, then map to heatmap coordinates.

        Args:
            offset: np.ndarray shape (2,), the vector (search - template)
            transform: Tuple[float, float] = (rotation_deg, scale)

        Returns:
            np.ndarray of shape (2,), transformed and rescaled offset in heatmap coordinates
        """

        crop_size = self.crop_size
        output_size = self.output_size

        if transform:
            rotation_deg, scale = transform
            theta = np.deg2rad(rotation_deg)

            # 2D rotation matrix with scale
            R = np.array([
                [np.cos(theta), -np.sin(theta)],
                [np.sin(theta),  np.cos(theta)]
            ]) * scale

            offset = R @ offset

        # Map from crop-pixel space into heatmap space, re-centered on the
        # middle of the output window.
        heatmap_scale = output_size / crop_size
        search_roi_hit = offset * heatmap_scale + output_size // 2

        return search_roi_hit

    def generate_gaussian_heatmap(self, center):
        """Return an (output_size, output_size) Gaussian peaked at `center`.

        `center` is (x, y) in heatmap coordinates; the peak value is 1.
        """
        output_size = self.output_size

        x = torch.arange(0, output_size, 1).float()
        y = torch.arange(0, output_size, 1).float()
        y = y[:, None]  # column vector so (x - x0), (y - y0) broadcast to 2D

        x0, y0 = center
        heatmap = torch.exp(-((x - x0)**2 + (y - y0)**2) /
                            (2 * self.gaussian_sigma**2))
        return heatmap

    def __len__(self):
        """Number of samples, excluding a tail margin.

        The margin drops trailing annotations whose template frame could not
        be paired with a search frame up to `max_future_seek` frames later.
        In deterministic mode each remaining annotation yields one sample per
        entry of `future_frame_seek`.
        """
        max_future_seek = max(self.future_frame_seek) + \
            self.temporal_jitter_range
        last_frame = self.tracking[-1]["frame_id"]
        num_margin_frames = 0

        # Count annotations (from the end) that are too close to the last
        # frame to allow the largest forward jump.
        for i in range(len(self.tracking) - 1, -1, -1):
            num_margin_frames += 1
            if self.tracking[i]["frame_id"] + max_future_seek < last_frame:
                break

        num_samples = len(self.tracking) - num_margin_frames - 1

        if not self.random_pick_future_seek:
            num_samples *= len(self.future_frame_seek)

        return num_samples

    def __del__(self):
        """Release the video capture when the dataset is garbage-collected."""
        # NOTE(review): if __init__ raised before `self.video` was assigned,
        # this raises AttributeError; a hasattr guard (as used in
        # BackgroundWindowDataset) would be safer.
        if self.video.isOpened():
            self.video.release()

    def __getitem__(self, idx):
        """Build one (template, search, heatmap) training triple.

        Picks a frame gap, optionally jitters/perturbs/rotates, crops both
        frames around the template centroid (so the heatmap encodes the
        target's offset inside the search window), and renders the target
        heatmap — all zeros for negative samples. Out-of-bounds crops are
        skipped by retrying with the next index.
        """
        if self.random_pick_future_seek:
            # Draw a random frame gap for this sample.
            frame_jump = random.choice(self.future_frame_seek)
            template_tracking = self.tracking[idx]
        else:
            # use modulo to cycle through future_frame_seek
            frame_jump = self.future_frame_seek[idx % len(
                self.future_frame_seek)]
            template_tracking = self.tracking[idx //
                                              len(self.future_frame_seek)]

        if self.rotation_range or self.scale_range:
            rotation_deg = np.random.uniform(-self.rotation_range,
                                             self.rotation_range) if self.rotation_range else 0.0
            scale = np.random.uniform(
                1 - self.scale_range, 1 + self.scale_range) if self.scale_range else 1.0
            transform = (rotation_deg, scale)
        else:
            transform = None

        template_frame_id = template_tracking["frame_id"]

        if self.temporal_jitter_range > 0:
            # NOTE(review): jitter near the first annotation can push
            # `template_smooth_id` below 0, which silently wraps via
            # negative indexing into the trajectory — TODO confirm intended.
            jitter = random.randint(-self.temporal_jitter_range,
                                    self.temporal_jitter_range)
            template_frame_id += jitter

        search_frame_id = template_frame_id + frame_jump

        # Convert absolute frame ids into indices of the interpolated
        # trajectory (which starts at min_frame).
        template_smooth_id = template_frame_id - self.min_frame
        search_smooth_id = template_smooth_id + frame_jump

        template_coordinate = self.smooth_trajectory[template_smooth_id]
        search_coordinate = self.smooth_trajectory[search_smooth_id]

        if self.centroid_perturbation_range > 0.0:
            perturbation_x = np.random.uniform(
                -self.centroid_perturbation_range, self.centroid_perturbation_range)
            perturbation_y = np.random.uniform(
                -self.centroid_perturbation_range, self.centroid_perturbation_range)
            template_coordinate = (
                template_coordinate[0] + perturbation_x, template_coordinate[1] + perturbation_y)

        is_negative = random.random() < self.negative_sample_prob

        if is_negative:
            # Center the crops well away from the target so the window
            # contains background only.
            offset_x = random.choice([-1, 1]) * random.randint(
                TrackingDataset._MAX_NEGATIVE_OFFSET // 2, TrackingDataset._MAX_NEGATIVE_OFFSET)
            offset_y = random.choice([-1, 1]) * random.randint(
                TrackingDataset._MAX_NEGATIVE_OFFSET // 2, TrackingDataset._MAX_NEGATIVE_OFFSET)

            template_coordinate = (
                search_coordinate[0] + offset_x,
                search_coordinate[1] + offset_y
            )

            max_frame_seek = max(self.future_frame_seek)
            # NOTE(review): this mixes index spaces — `search_smooth_id` is a
            # trajectory index (frame_id - min_frame), not an absolute frame
            # id; `search_frame_id + ...` was likely intended. TODO confirm.
            search_frame_id = search_smooth_id + \
                random.randint(-max_frame_seek, max_frame_seek)

        try:
            # Both crops share the template centroid, so the search crop
            # shows the target displaced by (search - template).
            template = self._get_crop(
                template_frame_id, template_coordinate, transform)
            search = self._get_crop(
                search_frame_id, template_coordinate, transform)
        except IndexError:
            return self.__getitem__((idx + 1) % len(self))

        to_tensor = transforms.ToTensor()
        template = to_tensor(template)
        search = to_tensor(search)

        # Transform/augment both images with same function.
        if self.augmentation_transform:
            stacked = torch.stack([template, search])
            transformed = self.augmentation_transform(stacked)
            template, search = transformed[0], transformed[1]

        # Normalize the images
        template = self.normalize(template)
        search = self.normalize(search)

        if is_negative:
            heatmap = torch.zeros(
                (1, self.output_size, self.output_size))
        else:
            offset = np.array(search_coordinate) - \
                np.array(template_coordinate)
            search_roi_hit = self.transform_offset_for_heatmap(
                offset, transform)
            heatmap = self.generate_gaussian_heatmap(
                search_roi_hit).unsqueeze(0)

        return (
            template, search, heatmap
        )
transform_offset_for_heatmap(offset, transform)

Apply rotation and scale to an offset vector, then map to heatmap coordinates.

Parameters:

Name Type Description Default
offset

np.ndarray shape (2,), the vector (search - template)

required
transform Tuple[float, float]

Tuple[float, float] = (rotation_deg, scale)

required

Returns:

Type Description

np.ndarray of shape (2,), transformed and rescaled offset in heatmap coordinates

Source code in src\aegear\nn\datasets.py
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
def transform_offset_for_heatmap(self, offset, transform: Tuple[float, float]):
    """
    Apply rotation and scale to an offset vector, then map to heatmap coordinates.

    Args:
        offset: np.ndarray shape (2,), the vector (search - template)
        transform: Tuple[float, float] = (rotation_deg, scale)

    Returns:
        np.ndarray of shape (2,), transformed and rescaled offset in heatmap coordinates
    """
    if transform:
        angle_deg, scale_factor = transform
        angle = np.deg2rad(angle_deg)
        cos_t = np.cos(angle)
        sin_t = np.sin(angle)

        # Scaled 2D rotation applied to the offset vector.
        rotation = np.array([
            [cos_t, -sin_t],
            [sin_t,  cos_t]
        ]) * scale_factor

        offset = rotation @ offset

    # Rescale from crop-pixel space to heatmap space and re-center on the
    # middle of the output window.
    return offset * (self.output_size / self.crop_size) + self.output_size // 2

CachedTrackingDataset

Bases: Dataset

Cached version of TrackingDataset. Loads crops and metadata from disk, avoiding video decoding at runtime. Each sample contains (template, search, heatmap).

Source code in src\aegear\nn\datasets.py
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
class CachedTrackingDataset(Dataset):
    """
    Cached version of TrackingDataset.
    Loads crops and metadata from disk, avoiding video decoding at runtime.
    Each sample contains (template, search, heatmap).
    """

    def __init__(self, root_dir, output_size=128, gaussian_sigma=6.0):
        metadata_path = os.path.join(root_dir, "metadata.json")
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)["samples"]

        self.root_dir = root_dir
        self.output_size = output_size
        self.gaussian_sigma = gaussian_sigma

        # ToTensor + ImageNet normalization, identical to TrackingDataset.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __len__(self):
        return len(self.metadata)

    def generate_heatmap(self, center):
        """Build a [1, H, W] Gaussian heatmap peaked at `center` (x, y)."""
        x0, y0 = center
        cols = torch.arange(0, self.output_size, dtype=torch.float32)
        rows = torch.arange(0, self.output_size, dtype=torch.float32).unsqueeze(1)
        squared_distance = (cols - x0) ** 2 + (rows - y0) ** 2
        heatmap = torch.exp(-squared_distance / (2 * self.gaussian_sigma ** 2))
        return heatmap.unsqueeze(0)  # Shape: [1, H, W]

    def _load_image(self, relative_path):
        # Resolve a cached crop on disk and return it as a normalized tensor.
        full_path = os.path.join(self.root_dir, relative_path)
        image = Image.open(full_path).convert("RGB")
        return self.normalize(self.to_tensor(image))

    def __getitem__(self, idx):
        item = self.metadata[idx]
        template = self._load_image(item["template_path"])
        search = self._load_image(item["search_path"])

        # Background samples carry an all-zero target heatmap.
        if item.get("background", False):
            heatmap = torch.zeros((1, self.output_size, self.output_size))
        else:
            heatmap = self.generate_heatmap(item["centroid"])

        return template, search, heatmap

BackgroundWindowDataset

Bases: Dataset

Dataset for sampling background (no-fish) windows from a video, using a sliding window approach. The user provides a list of frame indices known to contain only background (no fish present). Each sample is a cropped window from a background frame, with optional augmentation, rotation, and scaling. The output is (image, heatmap), where heatmap is always a zero tensor.

Source code in src\aegear\nn\datasets.py
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
class BackgroundWindowDataset(torch.utils.data.Dataset):
    """
    Dataset for sampling background (no-fish) windows from a video, using a sliding window approach.
    The user provides a list of frame indices known to contain only background (no fish present).
    Each sample is a cropped window from a background frame, with optional augmentation, rotation, and scaling.
    The output is (image, heatmap), where heatmap is always a zero tensor.
    """

    def __init__(
        self,
        video_path: str,
        background_frames: list[int],
        output_size: int = 128,
        crop_size: int = 168,
        siamese: bool = False,
        stride_portion: float = 0.5,
        augmentation_transform=None,
        rotation_range=None,
        scale_range=None,
    ):
        self.video_path = video_path
        self.background_frames = sorted(background_frames)
        self.output_size = output_size
        self.crop_size = crop_size
        self.siamese = siamese
        self.stride_portion = stride_portion
        self.augmentation_transform = augmentation_transform
        self.rotation_range = rotation_range
        self.scale_range = scale_range
        # ImageNet normalization, matching the other datasets in this module.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        # Open the video once up front to learn the frame geometry.
        self.video = cv2.VideoCapture(self.video_path)
        if not self.video.isOpened():
            raise Exception(f"Could not open video file: {self.video_path}")
        self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Enumerate every valid (frame, y, x) sliding-window position.
        stride = max(1, int(self.stride_portion * self.output_size))
        y_limit = self.frame_height - self.crop_size + 1
        x_limit = self.frame_width - self.crop_size + 1
        self.samples = [
            (frame_id, y, x)
            for frame_id in self.background_frames
            for y in range(0, y_limit, stride)
            for x in range(0, x_limit, stride)
        ]

    def __len__(self):
        return len(self.samples)

    def __del__(self):
        # Guarded: __init__ may have raised before self.video existed.
        if hasattr(self, 'video') and self.video.isOpened():
            self.video.release()

    def _read_frame(self, frame_id):
        # Seek-and-read a single frame, returned in RGB order.
        self.video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
        ret, img = self.video.read()
        if not ret:
            raise Exception(
                f"Could not read frame {frame_id} from video {self.video_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _sample_transform_params(self):
        # Draw a random rotation/scale for each configured range; identity
        # values otherwise. Draw order matches the original (rotation first).
        rotation_deg = 0.0
        scale = 1.0
        if self.rotation_range:
            rotation_deg = np.random.uniform(
                -self.rotation_range, self.rotation_range)
        if self.scale_range:
            scale = np.random.uniform(
                1 - self.scale_range, 1 + self.scale_range)
        return rotation_deg, scale

    def __getitem__(self, idx):
        frame_id, y, x = self.samples[idx]
        rotation_deg, scale = self._sample_transform_params()

        frame = self._read_frame(frame_id)
        window = frame[y:y + self.crop_size, x:x + self.crop_size, :]

        # Warp the larger window about its own center when requested.
        if rotation_deg != 0.0 or scale != 1.0:
            pivot = (self.crop_size // 2, self.crop_size // 2)
            matrix = cv2.getRotationMatrix2D(pivot, rotation_deg, scale)
            window = cv2.warpAffine(
                window, matrix, (self.crop_size, self.crop_size),
                flags=cv2.INTER_LINEAR)

        # Center-crop down to the output size.
        margin = self.crop_size // 2 - self.output_size // 2
        window = window[margin:margin + self.output_size,
                        margin:margin + self.output_size, :]

        window = transforms.ToTensor()(window)
        if self.augmentation_transform:
            window = self.augmentation_transform(window.unsqueeze(0)).squeeze(0)
        window = self.normalize(window)

        # Background sample: the target heatmap is always all zeros.
        heatmap = torch.zeros((1, self.output_size, self.output_size))

        if self.siamese:
            # For Siamese networks, return two identical crops.
            return window, window, heatmap
        else:
            return window, heatmap

WebTrackingDataset

Bases: IterableDataset

Webdataset-based tracking dataset.

Reads template/search image pairs and metadata from tar files. Each sample contains (template, search, heatmap).

Parameters:

Name Type Description Default
tar_urls

Path or list of paths to tar files (can include wildcards) Examples: - "path/to/tracking-{000000..000009}.tar" - ["path/to/shard1.tar", "path/to/shard2.tar"] - "s3://bucket/tracking-*.tar"

required
output_size int

Size of output heatmap (default: 128)

128
gaussian_sigma float

Sigma for Gaussian heatmap generation (default: 6.0)

6.0
shuffle bool

Whether to shuffle samples (default: True)

True
transform Optional[Callable]

Optional transform to apply to images

None
empty_check bool

If True, raises an error if no samples are found in shards

False
Source code in src\aegear\nn\datasets.py
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
class WebTrackingDataset(IterableDataset):
    """
    Webdataset-based tracking dataset.

    Reads template/search image pairs and metadata from tar files.
    Each sample contains (template, search, heatmap).

    Args:
        tar_urls: Path or list of paths to tar files (can include wildcards)
                  Examples: 
                    - "path/to/tracking-{000000..000009}.tar"
                    - ["path/to/shard1.tar", "path/to/shard2.tar"]
                    - "s3://bucket/tracking-*.tar"
        output_size: Size of output heatmap (default: 128)
        gaussian_sigma: Sigma for Gaussian heatmap generation (default: 6.0)
        shuffle: Whether to shuffle samples (default: True)
        transform: Optional transform to apply to images
        empty_check: If True, raises an error if no samples are found in shards
    """

    def __init__(
        self,
        tar_urls,
        output_size: int = 128,
        gaussian_sigma: float = 6.0,
        shuffle: bool = True,
        transform: Optional[Callable] = None,
        empty_check: bool = False,
        max_samples: Optional[int] = None
    ):
        self.tar_urls = tar_urls
        self.output_size = output_size
        self.gaussian_sigma = gaussian_sigma
        self.shuffle = shuffle
        self.custom_transform = transform
        self.empty_check = empty_check
        self.max_samples = max_samples

        # ToTensor + ImageNet normalization, shared with the other datasets.
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def generate_heatmap(self, center):
        """Generate Gaussian heatmap centered at the given position."""
        x0, y0 = center
        cols = torch.arange(0, self.output_size, dtype=torch.float32)
        rows = torch.arange(0, self.output_size, dtype=torch.float32).unsqueeze(1)
        squared_distance = (cols - x0) ** 2 + (rows - y0) ** 2
        heatmap = torch.exp(-squared_distance / (2 * self.gaussian_sigma ** 2))
        return heatmap.unsqueeze(0)  # Shape: [1, H, W]

    def _decode_image(self, raw_bytes):
        # JPEG bytes -> normalized float tensor.
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return self.normalize(self.to_tensor(image))

    def decode_sample(self, sample):
        """
        Decode a webdataset sample into (template, search, heatmap) tuple.

        Args:
            sample: Dictionary containing keys like 'template.jpg', 'search.jpg', 'json'

        Returns:
            Tuple of (template_tensor, search_tensor, heatmap_tensor)
        """
        template = self._decode_image(sample['template.jpg'])
        search = self._decode_image(sample['search.jpg'])

        metadata = json.loads(sample['json'])

        # Background samples carry an all-zero target heatmap.
        if metadata.get("background", False):
            heatmap = torch.zeros((1, self.output_size, self.output_size))
        else:
            heatmap = self.generate_heatmap(metadata["centroid"])

        return template, search, heatmap

    def __iter__(self):
        """Create an iterator over the dataset."""
        pipeline = wds.WebDataset(
            self.tar_urls, shardshuffle=False, empty_check=self.empty_check)

        # Shuffle within a 1000-sample buffer when enabled.
        if self.shuffle:
            pipeline = pipeline.shuffle(1000)

        pipeline = pipeline.map(self.decode_sample)
        sample_iter = iter(pipeline)

        if self.max_samples is None:
            return sample_iter

        # Hard-limit the stream per worker: with N workers each one yields
        # ceil(max_samples / N) samples; a single worker yields them all.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            limit = self.max_samples
        else:
            limit = int(np.ceil(self.max_samples / worker_info.num_workers))
        return itertools.islice(sample_iter, limit)
generate_heatmap(center)

Generate Gaussian heatmap centered at the given position.

Source code in src\aegear\nn\datasets.py
1107
1108
1109
1110
1111
1112
1113
1114
def generate_heatmap(self, center):
    """Generate Gaussian heatmap centered at the given position."""
    x0, y0 = center
    cols = torch.arange(self.output_size, dtype=torch.float32)
    rows = torch.arange(self.output_size, dtype=torch.float32).unsqueeze(1)
    squared_distance = (cols - x0) ** 2 + (rows - y0) ** 2
    blob = torch.exp(-squared_distance / (2 * self.gaussian_sigma ** 2))
    return blob.unsqueeze(0)  # Shape: [1, H, W]
decode_sample(sample)

Decode a webdataset sample into (template, search, heatmap) tuple.

Parameters:

Name Type Description Default
sample

Dictionary containing keys like 'template.jpg', 'search.jpg', 'json'

required

Returns:

Type Description

Tuple of (template_tensor, search_tensor, heatmap_tensor)

Source code in src\aegear\nn\datasets.py
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
def decode_sample(self, sample):
    """
    Decode a webdataset sample into (template, search, heatmap) tuple.

    Args:
        sample: Dictionary containing keys like 'template.jpg', 'search.jpg', 'json'

    Returns:
        Tuple of (template_tensor, search_tensor, heatmap_tensor)
    """
    def _to_normalized_tensor(raw_bytes):
        # JPEG bytes -> RGB PIL image -> normalized float tensor.
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return self.normalize(self.to_tensor(image))

    template = _to_normalized_tensor(sample['template.jpg'])
    search = _to_normalized_tensor(sample['search.jpg'])

    metadata = json.loads(sample['json'])

    # Background samples carry an all-zero target heatmap.
    if metadata.get("background", False):
        heatmap = torch.zeros((1, self.output_size, self.output_size))
    else:
        heatmap = self.generate_heatmap(metadata["centroid"])

    return template, search, heatmap

WebTrackingDatasetWithLength

Bases: WebTrackingDataset

Extended version with approximate length for DataLoader compatibility.

This is useful when you need a DataLoader with a known length for progress bars or epoch-based training.

Parameters:

Name Type Description Default
tar_urls

Path or list of paths to tar files

required
length int

Total number of samples in the dataset

required
output_size int

Size of output heatmap (default: 128)

128
gaussian_sigma float

Sigma for Gaussian heatmap generation (default: 6.0)

6.0
shuffle bool

Whether to shuffle samples (default: True)

True
transform Optional[Callable]

Optional transform to apply to images

None
Source code in src\aegear\nn\datasets.py
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
class WebTrackingDatasetWithLength(WebTrackingDataset):
    """
    WebTrackingDataset variant that reports a caller-supplied, approximate
    length, for DataLoader compatibility (progress bars, epoch-based loops).

    Args:
        tar_urls: Path or list of paths to tar files.
        length: Total number of samples in the dataset; also forwarded to the
            base class as ``max_samples``.
        output_size: Size of the output heatmap (default: 128).
        gaussian_sigma: Sigma for Gaussian heatmap generation (default: 6.0).
        shuffle: Whether to shuffle samples (default: True).
        transform: Optional transform to apply to images.
        empty_check: Forwarded to the base class (default: False).
    """

    def __init__(self, tar_urls, length: int, output_size: int = 128,
                 gaussian_sigma: float = 6.0, shuffle: bool = True,
                 transform: Optional[Callable] = None,
                 empty_check: bool = False):
        super().__init__(tar_urls, output_size, gaussian_sigma,
                         shuffle, transform, empty_check,
                         max_samples=length)
        self._length = length

    def __len__(self):
        """Return the approximate sample count supplied at construction."""
        return self._length

split_coco_annotations(coco_json_path, train_ratio=0.8, seed=42)

Loads a COCO JSON and splits it into train/val dictionaries based on image-level split.

Parameters:

Name Type Description Default
coco_json_path Path

Path to the COCO annotations.json.

required
train_ratio float

Ratio of images to assign to the training set.

0.8
seed int

Random seed for reproducibility.

42

Returns:

Type Description
Tuple[dict, dict]

Tuple[dict, dict]: (train_dict, val_dict)

Source code in src\aegear\nn\datasets.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def split_coco_annotations(
    coco_json_path: Path,
    train_ratio: float = 0.8,
    seed: int = 42
) -> Tuple[dict, dict]:
    """
    Loads a COCO JSON and splits it into train/val dictionaries based on image-level split.

    Args:
        coco_json_path (Path): Path to the COCO annotations.json.
        train_ratio (float): Ratio of images to assign to the training set.
        seed (int): Random seed for reproducibility.

    Returns:
        Tuple[dict, dict]: (train_dict, val_dict); each dict carries its split's
        "images", the "annotations" filtered to those images, and the shared
        "categories" list.
    """
    with open(coco_json_path, 'r') as f:
        coco = json.load(f)

    images = coco["images"]
    annotations = coco["annotations"]
    categories = coco["categories"]

    # Dedicated RNG instance: reproducible shuffle without clobbering the
    # caller's global `random` state. random.Random(seed).shuffle(...) yields
    # the exact same permutation as random.seed(seed) + random.shuffle(...).
    rng = random.Random(seed)
    shuffled_images = images[:]
    rng.shuffle(shuffled_images)

    split_idx = int(len(shuffled_images) * train_ratio)
    train_images = shuffled_images[:split_idx]
    val_images = shuffled_images[split_idx:]

    train_img_ids = {img["id"] for img in train_images}
    val_img_ids = {img["id"] for img in val_images}

    # Keep only annotations whose image landed in the corresponding split.
    train_annotations = [
        ann for ann in annotations if ann["image_id"] in train_img_ids]
    val_annotations = [
        ann for ann in annotations if ann["image_id"] in val_img_ids]

    train_dict = {
        "images": train_images,
        "annotations": train_annotations,
        "categories": categories
    }

    val_dict = {
        "images": val_images,
        "annotations": val_annotations,
        "categories": categories
    }

    return train_dict, val_dict

fetch_shard_dataset(output_dir, verbose=True)

Fetch the shards dataset from GCS to a given directory

Source code in src\aegear\nn\datasets.py
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
def fetch_shard_dataset(output_dir: str, verbose=True):
    """Fetch the shards dataset from GCS to a given directory"""

    # Ensure the destination directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # We require write access and an empty destination.
    if not os.access(output_dir, os.W_OK):
        raise PermissionError(f"Cannot write to directory: {output_dir}")

    if os.listdir(output_dir):
        raise FileExistsError(f"Directory is not empty: {output_dir}")

    # Anonymous (guest) client — the bucket is publicly readable.
    client = storage.Client.create_anonymous_client()
    bucket = client.bucket("aegear-training-data")

    blob_iter = bucket.list_blobs(prefix="shards")
    if verbose:
        blob_iter = tqdm(blob_iter, desc="Downloading shards")

    # Only the tarballs and the JSON manifest are of interest.
    for blob in blob_iter:
        if blob.name.endswith((".tar", ".json")):
            destination = os.path.join(output_dir, os.path.basename(blob.name))
            blob.download_to_filename(destination)

set_seed(seed=42)

Set random seed for reproducibility.

Source code in src\aegear\nn\datasets.py
1305
1306
1307
1308
1309
1310
1311
1312
1313
def set_seed(seed: int = 42):
    """Seed every RNG source (stdlib, NumPy, torch CPU and CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Force deterministic cuDNN kernel selection (may impact performance).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

split_shards_train_val(shard_dir, train_ratio=0.8, seed=42)

Split tar files into train and validation sets with a predictable seed.

Parameters:

Name Type Description Default
shard_dir str

Directory containing the tar files

required
train_ratio float

Ratio of data to use for training (default: 0.8)

0.8
seed int

Random seed for reproducibility

42

Returns:

Type Description
tuple

Tuple of (train_tar_files, val_tar_files)

Source code in src\aegear\nn\datasets.py
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
def split_shards_train_val(
    shard_dir: str,
    train_ratio: float = 0.8,
    seed: int = 42
) -> tuple:
    """
    Split tar files into train and validation sets with a predictable seed.

    Args:
        shard_dir: Directory containing the tar files
        train_ratio: Ratio of data to use for training (default: 0.8)
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_tar_files, val_tar_files), lists of path strings.
    """
    # Sort first so the shuffle input (and hence the split) is stable across
    # filesystems with different directory orderings.
    tar_files = [str(f) for f in sorted(Path(shard_dir).glob("*.tar"))]

    # Dedicated RNG instance: reproducible split without mutating the global
    # `random` state. Produces the exact same permutation as
    # random.seed(seed) + random.shuffle(...).
    random.Random(seed).shuffle(tar_files)

    # Split
    split_idx = int(len(tar_files) * train_ratio)
    train_files = tar_files[:split_idx]
    val_files = tar_files[split_idx:]

    print(f"Split {len(tar_files)} shards into:")
    print(f"  Training: {len(train_files)} shards")
    print(f"  Validation: {len(val_files)} shards")

    return train_files, val_files

create_webdataset_from_manifest(manifest_path, output_size=128, gaussian_sigma=6.0, train_ratio=0.8, seed=42, autodownload=True, verbose=True)

Create train and validation WebTrackingDatasets from a manifest file.

Parameters:

Name Type Description Default
manifest_path str

Path to the manifest JSON file

required
output_size int

Size of output heatmap

128
gaussian_sigma float

Sigma for Gaussian heatmap generation

6.0
train_ratio float

Ratio of data to use for training (default: 0.8)

0.8
seed int

Random seed for reproducible splitting (default: 42)

42
autodownload bool

Whether to auto-download shards if not present

True
verbose bool

Whether to print download progress

True

Returns:

Type Description
Tuple[WebTrackingDatasetWithLength, WebTrackingDatasetWithLength]

Tuple of (train_dataset, val_dataset) as WebTrackingDatasetWithLength instances

Raises:

Type Description
ValueError

If number of tar files does not match num_shards in manifest

Source code in src\aegear\nn\datasets.py
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
def create_webdataset_from_manifest(
    manifest_path: str,
    output_size: int = 128,
    gaussian_sigma: float = 6.0,
    train_ratio: float = 0.8,
    seed: int = 42,
    autodownload: bool = True,
    verbose: bool = True
) -> Tuple[WebTrackingDatasetWithLength, WebTrackingDatasetWithLength]:
    """
    Create train and validation WebTrackingDatasets from a manifest file.

    Args:
        manifest_path: Path to the manifest JSON file
        output_size: Size of output heatmap
        gaussian_sigma: Sigma for Gaussian heatmap generation
        train_ratio: Ratio of data to use for training (default: 0.8)
        seed: Random seed for reproducible splitting (default: 42)
        autodownload: Whether to auto-download shards if not present
        verbose: Whether to print download progress

    Returns:
        Tuple of (train_dataset, val_dataset) as WebTrackingDatasetWithLength instances

    Raises:
        ValueError: If the manifest reports zero shards, or the number of tar
            files found does not match num_shards in the manifest.
    """

    # Shards are expected to live next to the manifest file.
    data_dir = os.path.dirname(manifest_path)

    # If manifest or tar files do not exist, auto-download
    if autodownload and (not os.path.exists(manifest_path) or not any(Path(data_dir).glob("*.tar"))):
        if verbose:
            print("Manifest or tar files not found. Auto-downloading shards...")
        fetch_shard_dataset(data_dir, verbose=verbose)

    # Load manifest
    with open(manifest_path, 'r') as f:
        manifest = json.load(f)

    num_shards = manifest['num_shards']
    total_samples = manifest['total_samples']

    # Sanity-check the shard count before splitting. Guarding num_shards == 0
    # also avoids a ZeroDivisionError in the samples-per-shard estimate below.
    tar_files = [str(f) for f in sorted(Path(data_dir).glob("*.tar"))]
    if num_shards == 0:
        raise ValueError(f"Manifest reports zero shards: {manifest_path}")
    if len(tar_files) != num_shards:
        raise ValueError(
            f"Number of tar files found ({len(tar_files)}) does not match num_shards in manifest ({num_shards})")

    # Delegate the seeded shuffle/split to the canonical helper so the
    # partitioning logic lives in one place. It re-globs the same directory
    # and prints the same split summary this function used to emit inline.
    train_files, val_files = split_shards_train_val(
        data_dir, train_ratio=train_ratio, seed=seed)

    # Approximate per-split sample counts, assuming evenly filled shards.
    samples_per_shard = total_samples / num_shards
    train_samples = int(len(train_files) * samples_per_shard)
    val_samples = int(len(val_files) * samples_per_shard)

    print(f"  Approximate samples - Train: {train_samples}, Val: {val_samples}")

    # Create train dataset
    train_dataset = WebTrackingDatasetWithLength(
        tar_urls=train_files,
        length=train_samples,
        output_size=output_size,
        gaussian_sigma=gaussian_sigma,
        shuffle=True,  # Shuffle training data
        empty_check=False
    )

    # Create validation dataset
    val_dataset = WebTrackingDatasetWithLength(
        tar_urls=val_files,
        length=val_samples,
        output_size=output_size,
        gaussian_sigma=gaussian_sigma,
        shuffle=False,  # Don't shuffle validation data
        empty_check=False
    )

    return train_dataset, val_dataset

load_dataset_from_shards(manifest_path, output_size=128, gaussian_sigma=6.0, batch_size=128, train_ratio=0.8, num_workers=0, seed=42, autodownload=True, verbose=True)

Load training and validation datasets from tar shards.

Parameters:

Name Type Description Default
manifest_path str

Path to the manifest JSON file

required
output_size int

Size of output heatmap

128
gaussian_sigma float

Sigma for Gaussian heatmap generation

6.0
batch_size int

Batch size for DataLoader

128
train_ratio float

Ratio of data to use for training

0.8
num_workers int

Number of DataLoader workers

0
seed int

Random seed for reproducibility

42
autodownload bool

Whether to auto-download shards if not present

True
verbose bool

Whether to print download progress

True

Returns: Tuple of (train_loader, val_loader) — DataLoader instances wrapping the train and validation datasets

Source code in src\aegear\nn\datasets.py
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
def load_dataset_from_shards(manifest_path: str,
                             output_size: int = 128,
                             gaussian_sigma: float = 6.0,
                             batch_size: int = 128,
                             train_ratio: float = 0.8,
                             num_workers: int = 0,
                             seed: int = 42,
                             autodownload: bool = True,
                             verbose: bool = True) -> tuple:
    """
    Build training and validation DataLoaders from tar shards.

    Args:
        manifest_path: Path to the manifest JSON file
        output_size: Size of output heatmap
        gaussian_sigma: Sigma for Gaussian heatmap generation
        batch_size: Batch size for DataLoader
        train_ratio: Ratio of data to use for training
        num_workers: Number of DataLoader workers
        seed: Random seed for reproducibility
        autodownload: Whether to auto-download shards if not present
        verbose: Whether to print download progress
    Returns:
        Tuple of (train_loader, val_loader): DataLoader instances wrapping
        the shard-backed train and validation datasets.
    """
    train_dataset, val_dataset = create_webdataset_from_manifest(
        manifest_path=manifest_path,
        output_size=output_size,
        gaussian_sigma=gaussian_sigma,
        train_ratio=train_ratio,
        seed=seed,
        autodownload=autodownload,
        verbose=verbose
    )

    # Create dataloaders. No shuffle flag here: shuffling is configured on
    # the datasets themselves by create_webdataset_from_manifest.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )

    return train_loader, val_loader

model

CBAM

Bases: Module

Lightweight convolutional block attention module (CBAM) for channel and spatial attention.

Source code in src\aegear\nn\model.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class CBAM(nn.Module):
    """Lightweight convolutional block attention module (CBAM) for channel and spatial attention."""

    def __init__(self, in_channels):
        super().__init__()
        # Channel gate: global average pool -> 8x bottleneck MLP (as 1x1 convs)
        # -> sigmoid weights, one per channel.
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )
        # Spatial gate: 7x7 conv over stacked [max, mean] channel statistics.
        self.spatial = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Re-weight channels first.
        gated = x * self.channel(x)

        # Then re-weight spatial positions from per-pixel channel statistics.
        stats = torch.cat(
            [gated.max(dim=1, keepdim=True).values,
             gated.mean(dim=1, keepdim=True)],
            dim=1,
        )
        return gated * self.spatial(stats)

EfficientUNet

Bases: Module

EfficientUNet backbone based on EfficientNet-B0, enhanced with CBAM (Convolutional Block Attention Module) attention blocks after each encoder and decoder stage.

The architecture removes the deepest (last) encoder and decoder stages compared to a standard UNet, resulting in a lighter model with fewer parameters and reduced memory usage, while retaining strong feature extraction and localization capabilities.

CBAM modules are used to improve feature representation by applying both channel and spatial attention at multiple levels of the network, allowing the model to focus on the object of interest while ignoring irrelevant information. This is particularly useful in scenarios where the object of interest (e.g., fish) may be small and difficult to distinguish from the background, or when there are multiple objects present in the image.

Source code in src\aegear\nn\model.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
class EfficientUNet(nn.Module):
    """
    EfficientUNet backbone based on EfficientNet-B0, enhanced with CBAM
    (Convolutional Block Attention Module) attention blocks after each encoder
    and decoder stage.

    The architecture removes the deepest (last) encoder and
    decoder stages compared to a standard UNet, resulting in a lighter model
    with fewer parameters and reduced memory usage, while retaining strong
    feature extraction and localization capabilities.

    CBAM modules are used to improve feature representation by applying both
    channel and spatial attention at multiple levels of the network, allowing
    the model to focus on the object of interest while ignoring irrelevant information.
    This is particularly useful in scenarios where the object of interest (e.g., fish)
    may be small and difficult to distinguish from the background, or when there
    are multiple objects present in the image.
    """

    def __init__(self, weights=None, use_cbam=True, activation=nn.SiLU):
        """
        Args:
            weights: Optional weights argument forwarded to efficientnet_b0
                (None = random initialization).
            use_cbam: If False, every attention block is an nn.Identity.
            activation: Activation class used for bottleneck/decoder blocks,
                instantiated through _make_activation.
        """
        super().__init__()
        self.use_cbam = use_cbam
        self.activation = activation
        # Slice the EfficientNet-B0 feature extractor into encoder stages.
        backbone = efficientnet_b0(weights=weights)
        features = list(backbone.features.children())

        # Encoder stages
        self.enc1 = nn.Sequential(*features[:2])  # Output: 16 ch, S/2
        self.enc2 = nn.Sequential(*features[2:3])  # Output: 24 ch, S/4
        self.enc3 = nn.Sequential(*features[3:4])  # Output: 40 ch, S/8
        self.enc4 = nn.Sequential(*features[4:5])  # Output: 80 ch, S/16
        self.enc5 = nn.Sequential(*features[5:6])  # Output: 112 ch, S/16

        # Bottleneck with dilated convs.
        # dilation=2 grows the receptive field at S/16 without adding another
        # downsampling stage.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(112, 256, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(256),
            _make_activation(activation),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            _make_activation(activation),
        )

        self.att_bottleneck = CBAM(256) if use_cbam else nn.Identity()

        # Decoder with CBAM after skip merges.
        # Each stage input = previous decoder output + skip channels.
        # up4 is a plain conv block (no upsample): enc5's output is S/16 like
        # enc4's, so the first fusion happens at equal resolution.
        self.att4 = CBAM(256 + 112) if use_cbam else nn.Identity()
        self.up4 = self._conf_block(256 + 112, 64)

        self.att3 = CBAM(64 + 80) if use_cbam else nn.Identity()
        self.up3 = self._up_block(64 + 80, 32)

        self.att2 = CBAM(32 + 40) if use_cbam else nn.Identity()
        self.up2 = self._up_block(32 + 40, 24)

        self.att1 = CBAM(24 + 24) if use_cbam else nn.Identity()
        self.up1 = self._up_block(24 + 24, 16)

        self.att0 = CBAM(16 + 16) if use_cbam else nn.Identity()
        self.up0 = self._up_block(16 + 16, 8)

        # Final 1-channel output
        self.out = nn.Conv2d(8, 1, kernel_size=1)

    def _up_block(self, in_ch, out_ch):
        # 2x bilinear upsample followed by two conv/BN/activation pairs.
        act = self.activation
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            _make_activation(act),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            _make_activation(act),
        )

    def _conf_block(self, in_ch, out_ch):
        # Same as _up_block but without the upsample (resolution-preserving).
        act = self.activation
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            _make_activation(act),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            _make_activation(act),
        )

    def forward(self, x):
        # Return only the heatmap logits; drop the decoder features.
        return self.forward_with_decoded(x)[0]

    def forward_with_decoded(self, x):
        """Return (out, d0): 1-channel logits resized to the input's spatial
        size, and the final 8-channel decoder feature map."""
        # Encoder
        x1 = self.enc1(x)  # S/2
        x2 = self.enc2(x1)  # S/4
        x3 = self.enc3(x2)  # S/8
        x4 = self.enc4(x3)  # S/16
        x5 = self.enc5(x4)  # S/16

        b = self.bottleneck(x5)
        b = self.att_bottleneck(b)

        # Decoder: at each stage concatenate the skip connection, apply
        # attention, then the conv (+upsample) block.
        d4_cat = torch.cat([b, x5], dim=1)
        d4_att = self.att4(d4_cat)
        d4 = self.up4(d4_att)

        d3_cat = torch.cat([d4, x4], dim=1)
        d3_att = self.att3(d3_cat)
        d3 = self.up3(d3_att)

        d2_cat = torch.cat([d3, x3], dim=1)
        d2_att = self.att2(d2_cat)
        d2 = self.up2(d2_att)

        d1_cat = torch.cat([d2, x2], dim=1)
        d1_att = self.att1(d1_cat)
        d1 = self.up1(d1_att)

        d0_cat = torch.cat([d1, x1], dim=1)
        d0_att = self.att0(d0_cat)
        d0 = self.up0(d0_att)

        # Final output
        out = self.out(d0)

        # Resize to original input size
        out = F.interpolate(out,
                            size=x.shape[2:],
                            mode='bilinear',
                            align_corners=False)

        return out, d0

SiameseTracker

Bases: Module

Siamese UNet model for tracking, based on EfficientUNet.

This model is designed to take two inputs: a template image and a search image. The template image is the reference image of the object to be tracked, while the search image is the current frame in which the object is being searched for. The model processes both images through a shared UNet architecture, extracting features from both images and then concatenating them at each stage of the decoder. This allows the model to leverage the spatial information from both images, improving the tracking performance.

Source code in src\aegear\nn\model.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
class SiameseTracker(nn.Module):
    """
    Siamese UNet model for tracking, based on EfficientUNet.

    This model is designed to take two inputs: a template image and a search
    image. The template image is the reference image of the object to be
    tracked, while the search image is the current frame in which the object
    is being searched for. The model processes both images through a shared
    UNet architecture, extracting features from both images and then
    concatenating them at each stage of the decoder. This allows the model to
    leverage the spatial information from both images, improving the
    tracking performance.
    """

    def __init__(self, unet=None):
        """
        Args:
            unet: EfficientUNet whose encoder/bottleneck/output layers are
                shared with this tracker. If None, a fresh EfficientUNet() is
                built. A None sentinel is used instead of a default instance
                so that each no-arg SiameseTracker gets its own backbone;
                a `unet=EfficientUNet()` default would be constructed once at
                class-definition time and silently shared by every instance.
        """
        super().__init__()
        if unet is None:
            unet = EfficientUNet()

        # Share encoder stages from the UNet (applied to both branches).
        self.enc1 = unet.enc1
        self.enc2 = unet.enc2
        self.enc3 = unet.enc3
        self.enc4 = unet.enc4
        self.enc5 = unet.enc5

        # Share bottleneck and bottleneck attention
        self.bottleneck = unet.bottleneck
        self.att_bottleneck = unet.att_bottleneck

        # Decoder blocks with adjusted input channel sizes for concatenated Siamese features
        # The input channels to att/up blocks will be double the UNet's combined input

        self.att4 = CBAM(256 * 2 + 112 * 2)
        self.up4 = unet._conf_block(256 * 2 + 112 * 2, 64)

        self.att3 = CBAM(64 + 80 * 2)
        self.up3 = unet._up_block(64 + 80 * 2, 32)

        self.att2 = CBAM(32 + 40 * 2)
        self.up2 = unet._up_block(32 + 40 * 2, 24)

        self.att1 = CBAM(24 + 24 * 2)
        self.up1 = unet._up_block(24 + 24 * 2, 16)

        self.att0 = CBAM(16 + 16 * 2)
        self.up0 = unet._up_block(16 + 16 * 2, 8)

        # Re-use the output layer from UNet
        self.out = unet.out

    def forward(self, template, search):
        """Return 1-channel heatmap logits at the template's spatial size."""
        # Encoder: same (shared-weight) stages applied to both inputs.
        t1 = self.enc1(template)  # S/2
        s1 = self.enc1(search)

        t2 = self.enc2(t1)  # S/4
        s2 = self.enc2(s1)

        t3 = self.enc3(t2)  # S/8
        s3 = self.enc3(s2)

        t4 = self.enc4(t3)  # S/16
        s4 = self.enc4(s3)

        t5 = self.enc5(t4)  # S/16
        s5 = self.enc5(s4)

        # Bottleneck with attention.
        b_t = self.bottleneck(t5)
        b_s = self.bottleneck(s5)
        b_t_att = self.att_bottleneck(b_t)
        b_s_att = self.att_bottleneck(b_s)

        fused_bottleneck = torch.cat(
            [b_t_att, b_s_att], dim=1)

        # Decoder: at each stage fuse the previous decoder output with the
        # concatenated template+search skip features, then attend and convolve.
        d4_cat = torch.cat(
            [fused_bottleneck, torch.cat([t5, s5], dim=1)], dim=1)
        d4_att = self.att4(d4_cat)
        d4_fused = self.up4(d4_att)

        d3_cat = torch.cat([d4_fused, torch.cat([t4, s4], dim=1)], dim=1)
        d3_att = self.att3(d3_cat)
        d3_fused = self.up3(d3_att)

        d2_cat = torch.cat([d3_fused, torch.cat([t3, s3], dim=1)], dim=1)
        d2_att = self.att2(d2_cat)
        d2_fused = self.up2(d2_att)

        d1_cat = torch.cat([d2_fused, torch.cat([t2, s2], dim=1)], dim=1)
        d1_att = self.att1(d1_cat)
        d1_fused = self.up1(d1_att)

        d0_cat = torch.cat([d1_fused, torch.cat([t1, s1], dim=1)], dim=1)
        d0_att = self.att0(d0_cat)
        d0_fused = self.up0(d0_att)

        out = self.out(d0_fused)
        return F.interpolate(out, size=template.shape[2:], mode='bilinear', align_corners=False)

ConvClassifier

Bases: Module

A simple convolutional network for binary classification. This model is designed to classify whether a fish is present in a given region of interest (ROI) of the image.

Source code in src\aegear\nn\model.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class ConvClassifier(nn.Module):
    """
    A simple convolutional network for binary classification.
    This model is designed to classify whether a fish is present in a given
    region of interest (ROI) of the image.
    """
    # Side length (pixels) of the square ROI the classifier expects.
    ROI_SIZE = 64

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Three 2x poolings shrink the ROI by 8 in each dimension.
        self.fc1 = nn.Linear(128 * (ConvClassifier.ROI_SIZE // 8) ** 2, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        # conv -> ReLU -> 2x max-pool, three times.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Flatten to (batch, features) for the classifier head.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Sigmoid squashes the logit to a presence probability.
        return torch.sigmoid(self.fc2(x))

ops

RunPodLauncher

Manages RunPod pod lifecycle for training jobs.

Source code in src\aegear\nn\ops\runpod_launcher.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
class RunPodLauncher:
    """Manages RunPod pod lifecycle for training jobs.

    Wraps the RunPod GraphQL API behind a persistent authenticated
    ``requests.Session`` and provides helpers to launch, inspect,
    monitor and terminate GPU pods.
    """

    RUNPOD_API_BASE = "https://api.runpod.io/graphql"
    DEFAULT_IMAGE = "docker.io/ljubobratovicrelja/aegear:latest"
    DEFAULT_GPU_TYPE = "NVIDIA GeForce RTX 5090"
    DEFAULT_VOLUME_SIZE = 10
    DEFAULT_CONTAINER_DISK_SIZE = 8

    def __init__(self, api_token: str, docker_username: Optional[str] = None, docker_pat: Optional[str] = None):
        """Create a launcher.

        Args:
            api_token: RunPod API key, sent as a Bearer token on every request.
            docker_username: Optional Docker Hub username for registry auth.
            docker_pat: Optional Docker Hub personal access token.
        """
        self.api_token = api_token
        self.docker_username = docker_username
        self.docker_pat = docker_pat
        self.session = requests.Session()
        self.session.headers.update({
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}"
        })

    def _graphql_query(self, query: str, variables: Optional[Dict] = None) -> Dict:
        """POST a GraphQL query and return its ``data`` payload.

        Raises:
            RuntimeError: If the response carries GraphQL-level errors.
            requests.HTTPError: On a non-2xx HTTP status.
        """
        payload = {"query": query}
        if variables:
            payload["variables"] = variables
        response = self.session.post(self.RUNPOD_API_BASE, json=payload)
        response.raise_for_status()
        result = response.json()
        if "errors" in result:
            raise RuntimeError(f"GraphQL error: {result['errors']}")
        return result.get("data", {})

    def get_gpu_types(self) -> list:
        """Return the GPU types offered by RunPod, or [] on failure.

        Best-effort: listing GPU types is informational only, so any
        error is reported as a warning rather than raised.
        """
        query = """
        query GpuTypes {
            gpuTypes {
                id
                displayName
                memoryInGb
            }
        }
        """
        try:
            result = self._graphql_query(query)
            return result.get("gpuTypes", [])
        except Exception as e:
            print(f"Warning: Could not fetch GPU types: {e}")
            return []

    def create_container_registry_auth(self) -> Optional[str]:
        """Register Docker Hub credentials with RunPod for image pulls.

        Returns:
            The registry auth id, or None when credentials were not
            provided or the mutation failed (launching then proceeds
            unauthenticated and may hit Docker Hub rate limits).
        """
        if not self.docker_username or not self.docker_pat:
            return None
        query = """
        mutation SaveRegistryAuth($input: SaveRegistryAuthInput!) {
            saveRegistryAuth(input: $input) {
                id
                name
            }
        }
        """
        # Timestamped name keeps repeated launches from colliding.
        auth_name = f"dockerhub_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        variables = {
            "input": {
                "name": auth_name,
                "username": self.docker_username,
                "password": self.docker_pat
            }
        }
        try:
            result = self._graphql_query(query, variables)
            auth_id = result.get("saveRegistryAuth", {}).get("id")
            if auth_id:
                print(f"Created Docker Hub authentication: {auth_id}")
            return auth_id
        except Exception as e:
            print(f"Warning: Failed to create registry auth: {e}")
            print("  Proceeding without authentication (may hit Docker Hub rate limits)")
            return None

    def launch_pod(self, task_name: str, env_vars: Dict[str, str], gpu_type: str = DEFAULT_GPU_TYPE, gpu_count: int = 1, volume_size: int = DEFAULT_VOLUME_SIZE, container_disk_size: int = DEFAULT_CONTAINER_DISK_SIZE, image_name: str = DEFAULT_IMAGE) -> str:
        """Deploy an on-demand secure-cloud pod and return its id.

        Args:
            task_name: Display name for the pod.
            env_vars: Environment variables passed into the container.
            gpu_type: RunPod GPU type id.
            gpu_count: Number of GPUs to attach.
            volume_size: Persistent volume size in GB (mounted at /workspace).
            container_disk_size: Container disk size in GB.
            image_name: Docker image to run.

        Raises:
            RuntimeError: If the API does not return a pod id.
        """
        registry_auth_id = self.create_container_registry_auth()
        # Copy so the caller's dict is not mutated.  The API key lets the
        # container terminate its own pod when training completes.
        env_vars = {**env_vars, "RUNPOD_API_KEY": self.api_token}
        query = """
        mutation PodFindAndDeployOnDemand($input: PodFindAndDeployOnDemandInput!) {
            podFindAndDeployOnDemand(input: $input) {
                id
                desiredStatus
                imageName
                env
                machineId
                machine {
                    gpuDisplayName
                }
            }
        }
        """
        # RunPod expects env as a list of {key, value} string pairs.
        pod_env = [
            {"key": k, "value": str(v)}
            for k, v in env_vars.items()
        ]
        variables = {
            "input": {
                "name": task_name,
                "imageName": image_name,
                "gpuTypeId": gpu_type,
                "gpuCount": gpu_count,
                "volumeInGb": volume_size,
                "containerDiskInGb": container_disk_size,
                "volumeMountPath": "/workspace",
                "cloudType": "SECURE",
                "env": pod_env,
                "startSsh": False,
                "startJupyter": False,
            }
        }
        if registry_auth_id:
            variables["input"]["containerRegistryAuthId"] = registry_auth_id
        print(f"\nLaunching pod: {task_name}")
        print(f"   Image: {image_name}")
        print(f"   GPU: {gpu_type} x{gpu_count}")
        print(f"   Volume: {volume_size}GB, Container Disk: {container_disk_size}GB")
        try:
            result = self._graphql_query(query, variables)
        except Exception:
            # Surface the request parameters that most often cause rejections,
            # then re-raise the original error unchanged.
            print(f"\nFailed to create pod. Request details:")
            print(f"   GPU Type: {variables['input'].get('gpuTypeId', 'Not specified')}")
            print(f"   Cloud Type: {variables['input']['cloudType']}")
            raise
        pod = result.get("podFindAndDeployOnDemand", {})
        pod_id = pod.get("id")
        if not pod_id:
            raise RuntimeError("Failed to create pod - no ID returned")
        print(f"Pod created: {pod_id}")
        print(f"   Machine: {pod.get('machine', {}).get('gpuDisplayName', 'Unknown')}")
        print(f"   Self-termination: RUNPOD_POD_ID and RUNPOD_API_KEY are available in container")
        print(f"   Container will auto-terminate on training completion")
        return pod_id

    def get_pod_status(self, pod_id: str) -> Dict[str, Any]:
        """Return the pod's current status record (empty dict if absent)."""
        query = """
        query Pod($input: PodFilter!) {
            pod(input: $input) {
                id
                desiredStatus
                runtime {
                    uptimeInSeconds
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                    gpus {
                        id
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                }
            }
        }
        """
        variables = {"input": {"podId": pod_id}}
        result = self._graphql_query(query, variables)
        return result.get("pod", {})

    def get_pod_logs(self, pod_id: str, lines: int = 100) -> str:
        """Fetch recent pod logs via the REST endpoint; "" on any failure.

        Best-effort by design: callers (e.g. get_pod_exit_code) rely on an
        empty string rather than an exception when logs are unavailable.
        """
        try:
            endpoint = f"https://api.runpod.io/v1/pods/{pod_id}/logs"
            response = self.session.get(endpoint, params={"lines": lines})
            if response.status_code == 200:
                return response.text
        except Exception:
            # Catch Exception, not bare except: a bare except would also
            # swallow KeyboardInterrupt/SystemExit during monitoring.
            pass
        return ""

    def get_pod_exit_code(self, pod_id: str) -> int:
        """
        Extract exit code from pod logs.

        Looks for exit code patterns in the last lines of logs.
        Returns 0 if no exit code found or if pod terminated successfully.
        Returns 42 if CUDA validation failed (machine issue).
        Returns 1 for other failures.

        Args:
            pod_id: The pod ID to check

        Returns:
            Exit code (0, 1, 42, etc.)
        """
        logs = self.get_pod_logs(pod_id, lines=200)

        if not logs:
            # No logs available, assume success if pod terminated
            return 0

        # Check for explicit exit code 42 (CUDA unavailable)
        if "exit 42" in logs.lower() or "exiting with code 42" in logs.lower():
            return 42

        # Check for CUDA validation failure messages
        if "CUDA VALIDATION FAILED" in logs:
            return 42

        if "DEVICE VALIDATION FAILED" in logs:
            return 42

        # Check for successful completion
        if "Training completed successfully" in logs:
            return 0

        # Check for training failure
        if "Training failed" in logs:
            return 1

        # Default to success if pod terminated cleanly
        return 0

    def terminate_pod(self, pod_id: str):
        """Request termination of the given pod and report the outcome."""
        query = """
        mutation PodTerminate($input: PodTerminateInput!) {
            podTerminate(input: $input)
        }
        """
        variables = {"input": {"podId": pod_id}}
        print(f"\nTerminating pod: {pod_id}")
        result = self._graphql_query(query, variables)
        if result.get("podTerminate"):
            print("Pod terminated successfully")
        else:
            print("Pod termination request sent (status unclear)")

    def monitor_pod(self, pod_id: str, check_interval: int = 60, timeout_hours: int = 24, auto_terminate: bool = True):
        """Poll a pod until it stops, times out, or the user interrupts.

        Args:
            pod_id: Pod to watch.
            check_interval: Seconds between status polls.
            timeout_hours: Give up (and optionally terminate) after this long.
            auto_terminate: Terminate the pod on timeout/interrupt and, as an
                idempotent safety net, once monitoring ends for any reason.

        Returns:
            True if the pod stopped on its own, False on timeout or interrupt.
        """
        start_time = time.time()
        timeout_seconds = timeout_hours * 3600
        print(f"\nMonitoring pod: {pod_id}")
        print(f"   Check interval: {check_interval}s")
        print(f"   Timeout: {timeout_hours}h")
        print(f"   Auto-terminate: {auto_terminate}")
        print("\n" + "="*60)
        try:
            while True:
                elapsed = time.time() - start_time
                if elapsed > timeout_seconds:
                    print(f"\nTimeout reached ({timeout_hours}h)")
                    if auto_terminate:
                        self.terminate_pod(pod_id)
                    return False
                try:
                    status = self.get_pod_status(pod_id)
                    if status is None or not status:
                        # Pod vanished from the API: it self-terminated.
                        print(f"\n\nPod has been terminated (no longer found in API)")
                        print("   Training completed and pod self-terminated successfully!")
                        return True
                    desired_status = status.get("desiredStatus", "UNKNOWN")
                    runtime = status.get("runtime")
                    if desired_status in ["EXITED", "STOPPED", "TERMINATED"]:
                        print(f"\n\nPod stopped with status: {desired_status}")
                        return True
                    if runtime:
                        uptime = runtime.get("uptimeInSeconds", 0)
                        gpus = runtime.get("gpus", [])
                        print(f"\r[{time.strftime('%H:%M:%S')}] "
                              f"Status: {desired_status} | "
                              f"Uptime: {uptime//3600}h {(uptime%3600)//60}m | "
                              f"Elapsed: {elapsed//3600:.0f}h {(elapsed%3600)//60:.0f}m",
                              end="", flush=True)
                    else:
                        print(f"\r[{time.strftime('%H:%M:%S')}] "
                              f"Status: {desired_status} | "
                              f"Waiting for runtime...",
                              end="", flush=True)
                except Exception as e:
                    # Treat a failing status query as the pod having gone away.
                    print(f"\n\nPod terminated (API query failed: {e})")
                    return True
                time.sleep(check_interval)
        except KeyboardInterrupt:
            print(f"\n\nMonitoring interrupted by user")
            if auto_terminate:
                print("Terminating pod...")
                self.terminate_pod(pod_id)
            return False
        finally:
            # NOTE: runs on every exit path; terminating an already-gone pod
            # is harmless, so failures here are deliberately suppressed.
            if auto_terminate:
                try:
                    self.terminate_pod(pod_id)
                except Exception:
                    pass
get_pod_exit_code(pod_id)

Extract exit code from pod logs.

Looks for exit code patterns in the last lines of logs. Returns 0 if no exit code found or if pod terminated successfully. Returns 42 if CUDA validation failed (machine issue). Returns 1 for other failures.

Parameters:

Name Type Description Default
pod_id str

The pod ID to check

required

Returns:

Type Description
int

Exit code (0, 1, 42, etc.)

Source code in src\aegear\nn\ops\runpod_launcher.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def get_pod_exit_code(self, pod_id: str) -> int:
    """
    Derive an exit code for a pod from its recent log output.

    Scans the last lines of the pod's logs for known markers:
    0 when no logs or no failure marker is found (clean termination),
    42 when CUDA/device validation failed (machine issue),
    1 when training failed.

    Args:
        pod_id: The pod ID to check

    Returns:
        Exit code (0, 1, 42, etc.)
    """
    log_text = self.get_pod_logs(pod_id, lines=200)

    if not log_text:
        # Nothing to inspect; a silently terminated pod counts as success.
        return 0

    lowered = log_text.lower()

    # Explicit code-42 markers or validation banners -> bad machine.
    if "exit 42" in lowered or "exiting with code 42" in lowered:
        return 42
    for banner in ("CUDA VALIDATION FAILED", "DEVICE VALIDATION FAILED"):
        if banner in log_text:
            return 42

    if "Training completed successfully" in log_text:
        return 0
    if "Training failed" in log_text:
        return 1

    # No recognizable marker: assume a clean termination.
    return 0

PodManager

Manages RunPod pods with listing and termination capabilities.

Source code in src\aegear\nn\ops\runpod_launcher.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
class PodManager:
    """Manages RunPod pods with listing and termination capabilities.

    Interactive CLI-style companion to the launcher: the ``*_command``
    methods print reports and prompt for confirmation before destructive
    operations.
    """

    RUNPOD_API_BASE = "https://api.runpod.io/graphql"

    # Status descriptions for user-friendly explanations
    STATUS_DESCRIPTIONS = {
        "RUNNING": "🟢 Pod is actively running",
        "PENDING": "🟡 Pod is being created/started",
        "EXITED": "⚫ Pod has stopped/exited",
        "STOPPED": "⚫ Pod is stopped",
        "TERMINATED": "⚫ Pod has been terminated",
        "CREATED": "🟡 Pod created but not started",
        "STARTING": "🟡 Pod is starting up",
    }

    def __init__(self, api_token: str):
        """Initialize the pod manager."""
        self.api_token = api_token
        # Single session reused for every API call; auth header set once here.
        self.session = requests.Session()
        self.session.headers.update({
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}"
        })

    def _graphql_query(self, query: str, variables: Optional[Dict] = None) -> Dict:
        """Execute a GraphQL query against RunPod API.

        Raises:
            RuntimeError: On transport failure or GraphQL-level errors.
        """
        payload = {"query": query}
        if variables:
            payload["variables"] = variables

        try:
            response = self.session.post(self.RUNPOD_API_BASE, json=payload)
            response.raise_for_status()
            result = response.json()

            if "errors" in result:
                raise RuntimeError(f"GraphQL error: {result['errors']}")

            return result.get("data", {})
        except requests.exceptions.RequestException as e:
            # Normalize transport errors into RuntimeError for callers.
            raise RuntimeError(f"API request failed: {e}")

    def list_all_pods(self) -> List[Dict[str, Any]]:
        """List all pods in the account."""
        query = """
        query Pods {
            myself {
                pods {
                    id
                    name
                    desiredStatus
                    imageName
                    gpuCount
                    costPerHr
                    machine {
                        gpuDisplayName
                    }
                    runtime {
                        uptimeInSeconds
                        gpus {
                            gpuUtilPercent
                            memoryUtilPercent
                        }
                        ports {
                            privatePort
                            publicPort
                        }
                    }
                }
            }
        }
        """

        result = self._graphql_query(query)
        pods = result.get("myself", {}).get("pods", [])
        return pods

    def get_pod_details(self, pod_id: str) -> Dict[str, Any]:
        """Get detailed information about a specific pod."""
        query = """
        query Pod($input: PodFilter!) {
            pod(input: $input) {
                id
                name
                desiredStatus
                imageName
                gpuCount
                costPerHr
                machine {
                    gpuDisplayName
                }
                runtime {
                    uptimeInSeconds
                    gpus {
                        id
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                }
            }
        }
        """

        variables = {"input": {"podId": pod_id}}
        result = self._graphql_query(query, variables)
        return result.get("pod", {})

    def terminate_pod(self, pod_id: str) -> bool:
        """Terminate a specific pod. Returns True on success, False otherwise."""
        query = """
        mutation PodTerminate($input: PodTerminateInput!) {
            podTerminate(input: $input)
        }
        """

        variables = {"input": {"podId": pod_id}}

        try:
            result = self._graphql_query(query, variables)
            return result.get("podTerminate", False)
        except Exception as e:
            # Best-effort: report the failure but never raise to the caller.
            print(f"  ❌ Failed to terminate pod {pod_id}: {e}")
            return False

    def format_uptime(self, seconds: int) -> str:
        """Format uptime in human-readable format (e.g. "1d 2h 3m")."""
        delta = timedelta(seconds=seconds)
        days = delta.days
        hours = delta.seconds // 3600
        minutes = (delta.seconds % 3600) // 60

        parts = []
        if days > 0:
            parts.append(f"{days}d")
        # Show hours whenever days are shown, to keep the format unambiguous.
        if hours > 0 or days > 0:
            parts.append(f"{hours}h")
        parts.append(f"{minutes}m")

        return " ".join(parts)

    def format_cost(self, cost_per_hr: float, uptime_seconds: int) -> str:
        """Calculate and format accumulated cost; "N/A" when either input is falsy."""
        if cost_per_hr and uptime_seconds:
            hours = uptime_seconds / 3600
            total_cost = cost_per_hr * hours
            return f"${total_cost:.4f} (${cost_per_hr:.4f}/hr)"
        return "N/A"

    def calculate_cost(self, cost_per_hr: float, uptime_seconds: int) -> float:
        """Calculate accumulated cost in dollars (0.0 when either input is falsy)."""
        if cost_per_hr and uptime_seconds:
            hours = uptime_seconds / 3600
            return cost_per_hr * hours
        return 0.0

    def print_pod_summary(self, pod: Dict[str, Any], detailed: bool = True) -> None:
        """Print formatted pod information.

        Args:
            pod: Pod record as returned by the RunPod API.
            detailed: When True and the pod has a runtime, also print GPU
                utilization and exposed ports.
        """
        pod_id = pod.get("id", "unknown")
        name = pod.get("name", "unnamed")
        status = pod.get("desiredStatus", "UNKNOWN")
        image = pod.get("imageName", "unknown")
        gpu_count = pod.get("gpuCount", 0)
        cost_per_hr = pod.get("costPerHr", 0)

        machine = pod.get("machine", {})
        gpu_name = machine.get("gpuDisplayName", "unknown GPU")

        # runtime is None for pods that are not currently running.
        runtime = pod.get("runtime")
        uptime_seconds = runtime.get("uptimeInSeconds", 0) if runtime else 0

        # Status icon and description
        status_desc = self.STATUS_DESCRIPTIONS.get(status, f"❓ Unknown status: {status}")

        print(f"\n{'='*80}")
        print(f"Pod: {name}")
        print(f"ID:  {pod_id}")
        print(f"{'='*80}")
        print(f"Status:  {status_desc}")
        print(f"GPU:     {gpu_name} x{gpu_count}")
        print(f"Image:   {image}")

        if uptime_seconds > 0:
            print(f"Uptime:  {self.format_uptime(uptime_seconds)}")
            print(f"Cost:    {self.format_cost(cost_per_hr, uptime_seconds)}")
        else:
            print(f"Uptime:  Not running")
            print(f"Cost:    ${cost_per_hr:.4f}/hr (when running)")

        if detailed and runtime:
            gpus = runtime.get("gpus", [])
            if gpus:
                print(f"\nGPU Utilization:")
                for i, gpu in enumerate(gpus):
                    gpu_util = gpu.get("gpuUtilPercent", 0)
                    mem_util = gpu.get("memoryUtilPercent", 0)
                    print(f"  GPU {i}: {gpu_util}% compute, {mem_util}% memory")

            ports = runtime.get("ports", [])
            if ports:
                print(f"\nExposed Ports:")
                for port in ports:
                    private = port.get("privatePort")
                    public = port.get("publicPort")
                    print(f"  {private} -> {public}")

    def list_pods_command(self) -> None:
        """List all pods with detailed information, grouped by status."""
        print("\n" + "="*80)
        print("RUNPOD - ALL PODS")
        print("="*80)

        try:
            pods = self.list_all_pods()

            if not pods:
                print("\n✓ No pods found")
                return

            print(f"\nFound {len(pods)} pod(s):")

            # Separate by status
            running = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING"]]
            pending = [p for p in pods if p.get("desiredStatus") in ["PENDING", "CREATED"]]
            stopped = [p for p in pods if p.get("desiredStatus") in ["EXITED", "STOPPED", "TERMINATED"]]

            # Print running pods first
            if running:
                print(f"\n{'─'*80}")
                print(f"RUNNING PODS ({len(running)})")
                print(f"{'─'*80}")
                for pod in running:
                    self.print_pod_summary(pod, detailed=True)

            # Then pending
            if pending:
                print(f"\n{'─'*80}")
                print(f"PENDING PODS ({len(pending)})")
                print(f"{'─'*80}")
                for pod in pending:
                    self.print_pod_summary(pod, detailed=False)

            # Finally stopped
            if stopped:
                print(f"\n{'─'*80}")
                print(f"STOPPED PODS ({len(stopped)})")
                print(f"{'─'*80}")
                for pod in stopped:
                    self.print_pod_summary(pod, detailed=False)

            # Summary
            print(f"\n{'='*80}")
            print(f"SUMMARY: {len(running)} running, {len(pending)} pending, {len(stopped)} stopped")
            print(f"{'='*80}\n")

        except Exception as e:
            # CLI entry point: report and keep going rather than crash.
            print(f"\n❌ Error listing pods: {e}")
            import traceback
            traceback.print_exc()

    def kill_pod_command(self, pod_id: str) -> bool:
        """Kill a specific pod after showing its summary and confirming with the user."""
        print(f"\nTerminating pod: {pod_id}")

        try:
            # Get pod details first
            pod = self.get_pod_details(pod_id)
            if not pod:
                print(f"❌ Pod not found: {pod_id}")
                return False

            self.print_pod_summary(pod, detailed=False)

            # Confirm
            response = input("\nTerminate this pod? (y/N): ").strip().lower()
            if response != 'y':
                print("Cancelled.")
                return False

            # Terminate
            success = self.terminate_pod(pod_id)
            if success:
                print(f"✓ Pod {pod_id} terminated successfully")
                return True
            else:
                print(f"❌ Failed to terminate pod {pod_id}")
                return False

        except Exception as e:
            # CLI entry point: report and keep going rather than crash.
            print(f"❌ Error terminating pod: {e}")
            import traceback
            traceback.print_exc()
            return False

    def kill_all_command(self, running_only: bool = False) -> None:
        """Kill all pods (with confirmation).

        Args:
            running_only: When True, only active pods (running/starting/
                pending/created) are targeted; stopped pods are left alone.
        """
        print("\n" + "="*80)
        print("RUNPOD - KILL ALL PODS")
        print("="*80)

        try:
            pods = self.list_all_pods()

            if not pods:
                print("\n✓ No pods found")
                return

            # Filter pods based on running_only flag
            if running_only:
                pods_to_kill = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING", "PENDING", "CREATED"]]
                print(f"\nFound {len(pods_to_kill)} active pod(s) to terminate:")
            else:
                pods_to_kill = pods
                print(f"\nFound {len(pods_to_kill)} pod(s) to terminate:")

            if not pods_to_kill:
                print("\n✓ No pods to terminate")
                return

            # Display all pods
            for pod in pods_to_kill:
                self.print_pod_summary(pod, detailed=False)

            # Calculate total cost
            total_cost = sum(
                self.calculate_cost(p.get("costPerHr", 0), 
                                  p.get("runtime", {}).get("uptimeInSeconds", 0) if p.get("runtime") else 0)
                for p in pods_to_kill
            )

            print(f"\n{'='*80}")
            print(f"Total accumulated cost: ${total_cost:.4f}")
            print(f"{'='*80}\n")

            # Confirm
            prompt = "⚠️  TERMINATE ALL THESE PODS? (y/N): "
            response = input(prompt).strip().lower()

            if response != 'y':
                print("\nCancelled. No pods were terminated.")
                return

            # Terminate all
            print("\nTerminating pods...")
            success_count = 0
            failed_count = 0

            for pod in pods_to_kill:
                pod_id = pod.get("id")
                name = pod.get("name", "unnamed")

                print(f"  Terminating {name} ({pod_id})...", end=" ")

                if self.terminate_pod(pod_id):
                    print("✓")
                    success_count += 1
                else:
                    print("❌")
                    failed_count += 1

            print(f"\n{'='*80}")
            print(f"Termination complete: {success_count} succeeded, {failed_count} failed")
            print(f"{'='*80}\n")

        except Exception as e:
            # CLI entry point: report and keep going rather than crash.
            print(f"\n❌ Error in kill-all operation: {e}")
            import traceback
            traceback.print_exc()
list_all_pods()

List all pods in the account.

Source code in src\aegear\nn\ops\runpod_launcher.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def list_all_pods(self) -> List[Dict[str, Any]]:
    """Return every pod belonging to the account (empty list when none)."""
    query = """
    query Pods {
        myself {
            pods {
                id
                name
                desiredStatus
                imageName
                gpuCount
                costPerHr
                machine {
                    gpuDisplayName
                }
                runtime {
                    uptimeInSeconds
                    gpus {
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                    ports {
                        privatePort
                        publicPort
                    }
                }
            }
        }
    }
    """

    # The pods live under myself.pods in the response payload.
    data = self._graphql_query(query)
    return data.get("myself", {}).get("pods", [])
get_pod_details(pod_id)

Get detailed information about a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def get_pod_details(self, pod_id: str) -> Dict[str, Any]:
    """Fetch the full record for one pod (empty dict when not found)."""
    query = """
    query Pod($input: PodFilter!) {
        pod(input: $input) {
            id
            name
            desiredStatus
            imageName
            gpuCount
            costPerHr
            machine {
                gpuDisplayName
            }
            runtime {
                uptimeInSeconds
                gpus {
                    id
                    gpuUtilPercent
                    memoryUtilPercent
                }
                ports {
                    ip
                    isIpPublic
                    privatePort
                    publicPort
                    type
                }
            }
        }
    }
    """

    data = self._graphql_query(query, {"input": {"podId": pod_id}})
    return data.get("pod", {})
terminate_pod(pod_id)

Terminate a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def terminate_pod(self, pod_id: str) -> bool:
    """Request termination of one pod; True on success, False on any error."""
    query = """
    mutation PodTerminate($input: PodTerminateInput!) {
        podTerminate(input: $input)
    }
    """

    request_vars = {"input": {"podId": pod_id}}

    try:
        # A failed request is reported but never raised to the caller.
        data = self._graphql_query(query, request_vars)
        return data.get("podTerminate", False)
    except Exception as err:
        print(f"  ❌ Failed to terminate pod {pod_id}: {err}")
        return False
format_uptime(seconds)

Format uptime in human-readable format.

Source code in src\aegear\nn\ops\runpod_launcher.py
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
def format_uptime(self, seconds: int) -> str:
    """Render a duration in seconds as e.g. "1d 2h 3m" (minutes always shown)."""
    span = timedelta(seconds=seconds)
    day_count = span.days
    hour_count = span.seconds // 3600
    minute_count = (span.seconds % 3600) // 60

    pieces = []
    if day_count > 0:
        pieces.append(f"{day_count}d")
    # Hours are included whenever days are, so the format stays unambiguous.
    if day_count > 0 or hour_count > 0:
        pieces.append(f"{hour_count}h")
    pieces.append(f"{minute_count}m")
    return " ".join(pieces)
format_cost(cost_per_hr, uptime_seconds)

Calculate and format accumulated cost.

Source code in src\aegear\nn\ops\runpod_launcher.py
451
452
453
454
455
456
457
def format_cost(self, cost_per_hr: float, uptime_seconds: int) -> str:
    """Render accumulated cost plus hourly rate; "N/A" when either input is falsy."""
    if not (cost_per_hr and uptime_seconds):
        return "N/A"
    accumulated = cost_per_hr * (uptime_seconds / 3600)
    return f"${accumulated:.4f} (${cost_per_hr:.4f}/hr)"
calculate_cost(cost_per_hr, uptime_seconds)

Calculate accumulated cost.

Source code in src\aegear\nn\ops\runpod_launcher.py
459
460
461
462
463
464
def calculate_cost(self, cost_per_hr: float, uptime_seconds: int) -> float:
    """Return accumulated cost in dollars; 0.0 when either input is falsy."""
    if not cost_per_hr or not uptime_seconds:
        return 0.0
    # Same grouping as rate * hours to keep float results identical.
    return cost_per_hr * (uptime_seconds / 3600)
print_pod_summary(pod, detailed=True)

Print formatted pod information.

Source code in src\aegear\nn\ops\runpod_launcher.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
def print_pod_summary(self, pod: Dict[str, Any], detailed: bool = True):
    """Print a formatted, human-readable summary of one pod.

    Args:
        pod: Pod record as returned by the RunPod API; missing fields fall
            back to placeholder values.
        detailed: When True and runtime data is present, also print per-GPU
            utilization and exposed ports.
    """
    pod_id = pod.get("id", "unknown")
    name = pod.get("name", "unnamed")
    status = pod.get("desiredStatus", "UNKNOWN")
    image = pod.get("imageName", "unknown")
    gpu_count = pod.get("gpuCount", 0)
    cost_per_hr = pod.get("costPerHr", 0)

    machine = pod.get("machine", {})
    gpu_name = machine.get("gpuDisplayName", "unknown GPU")

    # "runtime" is absent/None for pods that are not currently running.
    runtime = pod.get("runtime")
    uptime_seconds = runtime.get("uptimeInSeconds", 0) if runtime else 0

    # Status icon and description
    status_desc = self.STATUS_DESCRIPTIONS.get(status, f"❓ Unknown status: {status}")

    print(f"\n{'='*80}")
    print(f"Pod: {name}")
    print(f"ID:  {pod_id}")
    print(f"{'='*80}")
    print(f"Status:  {status_desc}")
    print(f"GPU:     {gpu_name} x{gpu_count}")
    print(f"Image:   {image}")

    if uptime_seconds > 0:
        print(f"Uptime:  {self.format_uptime(uptime_seconds)}")
        print(f"Cost:    {self.format_cost(cost_per_hr, uptime_seconds)}")
    else:
        # Not running: show the hypothetical hourly rate instead of a total.
        print(f"Uptime:  Not running")
        print(f"Cost:    ${cost_per_hr:.4f}/hr (when running)")

    if detailed and runtime:
        gpus = runtime.get("gpus", [])
        if gpus:
            print(f"\nGPU Utilization:")
            for i, gpu in enumerate(gpus):
                gpu_util = gpu.get("gpuUtilPercent", 0)
                mem_util = gpu.get("memoryUtilPercent", 0)
                print(f"  GPU {i}: {gpu_util}% compute, {mem_util}% memory")

        ports = runtime.get("ports", [])
        if ports:
            print(f"\nExposed Ports:")
            for port in ports:
                private = port.get("privatePort")
                public = port.get("publicPort")
                print(f"  {private} -> {public}")
list_pods_command()

List all pods with detailed information.

Source code in src\aegear\nn\ops\runpod_launcher.py
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def list_pods_command(self):
    """List all pods with detailed information.

    Pods are grouped by desiredStatus (running / pending / stopped) and
    printed in that order, followed by a one-line summary. Errors are
    printed with a traceback rather than raised.
    """
    print("\n" + "="*80)
    print("RUNPOD - ALL PODS")
    print("="*80)

    try:
        pods = self.list_all_pods()

        if not pods:
            print("\n✓ No pods found")
            return

        print(f"\nFound {len(pods)} pod(s):")

        # Separate by status
        running = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING"]]
        pending = [p for p in pods if p.get("desiredStatus") in ["PENDING", "CREATED"]]
        stopped = [p for p in pods if p.get("desiredStatus") in ["EXITED", "STOPPED", "TERMINATED"]]

        # Print running pods first
        if running:
            print(f"\n{'─'*80}")
            print(f"RUNNING PODS ({len(running)})")
            print(f"{'─'*80}")
            for pod in running:
                # Only running pods get the detailed (GPU/ports) view.
                self.print_pod_summary(pod, detailed=True)

        # Then pending
        if pending:
            print(f"\n{'─'*80}")
            print(f"PENDING PODS ({len(pending)})")
            print(f"{'─'*80}")
            for pod in pending:
                self.print_pod_summary(pod, detailed=False)

        # Finally stopped
        if stopped:
            print(f"\n{'─'*80}")
            print(f"STOPPED PODS ({len(stopped)})")
            print(f"{'─'*80}")
            for pod in stopped:
                self.print_pod_summary(pod, detailed=False)

        # Summary
        print(f"\n{'='*80}")
        print(f"SUMMARY: {len(running)} running, {len(pending)} pending, {len(stopped)} stopped")
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"\n❌ Error listing pods: {e}")
        import traceback
        traceback.print_exc()
kill_pod_command(pod_id)

Kill a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def kill_pod_command(self, pod_id: str) -> bool:
    """Kill a specific pod after interactive confirmation.

    Args:
        pod_id: Identifier of the pod to terminate.

    Returns:
        True when the pod was terminated; False when the pod was not found,
        the user declined, or an error occurred.
    """
    print(f"\nTerminating pod: {pod_id}")

    try:
        # Get pod details first
        pod = self.get_pod_details(pod_id)
        if not pod:
            print(f"❌ Pod not found: {pod_id}")
            return False

        self.print_pod_summary(pod, detailed=False)

        # Confirm (default answer is "no" — anything but 'y' cancels)
        response = input("\nTerminate this pod? (y/N): ").strip().lower()
        if response != 'y':
            print("Cancelled.")
            return False

        # Terminate
        success = self.terminate_pod(pod_id)
        if success:
            print(f"✓ Pod {pod_id} terminated successfully")
            return True
        else:
            print(f"❌ Failed to terminate pod {pod_id}")
            return False

    except Exception as e:
        print(f"❌ Error terminating pod: {e}")
        import traceback
        traceback.print_exc()
        return False
kill_all_command(running_only=False)

Kill all pods (with confirmation).

Source code in src\aegear\nn\ops\runpod_launcher.py
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
def kill_all_command(self, running_only: bool = False):
    """Kill all pods, after printing totals and asking for confirmation.

    Args:
        running_only: When True, only RUNNING/STARTING/PENDING/CREATED pods
            are targeted; otherwise every pod in the account is.
    """
    print("\n" + "="*80)
    print("RUNPOD - KILL ALL PODS")
    print("="*80)

    try:
        pods = self.list_all_pods()

        if not pods:
            print("\n✓ No pods found")
            return

        # Filter pods based on running_only flag
        if running_only:
            pods_to_kill = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING", "PENDING", "CREATED"]]
            print(f"\nFound {len(pods_to_kill)} active pod(s) to terminate:")
        else:
            pods_to_kill = pods
            print(f"\nFound {len(pods_to_kill)} pod(s) to terminate:")

        if not pods_to_kill:
            print("\n✓ No pods to terminate")
            return

        # Display all pods
        for pod in pods_to_kill:
            self.print_pod_summary(pod, detailed=False)

        # Calculate total cost (pods without runtime contribute 0 uptime)
        total_cost = sum(
            self.calculate_cost(p.get("costPerHr", 0), 
                              p.get("runtime", {}).get("uptimeInSeconds", 0) if p.get("runtime") else 0)
            for p in pods_to_kill
        )

        print(f"\n{'='*80}")
        print(f"Total accumulated cost: ${total_cost:.4f}")
        print(f"{'='*80}\n")

        # Confirm (default answer is "no" — anything but 'y' cancels)
        prompt = "⚠️  TERMINATE ALL THESE PODS? (y/N): "
        response = input(prompt).strip().lower()

        if response != 'y':
            print("\nCancelled. No pods were terminated.")
            return

        # Terminate all
        print("\nTerminating pods...")
        success_count = 0
        failed_count = 0

        for pod in pods_to_kill:
            pod_id = pod.get("id")
            name = pod.get("name", "unnamed")

            print(f"  Terminating {name} ({pod_id})...", end=" ")

            if self.terminate_pod(pod_id):
                print("✓")
                success_count += 1
            else:
                print("❌")
                failed_count += 1

        print(f"\n{'='*80}")
        print(f"Termination complete: {success_count} succeeded, {failed_count} failed")
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"\n❌ Error in kill-all operation: {e}")
        import traceback
        traceback.print_exc()

exit_codes

Exit codes for training and HPO workflows.

These codes are used to communicate specific failure modes between the training script, container runtime, and HPO orchestrator.

get_exit_code_description(code)

Get human-readable description for an exit code.

Source code in src\aegear\nn\ops\exit_codes.py
20
21
22
23
24
25
26
27
def get_exit_code_description(code: int) -> str:
    """Get human-readable description for an exit code."""
    known = {
        EXIT_SUCCESS: "Success",
        EXIT_TRAINING_FAILURE: "Training failure",
        EXIT_CUDA_UNAVAILABLE: "CUDA unavailable (machine issue - should retry)",
    }
    try:
        return known[code]
    except KeyError:
        return f"Unknown exit code: {code}"

runpod_launcher

RunPod pod management utilities for training and HPO workflows.

RunPodLauncher

Manages RunPod pod lifecycle for training jobs.

Source code in src\aegear\nn\ops\runpod_launcher.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
class RunPodLauncher:
    """Manages RunPod pod lifecycle (create, monitor, terminate) for training jobs.

    All calls go through RunPod's GraphQL API using a persistent authenticated
    session. Optional Docker Hub credentials allow authenticated image pulls,
    avoiding anonymous rate limits.
    """
    RUNPOD_API_BASE = "https://api.runpod.io/graphql"
    DEFAULT_IMAGE = "docker.io/ljubobratovicrelja/aegear:latest"
    DEFAULT_GPU_TYPE = "NVIDIA GeForce RTX 5090"
    DEFAULT_VOLUME_SIZE = 10
    DEFAULT_CONTAINER_DISK_SIZE = 8

    def __init__(self, api_token: str, docker_username: Optional[str] = None, docker_pat: Optional[str] = None):
        """Create a launcher.

        Args:
            api_token: RunPod API token, sent as a Bearer credential.
            docker_username: Optional Docker Hub username for registry auth.
            docker_pat: Optional Docker Hub personal access token.
        """
        self.api_token = api_token
        self.docker_username = docker_username
        self.docker_pat = docker_pat
        self.session = requests.Session()
        self.session.headers.update({
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}"
        })

    def _graphql_query(self, query: str, variables: Optional[Dict] = None) -> Dict:
        """POST a GraphQL query and return its ``data`` payload.

        Raises:
            RuntimeError: if the response carries GraphQL-level errors.
        """
        payload = {"query": query}
        if variables:
            payload["variables"] = variables
        response = self.session.post(self.RUNPOD_API_BASE, json=payload)
        response.raise_for_status()
        result = response.json()
        if "errors" in result:
            raise RuntimeError(f"GraphQL error: {result['errors']}")
        return result.get("data", {})

    def get_gpu_types(self) -> list:
        """Return available GPU types (id, displayName, memoryInGb); [] on failure."""
        query = """
        query GpuTypes {
            gpuTypes {
                id
                displayName
                memoryInGb
            }
        }
        """
        try:
            result = self._graphql_query(query)
            return result.get("gpuTypes", [])
        except Exception as e:
            print(f"Warning: Could not fetch GPU types: {e}")
            return []

    def create_container_registry_auth(self) -> Optional[str]:
        """Register Docker Hub credentials with RunPod and return the auth id.

        Returns None when no credentials were configured or registration
        failed; pod creation then proceeds unauthenticated.
        """
        if not self.docker_username or not self.docker_pat:
            return None
        query = """
        mutation SaveRegistryAuth($input: SaveRegistryAuthInput!) {
            saveRegistryAuth(input: $input) {
                id
                name
            }
        }
        """
        # Timestamped name keeps repeated registrations distinct.
        auth_name = f"dockerhub_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        variables = {
            "input": {
                "name": auth_name,
                "username": self.docker_username,
                "password": self.docker_pat
            }
        }
        try:
            result = self._graphql_query(query, variables)
            auth_id = result.get("saveRegistryAuth", {}).get("id")
            if auth_id:
                print(f"Created Docker Hub authentication: {auth_id}")
            return auth_id
        except Exception as e:
            print(f"Warning: Failed to create registry auth: {e}")
            print("  Proceeding without authentication (may hit Docker Hub rate limits)")
            return None

    def launch_pod(self, task_name: str, env_vars: Dict[str, str], gpu_type: str = DEFAULT_GPU_TYPE, gpu_count: int = 1, volume_size: int = DEFAULT_VOLUME_SIZE, container_disk_size: int = DEFAULT_CONTAINER_DISK_SIZE, image_name: str = DEFAULT_IMAGE) -> str:
        """Create an on-demand pod and return its id.

        Args:
            task_name: Display name for the pod.
            env_vars: Container environment; RUNPOD_API_KEY is injected so the
                container can self-terminate. The caller's dict is not mutated.
            gpu_type: RunPod GPU type id.
            gpu_count: Number of GPUs to attach.
            volume_size: Persistent volume size in GB (mounted at /workspace).
            container_disk_size: Container disk size in GB.
            image_name: Docker image to run.

        Raises:
            RuntimeError: if the API returns no pod id.
        """
        registry_auth_id = self.create_container_registry_auth()
        # Fix: work on a copy so the caller's env_vars mapping is not mutated.
        env_vars = {**env_vars, "RUNPOD_API_KEY": self.api_token}
        query = """
        mutation PodFindAndDeployOnDemand($input: PodFindAndDeployOnDemandInput!) {
            podFindAndDeployOnDemand(input: $input) {
                id
                desiredStatus
                imageName
                env
                machineId
                machine {
                    gpuDisplayName
                }
            }
        }
        """
        pod_env = [
            {"key": k, "value": str(v)}
            for k, v in env_vars.items()
        ]
        variables = {
            "input": {
                "name": task_name,
                "imageName": image_name,
                "gpuTypeId": gpu_type,
                "gpuCount": gpu_count,
                "volumeInGb": volume_size,
                "containerDiskInGb": container_disk_size,
                "volumeMountPath": "/workspace",
                "cloudType": "SECURE",
                "env": pod_env,
                "startSsh": False,
                "startJupyter": False,
            }
        }
        if registry_auth_id:
            variables["input"]["containerRegistryAuthId"] = registry_auth_id
        print(f"\nLaunching pod: {task_name}")
        print(f"   Image: {image_name}")
        print(f"   GPU: {gpu_type} x{gpu_count}")
        print(f"   Volume: {volume_size}GB, Container Disk: {container_disk_size}GB")
        try:
            result = self._graphql_query(query, variables)
        except Exception:
            # Print the request context for debugging, then re-raise unchanged.
            print(f"\nFailed to create pod. Request details:")
            print(f"   GPU Type: {variables['input'].get('gpuTypeId', 'Not specified')}")
            print(f"   Cloud Type: {variables['input']['cloudType']}")
            raise
        pod = result.get("podFindAndDeployOnDemand", {})
        pod_id = pod.get("id")
        if not pod_id:
            raise RuntimeError("Failed to create pod - no ID returned")
        print(f"Pod created: {pod_id}")
        print(f"   Machine: {pod.get('machine', {}).get('gpuDisplayName', 'Unknown')}")
        print(f"   Self-termination: RUNPOD_POD_ID and RUNPOD_API_KEY are available in container")
        print(f"   Container will auto-terminate on training completion")
        return pod_id

    def get_pod_status(self, pod_id: str) -> Dict[str, Any]:
        """Return the pod's status record (desiredStatus plus runtime info)."""
        query = """
        query Pod($input: PodFilter!) {
            pod(input: $input) {
                id
                desiredStatus
                runtime {
                    uptimeInSeconds
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                    gpus {
                        id
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                }
            }
        }
        """
        variables = {"input": {"podId": pod_id}}
        result = self._graphql_query(query, variables)
        return result.get("pod", {})

    def get_pod_logs(self, pod_id: str, lines: int = 100) -> str:
        """Best-effort fetch of the last ``lines`` of a pod's logs via the REST API.

        Returns an empty string when the request fails or returns non-200.
        """
        try:
            endpoint = f"https://api.runpod.io/v1/pods/{pod_id}/logs"
            response = self.session.get(endpoint, params={"lines": lines})
            if response.status_code == 200:
                return response.text
        except Exception:
            # Fix: narrowed from a bare except. Log retrieval is best-effort;
            # failure must never break pod monitoring.
            pass
        return ""

    def get_pod_exit_code(self, pod_id: str) -> int:
        """
        Extract exit code from pod logs.

        Looks for exit code patterns in the last lines of logs.
        Returns 0 if no exit code found or if pod terminated successfully.
        Returns 42 if CUDA validation failed (machine issue).
        Returns 1 for other failures.

        Args:
            pod_id: The pod ID to check

        Returns:
            Exit code (0, 1, 42, etc.)
        """
        logs = self.get_pod_logs(pod_id, lines=200)

        if not logs:
            # No logs available, assume success if pod terminated
            return 0

        # Check for explicit exit code 42 (CUDA unavailable)
        if "exit 42" in logs.lower() or "exiting with code 42" in logs.lower():
            return 42

        # Check for CUDA validation failure messages
        if "CUDA VALIDATION FAILED" in logs:
            return 42

        if "DEVICE VALIDATION FAILED" in logs:
            return 42

        # Check for successful completion
        if "Training completed successfully" in logs:
            return 0

        # Check for training failure
        if "Training failed" in logs:
            return 1

        # Default to success if pod terminated cleanly
        return 0

    def terminate_pod(self, pod_id: str) -> None:
        """Request termination of a pod and report the outcome on stdout."""
        query = """
        mutation PodTerminate($input: PodTerminateInput!) {
            podTerminate(input: $input)
        }
        """
        variables = {"input": {"podId": pod_id}}
        print(f"\nTerminating pod: {pod_id}")
        result = self._graphql_query(query, variables)
        if result.get("podTerminate"):
            print("Pod terminated successfully")
        else:
            print("Pod termination request sent (status unclear)")

    def monitor_pod(self, pod_id: str, check_interval: int = 60, timeout_hours: int = 24, auto_terminate: bool = True) -> bool:
        """Poll a pod until it stops, the timeout expires, or the user interrupts.

        Returns:
            True when the pod finished or is no longer found (treated as
            successful self-termination); False on timeout or Ctrl-C.

        Note: with auto_terminate, termination is also attempted in the
        ``finally`` block as a safety net; errors from that extra attempt
        (e.g. pod already gone) are deliberately ignored.
        """
        start_time = time.time()
        timeout_seconds = timeout_hours * 3600
        print(f"\nMonitoring pod: {pod_id}")
        print(f"   Check interval: {check_interval}s")
        print(f"   Timeout: {timeout_hours}h")
        print(f"   Auto-terminate: {auto_terminate}")
        print("\n" + "="*60)
        try:
            while True:
                elapsed = time.time() - start_time
                if elapsed > timeout_seconds:
                    print(f"\nTimeout reached ({timeout_hours}h)")
                    if auto_terminate:
                        self.terminate_pod(pod_id)
                    return False
                try:
                    status = self.get_pod_status(pod_id)
                    if status is None or not status:
                        # Pod disappeared from the API: interpreted as
                        # successful self-termination by the container.
                        print(f"\n\nPod has been terminated (no longer found in API)")
                        print("   Training completed and pod self-terminated successfully!")
                        return True
                    desired_status = status.get("desiredStatus", "UNKNOWN")
                    runtime = status.get("runtime")
                    if desired_status in ["EXITED", "STOPPED", "TERMINATED"]:
                        print(f"\n\nPod stopped with status: {desired_status}")
                        return True
                    if runtime:
                        print(f"\r[{time.strftime('%H:%M:%S')}] "
                              f"Status: {desired_status} | "
                              f"Uptime: {uptime//3600}h {(uptime%3600)//60}m | "
                              f"Elapsed: {elapsed//3600:.0f}h {(elapsed%3600)//60:.0f}m",
                              end="", flush=True) if False else None
                        uptime = runtime.get("uptimeInSeconds", 0)
                        gpus = runtime.get("gpus", [])
                        print(f"\r[{time.strftime('%H:%M:%S')}] "
                              f"Status: {desired_status} | "
                              f"Uptime: {uptime//3600}h {(uptime%3600)//60}m | "
                              f"Elapsed: {elapsed//3600:.0f}h {(elapsed%3600)//60:.0f}m",
                              end="", flush=True)
                    else:
                        print(f"\r[{time.strftime('%H:%M:%S')}] "
                              f"Status: {desired_status} | "
                              f"Waiting for runtime...",
                              end="", flush=True)
                except Exception as e:
                    # A failed status query is treated the same as a vanished pod.
                    print(f"\n\nPod terminated (API query failed: {e})")
                    return True
                time.sleep(check_interval)
        except KeyboardInterrupt:
            print(f"\n\nMonitoring interrupted by user")
            if auto_terminate:
                print("Terminating pod...")
                self.terminate_pod(pod_id)
            return False
        finally:
            if auto_terminate:
                try:
                    self.terminate_pod(pod_id)
                except:
                    pass
get_pod_exit_code(pod_id)

Extract exit code from pod logs.

Looks for exit code patterns in the last lines of logs. Returns 0 if no exit code found or if pod terminated successfully. Returns 42 if CUDA validation failed (machine issue). Returns 1 for other failures.

Parameters:

Name Type Description Default
pod_id str

The pod ID to check

required

Returns:

Type Description
int

Exit code (0, 1, 42, etc.)

Source code in src\aegear\nn\ops\runpod_launcher.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def get_pod_exit_code(self, pod_id: str) -> int:
    """
    Extract exit code from pod logs.

    Looks for exit code patterns in the last lines of logs.
    Returns 0 if no exit code found or if pod terminated successfully.
    Returns 42 if CUDA validation failed (machine issue).
    Returns 1 for other failures.

    Args:
        pod_id: The pod ID to check

    Returns:
        Exit code (0, 1, 42, etc.)
    """
    logs = self.get_pod_logs(pod_id, lines=200)

    # No logs at all: treat a terminated pod as having exited cleanly.
    if not logs:
        return 0

    lowered = logs.lower()

    # Explicit exit-42 markers (case-insensitive) mean CUDA was unavailable.
    if "exit 42" in lowered or "exiting with code 42" in lowered:
        return 42

    # Validation-failure banners (case-sensitive) also map to 42.
    for marker in ("CUDA VALIDATION FAILED", "DEVICE VALIDATION FAILED"):
        if marker in logs:
            return 42

    if "Training completed successfully" in logs:
        return 0

    if "Training failed" in logs:
        return 1

    # Nothing recognized: assume a clean termination.
    return 0
PodManager

Manages RunPod pods with listing and termination capabilities.

Source code in src\aegear\nn\ops\runpod_launcher.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
class PodManager:
    """Manages RunPod pods with listing and termination capabilities.

    A console-oriented utility: lists pods, prints formatted summaries, and
    interactively terminates pods via the RunPod GraphQL API.
    """

    # GraphQL endpoint shared by all queries and mutations.
    RUNPOD_API_BASE = "https://api.runpod.io/graphql"

    # Status descriptions for user-friendly explanations
    STATUS_DESCRIPTIONS = {
        "RUNNING": "🟢 Pod is actively running",
        "PENDING": "🟡 Pod is being created/started",
        "EXITED": "⚫ Pod has stopped/exited",
        "STOPPED": "⚫ Pod is stopped",
        "TERMINATED": "⚫ Pod has been terminated",
        "CREATED": "🟡 Pod created but not started",
        "STARTING": "🟡 Pod is starting up",
    }

    def __init__(self, api_token: str):
        """Initialize the pod manager.

        Args:
            api_token: RunPod API token, attached as a Bearer credential to
                every request on the shared session.
        """
        self.api_token = api_token
        self.session = requests.Session()
        self.session.headers.update({
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}"
        })

    def _graphql_query(self, query: str, variables: Dict = None) -> Dict:
        """Run one GraphQL request and return its data payload.

        Transport failures are re-raised as RuntimeError; GraphQL-level
        errors in an otherwise successful response also raise RuntimeError.
        """
        body = {"query": query}
        if variables:
            body["variables"] = variables

        try:
            response = self.session.post(self.RUNPOD_API_BASE, json=body)
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"API request failed: {e}")

        # RuntimeError raised here is never caught by the handler above.
        if "errors" in result:
            raise RuntimeError(f"GraphQL error: {result['errors']}")
        return result.get("data", {})

    def list_all_pods(self) -> List[Dict[str, Any]]:
        """List all pods in the account.

        The query fetches identity, status, image, GPU and cost fields, plus
        runtime utilization/port data for pods that are running.

        Returns:
            List of pod dicts (possibly empty).
        """
        query = """
        query Pods {
            myself {
                pods {
                    id
                    name
                    desiredStatus
                    imageName
                    gpuCount
                    costPerHr
                    machine {
                        gpuDisplayName
                    }
                    runtime {
                        uptimeInSeconds
                        gpus {
                            gpuUtilPercent
                            memoryUtilPercent
                        }
                        ports {
                            privatePort
                            publicPort
                        }
                    }
                }
            }
        }
        """

        result = self._graphql_query(query)
        pods = result.get("myself", {}).get("pods", [])
        return pods

    def get_pod_details(self, pod_id: str) -> Dict[str, Any]:
        """Get detailed information about a specific pod.

        Args:
            pod_id: Identifier of the pod to look up.

        Returns:
            The pod's record (status, image, GPU/cost fields, runtime info);
            an empty dict when the API returns no pod.
        """
        query = """
        query Pod($input: PodFilter!) {
            pod(input: $input) {
                id
                name
                desiredStatus
                imageName
                gpuCount
                costPerHr
                machine {
                    gpuDisplayName
                }
                runtime {
                    uptimeInSeconds
                    gpus {
                        id
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                        type
                    }
                }
            }
        }
        """

        variables = {"input": {"podId": pod_id}}
        result = self._graphql_query(query, variables)
        return result.get("pod", {})

    def terminate_pod(self, pod_id: str) -> bool:
        """Ask the API to terminate one pod; True only on confirmed termination."""
        mutation = """
        mutation PodTerminate($input: PodTerminateInput!) {
            podTerminate(input: $input)
        }
        """

        try:
            data = self._graphql_query(mutation, {"input": {"podId": pod_id}})
        except Exception as e:
            # Report and swallow so bulk termination can continue.
            print(f"  ❌ Failed to terminate pod {pod_id}: {e}")
            return False
        return data.get("podTerminate", False)

    def format_uptime(self, seconds: int) -> str:
        """Render an uptime given in seconds as a compact "Xd Yh Zm" string."""
        span = timedelta(seconds=seconds)
        d = span.days
        # timedelta normalizes: .seconds is always the sub-day remainder.
        h, rem = divmod(span.seconds, 3600)
        m = rem // 60

        pieces = []
        if d > 0:
            pieces.append(f"{d}d")
        # Hours are shown whenever days are shown, even if zero.
        if h > 0 or d > 0:
            pieces.append(f"{h}h")
        pieces.append(f"{m}m")
        return " ".join(pieces)

    def format_cost(self, cost_per_hr: float, uptime_seconds: int) -> str:
        """Return "$total ($rate/hr)", or "N/A" when either input is falsy."""
        if not (cost_per_hr and uptime_seconds):
            return "N/A"
        accumulated = cost_per_hr * (uptime_seconds / 3600)
        return f"${accumulated:.4f} (${cost_per_hr:.4f}/hr)"

    def calculate_cost(self, cost_per_hr: float, uptime_seconds: int) -> float:
        """Return the accumulated cost in dollars; 0.0 when either input is falsy."""
        if not cost_per_hr or not uptime_seconds:
            return 0.0
        return cost_per_hr * (uptime_seconds / 3600)

    def print_pod_summary(self, pod: Dict[str, Any], detailed: bool = True):
        """Print a formatted, human-readable summary of one pod.

        Args:
            pod: Pod record as returned by the RunPod API; missing fields
                fall back to placeholder values.
            detailed: When True and runtime data is present, also print
                per-GPU utilization and exposed ports.
        """
        pod_id = pod.get("id", "unknown")
        name = pod.get("name", "unnamed")
        status = pod.get("desiredStatus", "UNKNOWN")
        image = pod.get("imageName", "unknown")
        gpu_count = pod.get("gpuCount", 0)
        cost_per_hr = pod.get("costPerHr", 0)

        machine = pod.get("machine", {})
        gpu_name = machine.get("gpuDisplayName", "unknown GPU")

        # "runtime" is absent/None for pods that are not currently running.
        runtime = pod.get("runtime")
        uptime_seconds = runtime.get("uptimeInSeconds", 0) if runtime else 0

        # Status icon and description
        status_desc = self.STATUS_DESCRIPTIONS.get(status, f"❓ Unknown status: {status}")

        print(f"\n{'='*80}")
        print(f"Pod: {name}")
        print(f"ID:  {pod_id}")
        print(f"{'='*80}")
        print(f"Status:  {status_desc}")
        print(f"GPU:     {gpu_name} x{gpu_count}")
        print(f"Image:   {image}")

        if uptime_seconds > 0:
            print(f"Uptime:  {self.format_uptime(uptime_seconds)}")
            print(f"Cost:    {self.format_cost(cost_per_hr, uptime_seconds)}")
        else:
            # Not running: show the hypothetical hourly rate instead of a total.
            print(f"Uptime:  Not running")
            print(f"Cost:    ${cost_per_hr:.4f}/hr (when running)")

        if detailed and runtime:
            gpus = runtime.get("gpus", [])
            if gpus:
                print(f"\nGPU Utilization:")
                for i, gpu in enumerate(gpus):
                    gpu_util = gpu.get("gpuUtilPercent", 0)
                    mem_util = gpu.get("memoryUtilPercent", 0)
                    print(f"  GPU {i}: {gpu_util}% compute, {mem_util}% memory")

            ports = runtime.get("ports", [])
            if ports:
                print(f"\nExposed Ports:")
                for port in ports:
                    private = port.get("privatePort")
                    public = port.get("publicPort")
                    print(f"  {private} -> {public}")

    def list_pods_command(self):
        """List all pods with detailed information.

        Pods are grouped by desiredStatus (running / pending / stopped) and
        printed in that order, followed by a one-line summary. Errors are
        printed with a traceback rather than raised.
        """
        print("\n" + "="*80)
        print("RUNPOD - ALL PODS")
        print("="*80)

        try:
            pods = self.list_all_pods()

            if not pods:
                print("\n✓ No pods found")
                return

            print(f"\nFound {len(pods)} pod(s):")

            # Separate by status
            running = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING"]]
            pending = [p for p in pods if p.get("desiredStatus") in ["PENDING", "CREATED"]]
            stopped = [p for p in pods if p.get("desiredStatus") in ["EXITED", "STOPPED", "TERMINATED"]]

            # Print running pods first
            if running:
                print(f"\n{'─'*80}")
                print(f"RUNNING PODS ({len(running)})")
                print(f"{'─'*80}")
                for pod in running:
                    # Only running pods get the detailed (GPU/ports) view.
                    self.print_pod_summary(pod, detailed=True)

            # Then pending
            if pending:
                print(f"\n{'─'*80}")
                print(f"PENDING PODS ({len(pending)})")
                print(f"{'─'*80}")
                for pod in pending:
                    self.print_pod_summary(pod, detailed=False)

            # Finally stopped
            if stopped:
                print(f"\n{'─'*80}")
                print(f"STOPPED PODS ({len(stopped)})")
                print(f"{'─'*80}")
                for pod in stopped:
                    self.print_pod_summary(pod, detailed=False)

            # Summary
            print(f"\n{'='*80}")
            print(f"SUMMARY: {len(running)} running, {len(pending)} pending, {len(stopped)} stopped")
            print(f"{'='*80}\n")

        except Exception as e:
            print(f"\n❌ Error listing pods: {e}")
            import traceback
            traceback.print_exc()

    def kill_pod_command(self, pod_id: str) -> bool:
        """Interactively terminate a single pod.

        Looks the pod up, prints its summary, asks for confirmation on stdin,
        and then issues the termination request.

        Args:
            pod_id: RunPod identifier of the pod to terminate.

        Returns:
            bool: True if the pod was terminated; False if it was not found,
            the user declined, the API call failed, or an exception occurred.
        """
        print(f"\nTerminating pod: {pod_id}")

        try:
            # Get pod details first
            pod = self.get_pod_details(pod_id)
            if not pod:
                print(f"❌ Pod not found: {pod_id}")
                return False

            self.print_pod_summary(pod, detailed=False)

            # Confirm — anything other than 'y' (case-insensitive) aborts.
            response = input("\nTerminate this pod? (y/N): ").strip().lower()
            if response != 'y':
                print("Cancelled.")
                return False

            # Terminate
            success = self.terminate_pod(pod_id)
            if success:
                print(f"✓ Pod {pod_id} terminated successfully")
                return True
            else:
                print(f"❌ Failed to terminate pod {pod_id}")
                return False

        except Exception as e:
            print(f"❌ Error terminating pod: {e}")
            import traceback
            traceback.print_exc()
            return False

    def kill_all_command(self, running_only: bool = False):
        """Kill all pods (with confirmation).

        Lists the pods slated for termination, reports their accumulated
        cost, asks for a single stdin confirmation, then terminates each
        pod in turn and prints a success/failure tally.

        Args:
            running_only: If True, only terminate active pods
                (RUNNING/STARTING/PENDING/CREATED); otherwise terminate
                every pod in the account.
        """
        print("\n" + "="*80)
        print("RUNPOD - KILL ALL PODS")
        print("="*80)

        try:
            pods = self.list_all_pods()

            if not pods:
                print("\n✓ No pods found")
                return

            # Filter pods based on running_only flag
            if running_only:
                pods_to_kill = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING", "PENDING", "CREATED"]]
                print(f"\nFound {len(pods_to_kill)} active pod(s) to terminate:")
            else:
                pods_to_kill = pods
                print(f"\nFound {len(pods_to_kill)} pod(s) to terminate:")

            if not pods_to_kill:
                print("\n✓ No pods to terminate")
                return

            # Display all pods
            for pod in pods_to_kill:
                self.print_pod_summary(pod, detailed=False)

            # Calculate total cost
            # (pods with no runtime record contribute zero seconds of uptime)
            total_cost = sum(
                self.calculate_cost(p.get("costPerHr", 0), 
                                  p.get("runtime", {}).get("uptimeInSeconds", 0) if p.get("runtime") else 0)
                for p in pods_to_kill
            )

            print(f"\n{'='*80}")
            print(f"Total accumulated cost: ${total_cost:.4f}")
            print(f"{'='*80}\n")

            # Confirm — anything other than 'y' (case-insensitive) aborts.
            prompt = "⚠️  TERMINATE ALL THESE PODS? (y/N): "
            response = input(prompt).strip().lower()

            if response != 'y':
                print("\nCancelled. No pods were terminated.")
                return

            # Terminate all
            print("\nTerminating pods...")
            success_count = 0
            failed_count = 0

            for pod in pods_to_kill:
                pod_id = pod.get("id")
                name = pod.get("name", "unnamed")

                print(f"  Terminating {name} ({pod_id})...", end=" ")

                if self.terminate_pod(pod_id):
                    print("✓")
                    success_count += 1
                else:
                    print("❌")
                    failed_count += 1

            print(f"\n{'='*80}")
            print(f"Termination complete: {success_count} succeeded, {failed_count} failed")
            print(f"{'='*80}\n")

        except Exception as e:
            print(f"\n❌ Error in kill-all operation: {e}")
            import traceback
            traceback.print_exc()
list_all_pods()

List all pods in the account.

Source code in src\aegear\nn\ops\runpod_launcher.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def list_all_pods(self) -> List[Dict[str, Any]]:
    """Return every pod owned by the authenticated account.

    Issues a single GraphQL query and returns the raw pod records,
    including machine and (when present) runtime information.

    Returns:
        List[Dict[str, Any]]: Pod records; empty list when none exist.
    """
    query = """
    query Pods {
        myself {
            pods {
                id
                name
                desiredStatus
                imageName
                gpuCount
                costPerHr
                machine {
                    gpuDisplayName
                }
                runtime {
                    uptimeInSeconds
                    gpus {
                        gpuUtilPercent
                        memoryUtilPercent
                    }
                    ports {
                        privatePort
                        publicPort
                    }
                }
            }
        }
    }
    """

    response = self._graphql_query(query)
    account = response.get("myself", {})
    return account.get("pods", [])
get_pod_details(pod_id)

Get detailed information about a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def get_pod_details(self, pod_id: str) -> Dict[str, Any]:
    """Fetch detailed information for one pod.

    Args:
        pod_id: RunPod identifier of the pod to look up.

    Returns:
        Dict[str, Any]: The pod record, or an empty dict when not found.
    """
    query = """
    query Pod($input: PodFilter!) {
        pod(input: $input) {
            id
            name
            desiredStatus
            imageName
            gpuCount
            costPerHr
            machine {
                gpuDisplayName
            }
            runtime {
                uptimeInSeconds
                gpus {
                    id
                    gpuUtilPercent
                    memoryUtilPercent
                }
                ports {
                    ip
                    isIpPublic
                    privatePort
                    publicPort
                    type
                }
            }
        }
    }
    """

    response = self._graphql_query(query, {"input": {"podId": pod_id}})
    return response.get("pod", {})
terminate_pod(pod_id)

Terminate a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def terminate_pod(self, pod_id: str) -> bool:
    """Terminate a single pod.

    Args:
        pod_id: RunPod identifier of the pod to terminate.

    Returns:
        bool: True on success; False if the API reports failure or the
        request raises (the error is printed, not propagated).
    """
    query = """
    mutation PodTerminate($input: PodTerminateInput!) {
        podTerminate(input: $input)
    }
    """

    payload = {"input": {"podId": pod_id}}

    try:
        return self._graphql_query(query, payload).get("podTerminate", False)
    except Exception as e:
        print(f"  ❌ Failed to terminate pod {pod_id}: {e}")
        return False
format_uptime(seconds)

Format uptime in human-readable format.

Source code in src\aegear\nn\ops\runpod_launcher.py
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
def format_uptime(self, seconds: int) -> str:
    """Render a duration in seconds as a compact "Xd Yh Zm" string.

    Days and hours are omitted when zero (hours are kept whenever days
    are shown); minutes are always present.

    Args:
        seconds: Duration in whole seconds.

    Returns:
        str: Human-readable duration, e.g. "1d 3h 7m" or "42m".
    """
    # Python floor division/modulo matches timedelta normalization,
    # including for negative inputs.
    days, remainder = divmod(seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes = remainder // 60

    pieces = []
    if days > 0:
        pieces.append(f"{days}d")
    if hours > 0 or days > 0:
        pieces.append(f"{hours}h")
    pieces.append(f"{minutes}m")

    return " ".join(pieces)
format_cost(cost_per_hr, uptime_seconds)

Calculate and format accumulated cost.

Source code in src\aegear\nn\ops\runpod_launcher.py
451
452
453
454
455
456
457
def format_cost(self, cost_per_hr: float, uptime_seconds: int) -> str:
    """Format accumulated cost plus hourly rate, or "N/A" when unknown.

    Args:
        cost_per_hr: Hourly rate in dollars.
        uptime_seconds: Elapsed uptime in seconds.

    Returns:
        str: e.g. "$1.2345 ($0.5000/hr)", or "N/A" if either input is falsy.
    """
    if not (cost_per_hr and uptime_seconds):
        return "N/A"
    accumulated = cost_per_hr * (uptime_seconds / 3600)
    return f"${accumulated:.4f} (${cost_per_hr:.4f}/hr)"
calculate_cost(cost_per_hr, uptime_seconds)

Calculate accumulated cost.

Source code in src\aegear\nn\ops\runpod_launcher.py
459
460
461
462
463
464
def calculate_cost(self, cost_per_hr: float, uptime_seconds: int) -> float:
    """Return the accumulated cost in dollars.

    Args:
        cost_per_hr: Hourly rate in dollars.
        uptime_seconds: Elapsed uptime in seconds.

    Returns:
        float: rate * hours, or 0.0 if either input is falsy.
    """
    if not (cost_per_hr and uptime_seconds):
        return 0.0
    return cost_per_hr * (uptime_seconds / 3600)
print_pod_summary(pod, detailed=True)

Print formatted pod information.

Source code in src\aegear\nn\ops\runpod_launcher.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
def print_pod_summary(self, pod: Dict[str, Any], detailed: bool = True):
    """Print formatted pod information.

    Writes a multi-line, human-readable summary of one pod to stdout:
    identity, status, GPU/image info, uptime and cost, and — when
    ``detailed`` is True and runtime data is present — per-GPU utilization
    and exposed port mappings.

    Args:
        pod: Pod record as returned by the pod-listing API.
        detailed: If True, also print GPU utilization and port details.
    """
    pod_id = pod.get("id", "unknown")
    name = pod.get("name", "unnamed")
    status = pod.get("desiredStatus", "UNKNOWN")
    image = pod.get("imageName", "unknown")
    gpu_count = pod.get("gpuCount", 0)
    cost_per_hr = pod.get("costPerHr", 0)

    machine = pod.get("machine", {})
    gpu_name = machine.get("gpuDisplayName", "unknown GPU")

    # "runtime" is absent/None for pods that are not currently running.
    runtime = pod.get("runtime")
    uptime_seconds = runtime.get("uptimeInSeconds", 0) if runtime else 0

    # Status icon and description
    status_desc = self.STATUS_DESCRIPTIONS.get(status, f"❓ Unknown status: {status}")

    print(f"\n{'='*80}")
    print(f"Pod: {name}")
    print(f"ID:  {pod_id}")
    print(f"{'='*80}")
    print(f"Status:  {status_desc}")
    print(f"GPU:     {gpu_name} x{gpu_count}")
    print(f"Image:   {image}")

    if uptime_seconds > 0:
        print(f"Uptime:  {self.format_uptime(uptime_seconds)}")
        print(f"Cost:    {self.format_cost(cost_per_hr, uptime_seconds)}")
    else:
        print(f"Uptime:  Not running")
        print(f"Cost:    ${cost_per_hr:.4f}/hr (when running)")

    if detailed and runtime:
        gpus = runtime.get("gpus", [])
        if gpus:
            print(f"\nGPU Utilization:")
            for i, gpu in enumerate(gpus):
                gpu_util = gpu.get("gpuUtilPercent", 0)
                mem_util = gpu.get("memoryUtilPercent", 0)
                print(f"  GPU {i}: {gpu_util}% compute, {mem_util}% memory")

        ports = runtime.get("ports", [])
        if ports:
            print(f"\nExposed Ports:")
            for port in ports:
                private = port.get("privatePort")
                public = port.get("publicPort")
                print(f"  {private} -> {public}")
list_pods_command()

List all pods with detailed information.

Source code in src\aegear\nn\ops\runpod_launcher.py
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def list_pods_command(self):
    """List all pods with detailed information.

    Fetches every pod via ``list_all_pods``, buckets them by their
    ``desiredStatus`` field (running / pending / stopped) and prints a
    formatted report to stdout. Exceptions are reported with a traceback
    rather than propagated, since this is a CLI entry point.
    """
    print("\n" + "="*80)
    print("RUNPOD - ALL PODS")
    print("="*80)

    try:
        pods = self.list_all_pods()

        if not pods:
            print("\n✓ No pods found")
            return

        print(f"\nFound {len(pods)} pod(s):")

        # Separate by status
        # NOTE(review): a pod whose status falls outside these three groups
        # is counted in the total above but never printed below — confirm
        # this is intentional.
        running = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING"]]
        pending = [p for p in pods if p.get("desiredStatus") in ["PENDING", "CREATED"]]
        stopped = [p for p in pods if p.get("desiredStatus") in ["EXITED", "STOPPED", "TERMINATED"]]

        # Print running pods first (these get the detailed GPU/port view)
        if running:
            print(f"\n{'─'*80}")
            print(f"RUNNING PODS ({len(running)})")
            print(f"{'─'*80}")
            for pod in running:
                self.print_pod_summary(pod, detailed=True)

        # Then pending
        if pending:
            print(f"\n{'─'*80}")
            print(f"PENDING PODS ({len(pending)})")
            print(f"{'─'*80}")
            for pod in pending:
                self.print_pod_summary(pod, detailed=False)

        # Finally stopped
        if stopped:
            print(f"\n{'─'*80}")
            print(f"STOPPED PODS ({len(stopped)})")
            print(f"{'─'*80}")
            for pod in stopped:
                self.print_pod_summary(pod, detailed=False)

        # Summary
        print(f"\n{'='*80}")
        print(f"SUMMARY: {len(running)} running, {len(pending)} pending, {len(stopped)} stopped")
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"\n❌ Error listing pods: {e}")
        import traceback
        traceback.print_exc()
kill_pod_command(pod_id)

Kill a specific pod.

Source code in src\aegear\nn\ops\runpod_launcher.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def kill_pod_command(self, pod_id: str) -> bool:
    """Interactively terminate a single pod.

    Looks the pod up, prints its summary, asks for confirmation on stdin,
    and then issues the termination request.

    Args:
        pod_id: RunPod identifier of the pod to terminate.

    Returns:
        bool: True if the pod was terminated; False if it was not found,
        the user declined, the API call failed, or an exception occurred.
    """
    print(f"\nTerminating pod: {pod_id}")

    try:
        # Get pod details first
        pod = self.get_pod_details(pod_id)
        if not pod:
            print(f"❌ Pod not found: {pod_id}")
            return False

        self.print_pod_summary(pod, detailed=False)

        # Confirm — anything other than 'y' (case-insensitive) aborts.
        response = input("\nTerminate this pod? (y/N): ").strip().lower()
        if response != 'y':
            print("Cancelled.")
            return False

        # Terminate
        success = self.terminate_pod(pod_id)
        if success:
            print(f"✓ Pod {pod_id} terminated successfully")
            return True
        else:
            print(f"❌ Failed to terminate pod {pod_id}")
            return False

    except Exception as e:
        print(f"❌ Error terminating pod: {e}")
        import traceback
        traceback.print_exc()
        return False
kill_all_command(running_only=False)

Kill all pods (with confirmation).

Source code in src\aegear\nn\ops\runpod_launcher.py
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
def kill_all_command(self, running_only: bool = False):
    """Kill all pods (with confirmation).

    Lists the pods slated for termination, reports their accumulated
    cost, asks for a single stdin confirmation, then terminates each
    pod in turn and prints a success/failure tally.

    Args:
        running_only: If True, only terminate active pods
            (RUNNING/STARTING/PENDING/CREATED); otherwise terminate
            every pod in the account.
    """
    print("\n" + "="*80)
    print("RUNPOD - KILL ALL PODS")
    print("="*80)

    try:
        pods = self.list_all_pods()

        if not pods:
            print("\n✓ No pods found")
            return

        # Filter pods based on running_only flag
        if running_only:
            pods_to_kill = [p for p in pods if p.get("desiredStatus") in ["RUNNING", "STARTING", "PENDING", "CREATED"]]
            print(f"\nFound {len(pods_to_kill)} active pod(s) to terminate:")
        else:
            pods_to_kill = pods
            print(f"\nFound {len(pods_to_kill)} pod(s) to terminate:")

        if not pods_to_kill:
            print("\n✓ No pods to terminate")
            return

        # Display all pods
        for pod in pods_to_kill:
            self.print_pod_summary(pod, detailed=False)

        # Calculate total cost
        # (pods with no runtime record contribute zero seconds of uptime)
        total_cost = sum(
            self.calculate_cost(p.get("costPerHr", 0), 
                              p.get("runtime", {}).get("uptimeInSeconds", 0) if p.get("runtime") else 0)
            for p in pods_to_kill
        )

        print(f"\n{'='*80}")
        print(f"Total accumulated cost: ${total_cost:.4f}")
        print(f"{'='*80}\n")

        # Confirm — anything other than 'y' (case-insensitive) aborts.
        prompt = "⚠️  TERMINATE ALL THESE PODS? (y/N): "
        response = input(prompt).strip().lower()

        if response != 'y':
            print("\nCancelled. No pods were terminated.")
            return

        # Terminate all
        print("\nTerminating pods...")
        success_count = 0
        failed_count = 0

        for pod in pods_to_kill:
            pod_id = pod.get("id")
            name = pod.get("name", "unnamed")

            print(f"  Terminating {name} ({pod_id})...", end=" ")

            if self.terminate_pod(pod_id):
                print("✓")
                success_count += 1
            else:
                print("❌")
                failed_count += 1

        print(f"\n{'='*80}")
        print(f"Termination complete: {success_count} succeeded, {failed_count} failed")
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"\n❌ Error in kill-all operation: {e}")
        import traceback
        traceback.print_exc()

training

Module containing various training-related utilities and functions.

WeightedBCEWithLogitsLoss

Custom weighted binary cross-entropy loss emphasizing Gaussian center.

Parameters:

Name Type Description Default
limit float

Threshold for positive region.

0.5
pos_weight float

Weight for positive region.

10.0
Source code in src\aegear\nn\training.py
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
class WeightedBCEWithLogitsLoss:
    """Weighted binary cross-entropy on logits, emphasizing the Gaussian core.

    Pixels whose target value exceeds ``limit`` are weighted by
    ``pos_weight``; all remaining pixels keep weight 1.

    Args:
        limit (float): Target threshold above which a pixel is "positive".
        pos_weight (float): Weight applied to positive-region pixels.
    """

    def __init__(self, limit=0.5, pos_weight=10.0):
        self.limit = limit
        self.pos_weight = pos_weight

    def __call__(self, pred, target):
        """Return the mean weighted BCE between logits and the target heatmap.

        Args:
            pred (torch.Tensor): Predicted logits.
            target (torch.Tensor): Target heatmap in [0, 1].

        Returns:
            torch.Tensor: Scalar loss value.
        """
        # Build the per-pixel weight map in one expression instead of
        # in-place masked assignment; values are identical.
        weight_map = torch.where(
            target > self.limit,
            torch.full_like(target, self.pos_weight),
            torch.ones_like(target),
        )
        return F.binary_cross_entropy_with_logits(
            pred, target, weight=weight_map, reduction='mean')

EfficientUNetLoss

Bases: WeightedBCEWithLogitsLoss

EfficientUNet loss combining BCE, centroid, sparsity, and Dice losses.

Parameters:

Name Type Description Default
limit float

Threshold for positive region.

0.5
pos_weight float

Weight for positive region.

5.0
centroid_weight float

Weight for centroid distance loss.

0.0025
sparsity_weight float

Weight for sparsity loss.

0.1
dice_weight float

Weight for Dice loss.

1.0
Source code in src\aegear\nn\training.py
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
class EfficientUNetLoss(WeightedBCEWithLogitsLoss):
    """Composite EfficientUNet loss: weighted BCE + centroid + sparsity + Dice.

    Args:
        limit (float): Threshold for positive region.
        pos_weight (float): Weight for positive region.
        centroid_weight (float): Weight for centroid distance loss.
        sparsity_weight (float): Weight for sparsity loss.
        dice_weight (float): Weight for Dice loss.
    """

    def __init__(self, limit=0.5, pos_weight=5.0, centroid_weight=2.5e-3, sparsity_weight=0.1, dice_weight=1.0):
        super().__init__(limit, pos_weight)
        self.centroid_weight = centroid_weight
        self.sparsity_weight = sparsity_weight
        self.dice_weight = dice_weight

    def __call__(self, pred, target, return_components=False):
        """Compute the total EfficientUNet loss.

        Args:
            pred (torch.Tensor): Predicted logits.
            target (torch.Tensor): Target heatmap.
            return_components (bool): If True, also return a dict of
                per-term values (raw and weighted) as Python floats.

        Returns:
            torch.Tensor or tuple: Loss, or (loss, components) when
            ``return_components`` is True.
        """
        bce = super().__call__(pred, target)
        centroid = self.centroid_distance_loss(pred, target)
        # Sparsity term penalizes overall activation mass of the heatmap.
        sparsity_raw = torch.sigmoid(pred).mean()
        sparsity = self.sparsity_weight * sparsity_raw
        dice = self.dice_loss(pred, target)

        total = bce + (self.centroid_weight * centroid) + sparsity + (self.dice_weight * dice)

        if not return_components:
            return total

        return total, {
            'bce_loss': bce.item(),
            'centroid_loss': centroid.item(),
            'centroid_loss_weighted': (self.centroid_weight * centroid).item(),
            'sparsity_loss': sparsity_raw.item(),
            'sparsity_loss_weighted': sparsity.item(),
            'dice_loss': dice.item(),
            'dice_loss_weighted': (self.dice_weight * dice).item(),
            'total_loss': total.item()
        }

    @staticmethod
    def dice_loss(pred, target, smooth=1.0):
        """Soft Dice loss (1 - Dice coefficient) over sigmoid probabilities.

        Args:
            pred (torch.Tensor): Logits from model.
            target (torch.Tensor): Ground truth mask.
            smooth (float): Smoothing factor (also guards division by zero).

        Returns:
            torch.Tensor: Dice loss value.
        """
        probs = torch.sigmoid(pred).view(-1)
        truth = target.view(-1)
        overlap = (probs * truth).sum()
        combined = probs.sum() + truth.sum()
        coefficient = (2. * overlap + smooth) / (combined + smooth)
        return 1. - coefficient

    @staticmethod
    def centroid_distance_loss(pred, target):
        """Mean Euclidean distance between predicted and target centroids.

        Samples where either centroid is undefined are skipped; when no
        pair remains, a zero tensor on the prediction's device is returned.

        Args:
            pred (torch.Tensor): Predicted logits.
            target (torch.Tensor): Target heatmap.

        Returns:
            torch.Tensor: Mean centroid distance.
        """
        pred_centroids = get_centroids_per_sample(torch.sigmoid(pred))
        true_centroids = get_centroids_per_sample(target)

        per_sample = []
        for p_c, t_c in zip(pred_centroids, true_centroids):
            if p_c is None or t_c is None:
                continue
            px, py, _ = p_c
            tx, ty, _ = t_c
            # 1e-8 keeps the sqrt gradient finite when centroids coincide.
            per_sample.append(torch.sqrt((px - tx) ** 2 + (py - ty) ** 2 + 1e-8))

        if not per_sample:
            return torch.tensor(0.0).to(pred.device)
        return torch.stack(per_sample).mean()
dice_loss(pred, target, smooth=1.0) staticmethod

Compute Dice loss (1 - Dice coefficient).

Parameters:

Name Type Description Default
pred Tensor

Logits from model.

required
target Tensor

Ground truth mask.

required
smooth float

Smoothing factor.

1.0

Returns:

Type Description

torch.Tensor: Dice loss value.

Source code in src\aegear\nn\training.py
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
@staticmethod
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss (1 - Dice coefficient) over sigmoid probabilities.

    Args:
        pred (torch.Tensor): Logits from model.
        target (torch.Tensor): Ground truth mask.
        smooth (float): Smoothing factor (also guards division by zero).

    Returns:
        torch.Tensor: Dice loss value.
    """
    probs = torch.sigmoid(pred).view(-1)
    truth = target.view(-1)
    overlap = (probs * truth).sum()
    combined = probs.sum() + truth.sum()
    coefficient = (2. * overlap + smooth) / (combined + smooth)
    return 1. - coefficient
centroid_distance_loss(pred, target) staticmethod

Compute centroid distance loss between prediction and target.

Parameters:

Name Type Description Default
pred Tensor

Predicted logits.

required
target Tensor

Target heatmap.

required

Returns:

Type Description

torch.Tensor: Mean centroid distance.

Source code in src\aegear\nn\training.py
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
@staticmethod
def centroid_distance_loss(pred, target):
    """Mean Euclidean distance between predicted and target centroids.

    Samples where either centroid is undefined are skipped; when no pair
    remains, a zero tensor on the prediction's device is returned.

    Args:
        pred (torch.Tensor): Predicted logits.
        target (torch.Tensor): Target heatmap.

    Returns:
        torch.Tensor: Mean centroid distance.
    """
    pred_centroids = get_centroids_per_sample(torch.sigmoid(pred))
    true_centroids = get_centroids_per_sample(target)

    per_sample = []
    for p_c, t_c in zip(pred_centroids, true_centroids):
        if p_c is None or t_c is None:
            continue
        px, py, _ = p_c
        tx, ty, _ = t_c
        # 1e-8 keeps the sqrt gradient finite when centroids coincide.
        per_sample.append(torch.sqrt((px - tx) ** 2 + (py - ty) ** 2 + 1e-8))

    if not per_sample:
        return torch.tensor(0.0).to(pred.device)
    return torch.stack(per_sample).mean()

SiameseLoss

Bases: EfficientUNetLoss

Siamese loss combining EfficientUNetLoss and RGB consistency loss.

Parameters:

Name Type Description Default
limit float

Threshold for positive region.

0.5
pos_weight float

Weight for positive region.

10.0
centroid_weight float

Weight for centroid distance loss.

0.0025
sparsity_weight float

Weight for sparsity loss.

0.001
dice_weight float

Weight for Dice loss.

1.0
rgb_weight float

Weight for RGB consistency loss.

0.005
rgb_sigma float

Sigma for Gaussian in RGB loss.

2.0
rgb_threshold float

Threshold for mask in RGB loss.

0.5
Source code in src\aegear\nn\training.py
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
class SiameseLoss(EfficientUNetLoss):
    """Siamese tracking loss: EfficientUNetLoss plus an RGB consistency term.

    Args:
        limit (float): Threshold for positive region.
        pos_weight (float): Weight for positive region.
        centroid_weight (float): Weight for centroid distance loss.
        sparsity_weight (float): Weight for sparsity loss.
        dice_weight (float): Weight for Dice loss.
        rgb_weight (float): Weight for RGB consistency loss.
        rgb_sigma (float): Sigma for Gaussian in RGB loss.
        rgb_threshold (float): Threshold for mask in RGB loss.
    """

    def __init__(
        self,
        limit=0.5,
        pos_weight=10.0,
        centroid_weight=2.5e-3,
        sparsity_weight=1e-3,
        dice_weight=1.0,
        rgb_weight=5e-3,
        rgb_sigma=2.0,
        rgb_threshold=0.5
    ):
        super().__init__(limit, pos_weight, centroid_weight, sparsity_weight, dice_weight)
        self.rgb_weight = rgb_weight
        self.rgb_sigma = rgb_sigma
        self.rgb_threshold = rgb_threshold

    def __call__(self, output, target, template, search, return_components=False):
        """Compute the total Siamese loss.

        Args:
            output (torch.Tensor): Predicted logits.
            target (torch.Tensor): Target heatmap.
            template (torch.Tensor): Template image.
            search (torch.Tensor): Search image.
            return_components (bool): If True, also return a dict of
                per-term values (raw and weighted) as Python floats.

        Returns:
            torch.Tensor or tuple: Loss, or (loss, components) when
            ``return_components`` is True.
        """
        if return_components:
            base_loss, base_parts = super().__call__(output, target, return_components=True)
        else:
            base_loss = super().__call__(output, target, return_components=False)
            base_parts = None

        raw_rgb = self.rgb_consistency_loss(template, search, output)
        weighted_rgb = self.rgb_weight * raw_rgb
        total = base_loss + weighted_rgb

        if base_parts is None:
            return total

        parts = base_parts.copy()
        parts['rgb_loss'] = raw_rgb.item()
        parts['rgb_loss_weighted'] = weighted_rgb.item()
        parts['total_loss'] = total.item()
        return total, parts

    def rgb_consistency_loss(self, template_img, search_img, pred_heatmap):
        """Penalize mean-RGB mismatch between template center and tracked region.

        The template's mean color is taken under a fixed centered Gaussian;
        the search image's mean color is taken under the thresholded,
        normalized predicted heatmap. The MSE between the two is averaged
        over the batch.

        Args:
            template_img (torch.Tensor): Template image tensor.
            search_img (torch.Tensor): Search image tensor.
            pred_heatmap (torch.Tensor): Predicted heatmap tensor.

        Returns:
            torch.Tensor: RGB consistency loss value.
        """
        batch, _, height, width = template_img.shape
        dev = template_img.device

        # One centered Gaussian weighting shared by the whole batch.
        ys, xs = torch.meshgrid(
            torch.linspace(0, height - 1, height, device=dev),
            torch.linspace(0, width - 1, width, device=dev),
            indexing='ij'
        )
        cy = (height - 1) / 2
        cx = (width - 1) / 2
        gauss = torch.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * self.rgb_sigma ** 2))
        gauss = gauss / (gauss.sum() + 1e-8)  # normalized (H, W); broadcasts over RGB

        total = 0.0
        for b in range(batch):
            hm = pred_heatmap[b]
            # Zero out sub-threshold responses, then normalize to unit mass.
            masked = hm * (hm > self.rgb_threshold).float()
            masked = masked / (masked.sum() + 1e-8)  # (1, H, W)
            # Heatmap-weighted mean RGB in the search crop.
            mean_rgb_search = (search_img[b] * masked).view(3, -1).sum(dim=1)
            # Gaussian-weighted mean RGB at the template center.
            mean_rgb_template = (template_img[b] * gauss).view(3, -1).sum(dim=1)
            total += F.mse_loss(mean_rgb_search, mean_rgb_template)
        return total / batch
rgb_consistency_loss(template_img, search_img, pred_heatmap)

Compute RGB consistency loss between template and search images.

Parameters:

Name Type Description Default
template_img Tensor

Template image tensor.

required
search_img Tensor

Search image tensor.

required
pred_heatmap Tensor

Predicted heatmap tensor.

required

Returns:

Type Description

torch.Tensor: RGB consistency loss value.

Source code in src\aegear\nn\training.py
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
def rgb_consistency_loss(self, template_img, search_img, pred_heatmap):
    """Compute RGB consistency loss between template and search images.

    Compares the mean RGB color of the search image under the (thresholded,
    renormalized) predicted heatmap against the mean RGB color of the
    template under a fixed centered Gaussian, averaged over the batch.

    Args:
        template_img (torch.Tensor): Template image tensor, shape (B, 3, H, W).
        search_img (torch.Tensor): Search image tensor, shape (B, 3, H, W).
        pred_heatmap (torch.Tensor): Predicted heatmap tensor.

    Returns:
        torch.Tensor: RGB consistency loss value.
    """
    batch_size, _, height, width = template_img.shape
    dev = template_img.device

    # Fixed Gaussian kernel centered on the template, shared by the batch.
    ys, xs = torch.meshgrid(
        torch.linspace(0, height - 1, height, device=dev),
        torch.linspace(0, width - 1, width, device=dev),
        indexing='ij'
    )
    cy = (height - 1) / 2
    cx = (width - 1) / 2
    sq_dist = (xs - cx) ** 2 + (ys - cy) ** 2
    gaussian = torch.exp(-sq_dist / (2 * self.rgb_sigma ** 2))
    gaussian = gaussian / (gaussian.sum() + 1e-8)
    gaussian = gaussian.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)

    total = 0.0
    for idx in range(batch_size):
        # Threshold the heatmap, then renormalize it into a weighting mask.
        heat = pred_heatmap[idx]
        weights = heat * (heat > self.rgb_threshold).float()
        weights = weights / (weights.sum() + 1e-8)  # (1, H, W)
        # Weighted mean color of the search crop under the prediction.
        mean_search = (search_img[idx] * weights).view(3, -1).sum(dim=1)
        # Gaussian-weighted mean color of the template center.
        mean_template = (template_img[idx] * gaussian[0]).view(3, -1).sum(dim=1)
        total = total + F.mse_loss(mean_search, mean_template)
    return total / batch_size

BaseVisualizer

Base class for visualizers used in training visualization.

Source code in src\aegear\nn\training.py
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
class BaseVisualizer:
    """Base class for visualizers used in training visualization."""

    def __init__(self, model, device, val_results, stage, epoch, output_dir="vis_epochs"):
        """Initialize the visualizer.

        Args:
            model: Model instance.
            device: Torch device.
            val_results (list): Validation results.
            stage (int): Training stage.
            epoch (int): Epoch number.
            output_dir (str): Output directory for visualizations.
        """
        self.model = model
        self.device = device
        self.val_results = val_results
        self.stage = stage
        self.epoch = epoch
        self.output_dir = output_dir

    def _save_fig(self, fig, subdir, prefix):
        """Save a matplotlib figure to disk and close it.

        Args:
            fig: Matplotlib figure.
            subdir (str): Subdirectory for saving.
            prefix (str): Filename prefix.
        """
        target_dir = os.path.join(self.output_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        filename = f"{prefix}_stage_{self.stage:03d}_epoch_{self.epoch:03d}.png"
        fig.savefig(os.path.join(target_dir, filename), dpi=200)
        plt.close(fig)

SiameseTrackingVisualizer

Bases: BaseVisualizer

Visualizer for Siamese tracking model performance and activations.

Source code in src\aegear\nn\training.py
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
class SiameseTrackingVisualizer(BaseVisualizer):
    """Visualizer for Siamese tracking model performance and activations."""

    def performance(self, num_samples=5, save=True):
        """Visualize performance samples for Siamese tracking.

        One row per sample: template, prediction overlay on the search
        image, and the absolute pred/GT heatmap difference.

        Args:
            num_samples (int): Number of samples per group (worst, middle, best).
            save (bool): Whether to save the figure to disk.

        Returns:
            The generated matplotlib figure.
        """
        samples = _sort_samples(self.val_results, num_samples)
        fig, axes = plt.subplots(
            len(samples), 3, figsize=(9, 3 * len(samples)))
        # plt.subplots returns a 1-D axes array for a single row; normalize
        # to 2-D so the axes[i, j] indexing below works for any sample count.
        axes = axes if len(samples) > 1 else axes[None, :]

        for i, result in enumerate(samples):
            template_img = TF.to_pil_image(denormalize(result['template']))
            search_img = TF.to_pil_image(denormalize(result['search']))
            search_np = TF.to_tensor(search_img).permute(1, 2, 0).numpy()

            pred, gt = result['pred_heatmap'], result['gt_heatmap']
            pred = pred.detach().cpu()
            gt = gt.detach().cpu()
            # Min-max normalize both heatmaps so color scales are comparable.
            pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
            gt_norm = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
            diff_norm = np.abs(pred_norm.numpy() - gt_norm.numpy())

            # Blend the search image with the jet-colored prediction heatmap.
            overlay = np.clip(0.6 * search_np + 0.4 *
                              plt.cm.jet(pred_norm)[..., :3], 0, 1)
            diff_rgb = plt.cm.magma(diff_norm)[..., :3]

            xg, yg = result['gt_centroid']
            xp, yp = result['pred_centroid']
            confidence = result['confidence']

            axes[i, 0].imshow(template_img)
            axes[i, 0].set_title(f"Template idx {i}")

            axes[i, 1].imshow(overlay)
            axes[i, 1].scatter([xp], [yp], c='red', marker='x', label='Pred')
            axes[i, 1].scatter([xg], [yg], c='green', marker='o', label='GT')
            axes[i, 1].set_title(f"Search | Conf: {confidence:.2f}")
            axes[i, 1].legend()

            axes[i, 2].imshow(diff_rgb)
            axes[i, 2].set_title("Abs Diff")

            for ax in axes[i]:
                ax.axis("off")

        plt.tight_layout()
        if save:
            self._save_fig(plt.gcf(), "performance", "epoch")

        return fig

    def activations(self, num_samples=3, save=True):
        """Visualize intermediate activations for the Siamese tracking model.

        Runs each sample through the model with forward hooks attached to
        the listed stages and plots the first few channels per stage.

        Args:
            num_samples (int): Number of samples to visualize.
            save (bool): Whether to save the figure to disk.

        Returns:
            The generated matplotlib figure.
        """
        output_dir = os.path.join(self.output_dir, "activations")
        os.makedirs(output_dir, exist_ok=True)

        stages = ['enc3', 'enc4', 'enc5', 'up4',
                  'up3', 'up2', 'up1', 'up0', 'out']
        channels_per_stage = 3

        activations = {}
        # Keep the hook handles so they can be removed afterwards; otherwise
        # every call to this method leaves another hook attached to the model.
        hook_handles = []
        for name in stages:
            layer = getattr(self.model, name)
            hook_handles.append(layer.register_forward_hook(
                lambda m, i, o, n=name: activations.update({n: o.detach().cpu()})))

        try:
            samples = _sort_samples(self.val_results, num_samples)
            n_cols = 1 + channels_per_stage * len(stages)
            fig, axs = plt.subplots(len(samples), n_cols, figsize=(
                n_cols * 2.5, len(samples) * 3))
            axs = axs if len(samples) > 1 else axs[None, :]

            self.model.eval()
            for row, sample in enumerate(samples):
                template = sample['template'].unsqueeze(0).to(self.device)
                search = sample['search'].unsqueeze(0).to(self.device)
                heatmap = sample['pred_heatmap'].numpy()

                # Forward pass only to trigger the hooks; output is unused.
                with torch.no_grad():
                    _ = self.model(template, search)

                overlay = denormalize(search[0]).permute(1, 2, 0).cpu().numpy()
                overlay[..., 0] = np.clip(overlay[..., 0] + 0.5 * heatmap, 0, 1)

                axs[row, 0].imshow(overlay)
                axs[row, 0].scatter([sample['gt_centroid'][0]], [
                                    sample['gt_centroid'][1]], c='green', marker='o')
                axs[row, 0].scatter([sample['pred_centroid'][0]], [
                                    sample['pred_centroid'][1]], c='red', marker='x')
                axs[row, 0].set_title(f"Conf: {sample['confidence']:.2f}")
                axs[row, 0].axis('off')

                col = 1
                for stage in stages:
                    act = activations[stage][0]
                    for ch in range(channels_per_stage):
                        if ch < act.shape[0]:
                            axs[row, col].imshow(act[ch], cmap='viridis')
                            axs[row, col].set_title(f'{stage} | Ch {ch}')
                        axs[row, col].axis('off')
                        col += 1
        finally:
            # Detach the forward hooks so repeated calls do not accumulate.
            for handle in hook_handles:
                handle.remove()

        plt.tight_layout()
        if save:
            self._save_fig(plt.gcf(), "activations", "activation")

        return fig
performance(num_samples=5, save=True)

Visualize performance samples for Siamese tracking.

Parameters:

Name Type Description Default
num_samples int

Number of samples per group (worst, middle, best).

5
Source code in src\aegear\nn\training.py
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
def performance(self, num_samples=5, save=True):
    """Visualize performance samples for Siamese tracking.

    One row per sample: template, prediction overlay on the search image,
    and the absolute pred/GT heatmap difference.

    Args:
        num_samples (int): Number of samples per group (worst, middle, best).
        save (bool): Whether to save the figure to disk.

    Returns:
        The generated matplotlib figure.
    """
    samples = _sort_samples(self.val_results, num_samples)
    fig, axes = plt.subplots(
        len(samples), 3, figsize=(9, 3 * len(samples)))
    # plt.subplots returns a 1-D axes array for a single row; normalize to
    # 2-D so the axes[i, j] indexing below works for any sample count.
    axes = axes if len(samples) > 1 else axes[None, :]

    for i, result in enumerate(samples):
        template_img = TF.to_pil_image(denormalize(result['template']))
        search_img = TF.to_pil_image(denormalize(result['search']))
        search_np = TF.to_tensor(search_img).permute(1, 2, 0).numpy()

        pred, gt = result['pred_heatmap'], result['gt_heatmap']
        pred = pred.detach().cpu()
        gt = gt.detach().cpu()
        # Min-max normalize both heatmaps so color scales are comparable.
        pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
        gt_norm = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
        diff_norm = np.abs(pred_norm.numpy() - gt_norm.numpy())

        # Blend the search image with the jet-colored prediction heatmap.
        overlay = np.clip(0.6 * search_np + 0.4 *
                          plt.cm.jet(pred_norm)[..., :3], 0, 1)
        diff_rgb = plt.cm.magma(diff_norm)[..., :3]

        xg, yg = result['gt_centroid']
        xp, yp = result['pred_centroid']
        confidence = result['confidence']

        axes[i, 0].imshow(template_img)
        axes[i, 0].set_title(f"Template idx {i}")

        axes[i, 1].imshow(overlay)
        axes[i, 1].scatter([xp], [yp], c='red', marker='x', label='Pred')
        axes[i, 1].scatter([xg], [yg], c='green', marker='o', label='GT')
        axes[i, 1].set_title(f"Search | Conf: {confidence:.2f}")
        axes[i, 1].legend()

        axes[i, 2].imshow(diff_rgb)
        axes[i, 2].set_title("Abs Diff")

        for ax in axes[i]:
            ax.axis("off")

    plt.tight_layout()
    if save:
        self._save_fig(plt.gcf(), "performance", "epoch")

    return fig
activations(num_samples=3, save=True)

Visualize activations for Siamese tracking model.

Parameters:

Name Type Description Default
num_samples int

Number of samples to visualize.

3
Source code in src\aegear\nn\training.py
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
def activations(self, num_samples=3, save=True):
    """Visualize activations for Siamese tracking model.

    Runs each sample through the model with forward hooks attached to the
    listed stages and plots the first few channels per stage.

    Args:
        num_samples (int): Number of samples to visualize.
        save (bool): Whether to save the figure to disk.

    Returns:
        The generated matplotlib figure.
    """
    output_dir = os.path.join(self.output_dir, "activations")
    os.makedirs(output_dir, exist_ok=True)

    stages = ['enc3', 'enc4', 'enc5', 'up4',
              'up3', 'up2', 'up1', 'up0', 'out']
    channels_per_stage = 3

    activations = {}
    # Keep the hook handles so they can be removed afterwards; otherwise
    # every call to this method leaves another hook attached to the model.
    hook_handles = []
    for name in stages:
        layer = getattr(self.model, name)
        hook_handles.append(layer.register_forward_hook(
            lambda m, i, o, n=name: activations.update({n: o.detach().cpu()})))

    try:
        samples = _sort_samples(self.val_results, num_samples)
        n_cols = 1 + channels_per_stage * len(stages)
        fig, axs = plt.subplots(len(samples), n_cols, figsize=(
            n_cols * 2.5, len(samples) * 3))
        axs = axs if len(samples) > 1 else axs[None, :]

        self.model.eval()
        for row, sample in enumerate(samples):
            template = sample['template'].unsqueeze(0).to(self.device)
            search = sample['search'].unsqueeze(0).to(self.device)
            heatmap = sample['pred_heatmap'].numpy()

            # Forward pass only to trigger the hooks; output is unused.
            with torch.no_grad():
                _ = self.model(template, search)

            overlay = denormalize(search[0]).permute(1, 2, 0).cpu().numpy()
            overlay[..., 0] = np.clip(overlay[..., 0] + 0.5 * heatmap, 0, 1)

            axs[row, 0].imshow(overlay)
            axs[row, 0].scatter([sample['gt_centroid'][0]], [
                                sample['gt_centroid'][1]], c='green', marker='o')
            axs[row, 0].scatter([sample['pred_centroid'][0]], [
                                sample['pred_centroid'][1]], c='red', marker='x')
            axs[row, 0].set_title(f"Conf: {sample['confidence']:.2f}")
            axs[row, 0].axis('off')

            col = 1
            for stage in stages:
                act = activations[stage][0]
                for ch in range(channels_per_stage):
                    if ch < act.shape[0]:
                        axs[row, col].imshow(act[ch], cmap='viridis')
                        axs[row, col].set_title(f'{stage} | Ch {ch}')
                    axs[row, col].axis('off')
                    col += 1
    finally:
        # Detach the forward hooks so repeated calls do not accumulate.
        for handle in hook_handles:
            handle.remove()

    plt.tight_layout()
    if save:
        self._save_fig(plt.gcf(), "activations", "activation")

    return fig

EfficientUNetVisualizer

Bases: BaseVisualizer

Visualizer for EfficientUNet model performance and activations.

Source code in src\aegear\nn\training.py
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
class EfficientUNetVisualizer(BaseVisualizer):
    """Visualizer for EfficientUNet model performance and activations."""

    def performance(self, num_samples=5, save=True):
        """Visualize performance samples for EfficientUNet.

        One row per sample: prediction overlay on the search image and the
        absolute pred/GT heatmap difference.

        Args:
            num_samples (int): Number of samples per group (worst, middle, best).
            save (bool): Whether to save the figure to disk.

        Returns:
            The generated matplotlib figure.
        """
        samples = _sort_samples(self.val_results, num_samples)
        fig, axes = plt.subplots(
            len(samples), 3, figsize=(9, 3 * len(samples)))
        # plt.subplots returns a 1-D axes array for a single row; normalize
        # to 2-D so the axes[i, j] indexing below works for any sample count.
        axes = axes if len(samples) > 1 else axes[None, :]

        for i, result in enumerate(samples):
            search_img = TF.to_pil_image(denormalize(result['search']))
            search_np = TF.to_tensor(search_img).permute(1, 2, 0).numpy()

            pred, gt = result['pred_heatmap'], result['gt_heatmap']
            pred = pred.detach().cpu()
            gt = gt.detach().cpu()
            # Min-max normalize both heatmaps so color scales are comparable.
            pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
            gt_norm = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
            diff_norm = np.abs(pred_norm.numpy() - gt_norm.numpy())

            # Blend the search image with the jet-colored prediction heatmap.
            overlay = np.clip(0.6 * search_np + 0.4 *
                              plt.cm.jet(pred_norm)[..., :3], 0, 1)
            diff_rgb = plt.cm.magma(diff_norm)[..., :3]

            xg, yg = result['gt_centroid']
            xp, yp = result['pred_centroid']
            confidence = result['confidence']

            axes[i, 0].imshow(overlay)
            axes[i, 0].scatter([xp], [yp], c='red', marker='x', label='Pred')
            axes[i, 0].scatter([xg], [yg], c='green', marker='o', label='GT')
            axes[i, 0].set_title(f"Search | Conf: {confidence:.2f}")
            axes[i, 0].legend()

            axes[i, 1].imshow(diff_rgb)
            axes[i, 1].set_title("Abs Diff")

            for ax in axes[i]:
                ax.axis("off")

        plt.tight_layout()

        if save:
            self._save_fig(plt.gcf(), "performance", "stage")

        return fig

    def activations(self, num_samples=3, save=True):
        """Visualize activations for EfficientUNet model.

        Runs each sample through the model with forward hooks attached to
        the listed stages and plots the first few channels per stage.

        Args:
            num_samples (int): Number of samples to visualize.
            save (bool): Whether to save the figure to disk.

        Returns:
            The generated matplotlib figure.
        """
        output_dir = os.path.join(self.output_dir, "activations")
        os.makedirs(output_dir, exist_ok=True)

        stages = ['enc1', 'enc2', 'enc3', 'enc4', 'enc5',
                  'up4', 'up3', 'up2', 'up1', 'up0', 'out']
        channels_per_stage = 3

        activations = {}
        # Keep the hook handles so they can be removed afterwards; otherwise
        # every call to this method leaves another hook attached to the model.
        hook_handles = []
        for name in stages:
            layer = getattr(self.model, name)
            hook_handles.append(layer.register_forward_hook(
                lambda m, i, o, n=name: activations.update({n: o.detach().cpu()})))

        try:
            samples = _sort_samples(self.val_results, num_samples)
            n_cols = 1 + channels_per_stage * len(stages)
            fig, axs = plt.subplots(len(samples), n_cols, figsize=(
                n_cols * 2.5, len(samples) * 3))
            axs = axs if len(samples) > 1 else axs[None, :]

            self.model.eval()
            for row, sample in enumerate(samples):
                search = sample['search'].unsqueeze(0).to(self.device)
                heatmap = sample['pred_heatmap'].numpy()

                # Forward pass only to trigger the hooks; output is unused.
                with torch.no_grad():
                    _ = self.model(search)

                overlay = denormalize(search[0]).permute(1, 2, 0).cpu().numpy()
                overlay[..., 0] = np.clip(overlay[..., 0] + 0.5 * heatmap, 0, 1)

                axs[row, 0].imshow(overlay)
                axs[row, 0].scatter([sample['gt_centroid'][0]], [
                                    sample['gt_centroid'][1]], c='green', marker='o')
                axs[row, 0].scatter([sample['pred_centroid'][0]], [
                                    sample['pred_centroid'][1]], c='red', marker='x')
                axs[row, 0].set_title(f"Conf: {sample['confidence']:.2f}")
                axs[row, 0].axis('off')
                # NOTE: the original called legend() here with no labeled
                # artists, which only emits a matplotlib warning; dropped.

                col = 1
                for stage in stages:
                    act = activations[stage][0]
                    for ch in range(channels_per_stage):
                        if ch < act.shape[0]:
                            axs[row, col].imshow(act[ch], cmap='viridis')
                            axs[row, col].set_title(f'{stage} | Ch {ch}')
                        axs[row, col].axis('off')
                        col += 1
        finally:
            # Detach the forward hooks so repeated calls do not accumulate.
            for handle in hook_handles:
                handle.remove()

        plt.tight_layout()
        if save:
            self._save_fig(plt.gcf(), "activations", "activation")

        return fig
performance(num_samples=5, save=True)

Visualize performance samples for EfficientUNet.

Parameters:

Name Type Description Default
num_samples int

Number of samples per group (worst, middle, best).

5
Source code in src\aegear\nn\training.py
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
def performance(self, num_samples=5, save=True):
    """Visualize performance samples for EfficientUNet.

    One row per sample: prediction overlay on the search image and the
    absolute pred/GT heatmap difference.

    Args:
        num_samples (int): Number of samples per group (worst, middle, best).
        save (bool): Whether to save the figure to disk.

    Returns:
        The generated matplotlib figure.
    """
    samples = _sort_samples(self.val_results, num_samples)
    fig, axes = plt.subplots(
        len(samples), 3, figsize=(9, 3 * len(samples)))
    # plt.subplots returns a 1-D axes array for a single row; normalize to
    # 2-D so the axes[i, j] indexing below works for any sample count.
    axes = axes if len(samples) > 1 else axes[None, :]

    for i, result in enumerate(samples):
        search_img = TF.to_pil_image(denormalize(result['search']))
        search_np = TF.to_tensor(search_img).permute(1, 2, 0).numpy()

        pred, gt = result['pred_heatmap'], result['gt_heatmap']
        pred = pred.detach().cpu()
        gt = gt.detach().cpu()
        # Min-max normalize both heatmaps so color scales are comparable.
        pred_norm = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
        gt_norm = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
        diff_norm = np.abs(pred_norm.numpy() - gt_norm.numpy())

        # Blend the search image with the jet-colored prediction heatmap.
        overlay = np.clip(0.6 * search_np + 0.4 *
                          plt.cm.jet(pred_norm)[..., :3], 0, 1)
        diff_rgb = plt.cm.magma(diff_norm)[..., :3]

        xg, yg = result['gt_centroid']
        xp, yp = result['pred_centroid']
        confidence = result['confidence']

        axes[i, 0].imshow(overlay)
        axes[i, 0].scatter([xp], [yp], c='red', marker='x', label='Pred')
        axes[i, 0].scatter([xg], [yg], c='green', marker='o', label='GT')
        axes[i, 0].set_title(f"Search | Conf: {confidence:.2f}")
        axes[i, 0].legend()

        axes[i, 1].imshow(diff_rgb)
        axes[i, 1].set_title("Abs Diff")

        for ax in axes[i]:
            ax.axis("off")

    plt.tight_layout()

    if save:
        self._save_fig(plt.gcf(), "performance", "stage")

    return fig
activations(num_samples=3, save=True)

Visualize activations for EfficientUNet model.

Parameters:

Name Type Description Default
num_samples int

Number of samples to visualize.

3
Source code in src\aegear\nn\training.py
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
def activations(self, num_samples=3, save=True):
    """Visualize activations for EfficientUNet model.

    Runs each sample through the model with forward hooks attached to the
    listed stages and plots the first few channels per stage.

    Args:
        num_samples (int): Number of samples to visualize.
        save (bool): Whether to save the figure to disk.

    Returns:
        The generated matplotlib figure.
    """
    output_dir = os.path.join(self.output_dir, "activations")
    os.makedirs(output_dir, exist_ok=True)

    stages = ['enc1', 'enc2', 'enc3', 'enc4', 'enc5',
              'up4', 'up3', 'up2', 'up1', 'up0', 'out']
    channels_per_stage = 3

    activations = {}
    # Keep the hook handles so they can be removed afterwards; otherwise
    # every call to this method leaves another hook attached to the model.
    hook_handles = []
    for name in stages:
        layer = getattr(self.model, name)
        hook_handles.append(layer.register_forward_hook(
            lambda m, i, o, n=name: activations.update({n: o.detach().cpu()})))

    try:
        samples = _sort_samples(self.val_results, num_samples)
        n_cols = 1 + channels_per_stage * len(stages)
        fig, axs = plt.subplots(len(samples), n_cols, figsize=(
            n_cols * 2.5, len(samples) * 3))
        axs = axs if len(samples) > 1 else axs[None, :]

        self.model.eval()
        for row, sample in enumerate(samples):
            search = sample['search'].unsqueeze(0).to(self.device)
            heatmap = sample['pred_heatmap'].numpy()

            # Forward pass only to trigger the hooks; output is unused.
            with torch.no_grad():
                _ = self.model(search)

            overlay = denormalize(search[0]).permute(1, 2, 0).cpu().numpy()
            overlay[..., 0] = np.clip(overlay[..., 0] + 0.5 * heatmap, 0, 1)

            axs[row, 0].imshow(overlay)
            axs[row, 0].scatter([sample['gt_centroid'][0]], [
                                sample['gt_centroid'][1]], c='green', marker='o')
            axs[row, 0].scatter([sample['pred_centroid'][0]], [
                                sample['pred_centroid'][1]], c='red', marker='x')
            axs[row, 0].set_title(f"Conf: {sample['confidence']:.2f}")
            axs[row, 0].axis('off')
            # NOTE: the original called legend() here with no labeled
            # artists, which only emits a matplotlib warning; dropped.

            col = 1
            for stage in stages:
                act = activations[stage][0]
                for ch in range(channels_per_stage):
                    if ch < act.shape[0]:
                        axs[row, col].imshow(act[ch], cmap='viridis')
                        axs[row, col].set_title(f'{stage} | Ch {ch}')
                    axs[row, col].axis('off')
                    col += 1
    finally:
        # Detach the forward hooks so repeated calls do not accumulate.
        for handle in hook_handles:
            handle.remove()

    plt.tight_layout()
    if save:
        self._save_fig(plt.gcf(), "activations", "activation")

    return fig

setup_logging(log_level=logging.INFO)

Set up logging for training.

Parameters:

Name Type Description Default
log_level int

Logging level.

INFO

Returns:

Type Description

logging.Logger: Configured logger instance.

Source code in src\aegear\nn\training.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def setup_logging(log_level=logging.INFO):
    """Set up logging for training.

    Configures the root logging format/level and returns the package's
    training logger.

    Args:
        log_level (int): Logging level.

    Returns:
        logging.Logger: Configured logger instance.
    """
    logging.basicConfig(
        format='%(asctime)s %(levelname)s %(message)s',
        level=log_level,
    )
    return logging.getLogger("aegear.train")

get_device()

Get the best available torch device (MPS, CUDA, or CPU).

Returns:

Type Description

torch.device: The selected device.

Source code in src\aegear\nn\training.py
53
54
55
56
57
58
59
60
61
62
63
64
def get_device():
    """Get the best available torch device (MPS, CUDA, or CPU).

    Returns:
        torch.device: The selected device.
    """
    # Preference order: Apple MPS, then CUDA, then CPU fallback.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

load_datasets(cache_dir, datasets, batch_size=128, gaussian_sigma=15.0)

Load training and validation datasets and create DataLoaders.

Parameters:

Name Type Description Default
cache_dir str

Directory containing cached datasets.

required
datasets list

List of dataset names.

required
batch_size int

Batch size for DataLoader.

128
gaussian_sigma float

Sigma for Gaussian heatmap.

15.0

Returns:

Name Type Description
tuple

(train_loader, val_loader, train_dataset, val_dataset)

Source code in src\aegear\nn\training.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def load_datasets(cache_dir, datasets, batch_size=128, gaussian_sigma=15.0):
    """Load training and validation datasets and create DataLoaders.

    Args:
        cache_dir (str): Directory containing cached datasets.
        datasets (list): List of dataset names.
        batch_size (int): Batch size for DataLoader.
        gaussian_sigma (float): Sigma for Gaussian heatmap.

    Returns:
        tuple: (train_loader, val_loader, train_dataset, val_dataset)
    """
    def _concat(split):
        # One cached dataset per name, concatenated into a single dataset.
        return ConcatDataset([
            CachedTrackingDataset(
                os.path.join(cache_dir, name, split),
                gaussian_sigma=gaussian_sigma,
            )
            for name in datasets
        ])

    train_dataset = _concat("train")
    val_dataset = _concat("val")
    # Only the training set is shuffled; both use pinned memory for faster
    # host-to-device transfer.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True,
    )
    return train_loader, val_loader, train_dataset, val_dataset

setup_model(weights='IMAGENET1K_V1', continue_training=False, use_best_model=False, model_dir='../data/training/models/efficient_unet', pretrained_model_dir='../models/', device=None, **model_kwargs)

Set up the EfficientUNet model for training or fine-tuning.

Parameters:

Name Type Description Default
weights str

Pretrained weights identifier.

'IMAGENET1K_V1'
continue_training bool

Whether to continue training from a checkpoint.

False
use_best_model bool

Use the best model checkpoint.

False
model_dir str

Directory for saving/loading models.

'../data/training/models/efficient_unet'
pretrained_model_dir str

Directory for pretrained models.

'../models/'
device device

Device to load model on.

None
**model_kwargs

Additional model arguments.

{}

Returns:

Name Type Description
EfficientUNet

Initialized model.

Source code in src\aegear\nn\training.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def setup_model(weights="IMAGENET1K_V1", continue_training=False, use_best_model=False, model_dir="../data/training/models/efficient_unet", pretrained_model_dir="../models/", device=None, **model_kwargs):
    """Set up the EfficientUNet model for training or fine-tuning.

    Args:
        weights (str): Pretrained weights identifier.
        continue_training (bool): Whether to continue training from a checkpoint.
        use_best_model (bool): Use the best model checkpoint.
        model_dir (str): Directory for saving/loading models.
        pretrained_model_dir (str): Directory for pretrained models.
        device (torch.device): Device to load model on.
        **model_kwargs: Additional model arguments.

    Returns:
        EfficientUNet: Initialized model.

    Raises:
        FileNotFoundError: If the requested best-model checkpoint is missing.
    """
    model = EfficientUNet(weights=weights, **model_kwargs)
    logger = logging.getLogger("aegear.train")
    if continue_training:
        if use_best_model:
            best_model_path = os.path.join(model_dir, "best_model.pth")
            # Explicit check instead of `assert`, which is stripped under -O.
            if not os.path.exists(best_model_path):
                raise FileNotFoundError(
                    f"Best model checkpoint not found: {best_model_path}")
        else:
            unet_model_filename = "model_efficient_unet"
            best_model_path = get_latest_model_path(pretrained_model_dir, unet_model_filename)
        logger.info(f"Continuing training of the UNet model from: {best_model_path}")
        # strict=False tolerates checkpoints with missing/extra keys.
        model.load_state_dict(torch.load(best_model_path, map_location=device), strict=False)
    else:
        logger.info("Training a new UNet model from ImageNet weights.")
    model.to(device)
    return model

freeze_model_layers(model, freeze_layers)

Freeze specified layers in the model (set requires_grad=False).

Parameters:

Name Type Description Default
model

Model instance.

required
freeze_layers list

List of layers (or names) to freeze.

required
Source code in src\aegear\nn\training.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def freeze_model_layers(model, freeze_layers):
    """Freeze specified layers in the model (set requires_grad=False).

    All parameters are first re-enabled, then only the listed layers are
    frozen.

    Args:
        model: Model instance.
        freeze_layers (list): List of layers (or dotted attribute names) to freeze.
    """
    # Start from a fully trainable model.
    for param in model.parameters():
        param.requires_grad = True

    for entry in freeze_layers:
        # Dotted strings are resolved against the model's attributes.
        if isinstance(entry, str):
            target = model
            for part in entry.split('.'):
                target = getattr(target, part)
        else:
            target = entry

        for param in target.parameters():
            param.requires_grad = False

set_layers_eval(model, layers)

Set specified layers to evaluation mode.

Parameters:

Name Type Description Default
model

Model instance.

required
layers list

List of layers (or names) to set to eval mode.

required
Source code in src\aegear\nn\training.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def set_layers_eval(model, layers):
    """Set specified layers to evaluation mode.

    Args:
        model: Model instance.
        layers (list): List of layers (or dotted attribute names) to set to
            eval mode.
    """
    for entry in layers:
        # Dotted strings are resolved against the model's attributes.
        if isinstance(entry, str):
            target = model
            for part in entry.split('.'):
                target = getattr(target, part)
        else:
            target = entry
        target.eval()

load_training_stages(model, stages_path=None)

Load training stages from a JSON file and resolve layer names to model attributes. Args: model: The model instance (EfficientUNet or SiameseTracker). stages_path: Path to the JSON file. Returns: List of training stage dicts with freeze_layers resolved.

Source code in src\aegear\nn\training.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def load_training_stages(model, stages_path=None):
    """
    Load training stages from a JSON file and resolve layer names to model attributes.

    Args:
        model: The model instance (EfficientUNet or SiameseTracker) whose
            attributes are looked up for every ``freeze_layers`` entry.
        stages_path: Path to the JSON file describing the stages.

    Returns:
        List of training stage dicts with ``freeze_layers`` resolved to the
        actual layer objects.

    Raises:
        IOError: If ``stages_path`` is None or does not exist.
        AttributeError: If a listed layer name cannot be resolved on *model*.
    """
    # Guard None explicitly: os.path.exists(None) raises an opaque TypeError
    # instead of the intended "file not found" error.
    if stages_path is None or not os.path.exists(stages_path):
        raise IOError(f"Training stages file not found: {stages_path}")

    with open(stages_path, 'r') as f:
        training_stages = json.load(f)

    for stage in training_stages:
        if 'freeze_layers' in stage:
            resolved_layers = []
            for layer_name in stage['freeze_layers']:
                # Resolve dotted paths like "encoder.enc1" attribute by attribute.
                layer = model
                for attr in layer_name.split('.'):
                    layer = getattr(layer, attr)
                resolved_layers.append(layer)
            stage['freeze_layers'] = resolved_layers

    return training_stages

get_default_training_stages(model_name, epochs=10, lr=0.0001)

Return default training stages for the given model name ('efficient_unet' or 'siamese'). The returned format matches what load_training_stages expects (layer names as strings). Args: model_name: 'efficient_unet' or 'siamese' epochs: Number of epochs for the stage(s). lr: Learning rate for the stage(s). Returns: List of training stage dicts with freeze_layers as strings.

Source code in src\aegear\nn\training.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def get_default_training_stages(model_name: str, epochs: int = 10, lr: float = 1e-4):
    """
    Build the built-in training-stage definitions for a model.

    The result has the same shape load_training_stages expects, with
    freeze_layers given as attribute-name strings.

    Args:
        model_name: 'efficient_unet' or 'siamese'.
        epochs: Base number of epochs for the stage(s).
        lr: Base learning rate for the stage(s).

    Returns:
        List of training stage dicts with freeze_layers as strings.

    Raises:
        ValueError: For any other model name.
    """
    if model_name == "efficient_unet":
        stage = {
            "freeze_layers": ["enc1", "enc2", "enc3", "enc4"],
            "epochs": epochs,
            "lr": lr,
        }
        return [stage]

    if model_name == "siamese":
        # Clamp both stage lengths to at least one full epoch.
        first_epochs = max(1, int(epochs))
        second_epochs = max(1, int(epochs // 2))
        first_stage = {
            "freeze_layers": ["enc1", "enc2", "enc3", "enc4", "enc5"],
            "epochs": first_epochs,
            "lr": 5.0 * lr,
        }
        second_stage = {
            "freeze_layers": ["enc1", "enc2", "enc3", "enc4"],
            "epochs": second_epochs,
            "lr": lr,
        }
        return [first_stage, second_stage]

    raise ValueError(f"Unknown model_name: {model_name}")

collect_val_results(val_batches, device)

Collect validation results from batches for visualization and metrics.

Parameters:

Name Type Description Default
val_batches list

List of validation batches.

required
device device

Device for tensor operations.

required

Returns:

Name Type Description
list

List of dicts containing results per sample.

Raises:

Type Description
ValueError

If a batch does not have 3 or 4 elements (unexpected batch size).

Source code in src\aegear\nn\training.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def collect_val_results(val_batches, device):
    """
    Assemble per-sample validation results for metrics and visualization.

    Args:
        val_batches (list): Validation batches; each is either
            (search, target, output) or (template, search, target, output).
        device (torch.device): Device for tensor operations.

    Returns:
        list: One dict per sample with heatmaps, centroids, confidence and the
            pred/gt centroid distance. Samples where either centroid could not
            be extracted are skipped.

    Raises:
        ValueError: If a batch does not have 3 or 4 elements (unexpected batch size).
    """
    results = []
    # Handles both EfficientUNet (no template) and SiameseTracker batches.
    for batch in val_batches:
        n_items = len(batch)
        if n_items == 4:
            template, search, target, output = batch
        elif n_items == 3:
            template = None
            search, target, output = batch
        else:
            raise ValueError(f"Unexpected batch size {n_items} in validation results. Expected 3 or 4.")

        # Upsample prediction and target heatmaps to the search-image size.
        full_size = search.shape[-2:]
        pred_resized = F.interpolate(torch.sigmoid(output), size=full_size, mode='bilinear', align_corners=False)
        target_resized = F.interpolate(target, size=full_size, mode='bilinear', align_corners=False)
        centroids_pred = get_centroids_per_sample(pred_resized)
        centroids_gt = get_centroids_per_sample(target_resized)

        for idx in range(search.size(0)):
            pred_c = centroids_pred[idx]
            gt_c = centroids_gt[idx]
            if pred_c is None or gt_c is None:
                continue
            x_pred, y_pred, confidence = pred_c
            x_gt, y_gt, _ = gt_c
            xp, yp = x_pred.item(), y_pred.item()
            xg, yg = x_gt.item(), y_gt.item()
            dist = np.sqrt((xp - xg) ** 2 + (yp - yg) ** 2)
            entry = {
                'search': search[idx].cpu(),
                'gt_heatmap': target_resized[idx, 0].cpu(),
                'pred_heatmap': pred_resized[idx, 0].cpu(),
                'gt_centroid': (xg, yg),
                'pred_centroid': (xp, yp),
                'confidence': confidence.item(),
                'distance': dist,
            }
            if template is not None:
                entry['template'] = template[idx].cpu()
            results.append(entry)
    return results

get_model_type(model, explicit_type=None)

Determine the model type ('efficient_unet' or 'siamese') from the model instance or explicit argument.

Parameters:

Name Type Description Default
model

Model instance.

required
explicit_type str

Explicit model type if provided.

None

Returns:

Name Type Description
str

Model type ('efficient_unet' or 'siamese').

Source code in src\aegear\nn\training.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def get_model_type(model, explicit_type=None):
    """
    Infer the model type string from an instance, unless given explicitly.

    Args:
        model: Model instance; its class name is inspected.
        explicit_type (str, optional): If truthy, returned unchanged.

    Returns:
        str: Model type ('efficient_unet' or 'siamese').

    Raises:
        ValueError: If the class name matches neither known model.
    """
    if explicit_type:
        return explicit_type
    cls_name = model.__class__.__name__.lower()
    # Map class-name prefixes to the canonical model-type strings.
    prefix_map = (
        ("efficientunet", "efficient_unet"),
        ("siamesetracker", "siamese"),
    )
    for prefix, type_name in prefix_map:
        if cls_name.startswith(prefix):
            return type_name
    raise ValueError("Unknown model type for training.")

process_train_batch(model, batch, model_type, device, loss_fn, return_components=False)

Process a single training batch for the given model type.

Parameters:

Name Type Description Default
model

Model instance.

required
batch

Batch data from DataLoader.

required
model_type str

Model type ('efficient_unet' or 'siamese').

required
device

Torch device.

required
loss_fn

Loss function.

required
return_components bool

If True, return loss components.

False

Returns:

Name Type Description
tuple

(loss, output) or (loss, output, components) if return_components=True

Source code in src\aegear\nn\training.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def process_train_batch(model, batch, model_type, device, loss_fn, return_components=False):
    """
    Run the forward pass and loss for a single training batch.

    Args:
        model: Model instance.
        batch: Batch data from DataLoader.
        model_type (str): Model type ('efficient_unet' or 'siamese').
        device: Torch device.
        loss_fn: Loss function.
        return_components (bool): If True, also return the loss components.

    Returns:
        tuple: (loss, output), or (loss, output, components) when
            return_components is True.

    Raises:
        ValueError: On an unknown model_type.
    """
    if model_type == "efficient_unet":
        # Batches may optionally carry a leading template tensor; ignore it.
        if len(batch) == 2:
            search, target = batch
        else:
            _, search, target = batch
        output = model(search.to(device))
        target = target.to(device)
        if return_components:
            loss, components = loss_fn(output, target, return_components=True)
            return loss, output, components
        loss = loss_fn(output, target)
    elif model_type == "siamese":
        template, search, target = batch
        template = template.to(device)
        search = search.to(device)
        output = model(template, search)
        target = target.to(device)
        if return_components:
            loss, components = loss_fn(output, target, template, search, return_components=True)
            return loss, output, components
        loss = loss_fn(output, target, template, search)
    else:
        raise ValueError("Unknown model_type in training loop.")
    return loss, output

process_val_batch(model, batch, model_type, device, loss_fn, return_components=False)

Process a single validation batch for the given model type.

Parameters:

Name Type Description Default
model

Model instance.

required
batch

Batch data from DataLoader.

required
model_type str

Model type ('efficient_unet' or 'siamese').

required
device

Torch device.

required
loss_fn

Loss function.

required
return_components bool

If True, return loss components.

False

Returns:

Name Type Description
tuple

(loss, batch_tuple) or (loss, batch_tuple, components) if return_components=True

Source code in src\aegear\nn\training.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def process_val_batch(model, batch, model_type, device, loss_fn, return_components=False):
    """
    Run the forward pass and loss for a single validation batch.

    Unlike process_train_batch, the full batch tuple (including the model
    output) is returned so callers can build visualizations from it.

    Args:
        model: Model instance.
        batch: Batch data from DataLoader.
        model_type (str): Model type ('efficient_unet' or 'siamese').
        device: Torch device.
        loss_fn: Loss function.
        return_components (bool): If True, also return the loss components.

    Returns:
        tuple: (loss, batch_tuple), or (loss, batch_tuple, components) when
            return_components is True.

    Raises:
        ValueError: On an unknown model_type.
    """
    if model_type == "efficient_unet":
        # Batches may optionally carry a leading template tensor; ignore it.
        if len(batch) == 2:
            search, target = batch
        else:
            _, search, target = batch
        output = model(search.to(device))
        target_dev = target.to(device)
        if return_components:
            loss, components = loss_fn(output, target_dev, return_components=True)
            return loss, (search, target, output), components
        loss = loss_fn(output, target_dev)
        return loss, (search, target, output)
    if model_type == "siamese":
        template, search, target = batch
        template_dev = template.to(device)
        search_dev = search.to(device)
        output = model(template_dev, search_dev)
        target_dev = target.to(device)
        if return_components:
            loss, components = loss_fn(output, target_dev, template_dev, search_dev, return_components=True)
            return loss, (template, search, target, output), components
        loss = loss_fn(output, target_dev, template_dev, search_dev)
        return loss, (template, search, target, output)
    raise ValueError("Unknown model_type in validation loop.")

get_visualizer(model_type, model, device, val_results, stage, epoch, output_dir)

Get the appropriate visualizer instance for the model type.

Parameters:

Name Type Description Default
model_type str

Model type ('efficient_unet' or 'siamese').

required
model

Model instance.

required
device

Torch device.

required
val_results

Validation results.

required
stage int

Training stage index.

required
epoch int

Epoch index.

required
output_dir str

Directory for visualizer outputs.

required

Returns:

Type Description

object or None: Visualizer instance or None if not applicable.

Source code in src\aegear\nn\training.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
def get_visualizer(model_type, model, device, val_results, stage, epoch, output_dir):
    """
    Build the visualizer matching the model type, or None for unknown types.

    Args:
        model_type (str): Model type ('efficient_unet' or 'siamese').
        model: Model instance.
        device: Torch device.
        val_results: Validation results.
        stage (int): Training stage index.
        epoch (int): Epoch index.
        output_dir (str): Directory for visualizer outputs.

    Returns:
        object or None: Visualizer instance or None if not applicable.
    """
    # Resolve the class lazily so unknown types never touch the visualizer names.
    if model_type == "efficient_unet":
        vis_cls = EfficientUNetVisualizer
    elif model_type == "siamese":
        vis_cls = SiameseTrackingVisualizer
    else:
        return None
    return vis_cls(model, device, val_results, stage, epoch, output_dir=output_dir)

create_scheduler(optimizer, scheduler_config, **kwargs)

Create a PyTorch LR scheduler from a config dict. Args: optimizer: Optimizer instance. scheduler_config (dict): Dict with 'type' and scheduler-specific kwargs. train_loader: DataLoader (needed for OneCycleLR). epochs: Number of epochs (needed for OneCycleLR). Returns: torch.optim.lr_scheduler._LRScheduler or ReduceLROnPlateau

Source code in src\aegear\nn\training.py
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
def create_scheduler(optimizer, scheduler_config, **kwargs):
    """
    Build a PyTorch LR scheduler from a config dict.

    Args:
        optimizer: Optimizer instance.
        scheduler_config (dict): Dict with 'type' and a 'params' sub-dict of
            scheduler-specific kwargs; None disables scheduling.
        **kwargs: For OneCycleLR, must include 'steps_per_epoch' and 'epochs'.

    Returns:
        torch.optim.lr_scheduler._LRScheduler or ReduceLROnPlateau, or None
        when scheduler_config is None.

    Raises:
        ValueError: If OneCycleLR is requested without the required kwargs,
            or the scheduler type is unknown.
    """
    if scheduler_config is None:
        return None

    sched_type = scheduler_config.get('type', 'ReduceLROnPlateau')
    params = scheduler_config.get('params', {})
    lr_sched = torch.optim.lr_scheduler

    if sched_type == 'OneCycleLR':
        # OneCycleLR needs the total schedule length up front.
        steps_per_epoch = kwargs.get('steps_per_epoch', None)
        epochs = kwargs.get('epochs', None)
        if steps_per_epoch is None or epochs is None:
            raise ValueError("steps_per_epoch and epochs must be provided for OneCycleLR scheduler.")
        return lr_sched.OneCycleLR(optimizer, steps_per_epoch=steps_per_epoch, epochs=epochs, **params)

    simple_types = {
        'ReduceLROnPlateau': lr_sched.ReduceLROnPlateau,
        'StepLR': lr_sched.StepLR,
        'CosineAnnealingLR': lr_sched.CosineAnnealingLR,
    }
    if sched_type not in simple_types:
        raise ValueError(f"Unknown scheduler type: {sched_type}")
    return simple_types[sched_type](optimizer, **params)

scheduler_config_to_json(scheduler_config)

Serialize scheduler config dict to JSON string.

Source code in src\aegear\nn\training.py
446
447
448
def scheduler_config_to_json(scheduler_config):
    """Encode a scheduler config dict as a JSON string."""
    encoded = json.dumps(scheduler_config)
    return encoded

scheduler_config_from_json(json_str)

Deserialize scheduler config dict from JSON string.

Source code in src\aegear\nn\training.py
450
451
452
def scheduler_config_from_json(json_str):
    """Decode a scheduler config dict from its JSON string form."""
    decoded = json.loads(json_str)
    return decoded

get_epoch_progress_message(current_epoch, total_epochs, epoch_time, epoch_times)

Generate a progress message for the current epoch.

Source code in src\aegear\nn\training.py
454
455
456
457
458
459
460
461
462
463
464
def get_epoch_progress_message(current_epoch, total_epochs, epoch_time, epoch_times):
    """Generate a progress message for the current epoch.

    Args:
        current_epoch (int): 1-based index of the epoch just finished.
        total_epochs (int): Total number of epochs in this stage.
        epoch_time (float): Duration of the finished epoch, in seconds.
        epoch_times (list): Durations of all completed epochs so far, used to
            estimate the remaining time.

    Returns:
        str: Human-readable progress line with elapsed time and ETA.
    """
    # Fall back to the current epoch's time when no history is available yet;
    # the original raised ZeroDivisionError on an empty list.
    if epoch_times:
        avg_epoch_time = sum(epoch_times) / len(epoch_times)
    else:
        avg_epoch_time = epoch_time
    epochs_left = total_epochs - current_epoch
    eta = avg_epoch_time * epochs_left
    eta_str = time.strftime('%H:%M:%S', time.gmtime(eta))
    epoch_time_str = time.strftime('%H:%M:%S', time.gmtime(epoch_time))

    return (f"Epoch {current_epoch}/{total_epochs} completed. "
        f"Time: {epoch_time_str}. "
        f"ETA: {eta_str}.")

compute_validation_metrics(model, val_loader, device, model_type, loss_fn=None)

Compute validation metrics: average centroid distance, confidence, within-radius percentages, and loss components.

Parameters:

Name Type Description Default
model

The trained model.

required
val_loader

Validation data loader.

required
device

Torch device.

required
model_type str

Model type ('efficient_unet' or 'siamese').

required
loss_fn

Optional loss function to compute loss components.

None

Returns:

Name Type Description
dict

Dictionary containing metrics and loss components.

Source code in src\aegear\nn\training.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
def compute_validation_metrics(model, val_loader, device, model_type, loss_fn=None):
    """
    Compute validation metrics: average centroid distance, confidence, within-radius percentages, and loss components.

    Args:
        model: The trained model.
        val_loader: Validation data loader.
        device: Torch device.
        model_type (str): Model type ('efficient_unet' or 'siamese').
        loss_fn: Optional loss function to compute loss components.

    Returns:
        dict: Metrics dictionary; when loss_fn is given, batch-averaged loss
            components are included under 'loss/<name>' keys.

    Raises:
        ValueError: If model_type is not recognized.
    """
    model.eval()
    total_distances = []
    total_confidences = []
    within_radius = {3: 0, 5: 0, 10: 0}
    n_samples = 0
    loss_components_sum = {}
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(
            val_loader,
            desc="Computing validation metrics",
            leave=False,
            disable=_should_disable_progress()
        ):
            # Single forward pass per batch: keep the raw logits for the loss
            # components and sigmoid them for centroid extraction. (The
            # previous version re-ran the model when loss_fn was provided.)
            if model_type == "efficient_unet":
                if len(batch) == 2:
                    imgs, heatmaps = batch
                else:
                    imgs, _, heatmaps = batch
                imgs = imgs.to(device)
                heatmaps = heatmaps.to(device)
                logits = model(imgs)
                batch_size = len(imgs)
            elif model_type == "siamese":
                template, search, heatmaps = batch
                template = template.to(device)
                search = search.to(device)
                heatmaps = heatmaps.to(device)
                logits = model(template, search)
                batch_size = len(template)
            else:
                raise ValueError(f"Unknown model_type: {model_type}")

            preds = torch.sigmoid(logits)
            centroids_pred = get_centroids_per_sample(preds)
            centroids_gt = get_centroids_per_sample(heatmaps)

            # Compute loss components if a loss function is provided.
            if loss_fn is not None:
                if model_type == "efficient_unet":
                    _, components = loss_fn(logits, heatmaps, return_components=True)
                else:
                    _, components = loss_fn(logits, heatmaps, template, search, return_components=True)

                # Accumulate only the weighted terms plus bce_loss/total_loss.
                for key, value in components.items():
                    if '_weighted' in key or key == 'total_loss' or key == 'bce_loss':
                        loss_components_sum[key] = loss_components_sum.get(key, 0.0) + value
                n_batches += 1

            for i in range(batch_size):
                p = centroids_pred[i]
                t = centroids_gt[i]

                # Skip samples where either centroid could not be extracted.
                if p is None or t is None:
                    continue

                x_pred, y_pred, confidence = p
                x_gt, y_gt, _ = t

                xp, yp = x_pred.item(), y_pred.item()
                xg, yg = x_gt.item(), y_gt.item()
                confidence = confidence.item()

                dist = np.sqrt((xp - xg) ** 2 + (yp - yg) ** 2)
                total_distances.append(dist)
                total_confidences.append(confidence)

                for r in within_radius:
                    if dist <= r:
                        within_radius[r] += 1
                n_samples += 1

    if n_samples == 0:
        # No valid samples: return zeroed metrics (loss components omitted,
        # matching the prior behavior).
        return {
            'avg_distance': 0.0,
            'avg_confidence': 0.0,
            'within_3px': 0.0,
            'within_5px': 0.0,
            'within_10px': 0.0,
            'n_samples': 0
        }

    metrics = {
        'avg_distance': np.mean(total_distances),
        'avg_confidence': np.mean(total_confidences),
        'within_3px': within_radius[3] / n_samples,
        'within_5px': within_radius[5] / n_samples,
        'within_10px': within_radius[10] / n_samples,
        'n_samples': n_samples
    }

    # Average the accumulated loss components over batches.
    if loss_components_sum and n_batches > 0:
        for key, value in loss_components_sum.items():
            metrics[f'loss/{key}'] = value / n_batches

    return metrics

save_model_with_clearml(model, path, clearml_task=None, artifact_name=None, metadata=None)

Save a model checkpoint and register it with ClearML if a task is provided.

Parameters:

Name Type Description Default
model

The PyTorch model to save.

required
path str

Path to save the model.

required
clearml_task

ClearML Task object or None.

None
artifact_name str

Name for the artifact in ClearML.

None
metadata dict

Optional metadata to attach.

None
Source code in src\aegear\nn\training.py
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
def save_model_with_clearml(model, path, clearml_task=None, artifact_name=None, metadata=None):
    """
    Save a model checkpoint, optionally registering it as a ClearML artifact.

    Args:
        model: The PyTorch model to save.
        path (str): Destination for the state-dict checkpoint.
        clearml_task: ClearML Task object, or None to skip registration.
        artifact_name (str): Artifact name; defaults to the file's basename.
        metadata (dict): Optional metadata to attach to the artifact.
    """
    torch.save(model.state_dict(), path)
    if clearml_task is None:
        return
    upload_name = artifact_name or os.path.basename(path)
    clearml_task.upload_artifact(
        name=upload_name,
        artifact_object=path,
        metadata=metadata,
    )

train(model, train_loader, val_loader, device, model_dir, checkpoint_dir, epoch_vis, training_stages, loss_fn=None, epoch_save_interval=1, model_type=None, use_visualizer=False, weight_decay=0.005, clearml_task=None, scheduler_config=None)

Unified training function for EfficientUNet and SiameseTracker. model_type: 'efficient_unet' or 'siamese'. If None, inferred from model class name.

Source code in src\aegear\nn\training.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
def train(
    model,
    train_loader,
    val_loader,
    device,
    model_dir,
    checkpoint_dir,
    epoch_vis,
    training_stages,
    loss_fn=None,
    epoch_save_interval=1,
    model_type=None,
    use_visualizer=False,
    weight_decay=5e-3,
    clearml_task=None,
    scheduler_config=None,
):
    """
    Unified training function for EfficientUNet and SiameseTracker.
    model_type: 'efficient_unet' or 'siamese'. If None, inferred from model class name.

    Args:
        model: Model to train.
        train_loader: Training data loader (``len()`` is optional; see below).
        val_loader: Validation data loader.
        device: Torch device used for training and final metric computation.
        model_dir (str): Directory where the running best model is saved.
        checkpoint_dir (str): Directory for periodic per-epoch checkpoints.
        epoch_vis: Epoch-visualization setting forwarded to ``get_visualizer``.
        training_stages (list[dict]): One dict per stage with keys
            ``"freeze_layers"``, ``"epochs"`` and ``"lr"``; a fresh Adam
            optimizer and scheduler are created for every stage.
        loss_fn: Optional loss; forwarded to batch processing and final metrics.
        epoch_save_interval (int): Save a checkpoint every N epochs.
        model_type (str): 'efficient_unet' or 'siamese'; inferred when None.
        use_visualizer (bool): Render performance/activation figures per epoch.
        weight_decay (float): Adam weight decay.
        clearml_task: Optional ClearML task; enables artifact upload and
            scalar/figure reporting, and suppresses tqdm progress bars.
        scheduler_config (dict): Scheduler spec for ``create_scheduler``;
            defaults to ReduceLROnPlateau(mode='min', factor=0.5, patience=3).

    Returns:
        list[tuple]: Per-epoch ``(train_loss, val_loss)`` pairs across all stages.
    """
    # Ensure model_dir and checkpoint_dir exist
    if model_dir and not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    if checkpoint_dir and not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Log configuration before training starts
    config_dict = {
        'model_type': get_model_type(model, model_type),
        'model_dir': model_dir,
        'checkpoint_dir': checkpoint_dir,
        'epoch_vis': epoch_vis,
        'epoch_save_interval': epoch_save_interval,
        'weight_decay': weight_decay,
        'use_visualizer': use_visualizer,
        'scheduler_config': scheduler_config,
        'training_stages': training_stages,
        'loss_fn': loss_fn.__class__.__name__ if loss_fn is not None else None,
    }
    if clearml_task is not None:
        logger = clearml_task.get_logger()
        logger.report_text(f"Training configuration:\n{json.dumps(config_dict, indent=2)}", logging.INFO, iteration=0)
    else:
        logging.getLogger("aegear.train").info(f"Training configuration:\n{json.dumps(config_dict, indent=2)}")

    # Best-model selection is tracked across ALL stages, not per stage.
    best_val_loss = float('inf')
    losses = []
    model_type = get_model_type(model, model_type)

    # ClearML captures console output, so tqdm bars are suppressed there.
    suppress_progress = (clearml_task is not None) or _should_disable_progress()

    # Loaders may be iterable-only and not support __len__.
    try:
        total_train_batches = len(train_loader)
    except TypeError:
        total_train_batches = None
    try:
        total_val_batches = len(val_loader)
    except TypeError:
        total_val_batches = None

    # Each stage re-freezes layers and rebuilds the optimizer and scheduler.
    for stage, training_stage in enumerate(training_stages):
        freeze_layers = training_stage["freeze_layers"]
        epochs = training_stage["epochs"]
        freeze_model_layers(model, freeze_layers)
        optimizer = torch.optim.Adam(model.parameters(), lr=training_stage["lr"], weight_decay=weight_decay)
        # Use scheduler_config if provided, else default ReduceLROnPlateau
        scfg = scheduler_config if scheduler_config is not None else {
            'type': 'ReduceLROnPlateau',
            'params': {'mode': 'min', 'factor': 0.5, 'patience': 3}
        }
        # For OneCycleLR, pass epochs as kwarg
        if scfg.get('type') == 'OneCycleLR':
            scheduler = create_scheduler(optimizer, scfg, steps_per_epoch=len(train_loader), epochs=epochs)
        else:
            scheduler = create_scheduler(optimizer, scfg)

        epoch_times = []
        total_epochs = epochs
        for epoch in range(epochs):
            epoch_start = time.time()
            model.train()
            # Frozen layers are kept in eval mode while the rest trains.
            set_layers_eval(model, freeze_layers)
            train_loss = 0.0
            train_loss_components = {}
            if suppress_progress:
                logging.getLogger("aegear.train").info(
                    f"Stage {stage + 1}, Training epoch {epoch + 1}/{epochs}"
                )
            train_bar = tqdm(
                train_loader,
                desc=f"Stage {stage + 1}, Training {epoch + 1}",
                leave=False,
                disable=suppress_progress
            )
            for batch_idx, batch in enumerate(train_bar, start=1):
                result = process_train_batch(model, batch, model_type, device, loss_fn, return_components=(clearml_task is not None))
                if clearml_task is not None:
                    loss, _, components = result
                    # Accumulate components
                    for key, value in components.items():
                        if key not in train_loss_components:
                            train_loss_components[key] = 0.0
                        train_loss_components[key] += value
                else:
                    loss, _ = result

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Step OneCycleLR scheduler per batch
                if scheduler is not None and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                    scheduler.step()

                train_loss += loss.item()
                if not suppress_progress:
                    train_bar.set_postfix(loss=loss.item())
                elif total_train_batches and batch_idx % BATCH_UPDATE_FREQUENCY == 0:
                    logging.getLogger("aegear.train").info(
                        f"Stage {stage + 1} epoch {epoch + 1}: batch {batch_idx}/{total_train_batches} loss={loss.item():.4f}"
                    )

            # Validation pass; batches are retained for visualization below.
            model.eval()
            val_loss = 0.0
            val_loss_components = {}
            val_batches = []
            if suppress_progress:
                logging.getLogger("aegear.train").info(
                    f"Stage {stage + 1}, Validation epoch {epoch + 1}/{epochs}"
                )
            val_bar = tqdm(
                val_loader,
                desc=f"Stage {stage + 1}, Validation {epoch + 1}",
                leave=False,
                disable=suppress_progress
            )

            with torch.no_grad():
                for batch_idx, batch in enumerate(val_bar, start=1):
                    result = process_val_batch(model, batch, model_type, device, loss_fn, return_components=(clearml_task is not None))
                    if clearml_task is not None:
                        loss, val_batch, components = result
                        # Accumulate components
                        for key, value in components.items():
                            if key not in val_loss_components:
                                val_loss_components[key] = 0.0
                            val_loss_components[key] += value
                    else:
                        loss, val_batch = result

                    # Detach all tensors in val_batch before storing
                    if isinstance(val_batch, tuple):
                        detached_batch = tuple(x.detach().cpu() if torch.is_tensor(x) else x for x in val_batch)
                    else:
                        detached_batch = val_batch.detach().cpu() if torch.is_tensor(val_batch) else val_batch
                    val_batches.append(detached_batch)
                    val_loss += loss.item()
                    if not suppress_progress:
                        val_bar.set_postfix(loss=loss.item())
                    elif total_val_batches and batch_idx % BATCH_UPDATE_FREQUENCY == 0:
                        logging.getLogger("aegear.train").info(
                            f"Stage {stage + 1} validation epoch {epoch + 1}: batch {batch_idx}/{total_val_batches} loss={loss.item():.4f}"
                        )

            # Average accumulated batch losses into per-epoch values.
            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            val_results = collect_val_results(val_batches, device)
            losses.append((train_loss, val_loss))
            # Step scheduler depending on type (OneCycleLR is stepped per batch, not per epoch)
            if scheduler is not None:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                elif not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                    scheduler.step()

            # Logging to ClearML if available
            if clearml_task is not None:
                logger = clearml_task.get_logger()
                logger.report_scalar("loss", "train", iteration=epoch+1, value=train_loss)
                logger.report_scalar("loss", "val", iteration=epoch+1, value=val_loss)

                # Log individual loss components
                if train_loss_components:
                    num_train_batches = len(train_loader)
                    for key, value in train_loss_components.items():
                        avg_value = value / num_train_batches
                        logger.report_scalar(f"loss_components/train", key, iteration=epoch+1, value=avg_value)

                if val_loss_components:
                    num_val_batches = len(val_loader)
                    for key, value in val_loss_components.items():
                        avg_value = value / num_val_batches
                        logger.report_scalar(f"loss_components/val", key, iteration=epoch+1, value=avg_value)

                # Log sample images from visualizer
                if use_visualizer:
                    visualizer = get_visualizer(model_type, model, device, val_results, stage, epoch, epoch_vis)
                    if visualizer:
                        perf_fig = visualizer.performance(num_samples=5, save=False)
                        act_fig = visualizer.activations(num_samples=3, save=False)

                        logger.report_matplotlib_figure(
                            title="Sample Evaluation",
                            series=f"Epoch {epoch+1}, Stage {stage+1}",
                            iteration=epoch+1,
                            report_image=True,  # These are samples evaluated, so we report them like debug samples.
                            figure=perf_fig
                        )
                        plt.close(perf_fig)

                        logger.report_matplotlib_figure(
                            title="Activation Evaluation",
                            series=f"Epoch {epoch+1}, Stage {stage+1}",
                            iteration=epoch+1,
                            report_image=False,  # These are plots to be inspected (show up among 'Plots' in ClearML UI)
                            figure=act_fig
                        )
                        plt.close(act_fig)
            else:
                # Fallback to stdout logging and tqdm
                logging.getLogger("aegear.train").info(f"Epoch {epoch+1}/{epochs} - Train: {train_loss:.4f} | Val: {val_loss:.4f}")
                if use_visualizer:
                    visualizer = get_visualizer(model_type, model, device, val_results, stage, epoch, epoch_vis)
                    if visualizer:
                        visualizer.performance(num_samples=5)
                        visualizer.activations(num_samples=3)


            # --- Epoch timing and ETA logging ---
            epoch_end = time.time()
            epoch_time = epoch_end - epoch_start
            epoch_times.append(epoch_time)
            msg = get_epoch_progress_message(epoch+1, total_epochs, epoch_time, epoch_times)
            if clearml_task is not None:
                logger.report_text(msg, logging.INFO, iteration=epoch+1)
            else:
                logging.getLogger("aegear.train").info(msg)
            # --- End epoch timing and ETA logging ---

            # Persist the best-so-far model whenever validation improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                logging.getLogger("aegear.train").info("New best model, saving.")
                save_model_with_clearml(
                    model,
                    f'{model_dir}/best_model.pth',
                    clearml_task,
                    artifact_name="best_model",
                    metadata={
                        "type": "best",
                        "stage": stage,
                        "epoch": epoch+1,
                        "val_loss": val_loss
                    }
                )
            if (epoch + 1) % epoch_save_interval == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f'model_stage_{stage+1}_epoch_{epoch+1}.pth')
                save_model_with_clearml(
                    model,
                    checkpoint_path,
                    clearml_task,
                    artifact_name=f"checkpoint_stage_{stage+1}_epoch_{epoch+1}",
                    metadata={
                        "type": "checkpoint",
                        "stage": stage,
                        "epoch": epoch+1,
                        "val_loss": val_loss
                    }
                )

    # Compute final validation metrics on the best model
    logger_obj = logging.getLogger("aegear.train")
    logger_obj.info("Computing final validation metrics on best model...")

    # Load best model
    best_model_path = f'{model_dir}/best_model.pth'
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        model.to(device)

        # Compute metrics (pass loss_fn to get loss components)
        final_metrics = compute_validation_metrics(model, val_loader, device, model_type, loss_fn=loss_fn)

        # Log metrics
        logger_obj.info(f"\nFinal Validation Metrics:")
        logger_obj.info(f"  Average centroid distance: {final_metrics['avg_distance']:.2f} px")
        logger_obj.info(f"  Average confidence: {final_metrics['avg_confidence']:.4f}")
        logger_obj.info(f"  Within 3px: {final_metrics['within_3px']:.2%}")
        logger_obj.info(f"  Within 5px: {final_metrics['within_5px']:.2%}")
        logger_obj.info(f"  Within 10px: {final_metrics['within_10px']:.2%}")
        logger_obj.info(f"  Total samples: {final_metrics['n_samples']}")

        # Log loss components if available
        loss_component_keys = [k for k in final_metrics.keys() if k.startswith('loss/')]
        if loss_component_keys:
            logger_obj.info(f"\nFinal Loss Components:")
            for key in sorted(loss_component_keys):
                clean_name = key.replace('loss/', '')
                logger_obj.info(f"  {clean_name}: {final_metrics[key]:.6f}")

        # Log to ClearML if available
        if clearml_task is not None:
            logger = clearml_task.get_logger()
            logger.report_single_value("final_avg_distance", final_metrics['avg_distance'])
            logger.report_single_value("final_avg_confidence", final_metrics['avg_confidence'])
            logger.report_single_value("final_within_3px", final_metrics['within_3px'])
            logger.report_single_value("final_within_5px", final_metrics['within_5px'])
            logger.report_single_value("final_within_10px", final_metrics['within_10px'])
            logger.report_single_value("final_n_samples", final_metrics['n_samples'])

            # Log final loss components to ClearML Summary
            for key in loss_component_keys:
                clean_name = key.replace('loss/', '')
                logger.report_single_value(f"final_loss_{clean_name}", final_metrics[key])
    else:
        logger_obj.warning(f"Best model not found at {best_model_path}, skipping final metrics computation.")

    return losses

get_confidence(heatmap)

Get confidence score from a heatmap by finding the maximum value.

Parameters:

Name Type Description Default
heatmap Tensor

Heatmap tensor of shape (B, 1, H, W).

required

Returns:

Name Type Description
float

Confidence score (max value in heatmap).

Source code in src\aegear\nn\training.py
926
927
928
929
930
931
932
933
934
935
936
937
938
939
def get_confidence(heatmap):
    """Get confidence score from a heatmap by finding the maximum value.

    Args:
        heatmap (torch.Tensor): Heatmap tensor of shape (B, 1, H, W).

    Returns:
        float: Confidence score (max value of the first sample's heatmap).
    """
    # The value at the argmax location is, by definition, the maximum of the
    # heatmap, so take the max of the first sample directly. The previous
    # argmax-then-index formulation produced per-sample index *tensors*,
    # which broke the scalar .item() call whenever the batch size was > 1.
    return heatmap[0, 0].max().item()

overlay_heatmap_on_rgb(rgb_tensor, heatmap, alpha=0.5, centroid_color=(0, 1, 0))

Overlay heatmap onto RGB image and draw a circle at the predicted centroid.

Parameters:

Name Type Description Default
rgb_tensor Tensor

RGB image tensor of shape (3, H, W).

required
heatmap ndarray

Heatmap array of shape (H, W).

required
alpha float

Blending weight for overlay.

0.5
centroid_color tuple

(R, G, B) color for centroid (0-1 range).

(0, 1, 0)

Returns:

Type Description

np.ndarray: Overlay image of shape (H, W, 3).

Source code in src\aegear\nn\training.py
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
def overlay_heatmap_on_rgb(rgb_tensor, heatmap, alpha=0.5, centroid_color=(0, 1, 0)):
    """Overlay heatmap onto RGB image and draw a circle at the predicted centroid.

    Args:
        rgb_tensor (torch.Tensor): RGB image tensor of shape (3, H, W).
        heatmap (np.ndarray): Heatmap array of shape (H, W).
        alpha (float): Blending weight for overlay.
        centroid_color (tuple): (R, G, B) color for centroid (0-1 range).

    Returns:
        np.ndarray: Overlay image of shape (H, W, 3), float values in [0, 1].
    """
    # CHW tensor -> HWC numpy array for matplotlib/OpenCV processing.
    rgb = rgb_tensor.permute(1, 2, 0).cpu().numpy()
    # NOTE(review): this undoes ImageNet normalization with a single
    # mean/std pair (0.485/0.229) for all three channels, unlike
    # denormalize() which uses per-channel statistics -- presumably an
    # accepted approximation for visualization; confirm intended.
    rgb = rgb * 0.229 + 0.485
    rgb = rgb.clip(0, 1)

    # Map heatmap values through the 'hot' colormap, dropping the alpha channel.
    heatmap_color = plt.cm.hot(heatmap)[..., :3]
    overlay = (1 - alpha) * rgb + alpha * heatmap_color

    # Find centroid: flat argmax converted back to (row, col).
    flat_idx = heatmap.reshape(-1).argmax()
    h, w = heatmap.shape
    cy = flat_idx // w
    cx = flat_idx % w

    # Draw circle (cv2 wants uint8 pixels; color given in reversed/BGR order).
    overlay_uint8 = (overlay * 255).astype(np.uint8)
    cx_int, cy_int = int(cx), int(cy)
    color_bgr = tuple(int(c * 255) for c in reversed(centroid_color))
    cv2.circle(overlay_uint8, (cx_int, cy_int), 4, color_bgr, thickness=1)

    return overlay_uint8 / 255.0

denormalize(img_tensor, clamp=True)

Denormalize an image tensor using ImageNet mean and std.

Parameters:

Name Type Description Default
img_tensor Tensor

Normalized image tensor.

required
clamp bool

Whether to clamp output to [0, 1].

True

Returns:

Type Description

torch.Tensor: Denormalized image tensor.

Source code in src\aegear\nn\training.py
976
977
978
979
980
981
982
983
984
985
986
987
988
989
def denormalize(img_tensor, clamp=True):
    """Reverse ImageNet normalization on an image tensor.

    Args:
        img_tensor (torch.Tensor): Normalized image tensor, channels first.
        clamp (bool): Whether to clamp output to [0, 1].

    Returns:
        torch.Tensor: Denormalized image tensor.
    """
    dev = img_tensor.device
    mean = torch.tensor([0.485, 0.456, 0.406], device=dev).reshape(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=dev).reshape(3, 1, 1)
    # addcmul computes mean + img * std in one fused call.
    restored = torch.addcmul(mean, img_tensor, std)
    if clamp:
        restored = restored.clamp(0, 1)
    return restored

get_centroids_per_sample(heatmap)

Get centroids from a batch of heatmaps.

Parameters:

Name Type Description Default
heatmap Tensor

Batch of heatmaps (B, 1, H, W).

required

Returns:

Name Type Description
list

List of (x, y, confidence) tuples or None per sample.

Source code in src\aegear\nn\training.py
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
def get_centroids_per_sample(heatmap):
    """Extract per-sample peak locations from a batch of heatmaps.

    Args:
        heatmap (torch.Tensor): Batch of heatmaps (B, 1, H, W).

    Returns:
        list: One (x, y, confidence) tuple per sample, or None for samples
            whose heatmap has essentially no activation.
    """
    batch, _, _, width = heatmap.shape
    maps = heatmap.squeeze(1)
    centroids = []
    for idx in range(batch):
        sample = maps[idx]
        # Near-zero mean activation -> treat as "no detection".
        if sample.mean().item() < 1e-8:
            centroids.append(None)
            continue
        peak = torch.argmax(sample)
        row, col = peak // width, peak % width
        centroids.append((col.float(), row.float(), sample[row, col].float()))
    return centroids

tracker

Prediction

A class to represent a prediction made by the model.

Source code in src\aegear\tracker.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Prediction:
    """A class to represent a prediction made by the model."""

    def __init__(self, confidence, centroid, roi=None):
        """Initialize the prediction.

        Parameters
        ----------

        confidence : float
            The confidence of the prediction.
        centroid : tuple
            The centroid of the prediction.
        roi : np.ndarray
            The region of interest of the prediction.
        """
        self.centroid = centroid
        self.confidence = confidence
        self.roi = roi

    def global_coordinates(self, origin):
        """Return a copy of this prediction with its centroid shifted by *origin*."""
        dx, dy = origin
        cx, cy = self.centroid
        return Prediction(self.confidence, (cx + dx, cy + dy), self.roi)

FishTracker

Source code in src\aegear\tracker.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
class FishTracker:
    """Two-stage fish tracker.

    A heatmap model (EfficientUNet) performs full-frame detection via a
    sliding window, and a siamese model (SiameseTracker) tracks frame to
    frame around the last known position. A KNN background subtractor
    supplies a motion mask used to gate candidate windows and tracked
    positions.
    """

    # Original window size for the training data.
    WINDOW_SIZE = 128
    # The size of the tracking window.
    TRACKER_WINDOW_SIZE = 128
    # Threshold for movement mask to consider valid movement.
    MOVEMENT_MASK_THRESHOLD = 0.1

    def __init__(self,
                 heatmap_model_path,
                 siamese_model_path,
                 tracking_threshold=0.9,
                 detection_threshold=0.85,
                 search_stride=0.5,
                 tracking_max_skip=10,
                 debug=False):
        """Initialize the tracker.

        Parameters
        ----------

        heatmap_model_path : str
            Path to the EfficientUNet detection weights.
        siamese_model_path : str
            Path to the SiameseTracker tracking weights.
        tracking_threshold : float
            Minimum confidence for the siamese tracker to accept a match.
        detection_threshold : float
            Minimum confidence for a sliding-window detection.
        search_stride : float
            Sliding-window stride as a fraction of WINDOW_SIZE.
        tracking_max_skip : int
            Maximum number of frames skipped between tracking attempts.
        debug : bool
            When True, print diagnostic messages.
        """
        self._debug = debug
        self._stride = search_stride
        self._device = FishTracker._select_device()
        self._transform = FishTracker._init_transform()
        self.heatmap_model = self._init_heatmap_model(heatmap_model_path)
        self.siamese_model = self._init_siamese_model(siamese_model_path)
        self.tracking_threshold = tracking_threshold
        self.detection_threshold = detection_threshold
        self.tracking_max_skip = tracking_max_skip

        self.last_result = None
        # NOTE(review): history is never touched inside this class;
        # presumably maintained for external consumers -- confirm.
        self.history = []
        self.frame_size = None

    def run_tracking(self,
                     video: VideoClip,
                     start_frame: int,
                     end_frame: int,
                     model_track_register,
                     progress_reporter: Optional[ProgressReporter] = None,
                     ui_update=None):
        """Run the tracking on a video.

        Frames are visited with an adaptive skip: the skip doubles (up to
        tracking_max_skip) on a successful track, and halves (down to 1)
        on a miss while a previous result is still held, retrying from the
        same anchor. A miss with no held result advances the anchor and
        clears the stored result.
        """

        bgs = self._init_background_subtractor(video, start_frame)
        current_skip = self.tracking_max_skip
        anchor_frame = start_frame

        self.last_result = None

        def progress_still_running():
            return progress_reporter is not None and progress_reporter.still_running()

        while anchor_frame < end_frame and progress_still_running():
            candidate = anchor_frame + current_skip
            if candidate >= end_frame:
                break

            # Read and pre-process the candidate.
            frame = video.get_frame(float(candidate) / video.fps)
            if frame is None:
                break

            result = self._track_frame(
                frame, mask=self._motion_detection(bgs, frame))

            if result is not None:
                # Store this result for further tracking.
                self.last_result = result
                model_track_register(
                    candidate, result.centroid, result.confidence)

                anchor_frame = candidate

                if progress_reporter is not None:
                    progress_reporter.update(anchor_frame)

                if current_skip < self.tracking_max_skip:
                    current_skip = min(
                        current_skip * 2, self.tracking_max_skip)
            else:
                # Miss while locked on: shrink the skip and retry the same
                # anchor before giving up the lock.
                if self.last_result is not None and current_skip > 1:
                    current_skip = max(current_skip // 2, 1)
                    continue

                anchor_frame = candidate
                self.last_result = None

            if ui_update is not None:
                ui_update(anchor_frame)

    @staticmethod
    def _select_device():
        """Select the device - try CUDA, if fails, try mps for Apple Silicon, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    @staticmethod
    def _init_transform():
        """Initialize the ImageNet-normalization transform applied to model inputs."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _init_heatmap_model(self, model_path):
        """Load the EfficientUNet detection model and put it in eval mode."""
        model = EfficientUNet(weights=None)
        model.load_state_dict(torch.load(
            model_path, map_location=self._device))
        model.to(self._device)

        # Set the model to evaluation mode
        model.eval()
        return model

    def _init_siamese_model(self, model_path):
        """Load the siamese tracking model and put it in eval mode."""
        model = SiameseTracker()
        model.load_state_dict(torch.load(
            model_path, map_location=self._device))
        model.to(self._device)

        # Set the model to evaluation mode
        model.eval()
        return model

    def _track_frame(self, frame, mask=None):
        """Track the fish in the given frame.

        Parameters
        ----------

        frame : np.ndarray
            The frame to track the fish in.
        mask : np.ndarray, optional
            The mask to use for tracking. If None, the whole frame is used.

        Returns
        -------

        Prediction or None
            The prediction made by the model, or None if no fish is detected.
        """
        if self.frame_size is None:
            self.frame_size = frame.shape[:2]

        self._debug_print("track")

        if self.last_result is None:
            self._debug_print("sliding")
            # Do a sliding window over the whole frame to try and find our fish.
            result = self._sliding_window_predict(frame, mask)

            if result is not None:
                prediction = result

                prediction.roi = self._tracking_roi(
                    frame, prediction.centroid)[1]

                return prediction
        else:
            self._debug_print("tracking")
            # Try getting a ROI around the last position.
            (x1, y1), current_roi = self._tracking_roi(
                frame, self.last_result.centroid)
            result = self._evaluate_siamese_model(
                self.last_result.roi, current_roi)

            if result is not None:
                prediction = result.global_coordinates((x1, y1))

                prediction.roi = self._tracking_roi(
                    frame, prediction.centroid)[1]

                # Figure out mask values at the tracked point.
                if mask is not None:
                    w = FishTracker.TRACKER_WINDOW_SIZE // 2
                    cx, cy = prediction.centroid
                    # Clamp window starts at 0: a centroid near the frame
                    # border would otherwise produce negative slice indices
                    # that wrap around (yielding an empty slice and a NaN
                    # mean, which silently passed the threshold check).
                    my1, mx1 = max(cy - w, 0), max(cx - w, 0)
                    mask_values = mask[my1: cy + w, mx1: cx + w].mean()
                    if mask_values < FishTracker.MOVEMENT_MASK_THRESHOLD:
                        self._debug_print("Tracking: Mask check failed")
                        return None

                self._debug_print(
                    f"Found fish at ({result.centroid}) with confidence {result.confidence}")

                return prediction

        return None

    def _tracking_roi(self, frame, centroid):
        """Get the tracking ROI around the centroid."""
        x, y = centroid
        h, w = frame.shape[:2]
        w_t = self.TRACKER_WINDOW_SIZE // 2

        # Clamp center so that full ROI fits in frame
        x = max(w_t, min(x, w - w_t))
        y = max(w_t, min(y, h - w_t))

        x1 = int(x - w_t)
        y1 = int(y - w_t)
        x2 = int(x + w_t)
        y2 = int(y + w_t)

        return (x1, y1), frame[y1:y2, x1:x2]

    def _init_background_subtractor(self, video: VideoClip, start_frame: int, history=50, dist2threshold=500, warmup=20):
        """Initialize the background subtractor, warmed up on frames before start_frame."""
        background_subtractor = cv2.createBackgroundSubtractorKNN(
            history=history, dist2Threshold=dist2threshold, detectShadows=False)

        # Warm up the background subtractor with a few frames.
        for fid in range(max(start_frame - warmup, 0), start_frame):
            t = float(fid) / video.fps
            f = video.get_frame(t)
            if f is None:
                continue

            gframe = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY)
            gframe = cv2.GaussianBlur(gframe, (5, 5), 1.0)

            background_subtractor.apply(gframe, learningRate=0.25)

        return background_subtractor

    def _motion_detection(self, bgs, frame):
        """Detect motion in the frame using the background subtractor."""

        gframe = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        gframe = cv2.GaussianBlur(gframe, (5, 5), 1.0)

        mask = bgs.apply(gframe, learningRate=0.125)

        # Close small holes in the foreground mask.
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k)

        return mask

    def _sliding_window_predict(self, frame, mask=None) -> Optional[Prediction]:
        """
        Do a sliding window over the whole frame to try and find our fish.

        Parameters
        ----------
        frame : np.ndarray
            The frame to do the sliding window over.
        mask : np.ndarray, optional
            Motion mask; windows with no masked motion are skipped.

        Returns
        -------

        Prediction or None
            The best prediction above the detection threshold, or None.
        """

        h, w = frame.shape[:2]
        results = []

        win_size = self.WINDOW_SIZE
        stride = int(self._stride * win_size)

        for y in range(0, h, stride):
            for x in range(0, w, stride):

                if mask is not None:
                    mask_roi = mask[y:y+win_size, x:x+win_size]
                    mask_sum = mask_roi.sum()

                    # Check if the window is in the mask.
                    if mask_sum == 0:
                        continue

                # NumPy slicing never raises on out-of-range stops; the
                # shape check below discards truncated border windows.
                window = frame[y:y+win_size, x:x+win_size]

                if window.shape[0] != win_size or window.shape[1] != win_size:
                    continue

                result = self._evaluate_heatmap_model(window)

                if not result:
                    continue

                # Map out the global coordinates of the predictions.
                results.append(result.global_coordinates((x, y)))

        if results:
            self._debug_print(f"Got {len(results)} results")

            # Sort by score
            results.sort(key=lambda p: p.confidence, reverse=True)

            # Get the best result
            result = results[0]

            if result.confidence < self.detection_threshold:

                self._debug_print(
                    f"Best candidate confidence {result.confidence} is below threshold {self.detection_threshold}")
                return None

            return result  # Return the best result

        self._debug_print("Not a single sliding window found a fish")

        return None

    @staticmethod
    def _get_centroid(heatmap):
        """Locate the peak of a (B, 1, H, W) heatmap.

        Returns (confidence, (x, y)) for the first sample, or None when the
        heatmap is essentially empty.
        """
        if heatmap.sum() < 1e-6:
            return None

        b, _, _, w = heatmap.shape
        flat_idx = torch.argmax(heatmap.view(b, -1), dim=1)
        y = flat_idx // w
        x = flat_idx % w

        # Get confidence at the centroid. Index the first sample explicitly
        # so batches larger than one cannot break the scalar .item() calls.
        confidence = heatmap[0, 0, y[0], x[0]].item()

        return confidence, (x[0].int().item(), y[0].int().item())

    def _evaluate_heatmap_model(self, window) -> Optional[Prediction]:
        """Evaluate the model on a window of the image.
        Note that this returns the prediction in window local space. For global space
        adjust the centroid and box coordinates accordingly using the origin of the window.
        """

        # Prepare the input.
        net_input = self._transform(window) \
                        .to(self._device) \
                        .unsqueeze(0)

        try:
            output = torch.sigmoid(self.heatmap_model(net_input))
        except Exception as e:
            self._debug_print(f"Error in model evaluation: {e}")
            # If we get an error, we just return None.
            return None

        result = FishTracker._get_centroid(output)

        if result is None:
            self._debug_print("Heatmap: No fish detected")
            return None

        (confidence, centroid) = result

        return Prediction(confidence, centroid)

    def _evaluate_siamese_model(self, last_roi, current_roi) -> Optional[Prediction]:
        """Evaluate the siamese tracker on a template/search ROI pair.

        Returns a Prediction in ROI-local coordinates, or None on failure
        or when confidence falls below the tracking threshold.
        """

        # Prepare the input.
        template = self._transform(last_roi) \
            .to(self._device) \
            .unsqueeze(0)

        search = self._transform(current_roi) \
            .to(self._device) \
            .unsqueeze(0)

        try:
            output = torch.sigmoid(self.siamese_model(template, search))
        except Exception as e:
            self._debug_print(f"Siamese: Error in model evaluation: {e}")
            # If we get an error, we just return None.
            return None

        result = FishTracker._get_centroid(output)

        if result is None:
            self._debug_print("Siamese: No fish detected")
            return None

        (confidence, centroid) = result

        if confidence < self.tracking_threshold:
            self._debug_print(
                f"Siamese: Confidence {confidence} is below threshold {self.tracking_threshold}")
            return None

        return Prediction(confidence, centroid, roi=None)

    def _debug_print(self, msg):
        """Print *msg* when debug mode is enabled."""
        if self._debug:
            print(msg)

run_tracking(video, start_frame, end_frame, model_track_register, progress_reporter=None, ui_update=None)

Run the tracking on a video.

Source code in src\aegear\tracker.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def run_tracking(self,
                 video: "VideoClip",
                 start_frame: int,
                 end_frame: int,
                 model_track_register,
                 progress_reporter: Optional["ProgressReporter"] = None,
                 ui_update=None):
    """Run the tracking on a video.

    Frames are visited with an adaptive skip: after each successful
    detection the skip doubles (capped at ``self.tracking_max_skip``);
    after a failed detection with a prior result the skip halves
    (floor 1) and a frame closer to the anchor is retried.

    Parameters
    ----------
    video : VideoClip
        Video to process.
    start_frame : int
        Index of the first frame to track from.
    end_frame : int
        Tracking stops once the anchor reaches this frame.
    model_track_register : callable
        Invoked as ``model_track_register(frame_idx, centroid, confidence)``
        for every accepted detection.
    progress_reporter : Optional[ProgressReporter]
        Optional progress/cancellation hook; polled via ``still_running()``
        and notified via ``update(frame)``.
    ui_update : callable, optional
        Optional UI callback invoked with the current anchor frame after
        each processed candidate.
    """
    bgs = self._init_background_subtractor(video, start_frame)
    current_skip = self.tracking_max_skip
    anchor_frame = start_frame

    self.last_result = None

    def progress_still_running():
        # Bug fix: with no reporter, tracking must keep running.  The old
        # check (`progress_reporter is not None and ...`) returned False
        # when progress_reporter was None, turning the whole loop into a
        # no-op for the default argument.
        return progress_reporter is None or progress_reporter.still_running()

    while anchor_frame < end_frame and progress_still_running():
        candidate = anchor_frame + current_skip
        if candidate >= end_frame:
            break

        # Read and pre-process the candidate frame.
        frame = video.get_frame(float(candidate) / video.fps)
        if frame is None:
            break

        result = self._track_frame(
            frame, mask=self._motion_detection(bgs, frame))

        if result is not None:
            # Store this result for further tracking.
            self.last_result = result
            model_track_register(
                candidate, result.centroid, result.confidence)

            anchor_frame = candidate

            if progress_reporter is not None:
                progress_reporter.update(anchor_frame)

            # Success: ramp the skip back up exponentially.
            if current_skip < self.tracking_max_skip:
                current_skip = min(
                    current_skip * 2, self.tracking_max_skip)
        else:
            # Miss: while we still have a recent result, retry closer to
            # the anchor before giving up on it.
            if self.last_result is not None and current_skip > 1:
                current_skip = max(current_skip // 2, 1)
                continue

            # Nothing to fall back on: move past the candidate and restart.
            anchor_frame = candidate
            self.last_result = None

        if ui_update is not None:
            ui_update(anchor_frame)

trajectory

Utility functions for working with 2D trajectories in image frames, including drawing, smoothing, and computing properties of motion paths.

Assumes trajectory is a list of (x, y) pixel coordinates sampled at video frame rate.

smooth_trajectory(trajectory, filterSize=15)

Apply Savitzky-Golay filter to smooth a trajectory.

Parameters:

Name Type Description Default
trajectory list of (t, x, y)

Frame id with raw trajectory points.

required
filterSize int

Window size for filtering (must be odd and >= 5).

15

Returns:

Type Description
list[tuple[int, int, int]]

list of (t, x, y): Smoothed trajectory points.

Source code in src\aegear\trajectory.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def smooth_trajectory(trajectory: list[tuple[int, int, int]], filterSize: int = 15) -> list[tuple[int, int, int]]:
    """
    Smooth a (t, x, y) trajectory with a Savitzky-Golay filter.

    Parameters:
        trajectory (list of (t, x, y)): Frame id with raw trajectory points.
        filterSize (int): Window size for filtering (must be odd and >= 5).

    Returns:
        list of (t, x, y): Smoothed trajectory points.
    """
    # The filter uses polyorder 3, so the window must be odd and >= 5.
    window = max(filterSize, 5)
    if window % 2 == 0:
        window += 1

    # Too few samples to fill one window: return the input untouched.
    if len(trajectory) < window:
        return trajectory

    data = np.array(trajectory)
    smoothed_columns = [
        savgol_filter(data[:, column], window, 3).astype(int)
        for column in range(3)
    ]
    return list(zip(*smoothed_columns))

detect_trajectory_outliers(trajectory, threshold=20.0)

Detects large jumps in pixel space, indicating likely tracking failures.

Parameters:

Name Type Description Default
trajectory list[tuple[int, int, int]]

List of (frame_idx, x, y) tuples.

required
threshold float

Maximum allowed pixel movement per frame.

20.0

Returns:

Type Description
list[int]

List of frame indices where jump exceeds threshold.

Source code in src\aegear\trajectory.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def detect_trajectory_outliers(
    trajectory: list[tuple[int, int, int]],
    threshold: float = 20.0  # distance in pixels per frame
) -> list[int]:
    """
    Detects large jumps in pixel space, indicating likely tracking failures.

    Args:
        trajectory: List of (frame_idx, x, y) tuples.
        threshold: Maximum allowed pixel movement per frame.

    Returns:
        List of frame indices where jump exceeds threshold.
    """
    # A jump needs two consecutive samples; anything shorter has none.
    if len(trajectory) < 2:
        return []

    data = np.asarray(trajectory)  # columns: frame_idx, x, y
    frames = data[:, 0]

    # Per-step displacement vectors and their Euclidean lengths.
    steps = np.diff(data[:, 1:], axis=0)
    step_len = np.sqrt((steps ** 2).sum(axis=1))

    # Flag the frame that *arrives* after a too-large jump.
    return list(frames[1:][step_len > threshold])

utils

Kalman2D

A simple 2D Kalman filter for tracking.

Source code in src\aegear\utils.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class Kalman2D:
    """A simple 2D Kalman filter for tracking."""

    def __init__(self, r=1.0, q=0.1):
        """Initialize the Kalman filter.

        Parameters
        ----------
        r : float
            The measurement noise.
        q : float
            The process noise.
        """
        self.x = np.zeros((4, 1))   # state vector [x, y, vx, vy]^T
        self.P = np.eye(4) * 1000   # large initial uncertainty

        # Constant-velocity transition: position += velocity each step.
        self.A = np.eye(4, dtype=int)
        self.A[0, 2] = 1
        self.A[1, 3] = 1

        # We observe position only (first two state components).
        self.H = np.eye(2, 4, dtype=int)

        self.R = np.eye(2) * r  # measurement noise
        self.Q = np.eye(4) * q  # process noise

    def reset(self, x, y):
        """Re-center the filter at (x, y) with zero velocity."""
        self.x = np.array([[x], [y], [0], [0]])
        # NOTE: covariance resets to identity here, not the large value
        # used at construction time.
        self.P = np.eye(4)

    def update(self, z):
        """Advance one step and fuse measurement ``z``; return the (x, y) estimate."""
        # --- Predict ---
        self.x = self.A @ self.x
        self.P = self.A @ self.P @ self.A.T + self.Q

        # --- Correct ---
        measurement = np.array(z).reshape(2, 1)
        innovation = measurement - self.H @ self.x
        innovation_cov = self.H @ self.P @ self.H.T + self.R
        gain = self.P @ self.H.T @ np.linalg.inv(innovation_cov)

        self.x = self.x + gain @ innovation
        self.P = (np.eye(4) - gain @ self.H) @ self.P

        return self.x[0, 0], self.x[1, 0]

resource_path(relative_path)

Get the absolute path to the resource, works for dev and PyInstaller.

Source code in src\aegear\utils.py
16
17
18
19
20
21
22
23
def resource_path(relative_path: str) -> Path:
    """Get the absolute path to the resource, works for dev and PyInstaller."""
    # PyInstaller unpacks bundled resources under sys._MEIPASS at runtime.
    meipass = getattr(sys, "_MEIPASS", None)
    if meipass is not None:
        base_path = Path(meipass)
    else:
        # Go two levels up from aegear/app.py → project root
        base_path = Path(__file__).resolve().parents[2]
    return base_path / relative_path

get_latest_model_path(directory, model_name)

Find the latest model file in the given directory matching the base model name. Model files are expected to be named as: modelname_YYYY-MM-DD.pth

Source code in src\aegear\utils.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def get_latest_model_path(directory, model_name) -> str:
    """
    Find the latest model file in the given directory matching the base model name.
    Model files are expected to be named as: modelname_YYYY-MM-DD.pth
    """
    # Capture the date stamp so candidates can be compared chronologically.
    name_re = re.compile(rf"{re.escape(model_name)}_(\d{{4}}-\d{{2}}-\d{{2}})\.pth")

    best = None  # (date, filename) of the newest valid match seen so far
    for entry in os.listdir(directory):
        match = name_re.fullmatch(entry)
        if not match:
            continue
        try:
            stamp = datetime.strptime(match.group(1), "%Y-%m-%d")
        except ValueError:
            # Pattern matched but the date itself is invalid (e.g. month 13).
            continue
        if best is None or stamp > best[0]:
            best = (stamp, entry)

    return os.path.join(directory, best[1]) if best else None

load_model_with_weights(model_class, model_path, device)

Load a model with weights from a checkpoint.

Parameters

model_class : torch.nn.Module Model class to instantiate model_path : str Path to model checkpoint device : str Device to load model on ('cuda' or 'cpu')

Returns

torch.nn.Module Loaded model

Source code in src\aegear\utils.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def load_model_with_weights(model_class, model_path, device):
    """Load a model with weights from a checkpoint.

    Parameters
    ----------
    model_class : torch.nn.Module
        Model class to instantiate
    model_path : str
        Path to model checkpoint
    device : str
        Device to load model on ('cuda' or 'cpu')

    Returns
    -------
    torch.nn.Module
        Loaded model
    """
    # Instantiate, restore weights, move to the target device, and switch
    # to inference mode before handing the model back.
    model = model_class()
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

download_dataset(dataset_dir, dataset_type='tracking')

Download dataset from GCS if not already present.

Parameters

dataset_dir : str Directory to download the dataset to dataset_type : str Type of dataset to download ('tracking' or 'detection')

Source code in src\aegear\utils.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def download_dataset(dataset_dir, dataset_type="tracking"):
    """Download dataset from GCS if not already present.

    Parameters
    ----------
    dataset_dir : str
        Directory to download the dataset to
    dataset_type : str
        Type of dataset to download ('tracking' or 'detection')
    """
    bucket_name = "aegear-training-data"
    blob_path = f"cache/{dataset_type}.zip"

    dataset_path = os.path.join(dataset_dir, dataset_type)

    # Already extracted: nothing to do.
    if os.path.exists(dataset_path):
        print(f"{dataset_type.capitalize()} dataset already exists. Skipping download.")
        return

    print(f"{dataset_type.capitalize()} dataset not found. Downloading...")
    zip_path = os.path.join(dataset_dir, f"{dataset_type}.zip")
    os.makedirs(dataset_dir, exist_ok=True)

    # Only fetch the archive when a previous run did not already leave one.
    if not os.path.exists(zip_path):
        print(f"Downloading gs://{bucket_name}/{blob_path} to {zip_path}...")

        # The bucket is public, so an anonymous client suffices.
        client = storage.Client.create_anonymous_client()
        blob = client.bucket(bucket_name).blob(blob_path)
        blob.download_to_filename(zip_path)

        print("Download complete.")

    # Unpack the archive next to where it was downloaded.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dataset_dir)
        print(f"Extracted to {dataset_dir}")

video

VideoClip

Minimalistic video clip class for reading video files.

Source code in src\aegear\video.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class VideoClip:
    """Minimalistic video clip class for reading video files."""

    def __init__(self, path):
        """Open the video at ``path``.

        Raises
        ------
        IOError
            If OpenCV cannot open the file.
        """
        self.path = path
        self._cap = cv2.VideoCapture(path)
        if not self._cap.isOpened():
            raise IOError(f"Cannot open video: {path}")

        self.fps = self._cap.get(cv2.CAP_PROP_FPS)
        self.num_frames = int(self._cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers/streams report fps == 0; avoid a
        # ZeroDivisionError and report an unknown (zero) duration instead.
        self.duration = self.num_frames / self.fps if self.fps else 0.0

    def get_frame(self, t):
        """
        Return the frame at time `t` (in seconds).
        """
        frame_id = int(t * self.fps)
        return self.get_frame_by_index(frame_id)

    def get_frame_by_index(self, frame_id):
        """
        Return the RGB frame at the given frame index, or None on failure.
        """
        self._cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
        success, frame = self._cap.read()
        if not success:
            return None

        # Convert BGR (OpenCV native) to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        return frame

    def get_frame_width(self):
        """
        Return the width of the video frames.
        """
        return int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH))

    def get_frame_height(self):
        """
        Return the height of the video frames.
        """
        return int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def get_frame_shape(self):
        """
        Return the shape of the video frames as (height, width, channels).
        """
        return (self.get_frame_height(), self.get_frame_width(), 3)

    def release(self):
        """Release the underlying capture; safe even after a failed init."""
        cap = getattr(self, "_cap", None)
        if cap is not None:
            cap.release()

    def __del__(self):
        # Guard against partially-constructed instances: if
        # cv2.VideoCapture itself raised, _cap was never assigned and the
        # old code crashed with AttributeError during garbage collection.
        self.release()

get_frame(t)

Return the frame at time t (in seconds).

Source code in src\aegear\video.py
16
17
18
19
20
21
def get_frame(self, t):
    """
    Return the frame at time `t` (in seconds).
    """
    # Convert seconds to a frame index using the clip's frame rate.
    return self.get_frame_by_index(int(t * self.fps))

get_frame_by_index(frame_id)

Return the frame at the given frame index.

Source code in src\aegear\video.py
23
24
25
26
27
28
29
30
31
32
33
34
35
def get_frame_by_index(self, frame_id):
    """
    Return the frame at the given frame index.
    """
    # Seek to the requested frame, then decode it.
    self._cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
    success, frame = self._cap.read()
    if not success:
        return None

    # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

get_frame_width()

Return the width of the video frames.

Source code in src\aegear\video.py
37
38
39
40
41
def get_frame_width(self):
    """
    Return the width of the video frames.
    """
    # CAP_PROP_* getters return floats; callers expect an int.
    width = self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    return int(width)

get_frame_height()

Return the height of the video frames.

Source code in src\aegear\video.py
43
44
45
46
47
def get_frame_height(self):
    """
    Return the height of the video frames.
    """
    # CAP_PROP_* getters return floats; callers expect an int.
    height = self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    return int(height)

get_frame_shape()

Return the shape of the video frames.

Source code in src\aegear\video.py
49
50
51
52
53
def get_frame_shape(self):
    """
    Return the shape of the video frames as (height, width, channels).
    """
    height, width = self.get_frame_height(), self.get_frame_width()
    return (height, width, 3)

visualization

Visualization utilities for model inspection using FiftyOne.

FiftyOneDatasetBuilder

Base class for building FiftyOne datasets from model predictions.

Source code in src\aegear\visualization.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class FiftyOneDatasetBuilder:
    """Base class for building FiftyOne datasets from model predictions."""

    def __init__(self, dataset, model, device, img_size):
        # Stored verbatim; subclasses drive inference through these fields.
        self.dataset = dataset
        self.model = model
        self.device = device
        self.img_size = img_size

    def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
        """Build and populate a FiftyOne dataset.

        Parameters
        ----------
        fo_dataset_name : str
            Name for the FiftyOne dataset
        batch_size : int
            Batch size for inference
        num_workers : int
            Number of workers for data loading

        Returns
        -------
        fo.Dataset
            Populated FiftyOne dataset
        """
        raise NotImplementedError("Subclasses must implement build_dataset")

    def _create_dataloader(self, batch_size, num_workers):
        """Create a DataLoader with shuffle=False so sample order stays stable."""
        return DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

build_dataset(fo_dataset_name, batch_size=128, num_workers=4)

Build and populate a FiftyOne dataset.

Parameters

fo_dataset_name : str Name for the FiftyOne dataset batch_size : int Batch size for inference num_workers : int Number of workers for data loading

Returns

fo.Dataset Populated FiftyOne dataset

Source code in src\aegear\visualization.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
    """Build and populate a FiftyOne dataset.

    Parameters
    ----------
    fo_dataset_name : str
        Name for the FiftyOne dataset
    batch_size : int
        Batch size for inference
    num_workers : int
        Number of workers for data loading

    Returns
    -------
    fo.Dataset
        Populated FiftyOne dataset
    """
    # Abstract: concrete builders decide how samples are created and scored.
    raise NotImplementedError("Subclasses must implement build_dataset")

TrackingDatasetBuilder

Bases: FiftyOneDatasetBuilder

Builder for tracking model evaluation datasets in FiftyOne.

Source code in src\aegear\visualization.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class TrackingDatasetBuilder(FiftyOneDatasetBuilder):
    """Builder for tracking model evaluation datasets in FiftyOne.

    Creates one sample per (template, search) metadata entry, attaches the
    ground-truth keypoint/heatmap, then runs the siamese model over the
    dataset and stores predicted heatmaps, keypoints, confidences and
    pixel distance errors on each sample.
    """

    def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
        """Build FiftyOne dataset for tracking model evaluation.

        Parameters
        ----------
        fo_dataset_name : str
            Name for the FiftyOne dataset; an existing dataset with the
            same name is overwritten.
        batch_size : int
            Batch size for inference.
        num_workers : int
            Number of workers for data loading.

        Returns
        -------
        fo.Dataset
            Populated FiftyOne dataset.
        """
        print("Populating FiftyOne dataset from metadata...")
        fo_dataset = fo.Dataset(fo_dataset_name, persistent=True, overwrite=True)

        # Add metadata samples
        samples_to_add = self._create_metadata_samples()
        fo_dataset.add_samples(samples_to_add)

        # Run inference and add predictions
        print("Running inference and adding predictions to FiftyOne...")
        self._add_predictions(fo_dataset, batch_size, num_workers)

        return fo_dataset

    def _create_metadata_samples(self):
        """Create FiftyOne samples from dataset metadata.

        One sample per metadata item, keyed on the search image path.
        Background items get an empty ``ground_truth`` field and a
        "background" tag; foreground items get a GT keypoint (coordinates
        normalized by ``self.img_size`` — presumably square images; TODO
        confirm) plus the raw GT heatmap.
        """
        samples = []
        for item in tqdm(self.dataset.metadata, desc="Loading metadata"):
            search_path = os.path.join(self.dataset.root_dir, item["search_path"])
            template_path = os.path.join(self.dataset.root_dir, item["template_path"])

            sample = fo.Sample(filepath=search_path)

            # Store template as a custom field with its own visualizations
            sample["template_filepath"] = template_path

            # Add template image as embedded visualization if available
            # This allows viewing template alongside search image
            # NOTE(review): template_img is opened but never used, and the
            # bare except swallows every error (including typos) — consider
            # narrowing to the actual I/O exceptions.
            try:
                import PIL.Image as Image
                template_img = Image.open(template_path)
                sample["template_image"] = fo.Image(filepath=template_path)
            except:
                pass  # Skip if image loading fails

            if item.get("background", False):
                # Background sample: no keypoints, tagged for filtering.
                sample["ground_truth"] = fo.Keypoints()
                sample.tags.append("background")
            else:
                # Store GT centroid as keypoint
                xg, yg = item["centroid"]
                xg_rel = xg / self.img_size
                yg_rel = yg / self.img_size

                gt_keypoint = fo.Keypoint(label="gt", points=[(xg_rel, yg_rel)])
                sample["ground_truth"] = fo.Keypoints(keypoints=[gt_keypoint])

                # Store raw GT heatmap
                gt_heatmap = self.dataset.generate_heatmap(item["centroid"])
                sample["gt_heatmap"] = fo.Heatmap(map=gt_heatmap.squeeze().numpy())

            # Store template and search ROI information if available
            if "template_bbox" in item:
                bbox = item["template_bbox"]
                # Normalize bbox to [0, 1] for search image
                x, y, w, h = bbox
                sample["template_roi"] = fo.Detection(
                    label="template_roi",
                    bounding_box=[x/self.img_size, y/self.img_size, w/self.img_size, h/self.img_size]
                )

            # Store search ROI information if available (the region being searched)
            if "search_bbox" in item:
                bbox = item["search_bbox"]
                x, y, w, h = bbox
                sample["search_roi"] = fo.Detection(
                    label="search_roi",
                    bounding_box=[x/self.img_size, y/self.img_size, w/self.img_size, h/self.img_size]
                )

            # Store metadata for reference
            sample["sample_metadata"] = {
                "search_path": item["search_path"],
                "template_path": item["template_path"],
                "is_background": item.get("background", False)
            }

            samples.append(sample)

        return samples

    def _add_predictions(self, fo_dataset, batch_size, num_workers):
        """Run inference and add predictions to FiftyOne dataset.

        Relies on the dataloader being unshuffled (see
        ``_create_dataloader``) so predictions can be matched to samples
        positionally via ``sample_ids``.
        """
        loader = self._create_dataloader(batch_size, num_workers)
        # Sample ids in insertion order; idx_counter walks them in lockstep
        # with the dataloader.
        sample_ids = fo_dataset.values("id")
        samples_to_save = []

        self.model.eval()
        with torch.no_grad():
            idx_counter = 0
            for templates, searches, heatmaps in tqdm(loader, desc="Evaluating"):
                templates = templates.to(self.device)
                searches = searches.to(self.device)

                # Run model
                preds_logits = self.model(templates, searches)
                preds = torch.sigmoid(preds_logits)

                # Interpolate
                # Bring both maps to img_size so pixel distances are
                # comparable with the GT keypoints.
                preds = F.interpolate(preds, size=(self.img_size, self.img_size),
                                    mode='bilinear', align_corners=False)
                heatmaps = F.interpolate(heatmaps, size=(self.img_size, self.img_size),
                                       mode='bilinear', align_corners=False)

                # Get centroids
                centroids_pred = get_centroids_per_sample(preds)
                centroids_gt = get_centroids_per_sample(heatmaps)

                for i in range(len(templates)):
                    sample_id = sample_ids[idx_counter]
                    sample = fo_dataset[sample_id]
                    idx_counter += 1

                    p = centroids_pred[i]
                    t = centroids_gt[i]

                    # Save predicted heatmap
                    pred_hm_np = preds[i, 0].cpu().numpy()
                    sample["pred_heatmap"] = fo.Heatmap(map=pred_hm_np)

                    pred_keypoints = []
                    if p is not None:
                        xp, yp, confidence = p
                        xp, yp = xp.item(), yp.item()
                        confidence = confidence.item()

                        # Normalize for FiftyOne
                        xp_rel, yp_rel = xp / self.img_size, yp / self.img_size

                        pred_keypoint = fo.Keypoint(
                            label="pred",
                            points=[(xp_rel, yp_rel)],
                            confidence=[confidence]
                        )
                        pred_keypoints.append(pred_keypoint)
                        sample["confidence"] = confidence

                    sample["prediction"] = fo.Keypoints(keypoints=pred_keypoints)

                    # Calculate distance error
                    if p is not None and t is not None:
                        xp, yp, _ = p
                        xg, yg, _ = t
                        dist = np.sqrt((xp.item() - xg.item())**2 + (yp.item() - yg.item())**2)
                        sample["distance_error"] = dist
                    else:
                        sample["distance_error"] = None

                    samples_to_save.append(sample)

        # NOTE(review): samples fetched from fo_dataset are re-added here;
        # verify this does not duplicate them — per-sample ``sample.save()``
        # may be the intended way to persist the new fields.
        fo_dataset.add_samples(samples_to_save)
        print(f"Successfully added predictions to {len(samples_to_save)} samples.")

build_dataset(fo_dataset_name, batch_size=128, num_workers=4)

Build FiftyOne dataset for tracking model evaluation.

Source code in src\aegear\visualization.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
    """Build FiftyOne dataset for tracking model evaluation."""
    print("Populating FiftyOne dataset from metadata...")
    # Persistent + overwrite: repeated runs reuse the same dataset name.
    fo_dataset = fo.Dataset(fo_dataset_name, persistent=True, overwrite=True)

    # Stage 1: ground-truth samples built from the metadata.
    fo_dataset.add_samples(self._create_metadata_samples())

    # Stage 2: model predictions attached to those samples.
    print("Running inference and adding predictions to FiftyOne...")
    self._add_predictions(fo_dataset, batch_size, num_workers)

    return fo_dataset

DetectionDatasetBuilder

Bases: FiftyOneDatasetBuilder

Builder for detection model evaluation datasets in FiftyOne.

Source code in src\aegear\visualization.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
class DetectionDatasetBuilder(FiftyOneDatasetBuilder):
    """Builder for detection model evaluation datasets in FiftyOne.

    Creates one sample per image metadata entry, attaches the ground-truth
    keypoint/heatmap, then runs the detection model over the dataset and
    stores predicted heatmaps, keypoints, confidences and pixel distance
    errors on each sample.
    """

    def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
        """Build FiftyOne dataset for detection model evaluation.

        Parameters
        ----------
        fo_dataset_name : str
            Name for the FiftyOne dataset; an existing dataset with the
            same name is overwritten.
        batch_size : int
            Batch size for inference.
        num_workers : int
            Number of workers for data loading.

        Returns
        -------
        fo.Dataset
            Populated FiftyOne dataset.
        """
        print("Populating FiftyOne dataset from metadata...")
        fo_dataset = fo.Dataset(fo_dataset_name, persistent=True, overwrite=True)

        # Add metadata samples
        samples_to_add = self._create_metadata_samples()
        fo_dataset.add_samples(samples_to_add)

        # Run inference and add predictions
        print("Running inference and adding predictions to FiftyOne...")
        self._add_predictions(fo_dataset, batch_size, num_workers)

        return fo_dataset

    def _create_metadata_samples(self):
        """Create FiftyOne samples from dataset metadata.

        Background items get an empty ``ground_truth`` field and a
        "background" tag; foreground items get a GT keypoint (coordinates
        normalized by ``self.img_size`` — presumably square images; TODO
        confirm) plus the raw GT heatmap.
        """
        samples = []
        for item in tqdm(self.dataset.metadata, desc="Loading metadata"):
            image_path = os.path.join(self.dataset.root_dir, item["image_path"])

            sample = fo.Sample(filepath=image_path)

            if item.get("background", False):
                # Background sample: no keypoints, tagged for filtering.
                sample["ground_truth"] = fo.Keypoints()
                sample.tags.append("background")
            else:
                # Store GT centroid as keypoint
                xg, yg = item["centroid"]
                xg_rel = xg / self.img_size
                yg_rel = yg / self.img_size

                gt_keypoint = fo.Keypoint(label="gt", points=[(xg_rel, yg_rel)])
                sample["ground_truth"] = fo.Keypoints(keypoints=[gt_keypoint])

                # Store raw GT heatmap
                gt_heatmap = self.dataset.generate_heatmap(item["centroid"])
                sample["gt_heatmap"] = fo.Heatmap(map=gt_heatmap.squeeze().numpy())

            samples.append(sample)

        return samples

    def _add_predictions(self, fo_dataset, batch_size, num_workers):
        """Run inference and add predictions to FiftyOne dataset.

        Relies on the dataloader being unshuffled (see
        ``_create_dataloader``) so predictions can be matched to samples
        positionally via ``sample_ids``.
        """
        loader = self._create_dataloader(batch_size, num_workers)
        # Sample ids in insertion order; idx_counter walks them in lockstep
        # with the dataloader.
        sample_ids = fo_dataset.values("id")
        samples_to_save = []

        self.model.eval()
        with torch.no_grad():
            idx_counter = 0
            for images, heatmaps in tqdm(loader, desc="Evaluating"):
                images = images.to(self.device)

                # Run model
                preds_logits = self.model(images)
                preds = torch.sigmoid(preds_logits)

                # Interpolate
                # Bring both maps to img_size so pixel distances are
                # comparable with the GT keypoints.
                preds = F.interpolate(preds, size=(self.img_size, self.img_size),
                                    mode='bilinear', align_corners=False)
                heatmaps = F.interpolate(heatmaps, size=(self.img_size, self.img_size),
                                       mode='bilinear', align_corners=False)

                # Get centroids
                centroids_pred = get_centroids_per_sample(preds)
                centroids_gt = get_centroids_per_sample(heatmaps)

                for i in range(len(images)):
                    sample_id = sample_ids[idx_counter]
                    sample = fo_dataset[sample_id]
                    idx_counter += 1

                    p = centroids_pred[i]
                    t = centroids_gt[i]

                    # Save predicted heatmap
                    pred_hm_np = preds[i, 0].cpu().numpy()
                    sample["pred_heatmap"] = fo.Heatmap(map=pred_hm_np)

                    pred_keypoints = []
                    if p is not None:
                        xp, yp, confidence = p
                        xp, yp = xp.item(), yp.item()
                        confidence = confidence.item()

                        # Normalize for FiftyOne
                        xp_rel, yp_rel = xp / self.img_size, yp / self.img_size

                        pred_keypoint = fo.Keypoint(
                            label="pred",
                            points=[(xp_rel, yp_rel)],
                            confidence=[confidence]
                        )
                        pred_keypoints.append(pred_keypoint)
                        sample["confidence"] = confidence

                    sample["prediction"] = fo.Keypoints(keypoints=pred_keypoints)

                    # Calculate distance error
                    if p is not None and t is not None:
                        xp, yp, _ = p
                        xg, yg, _ = t
                        dist = np.sqrt((xp.item() - xg.item())**2 + (yp.item() - yg.item())**2)
                        sample["distance_error"] = dist
                    else:
                        sample["distance_error"] = None

                    samples_to_save.append(sample)

        # NOTE(review): samples fetched from fo_dataset are re-added here;
        # verify this does not duplicate them — per-sample ``sample.save()``
        # may be the intended way to persist the new fields.
        fo_dataset.add_samples(samples_to_save)
        print(f"Successfully added predictions to {len(samples_to_save)} samples.")

build_dataset(fo_dataset_name, batch_size=128, num_workers=4)

Build FiftyOne dataset for detection model evaluation.

Source code in src\aegear\visualization.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def build_dataset(self, fo_dataset_name, batch_size=128, num_workers=4):
    """Build FiftyOne dataset for detection model evaluation."""
    print("Populating FiftyOne dataset from metadata...")
    # Persistent + overwrite: repeated runs reuse the same dataset name.
    fo_dataset = fo.Dataset(fo_dataset_name, persistent=True, overwrite=True)

    # Stage 1: ground-truth samples built from the metadata.
    fo_dataset.add_samples(self._create_metadata_samples())

    # Stage 2: model predictions attached to those samples.
    print("Running inference and adding predictions to FiftyOne...")
    self._add_predictions(fo_dataset, batch_size, num_workers)

    return fo_dataset