博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch 直接转 tensorrt 的trt文件,并运行,量化int8
阅读量:3577 次
发布时间:2019-05-20

本文共 2996 字,大约阅读时间需要 9 分钟。

这个项目是我使用tensorrt,torch2tr包,将训练的yolov5s模型,进行tensorrt推理加速,量化精度为int8,但值得注意的是tensorrt对于forward的部份操作是不支持的,如切片等,这时可以考虑模型部份转换或者改写forward方法;(关注查看完整代码)

1.tensorrt加速的原理:将conv、bn、relu 和 conv、relu 进行融合,融合为一层,从而减少网络参数;

2.tensorrt对于分支结构加速效果尤为明显,像inception网络等;如分支1、分支2、分支3可能在同一个时间步骤下都包含1*1卷积层,tensorrt会将这三个分支的1*1卷积层合并为一个1*1卷积层,从而达到减少卷积层的目的;

3.tensorrt还可以对模型进行量化;量化到int8、fp16等;

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

*

yolov5 pt2trt代码示例:

import torchfrom torch2trt import torch2trtfrom torch2trt import TRTModuleimport timemodel = torch.load('/home/Oyj/yolov5/yolov5s_trainModel.pt').cuda()x = torch.rand((1, 3, 608, 608)).cuda()  # 占位符,3通道的608*608尺寸图片,最好选用图片# convert to TensorRT feeding sample data as inputmodel_trt = torch2trt(model, [x],int8_mode=True)#这里首先把pytorch模型加载到CUDA,然后定义好输入的样例x(这里主要用来指定输入的shape,用ones, zeros都可以)。model_trt就是转成功的TensorRT模型,你运行上面代码没报错就证明你转tensorRT成功了。torch.save(model_trt.state_dict(), 'yolov5s_trt.pth')model_trt = TRTModule()model_trt.load_state_dict(torch.load('yolov5s_trt.pth'))def _make_grid(nx=20, ny=20):    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])    return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()def detect(pre):    z = []    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    stride = torch.Tensor([8,16,32]).cuda()    anchors = torch.Tensor([[[1.25000, 1.87500, 3.62500],                             [2.00000, 3.87500, 4.87500],                             [4.12500, 3.68750, 11.65625]],                            [[1.62500, 3.81250, 2.81250],                             [3.75000, 2.81250, 6.18750],                             [2.87500, 7.43750, 10.18750]]]).cuda()    anchor_grid = torch.Tensor([[[[[[ 10.,  13.]]],                                  [[[ 16.,  30.]]],                                  [[[ 33.,  23.]]]]],                                [[[[[ 30.,  61.]]],                                  [[[ 62.,  45.]]],                                  [[[ 59., 119.]]]]],                                [[[[[116.,  90.]]],                                  [[[156., 198.]]],                                  [[[373., 326.]]]]]]).cuda()    no = 28    for i in range(len(pre)):        # print(pre[0].shape)        # bs, _, ny, nx = x.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)        # x[i] = x.view(bs, 3, 28, ny, nx).permute(0, 1, 3, 4, 2).contiguous()        bs,_,ny,nx,_ = pre[i].shape        grid = _make_grid(nx, ny).to(pre[i].device)        y = pre[i].sigmoid()        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid) * stride[i]  # xy        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh        z.append(y.view(bs, -1, no))    return (torch.cat(z, 1), pre)c1 = time.time()for i in range(10000):    y_trt = model_trt(x)    # y_trt = model(x)    res = detect(y_trt)    print(res[0].shape,len(res[1]))print(time.time() - c1)# 0.0004208087921142578# print(y_trt.argmax(dim=1, keepdim=True))# tensor([[534]], device='cuda:0')# origin 57.03164482116699# trt 12.815325498580933

转载地址:http://quagj.baihongyu.com/

你可能感兴趣的文章
初识Struts
查看>>
多线程打印A12Z34。。。
查看>>
strutsc踩过的坑
查看>>
maven安装和使用踩坑
查看>>
第一次紧张刺激的面试
查看>>
咕泡笔记导读篇
查看>>
eclipse安装maven和简单使用
查看>>
关于交往所思
查看>>
jdbc的封装
查看>>
数据库存入数据变为???
查看>>
实现数据库源的几种方式和开源数据源的使用
查看>>
元数据的获取和 数据库读写操作封装
查看>>
java文件的上传和下载(细节问题)
查看>>
DBUtils框架QueryRunner的 使用
查看>>
springMVC之controller笔记
查看>>
springmvc类型转换
查看>>
ai 的研究生院校
查看>>
spring开发步骤以及bean的配置和注入方式
查看>>
关于鼻炎的日常饮食和注意
查看>>
Spring的IOC的注解的详解
查看>>