博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch使用总览
阅读量:5236 次
发布时间:2019-06-14

本文共 2212 字,大约阅读时间需要 7 分钟。

 

PyTorch使用总览

 

 

深度学习框架训练模型时的代码主要包含数据读取、网络构建和其他设置三方面,基本上掌握这三方面就可以较为灵活地使用框架训练模型。PyTorch是Facebook的官方深度学习框架之一,到现在开源1年时间,势头非常猛,相信使用过的人都会被其轻便和快速等特点深深吸引,因此这篇博客从整体上介绍如何使用PyTorch

PyTorch的官方github地址:
PyTorch官方文档:

 

建议先看看:,对Pytorch的使用有一个快速的了解。

 

接下来就按照上述的3个方面来介绍如何使用PyTorch。

 

一、数据读取

 

数据读取部分包含如何将你的图像和标签数据转换成PyTorch框架的Tensor数据类型,官方代码库中有一个接口例子:torchvision.ImageFolder,这个接口在 中有简单介绍。因为这个接口针对的数据存放方式是每个文件夹包含一个类的图像,但是实际应用中可能你的数据不是这样维护的,或者你的数据是多标签的,或者其他更复杂的形式,那么就需要自定义一个数据读取接口,这个时候就不得不提一个PyTorch中数据读取基类:torch.utils.data.Dataset,包括前面提到的torchvision.ImageFolder接口的对应类也是继承torch.utils.data.Dataset实现的,因此torch.utils.data.Dataset类是PyTorch框架中数据读取的核心。那么如何自定义一个数据读取接口呢?可以看博客:,这篇博客中从剖析torchvision.ImageFolder接口切入,然后引出如何自定义数据读取接口。这样就完成了数据的第一层封装。

 

在自定义数据读取接口时还有一步很重要的操作:数据预处理。常常我们在论文中看到的data argumentation就是指的数据预处理,对实验结果影响还是比较大的。该操作在PyTorch中可以通过torchvision.transforms接口来实现,具体请看博客: 的介绍。

 

经过上述的两个操作后,还需再进行一次封装,将数据和标签封装成数据迭代器,这样才方便模型训练的时候一个batch一个batch地进行,这就要用到torch.utils.data.DataLoader接口,该接口的一个输入就是前面继承自torch.utils.data.Dataset类的自定义了的对象(比如torchvision.ImageFolder类的对象),具体可以参考博客:

至此,从图像和标签文件就生成了Tensor类型的数据迭代器,后续仅需将Tensor对象用torch.autograd.Variable接口封装成Variable类型(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上运行则是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作为模型的输入了。

 

其他自定义的数据读取接口例子可以参考:,该项目中的read_ImageNetData.py脚本自定义了读取ImageNet数据集的接口,训练数据的读取和验证数据的读取采取不同的接口实现,比较有特点。

 

二、网络构建

 

PyTorch框架中提供了一些方便使用的网络结构及预训练模型接口:torchvision.models,具体可以看博客:。该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。

 

那么如何自定义网络结构呢?在PyTorch中,构建网络结构的类都是基于torch.nn.Module这个基类进行的,也就是说所有网络结构的构建都可以通过继承该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考:1、。该项目中的MobileNetV2.py脚本自定义了网络结构。2、。该项目中的se_resnet.py和se_resnext.py脚本分别自定义了不同的网络结构。

 

如果要用某预训练模型为自定义的网络结构进行参数初始化,可以用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化,具体可以看项目中的train.py脚本中if args.resume条件语句。

 

三、其他设置

 

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以看官方文档:。

 

学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多学习率变化策略可以看官方文档:。

 

损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

 

多GPU训练通过torch.nn.DataParallel接口实现,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上训练模型。

 

转载于:https://www.cnblogs.com/DicksonJYL/p/9576835.html

你可能感兴趣的文章
redis哨兵集群、docker入门
查看>>
hihoCoder 1233 : Boxes(盒子)
查看>>
oracle中anyData数据类型的使用实例
查看>>
C++对vector里面的元素排序及取任意重叠区间
查看>>
软件测试——性能测试总结
查看>>
12.4站立会议
查看>>
Java Concurrentmodificationexception异常原因和解决方法
查看>>
客户端访问浏览器的流程
查看>>
codeforces水题100道 第二十二题 Codeforces Beta Round #89 (Div. 2) A. String Task (strings)
查看>>
c++||template
查看>>
[BZOJ 5323][Jxoi2018]游戏
查看>>
编程面试的10大算法概念汇总
查看>>
Vue
查看>>
python-三级菜单和购物车程序
查看>>
条件断点 符号断点
查看>>
VMware12 + Ubuntu16.04 虚拟磁盘扩容
查看>>
水平垂直居中
查看>>
MySQL简介
查看>>
设计模式之桥接模式(Bridge)
查看>>
jquery的$(document).ready()和onload的加载顺序
查看>>