email:[email protected]
This article describes how to adapt YOLOv5 for pruning with NNI and how to test the pruned model.
import time

import torch
import torchvision
import torch.nn as nn
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner, L2NormPruner, FPGMPruner, ActivationAPoZRankPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params
from rich import print

from models.common import Conv
from models.experimental import attempt_load
from models.yolo import Detect
from utils.activations import SiLU
from utils.general import check_img_size
First, import the required packages, then load the model:
device = torch.device("cuda:1")
model = attempt_load('/backup/nni/yolov5/output_pruned/deepsort_det20211202.pt', map_location=device, inplace=True, fuse=True) # load FP32 model
model.eval()
This gives us the model object. At this point it exposes the names and metadata of all its layers via named_modules(); this information is used later to build the list of layers to prune.
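If you want to see which layer names are available (for example, to build the op_names list further down), a minimal sketch is to iterate over named_modules() and print every Conv2d; skipping the 'model.24' prefix (the Detect head in this YOLOv5 layout) matches the exclude rule used later and is an assumption about the model structure:

# List all prunable Conv2d layers; 'model.24' (the Detect head) is skipped
# here because it is excluded from pruning later on.
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d) and not name.startswith('model.24'):
        print(name, tuple(module.weight.shape))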
for k, m in model.named_modules():
    if isinstance(m, Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    elif isinstance(m, Detect):
        m.inplace = False
        m.onnx_dynamic = False
        if hasattr(m, 'forward_export'):
            m.forward = m.forward_export  # assign custom forward (optional)
Next, iterate over all modules of the model: each Conv layer's nn.SiLU activation is swapped for the export-friendly SiLU implementation (functionally identical), i.e.:
class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)
and the Detect head's inplace and dynamic-ONNX options are turned off.
imgsz = (640, 640)
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
im = torch.zeros(1, 3, *imgsz).to(device)  # dummy input, BCHW
dummy_input = im
Set up the dummy input im.
cfg_list = [{
'sparsity': 0.3, 'op_types': ['Conv2d'],'op_names': [
'model.0.conv',
'model.1.conv',
'model.2.cv1.conv',
'model.2.cv2.conv',
'model.2.cv3.conv',
'model.2.m.0.cv1.conv',
'model.2.m.0.cv2.conv',
'model.2.m.1.cv1.conv',
'model.2.m.1.cv2.conv',
'model.2.m.2.cv1.conv',
'model.2.m.2.cv2.conv',
'model.2.m.3.cv1.conv',
'model.2.m.3.cv2.conv',
'model.3.conv',
'model.4.cv1.conv',
'model.4.cv2.conv',
'model.4.cv3.conv',
'model.4.m.0.cv1.conv',
'model.4.m.0.cv2.conv',
'model.4.m.1.cv1.conv',
'model.4.m.1.cv2.conv',
'model.4.m.2.cv1.conv',
'model.4.m.2.cv2.conv',
'model.4.m.3.cv1.conv',
'model.4.m.3.cv2.conv',
'model.4.m.4.cv1.conv',
'model.4.m.4.cv2.conv',
'model.4.m.5.cv1.conv',
'model.4.m.5.cv2.conv',
'model.4.m.6.cv1.conv',
'model.4.m.6.cv2.conv',
'model.4.m.7.cv1.conv',
'model.4.m.7.cv2.conv',
'model.5.conv',
'model.6.cv1.conv',
'model.6.cv2.conv',
'model.6.cv3.conv',
'model.6.m.0.cv1.conv',
'model.6.m.0.cv2.conv',
'model.6.m.1.cv1.conv',
'model.6.m.1.cv2.conv',
'model.6.m.2.cv1.conv',
'model.6.m.2.cv2.conv',
'model.6.m.3.cv1.conv',
'model.6.m.3.cv2.conv',
'model.6.m.4.cv1.conv',
'model.6.m.4.cv2.conv',
'model.6.m.5.cv1.conv',
'model.6.m.5.cv2.conv',
'model.6.m.6.cv1.conv',
'model.6.m.6.cv2.conv',
'model.6.m.7.cv1.conv',
'model.6.m.7.cv2.conv',
'model.6.m.8.cv1.conv',
'model.6.m.8.cv2.conv',
'model.6.m.9.cv1.conv',
'model.6.m.9.cv2.conv',
'model.6.m.10.cv1.conv',
'model.6.m.10.cv2.conv',
'model.6.m.11.cv1.conv',
'model.6.m.11.cv2.conv',
'model.7.conv',
'model.8.cv1.conv',
'model.8.cv2.conv',
'model.8.cv3.conv',
'model.8.m.0.cv1.conv',
'model.8.m.0.cv2.conv',
'model.8.m.1.cv1.conv',
'model.8.m.1.cv2.conv',
'model.8.m.2.cv1.conv',
'model.8.m.2.cv2.conv',
'model.8.m.3.cv1.conv',
'model.8.m.3.cv2.conv',
'model.9.cv1.conv',
'model.9.cv2.conv',
'model.10.conv',
'model.13.cv1.conv',
'model.13.cv2.conv',
'model.13.cv3.conv',
'model.13.m.0.cv1.conv',
'model.13.m.0.cv2.conv',
'model.13.m.1.cv1.conv',
'model.13.m.1.cv2.conv',
'model.13.m.2.cv1.conv',
'model.13.m.2.cv2.conv',
'model.13.m.3.cv1.conv',
'model.13.m.3.cv2.conv',
'model.14.conv',
'model.17.cv1.conv',
'model.17.cv2.conv',
'model.17.cv3.conv',
'model.17.m.0.cv1.conv',
'model.17.m.0.cv2.conv',
'model.17.m.1.cv1.conv',
'model.17.m.1.cv2.conv',
'model.17.m.2.cv1.conv',
'model.17.m.2.cv2.conv',
'model.17.m.3.cv1.conv',
'model.17.m.3.cv2.conv',
'model.18.conv',
'model.20.cv1.conv',
'model.20.cv2.conv',
'model.20.cv3.conv',
'model.20.m.0.cv1.conv',
'model.20.m.0.cv2.conv',
'model.20.m.1.cv1.conv',
'model.20.m.1.cv2.conv',
'model.20.m.2.cv1.conv',
'model.20.m.2.cv2.conv',
'model.20.m.3.cv1.conv',
'model.20.m.3.cv2.conv',
'model.21.conv',
'model.23.cv1.conv',
'model.23.cv2.conv',
'model.23.cv3.conv',
'model.23.m.0.cv1.conv',
'model.23.m.0.cv2.conv',
'model.23.m.1.cv1.conv',
'model.23.m.1.cv2.conv',
'model.23.m.2.cv1.conv',
'model.23.m.2.cv2.conv',
'model.23.m.3.cv1.conv',
'model.23.m.3.cv2.conv'
]
}, {
    'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'],
    'exclude': True
}
]
Set the pruning config: all convolution layers go into the to-prune list, and the three convs of the final Detect head (model.24.m.0/1/2) are excluded; see the alternative sketch below.
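Typing out the op_names list by hand is error prone; an alternative sketch (assuming the Detect head lives at model.24, as in the config above) generates an equivalent config programmatically from named_modules():

# Build the same cfg_list automatically: prune every Conv2d at 30% sparsity,
# keeping the Detect-head convs under model.24 excluded.
prune_names = [name for name, m in model.named_modules()
               if isinstance(m, nn.Conv2d) and not name.startswith('model.24')]
cfg_list = [
    {'sparsity': 0.3, 'op_types': ['Conv2d'], 'op_names': prune_names},
    {'op_names': ['model.24.m.0', 'model.24.m.1', 'model.24.m.2'], 'exclude': True},
]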
pruner = L1NormPruner(model, cfg_list)
# pruner = L2NormPruner(model, cfg_list)
# pruner = FPGMPruner(model, cfg_list)
_, masks = pruner.compress()
# print(masks)
pruner.export_model(model_path='deepsort_yolov5m.pt', mask_path='deepsort_mask.pt')
pruner.show_pruned_weights()
pruner._unwrap_model()
Call compress() to compute the pruning masks for the model; export_model() saves the weights and masks, show_pruned_weights() prints the per-layer sparsity, and _unwrap_model() removes the pruner wrappers before speedup.
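To check what compress() produced, a minimal sketch is to measure the fraction of zeroed entries in each returned mask; this assumes the v2 pruner's mask layout of layer name mapped to tensor-name/mask pairs:

# Report the fraction of masked (zeroed) entries in each mask returned by compress().
for layer_name, layer_masks in masks.items():
    for tensor_name, mask in layer_masks.items():
        sparsity = (mask == 0).float().mean().item()
        print(f'{layer_name}.{tensor_name}: {sparsity:.2%} masked')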
print("im.shape:",dummy_input.shape)
# 1.
start = time.time()
for _ in range(100):
use_mask_out = model(dummy_input)
# print(use_mask_out[0].shape)
print('elapsed time_before_pruned: ', (time.time() - start)*100)
Measure the model's forward time before speedup.
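The timing code further down calls get_parameter_number(), which is not defined in this post; a minimal sketch of such a helper (the name and return format are assumptions) is:

def get_parameter_number(net):
    # Count total and trainable parameters of a model.
    total = sum(p.numel() for p in net.parameters())
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total, 'Trainable': trainable}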
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file="deepsort_mask.pt")
m_speedup.speedup_model()
model.eval()
flops, params, _ = count_flops_params(model, dummy_input)  # also prints a per-layer summary
torch.save(model, "output_pruned/pruned_deepsortdetv2.pt")
start = time.time()
for _ in range(10):
    use_mask_out = model(dummy_input)
print('average forward time after speedup: ', (time.time() - start) / 10)
print(get_parameter_number(model))
Apply ModelSpeedup to actually remove the masked channels, count the FLOPs and parameters of the pruned model, save it, and time it again.
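Since the activations and the Detect head were already made export-friendly above, a natural next step is to export the pruned model to ONNX; a minimal sketch (the output filename and opset version are assumptions) is:

# Export the pruned model to ONNX with static shapes (onnx_dynamic was disabled earlier).
torch.onnx.export(
    model,
    dummy_input,
    "output_pruned/pruned_deepsortdetv2.onnx",  # assumed output path
    opset_version=12,
    input_names=['images'],
    output_names=['output'],
)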
# load the pruned model
model_to_test = torch.load("output_pruned/pruned_deepsortdetv2.pt")
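As a quick sanity check (assuming the same dummy_input and device as above are still available), run one forward pass with the reloaded model and inspect the output shape:

# Run one forward pass on the reloaded pruned model to confirm it still works.
model_to_test.eval()
with torch.no_grad():
    out = model_to_test(dummy_input)
print('pruned model output shape:', out[0].shape)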
That's all.