Bug Description
A runtime error is thrown when a compiled TorchScript model is loaded with torch.jit.load:
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0
The same code works in a Docker image with Torch 1.10.0a0+0aef44c, Torch-TensorRT 1.0.0+55c3bab4, TensorRT 8.0.3-1, and CUDA 11.3.
To Reproduce
Code:
import torch
import torch_tensorrt
import torch.nn as nn
import timm


class TrEstimator(nn.Module):
    def __init__(self, cfg, pretrained=True):
        super().__init__()
        base_model_name = cfg['base_model_name']
        input_depth = 2  # two stacked frames form a 2-channel input
        self.weight_scale = cfg['weight_scale']
        self.base_model = timm.create_model(base_model_name,
                                            features_only=True,
                                            in_chans=input_depth,
                                            pretrained=pretrained)
        out_ch = self.base_model.feature_info.channels()[-1]
        self.conv_heatmap = nn.Conv2d(out_ch, 1, kernel_size=1, bias=True)
        self.conv_offset = nn.Conv2d(out_ch, 2, kernel_size=1, bias=True)

    def freeze_encoder(self):
        self.base_model.freeze_encoder()

    def unfreeze_encoder(self):
        self.base_model.unfreeze_encoder()

    def forward(self, prev_frame, cur_frame):
        # (N, H, W) + (N, H, W) -> (N, 2, H, W)
        inputs = torch.stack([prev_frame, cur_frame], dim=1)
        x = self.base_model(inputs)
        x = x[-1]  # deepest feature map only
        x = x[:, :, 2:-2, 2:-2]  # crop 2 pixels from each spatial border
        x_hm = self.conv_heatmap(x)
        # Normalize over the spatial dimensions so each heatmap sums to 1
        m = torch.exp(self.weight_scale * torch.sigmoid(x_hm))
        heatmap = m / torch.sum(m, dim=(2, 3), keepdim=True)
        offsets = self.conv_offset(x)
        return heatmap, offsets  # two outputs


model = TrEstimator(
    cfg=dict(
        base_model_name="resnet34",
        weight_scale=3.0
    ), pretrained=False)
model = model.half().cuda()
model.eval()

print('Compiling TensorRT model ...')
model_jit = torch.jit.trace(model, (torch.rand((1, 640, 1024)).half().cuda(), torch.rand((1, 640, 1024)).half().cuda()))
model = torch_tensorrt.compile(
    model_jit,
    inputs=[torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half), torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half)],
    enabled_precisions={torch.half},
    truncate_long_and_double=True
)
torch.jit.save(model, f"./model.ts")

print('Loading TensorRT model ...')
model_new = torch.jit.load(f"./model.ts")  # raises the RuntimeError below
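For reference, the output shapes the traced model should produce can be worked out by hand. This is a minimal sketch, assuming timm's resnet34 with features_only=True ends at an output stride of 32 (with 512 channels), which holds for these input sizes since 640 and 1024 are multiples of 32; the helper name expected_output_shapes is hypothetical and not part of the reproduction.

```python
def expected_output_shapes(batch=1, height=640, width=1024, stride=32, crop=2):
    """Hypothetical helper: (heatmap, offsets) shapes for TrEstimator.

    Assumes the deepest timm feature map has the given output stride and that
    forward() crops `crop` pixels from each spatial border (x[:, :, 2:-2, 2:-2]).
    """
    feat_h, feat_w = height // stride, width // stride   # 640/32 = 20, 1024/32 = 32
    out_h, out_w = feat_h - 2 * crop, feat_w - 2 * crop  # 16, 28 after cropping
    heatmap_shape = (batch, 1, out_h, out_w)  # conv_heatmap has 1 output channel
    offsets_shape = (batch, 2, out_h, out_w)  # conv_offset has 2 output channels
    return heatmap_shape, offsets_shape


if __name__ == "__main__":
    # The two (1, 640, 1024) frames are stacked into a (1, 2, 640, 1024) input.
    print(expected_output_shapes())
```

The model returns two tensors (heatmap and offsets), so the compiled engine is expected to expose two outputs; the failing lookup is for output_0.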
Output:
Compiling TensorRT model ...
WARNING: [Torch-TensorRT] - For input prev_frame, found user specified input dtype as Float16. The compiler is going to use the user setting Float16
WARNING: [Torch-TensorRT] - For input cur_frame, found user specified input dtype as Float16. The compiler is going to use the user setting Float16
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - Sum converter disregards dtype
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0
Loading TensorRT model ...
Traceback (most recent call last):
  File "compile.py", line 61, in <module>
    model_new = torch.jit.load(f"./model.ts")
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0
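One way to narrow the failure down is to compare the output names the runtime looks for against the binding names a serialized engine actually exposes. This is a minimal sketch, assuming the TensorRT 8.x Python bindings (trt.Runtime, deserialize_cuda_engine, num_bindings, get_binding_name, binding_is_input) and that the raw engine bytes were obtained separately, for example with torch_tensorrt.ts.convert_method_to_trt_engine (assumed available in this Torch-TensorRT release); the helper names missing_outputs and engine_binding_names are hypothetical.

```python
from typing import Iterable, List, Tuple


def missing_outputs(expected: Iterable[str], engine_bindings: Iterable[str]) -> List[str]:
    """Return the expected output names that the engine does not expose."""
    available = set(engine_bindings)
    return [name for name in expected if name not in available]


def engine_binding_names(engine_bytes: bytes) -> List[Tuple[str, bool]]:
    """List (name, is_input) for every binding of a serialized TensorRT engine.

    Assumes the TensorRT 8.x Python API. tensorrt is imported lazily so the
    pure-Python helper above works without TensorRT installed.
    """
    import tensorrt as trt

    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    engine = runtime.deserialize_cuda_engine(engine_bytes)
    if engine is None:
        raise RuntimeError("TensorRT could not deserialize the engine")
    return [(engine.get_binding_name(i), engine.binding_is_input(i))
            for i in range(engine.num_bindings)]


if __name__ == "__main__":
    # Hypothetical binding names, only to show what a mismatch looks like.
    print(missing_outputs(["output_0", "output_1"], ["input_0", "input_1", "output_1"]))
```

Feeding the names returned by engine_binding_names into missing_outputs in the failing environment would show which output names the engine actually carries.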
Expected behavior
Model loads successfully.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version: 1.3.0
- PyTorch Version: 1.13.0+cu117
- CPU Architecture: Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
- OS (e.g., Linux): Ubuntu 20.04.5 LTS
- How you installed PyTorch (conda, pip, libtorch, source): pip (installed after uninstalling the bundled version in the NGC Docker image nvcr.io/nvidia/pytorch:22.11-py3)
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.8.10
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
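The empty fields above (CUDA version, GPU model) can be filled in from inside the failing environment. This is a minimal sketch that collects them; each package is imported only if it loads, and the optional debug-logging helper assumes the torch_tensorrt.logging module of the 1.x releases (set_reportable_log_level and Level.Debug). The function names collect_versions and enable_torch_tensorrt_debug_logging are hypothetical.

```python
import importlib
import platform


def collect_versions():
    """Gather the version fields asked for in the Environment section.

    Packages that fail to import are reported as None instead of raising.
    """
    info = {"python": platform.python_version()}
    for name in ("torch", "torch_tensorrt", "tensorrt"):
        try:
            module = importlib.import_module(name)
            info[name] = getattr(module, "__version__", None)
        except Exception:  # ImportError, or OSError if shared libraries are missing
            info[name] = None
    try:
        import torch
    except Exception:
        torch = None
    if torch is not None:
        info["cuda (torch build)"] = torch.version.cuda
        info["cudnn (loaded)"] = torch.backends.cudnn.version()
        info["gpu"] = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    return info


def enable_torch_tensorrt_debug_logging():
    """Assumed 1.x API: raise Torch-TensorRT's log level to Debug."""
    import torch_tensorrt
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```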
Additional context
dpkg -l | grep nvinfer
ii libnvinfer-bin 8.5.1-1+cuda11.8 amd64 TensorRT binaries
ii libnvinfer-dev 8.5.1-1+cuda11.8 amd64 TensorRT development libraries and headers
ii libnvinfer-plugin-dev 8.5.1-1+cuda11.8 amd64 TensorRT plugin libraries
ii libnvinfer-plugin8 8.5.1-1+cuda11.8 amd64 TensorRT plugin libraries
ii libnvinfer8 8.5.1-1+cuda11.8 amd64 TensorRT runtime libraries
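The system packages above report TensorRT 8.5.1 built for CUDA 11.8, while the compile log warns that TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0, so the libraries loaded at runtime may not be the ones TensorRT was built against. This is a minimal sketch for comparing the system libnvinfer version with the tensorrt Python package that is actually imported, assuming the package exposes tensorrt.__version__ (as the official wheels do); the helper names parse_dpkg_list, upstream_version, and same_release are hypothetical.

```python
from typing import Dict

# The `dpkg -l | grep nvinfer` output reported above.
DPKG_SAMPLE = """\
ii libnvinfer-bin 8.5.1-1+cuda11.8 amd64 TensorRT binaries
ii libnvinfer-dev 8.5.1-1+cuda11.8 amd64 TensorRT development libraries and headers
ii libnvinfer-plugin-dev 8.5.1-1+cuda11.8 amd64 TensorRT plugin libraries
ii libnvinfer-plugin8 8.5.1-1+cuda11.8 amd64 TensorRT plugin libraries
ii libnvinfer8 8.5.1-1+cuda11.8 amd64 TensorRT runtime libraries
"""


def parse_dpkg_list(text: str) -> Dict[str, str]:
    """Map package name -> version for installed ('ii') lines of `dpkg -l` output."""
    versions = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) >= 3 and fields[0] == "ii":
            versions[fields[1]] = fields[2]
    return versions


def upstream_version(deb_version: str) -> str:
    """Strip the Debian revision, e.g. '8.5.1-1+cuda11.8' -> '8.5.1'."""
    return deb_version.split("-", 1)[0]


def same_release(deb_version: str, py_version: str) -> bool:
    """Compare dotted components, so '8.5.1' matches '8.5.1.7' but not '8.5.10.1'."""
    deb = upstream_version(deb_version).split(".")
    return py_version.split(".")[:len(deb)] == deb


if __name__ == "__main__":
    system = parse_dpkg_list(DPKG_SAMPLE)["libnvinfer8"]
    try:
        import tensorrt
        print(system, tensorrt.__version__, same_release(system, tensorrt.__version__))
    except ImportError:
        print(system, "tensorrt Python package not importable")
```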