安装了TVM的框架后,追寻TVM提供的教程,以Compile MXNet Models为例子,开始TVM的尝试。[1]
1. Prepare
主要包括环境的配置和网络模型的获取。
1. Environment
通过编译安装TVM
库,可参考: TVM学习(1)–搭建环境
首先使用pip3
安装mxnet的package:
pip3 install mxnet --user
接下来即可在Python中使用TVM
和MXNet
了:
# some standard imports
import mxnet as mx
import tvm
import tvm.relay as relay
import numpy as np
2. Get Model & Data
由于TVM对MXNet原生支持,这里只需要下载预训练好的MXNet模型即可,这里选取了Resnet18
网络。
-
从
mxnet.gluon.model_zoo.vision
获取模型from mxnet.gluon.model_zoo.vision import get_model block = get_model('resnet18_v1', pretrained=True)
首先通过
mxnet.gluon.model_zoo.vision
包的get_model
函数获取预编译的resnet18_v1
模型:其分别调用了
mxnet.gluon.model_zoo.model_store.py
的resnet18_v1()
函数, 以及mxnet.gluon.model_zoo.vision.resnet.py
的get_model_file()
函数获取parameters.get_model('resnet18_v1', pretrained=True) resnet18_v1(pretrained=True) get_resnet(1, 18, pretrained=True) get_model_file('resnet%d_v%d'%(num_layers, version), root=root) # num_layers=18 version=1
get_model_file()函数从https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/下载相应的文件模型参数文件。
下载到本地地址为$HOME/.mxnet/models
-
获取输入Data
这里我们选择了一张小猫图片,如下图:
输入图片的获取和处理程序如下:
from tvm.contrib.download import download_testdata from PIL import Image from matplotlib import pyplot as plt img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' img_name = 'cat.png' img_path = download_testdata(img_url, 'cat.png', module='data') # download from url. Functions: # download.py -> download_testdata() ----> download.py -> download() image = Image.open(img_path).resize((224, 224)) # 将图像resize为特定大小[224, 224] def transform_image(image): # mean = [123., 117., 104.] # std = [58.395, 57.12, 57.375] # 在ImageNet上训练数据集的mean和std # 图片的均值和方差 image = np.array(image) - np.array([123., 117., 104.]) image /= np.array([58.395, 57.12, 57.375]) # 将RGB三维图片转化为numpy数组。 # 得到的数组为(224,224,3)的格式。 image = image.transpose((2, 0, 1)) # 举着转置,(224,224,3) ---> (3,224,224) image = image[np.newaxis, :] # 添加轴 (3,224,224) ---> (1,3,224,224) return image x = transform_image(image) # print('x', x.shape)
这样我们得到的x为[1,3,224,224]维度的ndarray。这个符合NCHW格式标准,也是我们通用的张量格式。
-
图片分类标签
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', '4d0b62f3d01426887599d4f7ede23ee5/raw/', '596b27d23537e5a1b5751d2b0481ef172f58b539/', 'imagenet1000_clsid_to_human.txt']) synset_name = 'imagenet1000_clsid_to_human.txt' synset_path = download_testdata(synset_url, synset_name, module='data') with open(synset_path) as f: synset = eval(f.read())
图片标签从上述文本的网站上下载,保存为文本格式。
2. Compile
将获取的model(resnet18_v1
)从mxnet
格式编译为relay
格式,进一步获取功能函数和权重参数信息。
-
relay读取resnet,并添加层
里我们使用的是TVM中的Relay IR,这个IR简单来说就是可以读取我们的模型并按照模型的顺序搭建出一个可以执行的计算图出来,当然,我们可以对这个计算图进行一系列优化。(现在TVM主推Relay而不是NNVM,Relay可以称为二代NNVM)。
shape_dict = {'data': x.shape} mod, params = relay.frontend.from_mxnet(block, shape_dict) # Function in python/tvm/relay/frontend/mxnet.py # Convert from MXNet"s model into compatible relay Function. ## we want a probability so add a softmax operator func = mod["main"] func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) # 将在mod的relay表达式最后添加 # nn.softmax(*) # 对最后的结果应用softmax
-
对应硬件target编译模型
首先来我们设置Target为
llvm
,也就是部署到CPU端。通过使用relay.build_config
函数对build的参数设定, 通过使用relay.build
函数对其进行编译,获取graph, lib, params.target = 'llvm' with relay.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params)
该步骤获取的结果如下:
- graph : str – The json string that can be accepted by graph runtime.
- lib : tvm.Module – The module containing necessary libraries.
- params : dict – The parameters of the final graph.
Run
这一步将编译好的graph
, lib
, params
在Traget上执行。
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
# 指定TVM Context
dtype = 'float32'
m = graph_runtime.create(graph, lib, ctx)
# 建立运行时, m即为可执行
# set inputs
m.set_input('data', tvm.nd.array(x.astype(dtype)))
m.set_input(**params) #权重信息
# execute
m.run()
# get outputs
tvm_output = m.get_output(0)
top1 = np.argmax(tvm_output.asnumpy()[0])
print('TVM prediction top-1:', top1, synset[top1])
# 将获得的结果的最大值与标签序号对比,获取结果。
Result
TVM结果:
MXNet原始结果:
可见其结果一致,但时间上差距不大,可能由于网络较为简单,不具有代表性。
Reference
[1] Tiqi Chen et., Compile MXNet Models, [OL], https://docs.tvm.ai/tutorials/frontend/from_mxnet.html#sphx-glr-tutorials-frontend-from-mxnet-py
Update Log
- Oct 6 2019, Write Complete. <Procrastination!!!>