
Batched NMS generates an ONNX If operator, which causes an ONNX conversion problem #8892

Open
@RiverLight4

Description

🐛 Describe the bug

Background

I'd like to convert a Mask R-CNN model into a RyzenAI model and run it on a Ryzen NPU.
When I convert the torchvision Mask R-CNN model into ONNX, the exported graph contains an If operator.
RyzenAI doesn't support the If operator for NPU inference, so I tried to remove it, but I couldn't find a way to do so.

Issue

Could anyone please check and answer the following?

  • Is this behavior expected?
  • What should I do to convert Mask R-CNN into an ONNX model that does not use the If operator?

Sample code

import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=None, progress=True)
# checkpoint already downloaded from https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth
model.load_state_dict(torch.load('./maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth'))
model.eval()

x = torch.rand(1, 3, 480, 640)
predictions = model(x)

# The detection outputs are boxes [num, 4], labels [num], scores [num] and masks [num, 1, H, W],
# so the variable number of detections is axis 0 of every output.
torch.onnx.export(model,
                  x,
                  "mask_rcnn_torchvision.onnx",
                  opset_version=11,
                  input_names=['input'],
                  output_names=['boxes', 'labels', 'scores', 'masks'],
                  dynamic_axes={'input': {2: 'height', 3: 'width'}, 'boxes': {0: 'num'}, 'labels': {0: 'num'}, 'scores': {0: 'num'}, 'masks': {0: 'num'}})

Results

mask_rcnn_torchvision.onnx contains an If operator. Although this operator is in the RPN block, its name is not /rpn/foobar but onnx::If_xxxx.

[Image: the If node in the exported graph]

sub_graph1 (ELSE subgraph):

[Image]

sub_graph (THEN subgraph):

[Image]

Expected behavior

mask_rcnn_torchvision.onnx does not contain an If operator.

My investigation

I traced through the torchvision Mask R-CNN implementation and found that the If operator may be caused by torchvision.ops.batched_nms():

# Excerpt from torchvision/ops/boxes.py (imports added so the snippet is self-contained)
import torch
import torchvision
from torch import Tensor
from torchvision.ops import nms
from torchvision.utils import _log_api_usage_once


def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value correspond to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor[N, 4]): boxes where NMS will be performed. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        idxs (Tensor[N]): indices of the categories for each one of the boxes.
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)


@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep


@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
    return keep_indices[scores[keep_indices].sort(descending=True)[1]]
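To make the two strategies above concrete, here is a minimal pure-Python sketch. It assumes plain lists of (x1, y1, x2, y2) tuples instead of tensors, and greedy_nms, batched_nms_coordinate_trick and batched_nms_vanilla are hypothetical stand-ins, not torchvision APIs. The coordinate trick shifts every box by class_id * (max_coordinate + 1); with non-negative coordinates (as the docstring requires), boxes of different classes then occupy disjoint ranges, so one NMS call only suppresses within a class. The vanilla variant runs NMS per class and re-sorts the survivors by score, so both should return the same indices.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), non-negative coordinates assumed


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float) -> List[int]:
    """Hypothetical stand-in for torchvision.ops.nms: visit boxes in decreasing score
    order and keep a box only if its IoU with every already-kept box is <= threshold."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


def batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) -> List[int]:
    # Mirrors the early return in _batched_nms_coordinate_trick: max() of an empty set is undefined.
    if not boxes:
        return []
    max_coordinate = max(max(b) for b in boxes)
    shifted = []
    for box, class_id in zip(boxes, idxs):
        # Class c lands in [c * (M + 1), c * (M + 1) + M], disjoint from every other class.
        offset = class_id * (max_coordinate + 1)
        shifted.append(tuple(v + offset for v in box))
    return greedy_nms(shifted, scores, iou_threshold)


def batched_nms_vanilla(boxes, scores, idxs, iou_threshold) -> List[int]:
    keep: List[int] = []
    for class_id in sorted(set(idxs)):
        # Run NMS separately on the boxes of one class, then map back to global indices.
        members = [i for i, c in enumerate(idxs) if c == class_id]
        kept_local = greedy_nms([boxes[i] for i in members], [scores[i] for i in members], iou_threshold)
        keep.extend(members[j] for j in kept_local)
    # Re-sort the survivors by decreasing score.
    return sorted(keep, key=lambda i: scores[i], reverse=True)


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (0, 0, 10, 10), (20, 20, 30, 30)]
    scores = [0.9, 0.8, 0.7, 0.6]
    idxs = [0, 0, 1, 0]
    print(batched_nms_coordinate_trick(boxes, scores, idxs, 0.5))  # [0, 2, 3]
    print(batched_nms_vanilla(boxes, scores, idxs, 0.5))  # [0, 2, 3]
```

In the usage example, box 1 is suppressed by box 0 (IoU 81/119, about 0.68), while box 2 survives even though it has exactly the same coordinates as box 0, because it belongs to a different class.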

I think the If operator is generated by this code:

if boxes.numel() == 0:
    return torch.empty((0,), dtype=torch.int64, device=boxes.device)

However, commenting out this code does not remove the If operator, and using _batched_nms_vanilla() instead of _batched_nms_coordinate_trick() does not help either.
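When swapping helpers, note that the dispatch in batched_nms picks _batched_nms_vanilla only when boxes.numel() exceeds 4000 on CPU (20000 on other devices) and torchvision is not tracing. The plain-Python sketch below mirrors that condition; batched_nms_path is a hypothetical name, the tracing flag stands in for torchvision._is_tracing(), and boxes is assumed to have shape [N, 4], so numel() == 4 * N. It shows that during torch.onnx.export the coordinate-trick helper is selected regardless of the number of boxes. Both helpers are also decorated with @torch.jit._script_if_tracing, which compiles them with TorchScript when they are called under tracing, so Python control flow inside them is kept as graph control flow (which the ONNX exporter can emit as If or Loop nodes) rather than being traced away.

```python
def batched_nms_path(num_boxes: int, device_type: str, tracing: bool) -> str:
    """Return the name of the helper torchvision's batched_nms would call (sketch).

    Assumes boxes has shape [num_boxes, 4], so boxes.numel() == 4 * num_boxes,
    and models torchvision._is_tracing() as the `tracing` flag.
    """
    numel = 4 * num_boxes
    threshold = 4000 if device_type == "cpu" else 20000  # benchmark-derived thresholds
    if numel > threshold and not tracing:
        return "_batched_nms_vanilla"
    return "_batched_nms_coordinate_trick"


if __name__ == "__main__":
    # Eager CPU run with many boxes: the per-class loop is used.
    print(batched_nms_path(2000, "cpu", tracing=False))  # _batched_nms_vanilla
    # Same input while tracing for ONNX export: always the coordinate trick.
    print(batched_nms_path(2000, "cpu", tracing=True))  # _batched_nms_coordinate_trick
```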

Versions

(ryzen-ai) PS C:\work\mask_vision> python .\collect_env.py
Collecting environment information...
PyTorch version: 2.3.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Enterprise Evaluation (10.0.22631 64 bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: N/A

Python version: 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:07:43) [MSC v.1942 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 7 8845HS
[Detail information is REDACTED because of NDA]

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxruntime==1.17.0
[pip3] onnxruntime_extensions==0.13.0
[pip3] onnxruntime-vitisai==1.19.0
[pip3] onnxscript==0.1.0.dev20250123
[pip3] onnxsim==0.4.36
[pip3] torch==2.3.1+cpu
[pip3] torch-geometric==2.6.1
[pip3] torchvision==0.18.1+cpu
[pip3] vai-q-onnx==1.19.0
[conda] mkl                       2021.4.0                 pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.3.1+cpu                pypi_0    pypi
[conda] torch-geometric           2.6.1                    pypi_0    pypi
[conda] torchvision               0.18.1+cpu               pypi_0    pypi
(ryzen-ai) PS C:\work\mask_vision>

NOTE: It is difficult for me to update to the newest versions because of compatibility constraints. batched_nms does not seem to have changed since then, so the problem should be reproducible with the newest versions as well.
