Skip to content

Utils

box_iou_batch

Compute Intersection over Union (IoU) of two sets of bounding boxes - boxes_true and boxes_detection. Both sets of boxes are expected to be in (x_min, y_min, x_max, y_max) format.

Parameters:

Name Type Description Default
boxes_true ndarray

2D np.ndarray representing ground-truth boxes. shape = (N, 4) where N is number of true objects.

required
boxes_detection ndarray

2D np.ndarray representing detection boxes. shape = (M, 4) where M is number of detected objects.

required

Returns:

Type Description
ndarray

np.ndarray: Pairwise IoU of boxes from boxes_true and boxes_detection. shape = (N, M) where N is number of true objects and M is number of detected objects.

Source code in supervision/detection/utils.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def box_iou_batch(boxes_true: np.ndarray, boxes_detection: np.ndarray) -> np.ndarray:
    """
    Compute the pairwise Intersection over Union (IoU) between two sets of
        bounding boxes, `boxes_true` and `boxes_detection`. Every box is expected
        in `(x_min, y_min, x_max, y_max)` format.

    Args:
        boxes_true (np.ndarray): 2D `np.ndarray` of ground-truth boxes with
            `shape = (N, 4)`, where `N` is the number of true objects.
        boxes_detection (np.ndarray): 2D `np.ndarray` of detected boxes with
            `shape = (M, 4)`, where `M` is the number of detected objects.

    Returns:
        np.ndarray: IoU matrix with `shape = (N, M)`; entry `[i, j]` is the IoU of
            `boxes_true[i]` and `boxes_detection[j]`.
    """
    true_area = (boxes_true[:, 2] - boxes_true[:, 0]) * (
        boxes_true[:, 3] - boxes_true[:, 1]
    )
    detection_area = (boxes_detection[:, 2] - boxes_detection[:, 0]) * (
        boxes_detection[:, 3] - boxes_detection[:, 1]
    )

    # Broadcast (N, 1, 2) against (1, M, 2) to get every true/detection pairing.
    upper_left = np.maximum(
        boxes_true[:, np.newaxis, :2], boxes_detection[np.newaxis, :, :2]
    )
    lower_right = np.minimum(
        boxes_true[:, np.newaxis, 2:], boxes_detection[np.newaxis, :, 2:]
    )

    # Negative extents mean the boxes do not overlap; clamp them to zero.
    overlap_wh = np.clip(lower_right - upper_left, 0, None)
    intersection = overlap_wh[..., 0] * overlap_wh[..., 1]

    union = true_area[:, np.newaxis] + detection_area[np.newaxis, :] - intersection
    return intersection / union

non_max_suppression

Perform Non-Maximum Suppression (NMS) on object detection predictions.

Parameters:

Name Type Description Default
predictions ndarray

An array of object detection predictions in the format of (x_min, y_min, x_max, y_max, score) or (x_min, y_min, x_max, y_max, score, class).

required
iou_threshold float

The intersection-over-union threshold to use for non-maximum suppression.

0.5

Returns:

Type Description
ndarray

np.ndarray: A boolean array indicating which predictions to keep after non-maximum suppression.

Raises:

Type Description
AssertionError

If iou_threshold is not within the closed range from 0 to 1.

Source code in supervision/detection/utils.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def non_max_suppression(
    predictions: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
    """
    Perform Non-Maximum Suppression (NMS) on object detection predictions.

    Predictions are visited from highest to lowest score. Each prediction that
    is still kept suppresses every other prediction of the same class whose IoU
    with it is strictly greater than `iou_threshold`.

    Args:
        predictions (np.ndarray): An array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`. When the class
            column is absent, all predictions share one class (class-agnostic NMS).
        iou_threshold (float, optional): The intersection-over-union threshold
            to use for non-maximum suppression.

    Returns:
        np.ndarray: A boolean array, aligned with the input rows, indicating
            which predictions to keep after non-maximum suppression.

    Raises:
        AssertionError: If `iou_threshold` is not within the
            closed range from `0` to `1`.
    """
    assert 0 <= iou_threshold <= 1, (
        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
        f"{iou_threshold} given."
    )
    count, width = predictions.shape

    if width == 5:
        # No class column: append a constant class so NMS is class-agnostic.
        predictions = np.column_stack((predictions, np.zeros(count)))

    # Row indices ordered by descending score (column 4).
    order = np.argsort(predictions[:, 4])[::-1]
    ranked = predictions[order]

    classes = ranked[:, 5]
    # Subtracting the identity stops a box from suppressing itself.
    overlaps = box_iou_batch(ranked[:, :4], ranked[:, :4]) - np.eye(count)

    keep = np.ones(count, dtype=bool)
    for i, overlap_row in enumerate(overlaps):
        if keep[i]:
            suppressed = (overlap_row > iou_threshold) & (classes == classes[i])
            keep &= ~suppressed

    # Undo the score sort so the mask lines up with the caller's rows.
    return keep[np.argsort(order)]

polygon_to_mask

Generate a mask from a polygon.

Parameters:

Name Type Description Default
polygon ndarray

The polygon for which the mask should be generated, given as a list of vertices.

required
resolution_wh Tuple[int, int]

The width and height of the desired resolution.

required

Returns:

Type Description
ndarray

np.ndarray: The generated 2D mask, where the polygon is marked with 1's and the rest is filled with 0's.

Source code in supervision/detection/utils.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def polygon_to_mask(polygon: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
    """Generate a mask from a polygon.

    Args:
        polygon (np.ndarray): The polygon for which the mask should be generated,
            given as an array of `(x, y)` vertices. Vertices are cast to
            `np.int32` before rasterisation (fractional parts are truncated).
        resolution_wh (Tuple[int, int]): The width and height of the desired resolution.

    Returns:
        np.ndarray: The generated 2D mask of shape `(height, width)`, where the
            polygon is marked with `1`'s and the rest is filled with `0`'s.
    """
    width, height = resolution_wh
    mask = np.zeros((height, width))

    # cv2.fillPoly only accepts 32-bit integer vertices. NumPy's default integer
    # dtype (int64 on most platforms) and float arrays are rejected by OpenCV,
    # so cast explicitly.
    vertices = np.asarray(polygon).astype(np.int32)
    cv2.fillPoly(mask, [vertices], color=1)
    return mask

mask_to_xyxy

Converts a 3D np.array of 2D bool masks into a 2D np.array of bounding boxes.

Parameters:

Name Type Description Default
masks ndarray

A 3D np.array of shape (N, H, W) containing 2D bool masks

required

Returns:

Type Description
ndarray

np.ndarray: A 2D np.array of shape (N, 4) containing the bounding boxes (x_min, y_min, x_max, y_max) for each mask

Source code in supervision/detection/utils.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def mask_to_xyxy(masks: np.ndarray) -> np.ndarray:
    """
    Converts a 3D `np.array` of 2D bool masks into a 2D `np.array` of bounding boxes.

    Parameters:
        masks (np.ndarray): A 3D `np.array` of shape `(N, H, W)`
            containing 2D bool masks. Axis 1 is the `y` (row) axis and
            axis 2 is the `x` (column) axis.

    Returns:
        np.ndarray: A 2D integer `np.array` of shape `(N, 4)` containing the
            bounding boxes `(x_min, y_min, x_max, y_max)` for each mask. The max
            coordinates are inclusive pixel indices. A mask with no `True`
            pixels yields `[0, 0, 0, 0]`.
    """
    n = masks.shape[0]
    bboxes = np.zeros((n, 4), dtype=int)

    for i, mask in enumerate(masks):
        # Only the first/last occupied row and column matter, so project the
        # mask onto each axis instead of listing every set pixel via np.where.
        occupied_rows = np.flatnonzero(np.any(mask, axis=1))
        occupied_cols = np.flatnonzero(np.any(mask, axis=0))

        if occupied_rows.size > 0:
            bboxes[i, :] = [
                occupied_cols[0],
                occupied_rows[0],
                occupied_cols[-1],
                occupied_rows[-1],
            ]

    return bboxes

mask_to_polygons

Converts a binary mask to a list of polygons.

Parameters:

Name Type Description Default
mask ndarray

A binary mask represented as a 2D NumPy array of shape (H, W), where H and W are the height and width of the mask, respectively.

required

Returns:

Type Description
List[ndarray]

List[np.ndarray]: A list of polygons, where each polygon is represented by a NumPy array of shape (N, 2), containing the x, y coordinates of the points. Polygons with fewer points than MIN_POLYGON_POINT_COUNT = 3 are excluded from the output.

Source code in supervision/detection/utils.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def mask_to_polygons(mask: np.ndarray) -> List[np.ndarray]:
    """
    Converts a binary mask to a list of polygons.

    Contours are traced with `cv2.findContours` (`RETR_TREE` retrieval,
    `CHAIN_APPROX_SIMPLE` approximation) on a `uint8` copy of the mask.

    Parameters:
        mask (np.ndarray): A binary mask represented as a 2D NumPy array of
            shape `(H, W)`, where H and W are the height and width of
            the mask, respectively.

    Returns:
        List[np.ndarray]: A list of polygons. Each polygon is a NumPy array of
            shape `(N, 2)` holding the `x`, `y` coordinates of its points.
            Contours with fewer points than `MIN_POLYGON_POINT_COUNT = 3` are
            left out.
    """
    binary = mask.astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    polygons = []
    for contour in contours:
        if contour.shape[0] < MIN_POLYGON_POINT_COUNT:
            continue
        # OpenCV contours are (N, 1, 2); drop the singleton middle axis.
        polygons.append(np.squeeze(contour, axis=1))
    return polygons

polygon_to_xyxy

Converts a polygon represented by a NumPy array into a bounding box.

Parameters:

Name Type Description Default
polygon ndarray

A polygon represented by a NumPy array of shape (N, 2), containing the x, y coordinates of the points.

required

Returns:

Type Description
ndarray

np.ndarray: A 1D NumPy array containing the bounding box (x_min, y_min, x_max, y_max) of the input polygon.

Source code in supervision/detection/utils.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def polygon_to_xyxy(polygon: np.ndarray) -> np.ndarray:
    """
    Converts a polygon represented by a NumPy array into a bounding box.

    Parameters:
        polygon (np.ndarray): A polygon represented by a NumPy array of shape `(N, 2)`,
            containing the `x`, `y` coordinates of the points.

    Returns:
        np.ndarray: A 1D NumPy array `(x_min, y_min, x_max, y_max)` holding the
            per-axis minimum and maximum of the polygon's points.
    """
    lowest = np.min(polygon, axis=0)
    highest = np.max(polygon, axis=0)
    (x_min, y_min), (x_max, y_max) = lowest, highest
    return np.array((x_min, y_min, x_max, y_max))

filter_polygons_by_area

Filters a list of polygons based on their area.

Parameters:

Name Type Description Default
polygons List[ndarray]

A list of polygons, where each polygon is represented by a NumPy array of shape (N, 2), containing the x, y coordinates of the points.

required
min_area Optional[float]

The minimum area threshold. Only polygons with an area greater than or equal to this value will be included in the output. If set to None, no minimum area constraint will be applied.

None
max_area Optional[float]

The maximum area threshold. Only polygons with an area less than or equal to this value will be included in the output. If set to None, no maximum area constraint will be applied.

None

Returns:

Type Description
List[ndarray]

List[np.ndarray]: A new list of polygons containing only those with areas within the specified thresholds.

Source code in supervision/detection/utils.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def filter_polygons_by_area(
    polygons: List[np.ndarray],
    min_area: Optional[float] = None,
    max_area: Optional[float] = None,
) -> List[np.ndarray]:
    """
    Filters a list of polygons based on their area.

    Areas are measured with `cv2.contourArea`. When both bounds are `None`, the
    input list itself is returned unchanged (no copy is made).

    Parameters:
        polygons (List[np.ndarray]): A list of polygons, where each polygon is
            represented by a NumPy array of shape `(N, 2)`,
            containing the `x`, `y` coordinates of the points.
        min_area (Optional[float]): Inclusive lower area bound. If None,
            no minimum area constraint is applied.
        max_area (Optional[float]): Inclusive upper area bound. If None,
            no maximum area constraint is applied.

    Returns:
        List[np.ndarray]: A new list of polygons containing only those with
            areas within the specified thresholds.
    """
    if min_area is None and max_area is None:
        return polygons

    def _within_bounds(area: float) -> bool:
        # Mirror the inclusive checks exactly; a None bound is always satisfied.
        lower_ok = min_area is None or area >= min_area
        upper_ok = max_area is None or area <= max_area
        return lower_ok and upper_ok

    areas = [cv2.contourArea(polygon) for polygon in polygons]
    return [
        polygon for polygon, area in zip(polygons, areas) if _within_bounds(area)
    ]

move_boxes

Parameters:

Name Type Description Default
xyxy ndarray

An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]

required
offset array

An array of shape (2,) containing offset values in the format [dx, dy].

required

Returns:

Type Description
ndarray

np.ndarray: Repositioned bounding boxes.

Example
import numpy as np
import supervision as sv

boxes = np.array([[10, 10, 20, 20], [30, 30, 40, 40]])
offset = np.array([5, 5])
moved_box = sv.move_boxes(boxes, offset)
print(moved_box)
# np.array([
#    [15, 15, 25, 25],
#    [35, 35, 45, 45]
# ])
Source code in supervision/detection/utils.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray:
    """
    Shift bounding boxes by a constant `(dx, dy)` offset.

    Parameters:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes
            coordinates in format `[x1, y1, x2, y2]`
        offset (np.array): An array of shape `(2,)` containing offset values in the
            format `[dx, dy]`.

    Returns:
        np.ndarray: Repositioned bounding boxes.

    Example:
        ```python
        import numpy as np
        import supervision as sv

        boxes = np.array([[10, 10, 20, 20], [30, 30, 40, 40]])
        offset = np.array([5, 5])
        moved_box = sv.move_boxes(boxes, offset)
        print(moved_box)
        # np.array([
        #    [15, 15, 25, 25],
        #    [35, 35, 45, 45]
        # ])
        ```
    """
    # Repeat [dx, dy] -> [dx, dy, dx, dy] so both corners move by the same amount.
    corner_shift = np.tile(offset, 2)
    return xyxy + corner_shift

scale_boxes

Scale the dimensions of bounding boxes.

Parameters:

Name Type Description Default
xyxy ndarray

An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]

required
factor float

A float value representing the factor by which the box dimensions are scaled. A factor greater than 1 enlarges the boxes, while a factor less than 1 shrinks them.

required

Returns:

Type Description
ndarray

np.ndarray: Scaled bounding boxes.

Example
import numpy as np
import supervision as sv

boxes = np.array([[10, 10, 20, 20], [30, 30, 40, 40]])
factor = 1.5
scaled_bb = sv.scale_boxes(boxes, factor)
print(scaled_bb)
# np.array([
#    [ 7.5,  7.5, 22.5, 22.5],
#    [27.5, 27.5, 42.5, 42.5]
# ])
Source code in supervision/detection/utils.py
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
def scale_boxes(xyxy: np.ndarray, factor: float) -> np.ndarray:
    """
    Scale the dimensions of bounding boxes about their centers.

    Parameters:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes
            coordinates in format `[x1, y1, x2, y2]`
        factor (float): Multiplier applied to each box's width and height while
            its center stays fixed. A factor greater than 1 enlarges the boxes,
            while a factor less than 1 shrinks them.

    Returns:
        np.ndarray: Scaled bounding boxes.

    Example:
        ```python
        import numpy as np
        import supervision as sv

        boxes = np.array([[10, 10, 20, 20], [30, 30, 40, 40]])
        factor = 1.5
        scaled_bb = sv.scale_boxes(boxes, factor)
        print(scaled_bb)
        # np.array([
        #    [ 7.5,  7.5, 22.5, 22.5],
        #    [27.5, 27.5, 42.5, 42.5]
        # ])
        ```
    """
    top_left = xyxy[:, :2]
    bottom_right = xyxy[:, 2:]

    center = (top_left + bottom_right) / 2
    scaled_wh = (bottom_right - top_left) * factor
    half_wh = scaled_wh / 2

    return np.hstack((center - half_wh, center + half_wh))

Comments