Starting with version 0.23.0, supervision ships a new metrics module.
The metrics documented here belong to the legacy evaluation API and will be deprecated in a future release.
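For orientation before the full source listings, here is a minimal end-to-end sketch of the legacy API. It reuses the toy tensors from the `from_tensors` docstrings below; the class names passed to `classes` are hypothetical placeholders.

```python
import numpy as np
import supervision as sv

# One image of toy ground truth: rows are (x_min, y_min, x_max, y_max, class).
targets = [
    np.array(
        [
            [0.0, 0.0, 3.0, 3.0, 1],
            [2.0, 2.0, 5.0, 5.0, 1],
            [6.0, 1.0, 8.0, 3.0, 2],
        ]
    ),
]

# Matching toy predictions: rows are (x_min, y_min, x_max, y_max, class, conf).
predictions = [
    np.array(
        [
            [0.0, 0.0, 3.0, 3.0, 1, 0.9],
            [0.1, 0.1, 3.0, 3.0, 0, 0.9],
            [6.0, 1.0, 8.0, 3.0, 1, 0.8],
        ]
    ),
]

# Hypothetical class names; the class ids above index into this list.
classes = ["dog", "person", "car"]

confusion_matrix = sv.ConfusionMatrix.from_tensors(
    predictions=predictions, targets=targets, classes=classes
)
print(confusion_matrix.matrix)

mean_average_precision = sv.MeanAveragePrecision.from_tensors(
    predictions=predictions, targets=targets
)
print(mean_average_precision.map50_95)
```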
````python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Helpers referenced below (Detections, DetectionDataset, box_iou_batch,
# detections_to_tensor, validate_input_tensors) come from elsewhere in
# supervision and are omitted here.


@dataclass
class ConfusionMatrix:
    """
    Confusion matrix for object detection tasks.

    Attributes:
        matrix (np.ndarray): A 2D `np.ndarray` of shape
            `(len(classes) + 1, len(classes) + 1)` containing the number of
            `TP`, `FP`, `FN` and `TN` for each class.
        classes (List[str]): Model class names.
        conf_threshold (float): Detection confidence threshold between `0` and `1`.
            Detections with lower confidence will be excluded from the matrix.
        iou_threshold (float): Detection IoU threshold between `0` and `1`.
            Detections with lower IoU will be classified as `FP`.
    """

    matrix: np.ndarray
    classes: List[str]
    conf_threshold: float
    iou_threshold: float

    @classmethod
    def from_detections(
        cls,
        predictions: List[Detections],
        targets: List[Detections],
        classes: List[str],
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5,
    ) -> ConfusionMatrix:
        """
        Calculate confusion matrix based on predicted and ground-truth detections.

        Args:
            targets (List[Detections]): Detections objects from ground-truth.
            predictions (List[Detections]): Detections objects predicted by the model.
            classes (List[str]): Model class names.
            conf_threshold (float): Detection confidence threshold between `0` and `1`.
                Detections with lower confidence will be excluded.
            iou_threshold (float): Detection IoU threshold between `0` and `1`.
                Detections with lower IoU will be classified as `FP`.

        Returns:
            ConfusionMatrix: New instance of ConfusionMatrix.

        Example:
            ```python
            import supervision as sv

            targets = [
                sv.Detections(...),
                sv.Detections(...)
            ]

            predictions = [
                sv.Detections(...),
                sv.Detections(...)
            ]

            confusion_matrix = sv.ConfusionMatrix.from_detections(
                predictions=predictions,
                targets=targets,
                classes=['person', ...]
            )

            print(confusion_matrix.matrix)
            # np.array([
            #     [0., 0., 0., 0.],
            #     [0., 1., 0., 1.],
            #     [0., 1., 1., 0.],
            #     [1., 1., 0., 0.]
            # ])
            ```
        """
        prediction_tensors = []
        target_tensors = []
        for prediction, target in zip(predictions, targets):
            prediction_tensors.append(
                detections_to_tensor(prediction, with_confidence=True)
            )
            target_tensors.append(
                detections_to_tensor(target, with_confidence=False)
            )
        return cls.from_tensors(
            predictions=prediction_tensors,
            targets=target_tensors,
            classes=classes,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
        )

    @classmethod
    def from_tensors(
        cls,
        predictions: List[np.ndarray],
        targets: List[np.ndarray],
        classes: List[str],
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5,
    ) -> ConfusionMatrix:
        """
        Calculate confusion matrix based on predicted and ground-truth detections.

        Args:
            predictions (List[np.ndarray]): Each element of the list describes
                a single image and has `shape = (M, 6)` where `M` is
                the number of detected objects. Each row is expected to be
                in `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (List[np.ndarray]): Each element of the list describes a single
                image and has `shape = (N, 5)` where `N` is the number of
                ground-truth objects. Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.
            classes (List[str]): Model class names.
            conf_threshold (float): Detection confidence threshold between `0` and `1`.
                Detections with lower confidence will be excluded.
            iou_threshold (float): Detection IoU threshold between `0` and `1`.
                Detections with lower IoU will be classified as `FP`.

        Returns:
            ConfusionMatrix: New instance of ConfusionMatrix.

        Example:
            ```python
            import supervision as sv
            import numpy as np

            targets = [
                np.array(
                    [
                        [0.0, 0.0, 3.0, 3.0, 1],
                        [2.0, 2.0, 5.0, 5.0, 1],
                        [6.0, 1.0, 8.0, 3.0, 2],
                    ]
                ),
                np.array([[1.0, 1.0, 2.0, 2.0, 2]]),
            ]

            predictions = [
                np.array(
                    [
                        [0.0, 0.0, 3.0, 3.0, 1, 0.9],
                        [0.1, 0.1, 3.0, 3.0, 0, 0.9],
                        [6.0, 1.0, 8.0, 3.0, 1, 0.8],
                        [1.0, 6.0, 2.0, 7.0, 1, 0.8],
                    ]
                ),
                np.array([[1.0, 1.0, 2.0, 2.0, 2, 0.8]])
            ]

            confusion_matrix = sv.ConfusionMatrix.from_tensors(
                predictions=predictions,
                targets=targets,
                classes=['person', ...]
            )

            print(confusion_matrix.matrix)
            # np.array([
            #     [0., 0., 0., 0.],
            #     [0., 1., 0., 1.],
            #     [0., 1., 1., 0.],
            #     [1., 1., 0., 0.]
            # ])
            ```
        """
        validate_input_tensors(predictions, targets)
        num_classes = len(classes)
        matrix = np.zeros((num_classes + 1, num_classes + 1))
        for true_batch, detection_batch in zip(targets, predictions):
            matrix += cls.evaluate_detection_batch(
                predictions=detection_batch,
                targets=true_batch,
                num_classes=num_classes,
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
            )
        return cls(
            matrix=matrix,
            classes=classes,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
        )

    @staticmethod
    def evaluate_detection_batch(
        predictions: np.ndarray,
        targets: np.ndarray,
        num_classes: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> np.ndarray:
        """
        Calculate confusion matrix for a batch of detections for a single image.

        Args:
            predictions (np.ndarray): Batch prediction. Describes a single image and
                has `shape = (M, 6)` where `M` is the number of detected objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (np.ndarray): Batch target labels. Describes a single image and
                has `shape = (N, 5)` where `N` is the number of ground-truth objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.
            num_classes (int): Number of classes.
            conf_threshold (float): Detection confidence threshold between `0` and `1`.
                Detections with lower confidence will be excluded.
            iou_threshold (float): Detection IoU threshold between `0` and `1`.
                Detections with lower IoU will be classified as `FP`.

        Returns:
            np.ndarray: Confusion matrix based on a single image.
        """
        result_matrix = np.zeros((num_classes + 1, num_classes + 1))

        # Drop detections below the confidence threshold.
        conf_idx = 5
        confidence = predictions[:, conf_idx]
        detection_batch_filtered = predictions[confidence > conf_threshold]

        class_id_idx = 4
        true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
        detection_classes = np.array(
            detection_batch_filtered[:, class_id_idx], dtype=np.int16
        )
        true_boxes = targets[:, :class_id_idx]
        detection_boxes = detection_batch_filtered[:, :class_id_idx]

        # Pairwise IoU between ground-truth and detected boxes.
        iou_batch = box_iou_batch(
            boxes_true=true_boxes, boxes_detection=detection_boxes
        )
        matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()

        if matched_idx[0].shape[0]:
            matches = np.stack(
                (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
            )
            matches = ConfusionMatrix._drop_extra_matches(matches=matches)
        else:
            matches = np.zeros((0, 3))

        matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
            np.int16
        )

        for i, true_class_value in enumerate(true_classes):
            j = matched_true_idx == i
            if matches.shape[0] > 0 and sum(j) == 1:
                result_matrix[
                    true_class_value, detection_classes[matched_detection_idx[j]]
                ] += 1  # TP
            else:
                result_matrix[true_class_value, num_classes] += 1  # FN

        for i, detection_class_value in enumerate(detection_classes):
            if not any(matched_detection_idx == i):
                result_matrix[num_classes, detection_class_value] += 1  # FP

        return result_matrix

    @staticmethod
    def _drop_extra_matches(matches: np.ndarray) -> np.ndarray:
        """
        Deduplicate matches. If there are multiple matches for the same true or
        predicted box, only the one with the highest IoU is kept.
        """
        if matches.shape[0] > 0:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        return matches

    @classmethod
    def benchmark(
        cls,
        dataset: DetectionDataset,
        callback: Callable[[np.ndarray], Detections],
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5,
    ) -> ConfusionMatrix:
        """
        Calculate confusion matrix from dataset and callback function.

        Args:
            dataset (DetectionDataset): Object detection dataset used for evaluation.
            callback (Callable[[np.ndarray], Detections]): Function that takes an
                image as input and returns Detections object.
            conf_threshold (float): Detection confidence threshold between `0` and `1`.
                Detections with lower confidence will be excluded.
            iou_threshold (float): Detection IoU threshold between `0` and `1`.
                Detections with lower IoU will be classified as `FP`.

        Returns:
            ConfusionMatrix: New instance of ConfusionMatrix.

        Example:
            ```python
            import supervision as sv
            from ultralytics import YOLO

            dataset = sv.DetectionDataset.from_yolo(...)

            model = YOLO(...)
            def callback(image: np.ndarray) -> sv.Detections:
                result = model(image)[0]
                return sv.Detections.from_ultralytics(result)

            confusion_matrix = sv.ConfusionMatrix.benchmark(
                dataset = dataset,
                callback = callback
            )

            print(confusion_matrix.matrix)
            # np.array([
            #     [0., 0., 0., 0.],
            #     [0., 1., 0., 1.],
            #     [0., 1., 1., 0.],
            #     [1., 1., 0., 0.]
            # ])
            ```
        """
        predictions, targets = [], []
        for _, image, annotation in dataset:
            predictions_batch = callback(image)
            predictions.append(predictions_batch)
            targets.append(annotation)
        return cls.from_detections(
            predictions=predictions,
            targets=targets,
            classes=dataset.classes,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
        )

    def plot(
        self,
        save_path: Optional[str] = None,
        title: Optional[str] = None,
        classes: Optional[List[str]] = None,
        normalize: bool = False,
        fig_size: Tuple[int, int] = (12, 10),
    ) -> matplotlib.figure.Figure:
        """
        Create confusion matrix plot and save it at selected location.

        Args:
            save_path (Optional[str]): Path to save the plot. If not provided,
                plot will be displayed.
            title (Optional[str]): Title of the plot.
            classes (Optional[List[str]]): List of classes to be displayed on the
                plot. If not provided, all classes will be displayed.
            normalize (bool): If True, normalize the confusion matrix.
            fig_size (Tuple[int, int]): Size of the plot.

        Returns:
            matplotlib.figure.Figure: Confusion matrix plot.
        """
        array = self.matrix.copy()

        if normalize:
            eps = 1e-8
            array = array / (array.sum(0).reshape(1, -1) + eps)

        array[array < 0.005] = np.nan

        fig, ax = plt.subplots(figsize=fig_size, tight_layout=True, facecolor="white")

        class_names = classes if classes is not None else self.classes
        use_labels_for_ticks = class_names is not None and (0 < len(class_names) < 99)
        if use_labels_for_ticks:
            x_tick_labels = [*class_names, "FN"]
            y_tick_labels = [*class_names, "FP"]
            num_ticks = len(x_tick_labels)
        else:
            x_tick_labels = None
            y_tick_labels = None
            num_ticks = len(array)

        im = ax.imshow(array, cmap="Blues")
        cbar = ax.figure.colorbar(im, ax=ax)
        cbar.mappable.set_clim(vmin=0, vmax=np.nanmax(array))

        if x_tick_labels is None:
            tick_interval = 2
        else:
            tick_interval = 1
        ax.set_xticks(np.arange(0, num_ticks, tick_interval), labels=x_tick_labels)
        ax.set_yticks(np.arange(0, num_ticks, tick_interval), labels=y_tick_labels)

        plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="default")

        labelsize = 10 if num_ticks < 50 else 8
        ax.tick_params(axis="both", which="both", labelsize=labelsize)

        if num_ticks < 30:
            for i in range(array.shape[0]):
                for j in range(array.shape[1]):
                    n_preds = array[i, j]
                    if not np.isnan(n_preds):
                        ax.text(
                            j,
                            i,
                            f"{n_preds:.2f}" if normalize else f"{n_preds:.0f}",
                            ha="center",
                            va="center",
                            color="black"
                            if n_preds < 0.5 * np.nanmax(array)
                            else "white",
                        )

        if title:
            ax.set_title(title, fontsize=20)

        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_facecolor("white")
        if save_path:
            fig.savefig(
                save_path, dpi=250, facecolor=fig.get_facecolor(), transparent=True
            )
        return fig
````
````python
# Uses the same imports and helpers as ConfusionMatrix above.


@dataclass(frozen=True)
class MeanAveragePrecision:
    """
    Mean Average Precision for object detection tasks.

    Attributes:
        map50_95 (float): Mean Average Precision (mAP) calculated over IoU
            thresholds ranging from `0.50` to `0.95` with a step size of `0.05`.
        map50 (float): Mean Average Precision (mAP) calculated specifically at
            an IoU threshold of `0.50`.
        map75 (float): Mean Average Precision (mAP) calculated specifically at
            an IoU threshold of `0.75`.
        per_class_ap50_95 (np.ndarray): Average Precision (AP) values calculated
            over IoU thresholds ranging from `0.50` to `0.95` with a step size of
            `0.05`, provided for each individual class.
    """

    map50_95: float
    map50: float
    map75: float
    per_class_ap50_95: np.ndarray

    @classmethod
    def from_detections(
        cls,
        predictions: List[Detections],
        targets: List[Detections],
    ) -> MeanAveragePrecision:
        """
        Calculate mean average precision based on predicted and ground-truth
        detections.

        Args:
            targets (List[Detections]): Detections objects from ground-truth.
            predictions (List[Detections]): Detections objects predicted by the model.

        Returns:
            MeanAveragePrecision: New instance of MeanAveragePrecision.

        Example:
            ```python
            import supervision as sv

            targets = [
                sv.Detections(...),
                sv.Detections(...)
            ]

            predictions = [
                sv.Detections(...),
                sv.Detections(...)
            ]

            mean_average_precision = sv.MeanAveragePrecision.from_detections(
                predictions=predictions,
                targets=targets,
            )

            print(mean_average_precision.map50_95)
            # 0.2899
            ```
        """
        prediction_tensors = []
        target_tensors = []
        for prediction, target in zip(predictions, targets):
            prediction_tensors.append(
                detections_to_tensor(prediction, with_confidence=True)
            )
            target_tensors.append(
                detections_to_tensor(target, with_confidence=False)
            )
        return cls.from_tensors(
            predictions=prediction_tensors,
            targets=target_tensors,
        )

    @classmethod
    def benchmark(
        cls,
        dataset: DetectionDataset,
        callback: Callable[[np.ndarray], Detections],
    ) -> MeanAveragePrecision:
        """
        Calculate mean average precision from dataset and callback function.

        Args:
            dataset (DetectionDataset): Object detection dataset used for evaluation.
            callback (Callable[[np.ndarray], Detections]): Function that takes an
                image as input and returns Detections object.

        Returns:
            MeanAveragePrecision: New instance of MeanAveragePrecision.

        Example:
            ```python
            import supervision as sv
            from ultralytics import YOLO

            dataset = sv.DetectionDataset.from_yolo(...)

            model = YOLO(...)
            def callback(image: np.ndarray) -> sv.Detections:
                result = model(image)[0]
                return sv.Detections.from_ultralytics(result)

            mean_average_precision = sv.MeanAveragePrecision.benchmark(
                dataset = dataset,
                callback = callback
            )

            print(mean_average_precision.map50_95)
            # 0.433
            ```
        """
        predictions, targets = [], []
        for _, image, annotation in dataset:
            predictions_batch = callback(image)
            predictions.append(predictions_batch)
            targets.append(annotation)
        return cls.from_detections(
            predictions=predictions,
            targets=targets,
        )

    @classmethod
    def from_tensors(
        cls,
        predictions: List[np.ndarray],
        targets: List[np.ndarray],
    ) -> MeanAveragePrecision:
        """
        Calculate Mean Average Precision based on predicted and ground-truth
        detections at different IoU thresholds.

        Args:
            predictions (List[np.ndarray]): Each element of the list describes
                a single image and has `shape = (M, 6)` where `M` is
                the number of detected objects. Each row is expected to be
                in `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (List[np.ndarray]): Each element of the list describes a single
                image and has `shape = (N, 5)` where `N` is the
                number of ground-truth objects. Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.

        Returns:
            MeanAveragePrecision: New instance of MeanAveragePrecision.

        Example:
            ```python
            import supervision as sv
            import numpy as np

            targets = [
                np.array(
                    [
                        [0.0, 0.0, 3.0, 3.0, 1],
                        [2.0, 2.0, 5.0, 5.0, 1],
                        [6.0, 1.0, 8.0, 3.0, 2],
                    ]
                ),
                np.array([[1.0, 1.0, 2.0, 2.0, 2]]),
            ]

            predictions = [
                np.array(
                    [
                        [0.0, 0.0, 3.0, 3.0, 1, 0.9],
                        [0.1, 0.1, 3.0, 3.0, 0, 0.9],
                        [6.0, 1.0, 8.0, 3.0, 1, 0.8],
                        [1.0, 6.0, 2.0, 7.0, 1, 0.8],
                    ]
                ),
                np.array([[1.0, 1.0, 2.0, 2.0, 2, 0.8]])
            ]

            mean_average_precision = sv.MeanAveragePrecision.from_tensors(
                predictions=predictions,
                targets=targets,
            )

            print(mean_average_precision.map50_95)
            # 0.6649
            ```
        """
        validate_input_tensors(predictions, targets)
        iou_thresholds = np.linspace(0.5, 0.95, 10)
        stats = []

        # Gather matching stats for predictions and targets
        for true_objs, predicted_objs in zip(targets, predictions):
            if predicted_objs.shape[0] == 0:
                if true_objs.shape[0]:
                    stats.append(
                        (
                            np.zeros((0, iou_thresholds.size), dtype=bool),
                            *np.zeros((2, 0)),
                            true_objs[:, 4],
                        )
                    )
                continue

            if true_objs.shape[0]:
                matches = cls._match_detection_batch(
                    predicted_objs, true_objs, iou_thresholds
                )
                stats.append(
                    (
                        matches,
                        predicted_objs[:, 5],
                        predicted_objs[:, 4],
                        true_objs[:, 4],
                    )
                )

        # Compute average precisions if any matches exist
        if stats:
            concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
            average_precisions = cls._average_precisions_per_class(*concatenated_stats)
            map50 = average_precisions[:, 0].mean()
            map75 = average_precisions[:, 5].mean()
            map50_95 = average_precisions.mean()
        else:
            map50, map75, map50_95 = 0, 0, 0
            average_precisions = []

        return cls(
            map50_95=map50_95,
            map50=map50,
            map75=map75,
            per_class_ap50_95=average_precisions,
        )

    @staticmethod
    def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
        """
        Compute the average precision using 101-point interpolation (COCO), given
        the recall and precision curves.

        Args:
            recall (np.ndarray): The recall curve.
            precision (np.ndarray): The precision curve.

        Returns:
            float: Average precision.
        """
        extended_recall = np.concatenate(([0.0], recall, [1.0]))
        extended_precision = np.concatenate(([1.0], precision, [0.0]))
        max_accumulated_precision = np.flip(
            np.maximum.accumulate(np.flip(extended_precision))
        )
        interpolated_recall_levels = np.linspace(0, 1, 101)
        interpolated_precision = np.interp(
            interpolated_recall_levels, extended_recall, max_accumulated_precision
        )
        average_precision = np.trapz(
            interpolated_precision, interpolated_recall_levels
        )
        return average_precision

    @staticmethod
    def _match_detection_batch(
        predictions: np.ndarray, targets: np.ndarray, iou_thresholds: np.ndarray
    ) -> np.ndarray:
        """
        Match predictions with target labels based on IoU levels.

        Args:
            predictions (np.ndarray): Batch prediction. Describes a single image and
                has `shape = (M, 6)` where `M` is the number of detected objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (np.ndarray): Batch target labels. Describes a single image and
                has `shape = (N, 5)` where `N` is the number of ground-truth objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.
            iou_thresholds (np.ndarray): Array contains different IoU thresholds.

        Returns:
            np.ndarray: Matched prediction with target labels result.
        """
        num_predictions, num_iou_levels = (
            predictions.shape[0],
            iou_thresholds.shape[0],
        )
        correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
        iou = box_iou_batch(targets[:, :4], predictions[:, :4])
        correct_class = targets[:, 4:5] == predictions[:, 4]

        for i, iou_level in enumerate(iou_thresholds):
            matched_indices = np.where((iou >= iou_level) & correct_class)

            if matched_indices[0].shape[0]:
                combined_indices = np.stack(matched_indices, axis=1)
                iou_values = iou[matched_indices][:, None]
                matches = np.hstack([combined_indices, iou_values])

                if matched_indices[0].shape[0] > 1:
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

                correct[matches[:, 1].astype(int), i] = True

        return correct

    @staticmethod
    def _average_precisions_per_class(
        matches: np.ndarray,
        prediction_confidence: np.ndarray,
        prediction_class_ids: np.ndarray,
        true_class_ids: np.ndarray,
        eps: float = 1e-16,
    ) -> np.ndarray:
        """
        Compute the average precision, given the recall and precision curves.
        Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.

        Args:
            matches (np.ndarray): True positives.
            prediction_confidence (np.ndarray): Objectness value from 0-1.
            prediction_class_ids (np.ndarray): Predicted object classes.
            true_class_ids (np.ndarray): True object classes.
            eps (float): Small value to prevent division by zero.

        Returns:
            np.ndarray: Average precision for different IoU levels.
        """
        sorted_indices = np.argsort(-prediction_confidence)
        matches = matches[sorted_indices]
        prediction_class_ids = prediction_class_ids[sorted_indices]

        unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
        num_classes = unique_classes.shape[0]

        average_precisions = np.zeros((num_classes, matches.shape[1]))

        for class_idx, class_id in enumerate(unique_classes):
            is_class = prediction_class_ids == class_id
            total_true = class_counts[class_idx]
            total_prediction = is_class.sum()

            if total_prediction == 0 or total_true == 0:
                continue

            false_positives = (1 - matches[is_class]).cumsum(0)
            true_positives = matches[is_class].cumsum(0)
            recall = true_positives / (total_true + eps)
            precision = true_positives / (true_positives + false_positives)

            for iou_level_idx in range(matches.shape[1]):
                average_precisions[class_idx, iou_level_idx] = (
                    MeanAveragePrecision.compute_average_precision(
                        recall[:, iou_level_idx], precision[:, iou_level_idx]
                    )
                )

        return average_precisions
````