Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update threshold for batched NMS #8925

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Update threshold for batched NMS #8925

wants to merge 2 commits into from

Conversation

NicolasHug
Copy link
Member

@NicolasHug NicolasHug commented Feb 20, 2025

With #8766 now merged, a normal call to nms() is able to handle much more boxes while staying fast. This PR updates our internal heuristic in batched_nms() to switch between the coordinate trick and the for-loop-based approach.

100_000 seems to be a decent new threshold.

WITH NUM_CLASSES = 80
num_boxes = 10 med = 0.48ms +- 0.12  # _batched_nms_coordinate_trick
num_boxes = 10 med = 4.36ms +- 0.24  # _batched_nms_vanilla

num_boxes = 100 med = 0.48ms +- 0.04
num_boxes = 100 med = 25.95ms +- 0.38

num_boxes = 1000 med = 0.55ms +- 0.04
num_boxes = 1000 med = 36.35ms +- 0.42

num_boxes = 10000 med = 1.41ms +- 0.08
num_boxes = 10000 med = 37.21ms +- 0.80

num_boxes = 20000 med = 2.85ms +- 0.07
num_boxes = 20000 med = 37.75ms +- 3.63

num_boxes = 80000 med = 18.23ms +- 1.95
num_boxes = 80000 med = 39.47ms +- 0.48

num_boxes = 100000 med = 26.48ms +- 2.04
num_boxes = 100000 med = 36.94ms +- 0.66

num_boxes = 200000 med = 93.27ms +- 4.03
num_boxes = 200000 med = 45.62ms +- 2.52


WITH NUM_CLASSES = 10
num_boxes = 10 med = 0.44ms +- 0.11
num_boxes = 10 med = 3.26ms +- 0.12

num_boxes = 100 med = 0.45ms +- 0.04
num_boxes = 100 med = 4.59ms +- 0.32

num_boxes = 1000 med = 0.50ms +- 0.04
num_boxes = 1000 med = 4.71ms +- 0.29

num_boxes = 10000 med = 1.33ms +- 0.11
num_boxes = 10000 med = 5.17ms +- 0.08

num_boxes = 20000 med = 2.22ms +- 0.06
num_boxes = 20000 med = 5.82ms +- 0.19

num_boxes = 80000 med = 17.78ms +- 1.46
num_boxes = 80000 med = 10.21ms +- 0.67

num_boxes = 100000 med = 25.91ms +- 2.55
num_boxes = 100000 med = 11.35ms +- 0.91

num_boxes = 200000 med = 92.23ms +- 0.24
num_boxes = 200000 med = 18.48ms +- 1.39
import torch
from time import perf_counter_ns
from torchvision.ops.boxes import _batched_nms_coordinate_trick, _batched_nms_vanilla


def bench(f, *args, num_exp=100, warmup=0, **kwargs):

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        torch.cuda.synchronize()
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms", prefix=""):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{prefix}{med = :.2f}{unit} +- {std:.2f}")
    return med


def make_boxes(num_boxes, num_classes=10):
    boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1).to("cuda")
    assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
    assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2

    scores = torch.rand(num_boxes).to("cuda")
    idxs = torch.randint(0, num_classes, size=(num_boxes,)).to("cuda")
    return boxes, scores, idxs

NUM_EXP = 30
for num_boxes in (10, 100, 1000, 10000, 20_000, 80_000, 100000, 200_000):
    for f in (_batched_nms_coordinate_trick, _batched_nms_vanilla):
        boxes, scores, idxs = make_boxes(num_boxes)
        times = bench(f, boxes, scores, idxs, iou_threshold=.7, warmup=1, num_exp=NUM_EXP)
        report_stats(times, prefix=f"{num_boxes = } ")
    print()

Copy link

pytorch-bot bot commented Feb 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8925

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Unrelated Failures

As of commit f3b94cf with merge base e239710 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants