Shortcuts

Source code for mmdet.models.test_time_augs.merge_augs

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Union

import numpy as np
import torch

try:
    from mmcv.ops import nms
except ImportError:

    def nms(*args, **kwargs):
        raise RuntimeError('nms requires mmcv to be compiled with ops. Please '
                           'reinstall onedl-mmcv with CUDA support.')


from mmengine.config import ConfigDict
from torch import Tensor

from mmdet.structures.bbox import bbox_mapping_back


# TODO remove this, never be used in mmdet
[docs] def merge_aug_proposals(aug_proposals, img_metas, cfg): """Merge augmented proposals (multiscale, flip, etc.) Args: aug_proposals (list[Tensor]): proposals from different testing schemes, shape (n, 5). Note that they are not rescaled to the original image size. img_metas (list[dict]): list of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmdet/datasets/pipelines/formatting.py:Collect`. cfg (dict): rpn test config. Returns: Tensor: shape (n, 4), proposals corresponding to original image scale. """ cfg = copy.deepcopy(cfg) # deprecate arguments warning if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg: warnings.warn( 'In rpn_proposal or test_cfg, ' 'nms_thr has been moved to a dict named nms as ' 'iou_threshold, max_num has been renamed as max_per_img, ' 'name of original arguments and the way to specify ' 'iou_threshold of NMS will be deprecated.') if 'nms' not in cfg: cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr)) if 'max_num' in cfg: if 'max_per_img' in cfg: assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \ f'max_per_img at the same time, but get {cfg.max_num} ' \ f'and {cfg.max_per_img} respectively' \ f'Please delete max_num which will be deprecated.' else: cfg.max_per_img = cfg.max_num if 'nms_thr' in cfg: assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \ f'iou_threshold in nms and ' \ f'nms_thr at the same time, but get ' \ f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \ f' respectively. Please delete the nms_thr ' \ f'which will be deprecated.' recovered_proposals = [] for proposals, img_info in zip(aug_proposals, img_metas): img_shape = img_info['img_shape'] scale_factor = img_info['scale_factor'] flip = img_info['flip'] flip_direction = img_info['flip_direction'] _proposals = proposals.clone() _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape, scale_factor, flip, flip_direction) recovered_proposals.append(_proposals) aug_proposals = torch.cat(recovered_proposals, dim=0) merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(), aug_proposals[:, -1].contiguous(), cfg.nms.iou_threshold) scores = merged_proposals[:, 4] _, order = scores.sort(0, descending=True) num = min(cfg.max_per_img, merged_proposals.shape[0]) order = order[:num] merged_proposals = merged_proposals[order, :] return merged_proposals
# TODO remove this, never be used in mmdet
[docs] def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg): """Merge augmented detection bboxes and scores. Args: aug_bboxes (list[Tensor]): shape (n, 4*#class) aug_scores (list[Tensor] or None): shape (n, #class) img_shapes (list[Tensor]): shape (3, ). rcnn_test_cfg (dict): rcnn test config. Returns: tuple: (bboxes, scores) """ recovered_bboxes = [] for bboxes, img_info in zip(aug_bboxes, img_metas): img_shape = img_info[0]['img_shape'] scale_factor = img_info[0]['scale_factor'] flip = img_info[0]['flip'] flip_direction = img_info[0]['flip_direction'] bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip, flip_direction) recovered_bboxes.append(bboxes) bboxes = torch.stack(recovered_bboxes).mean(dim=0) if aug_scores is None: return bboxes else: scores = torch.stack(aug_scores).mean(dim=0) return bboxes, scores
[docs] def merge_aug_results(aug_batch_results, aug_batch_img_metas): """Merge augmented detection results, only bboxes corresponding score under flipping and multi-scale resizing can be processed now. Args: aug_batch_results (list[list[[obj:`InstanceData`]]): Detection results of multiple images with different augmentations. The outer list indicate the augmentation . The inter list indicate the batch dimension. Each item usually contains the following keys. - scores (Tensor): Classification scores, in shape (num_instance,) - labels (Tensor): Labels of bboxes, in shape (num_instances,). - bboxes (Tensor): In shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). aug_batch_img_metas (list[list[dict]]): The outer list indicates test-time augs (multiscale, flip, etc.) and the inner list indicates images in a batch. Each dict in the list contains information of an image in the batch. Returns: batch_results (list[obj:`InstanceData`]): Same with the input `aug_results` except that all bboxes have been mapped to the original scale. """ num_augs = len(aug_batch_results) num_imgs = len(aug_batch_results[0]) batch_results = [] aug_batch_results = copy.deepcopy(aug_batch_results) for img_id in range(num_imgs): aug_results = [] for aug_id in range(num_augs): img_metas = aug_batch_img_metas[aug_id][img_id] results = aug_batch_results[aug_id][img_id] img_shape = img_metas['img_shape'] scale_factor = img_metas['scale_factor'] flip = img_metas['flip'] flip_direction = img_metas['flip_direction'] bboxes = bbox_mapping_back(results.bboxes, img_shape, scale_factor, flip, flip_direction) results.bboxes = bboxes aug_results.append(results) merged_aug_results = results.cat(aug_results) batch_results.append(merged_aug_results) return batch_results
[docs] def merge_aug_scores(aug_scores): """Merge augmented bbox scores.""" if isinstance(aug_scores[0], torch.Tensor): return torch.mean(torch.stack(aug_scores), dim=0) else: return np.mean(aug_scores, axis=0)
[docs] def merge_aug_masks(aug_masks: List[Tensor], img_metas: dict, weights: Optional[Union[list, Tensor]] = None) -> Tensor: """Merge augmented mask prediction. Args: aug_masks (list[Tensor]): each has shape (n, c, h, w). img_metas (dict): Image information. weights (list or Tensor): Weight of each aug_masks, the length should be n. Returns: Tensor: has shape (n, c, h, w) """ recovered_masks = [] for i, mask in enumerate(aug_masks): if weights is not None: assert len(weights) == len(aug_masks) weight = weights[i] else: weight = 1 flip = img_metas.get('flip', False) if flip: flip_direction = img_metas['flip_direction'] if flip_direction == 'horizontal': mask = mask[:, :, :, ::-1] elif flip_direction == 'vertical': mask = mask[:, :, ::-1, :] elif flip_direction == 'diagonal': mask = mask[:, :, :, ::-1] mask = mask[:, :, ::-1, :] else: raise ValueError( f"Invalid flipping direction '{flip_direction}'") recovered_masks.append(mask[None, :] * weight) merged_masks = torch.cat(recovered_masks, 0).mean(dim=0) if weights is not None: merged_masks = merged_masks * len(weights) / sum(weights) return merged_masks