JobPlus知识库 IT 工业智能4.0 文章
用MXnet预训练模型初始化Pytorch模型

1、MXnet符号图:

基于MXnet所构建的符号图是一种静态计算图,图结构与内存管理都是静态的。以Resnet50_v2为例,Bottleneck结构的符号图如下:

[python]

  1. bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')  
  2. act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')  
  3. conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),  
  4.                            no_bias=True, workspace=workspace, name=name + '_conv1')  
  5. bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')  
  6. act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')  
  7. conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),  
  8.                            no_bias=True, workspace=workspace, name=name + '_conv2')  
  9. bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')  
  10. act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')  
  11. conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,  
  12.                            workspace=workspace, name=name + '_conv3')  
  13. if dim_match:  
  14.     shortcut = data  
  15. else:  
  16.     shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,  
  17.                                     workspace=workspace, name=name+'_sc')  
  18. return conv3 + shortcut  

2、加载符号图与模型参数:

MXnet预训练模型包括json配置文件与param参数文件:

-- resnet-50-0000.params

-- resnet-50-symbol.json

通过加载这两个文件,便可以获得符号图结构、模型权重与辅助参数信息:

[python] 

  1.         prefix, index, num_layer = 'resnet-50', args.epoch, 50  
  2.         prefix = os.path.join(ROOT_PATH, "./mx_model/models/{}".format(prefix))  
  3.         symbol, param_args, param_auxs = mx.model.load_checkpoint(prefix, index)  

3、Pytorch动态图:

Pytorch是一种动态类型框架,计算图构建与内存管理都是动态的,适合专注于研究的算法开发。按照命令式编程方式,能够及时获取计算图中Tensor及其导数的数值信息。Resnet50_v2的Bottleneck结构如下:

[python]

  1. class Bottleneck(nn.Module):  
  2.     expansion = 4  
  3.   
  4.     def __init__(self, inplanes, planes, stride=1, downsample=False):  
  5.         super(Bottleneck, self).__init__()  
  6.         self.bn1 = nn.BatchNorm2d(inplanes, eps)  
  7.         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)  
  8.         self.bn2 = nn.BatchNorm2d(planes, eps)  
  9.         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,  
  10.                                padding=1, bias=False)  
  11.         self.bn3 = nn.BatchNorm2d(planes, eps)  
  12.         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)  
  13.         self.relu = nn.ReLU(inplace=True)  
  14.         self.downsample = downsample  
  15.         if downsample:  
  16.             self.conv_sc = nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=stride, bias=False)  
  17.         self.stride = stride  
  18.   
  19.     def forward(self, input):  
  20.   
  21.         out = self.bn1(input)  
  22.         out1 = self.relu(out)  
  23.         residual = input  
  24.         out = self.conv1(out1)  
  25.   
  26.         out = self.bn2(out)  
  27.         out = self.relu(out)  
  28.         out = self.conv2(out)  
  29.   
  30.         out = self.bn3(out)  
  31.         out = self.relu(out)  
  32.         out = self.conv3(out)  
  33.   
  34.         if self.downsample:  
  35.             residual = self.conv_sc(out1)  
  36.         out += residual  
  37.         return out  

4、解析MXnet参数、初始化Pytorch模型:

首先需要将MXnet参数转为Numpy数组形式的字典。BN层、Conv2D层、FC层解析如下:

[python] 

  1. def bn_parse(args, auxs, name, args_dict, fix_gamma=False):  
  2.     """ name0: PyTorch layer name; 
  3.         name1: MXnet layer name."""  
  4.     args_dict[name[0]] = {}  
  5.     if not fix_gamma:  
  6.         args_dict[name[0]]['running_mean'] = auxs[name[1]+'_moving_mean'].asnumpy()  
  7.         args_dict[name[0]]['running_var'] = auxs[name[1]+'_moving_var'].asnumpy()  
  8.         args_dict[name[0]]['gamma'] = args[name[1]+'_gamma'].asnumpy()  
  9.         args_dict[name[0]]['beta'] = args[name[1]+'_beta'].asnumpy()  
  10.     else:  
  11.         _mv = auxs[name[1]+'_moving_var'].asnumpy()  
  12.         _mm = auxs[name[1]+'_moving_mean'].asnumpy() - np.multiply(args[name[1]+'_beta'].asnumpy(), np.sqrt(_mv+eps))  
  13.         args_dict[name[0]]['running_mean'] = _mm  
  14.         args_dict[name[0]]['running_var'] = _mv  
  15.     return args_dict  

[python] 

  1. def conv_parse(args, auxs, name, args_dict):  
  2.     """ name0: PyTorch layer name; 
  3.         name1: MXnet layer name."""  
  4.     args_dict[name[0]] = {}  
  5.     w = args[name[1]+'_weight'].asnumpy()  
  6.     args_dict[name[0]]['weight'] = w # N, M, k1, k2  
  7.     return args_dict  

[python]

  1. def fc_parse(args, auxs, name, args_dict):  
  2.     """ name0: PyTorch layer name; 
  3.         name1: MXnet layer name."""  
  4.     args_dict[name[0]] = {}  
  5.     args_dict[name[0]]['weight'] = args[name[1]+'_weight'].asnumpy()  
  6.     args_dict[name[0]]['bias'] = args[name[1]+'_bias'].asnumpy()  
  7.     return args_dict  

然后逐层遍历Pytorch的每个module,并完成模型参数赋值,从而实现用MXnet预训练模型初始化Pytorch模型的目的:

[python] 

  1. # model initialization for PyTorch from MXnet params  
  2. class resnet(object):  
  3.     def __init__(self, name, num_layer, args, auxs, prefix='module.'):  
  4.         self.name = name  
  5.         num_stages = 4  
  6.         if num_layer == 50:  
  7.             units = [3, 4, 6, 3]  
  8.         elif num_layer == 101:  
  9.             units = [3, 4, 23, 3]  
  10.         self.num_layer = str(num_layer)  
  11.         self.param_dict = arg_parse(args, auxs, num_stages, units, prefix=prefix)  
  12.   
  13.     def bn_init(self, n, m):  
  14.         if not (m.weight is None):  
  15.             m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['gamma']))  
  16.             m.bias.data.copy_(torch.FloatTensor(self.param_dict[n]['beta']))  
  17.         m.running_mean.copy_(torch.FloatTensor(self.param_dict[n]['running_mean']))  
  18.         m.running_var.copy_(torch.FloatTensor(self.param_dict[n]['running_var']))  
  19.   
  20.     def conv_init(self, n, m):  
  21.         #m.weight.data.zero_()  
  22.         m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['weight']))  
  23.   
  24.     def fc_init(self, n, m):  
  25.         m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['weight']))  
  26.         m.bias.data.copy_(torch.FloatTensor(self.param_dict[n]['bias']))  
  27.   
  28.     def init_model(self, model):  
  29.         for n, m in model.named_modules():  
  30.             if isinstance(m, nn.BatchNorm2d):  
  31.                 self.bn_init(n, m)  
  32.             elif isinstance(m, nn.Conv2d):  
  33.                 self.conv_init(n, m)  
  34.             elif isinstance(m, nn.Linear):  
  35.                 self.fc_init(n, m)  
  36.         return model  

5、使用MXnet的数据加载器:

mx.io.ImageRecordIter的输出转为Pytorch Tensor,便可用于Pytorch模型的训练、验证与测试,迭代器设计如下:

[python] 

  1. def __iter__(self):  
  2.         for batch in self.data:  
  3.             nd_data = batch.data[0].asnumpy()  
  4.             nd_label = batch.label[0].asnumpy()  
  5.             input_data = torch.FloatTensor(nd_data)  
  6.             input_label = torch.LongTensor(nd_label)  
  7.   
  8.             if self.cuda:  
  9.                 yield input_data.cuda(non_blocking=True), input_label.cuda(non_blocking=True)  
  10.             else:  
  11.                 yield input_data, input_label  



如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
250人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序