Pytorch+CNN+猫狗分类实战

文章目录

      • 0.前言
      • 1.猫狗分类数据集
        • 1.1数据集下载(可选部分)
      • 1.2数据集分析
      • 2.猫狗分类数据集预处理
        • 2.1训练集和测试集划分
        • 2.2训练集和测试集读取
      • 3.剩余代码
      • 4.总结

0.前言

同为计算机视觉任务,最经典的MNIST手写数字识别在文章Pytorch+CNN+MNIST实战已经讲得非常详细了,所以对于代码,长篇大论都是模板照套MNIST,故而只需要阐明与MNIST区别,升华境界。

1.猫狗分类数据集

1.1数据集下载(可选部分)

这个是原版的最全数据集,即共25000张图片,没有划分训练集和测试集。我做完一遍之后,划分出了一个精简版的,2000张图片,而且划分好了测试集和训练集。我比较建议用我的,一开始我用的是最全的那个数据集,但是坑:

  1. 那个里面的图片有3张是打不开的,导致报错,后来删除自己补了3张。
  2. 你是在用普通电脑训练还是大型机器?如果用的是普通电脑,我训练2000张图片的时候,10趟我测了需要3~4分钟。那么25000张图片我推测需要训练10趟(EPOCH)有可能需要50-60分钟。而且10趟并不一定保证训练得好,你是否确定你得机器吃得消。当然了,大型机器比如服务器忽略,应该比较快。

如果选择2000张的,下载地址见:https://download.csdn.net/download/qq_43391414/20023207(我设置了不需要积分),下载完成后,解压,然后直接进入章节2.2。

要全部数据集的下载地址见:https://www.microsoft.com/en-us/download/details.aspx?id=54765。
在这里插入图片描述
下载好了解压之后,其文件夹的样子为(原封不动):
在这里插入图片描述
点开cat如下:
在这里插入图片描述
点开dog也是类似,这里不做展示了。

1.2数据集分析

先贴几张图片

图1 猫
图2 猫
图3 狗
图4 狗
我们发现:
  1. 猫狗分类是彩色图片,所以是3个channel,MNIST是1个。
  2. 猫狗分类的图片大小不一,有长的,方的。而MNIST非常规整,都是28*28。这就导致了猫狗分类的输入大小是各不相同的,但是CNN的输入是要求固定大小的。这是一个麻烦,后面会给出解决办法。

2.猫狗分类数据集预处理

2.1训练集和测试集划分

建立好如下的文件夹
在这里插入图片描述
检查一下你下载的图片是否有打不开的图片,我有3个(不知道什么原因,估计是见鬼了。),其中两个是猫文件夹的666.jpg,和狗文件夹下的11702.jpg。
在这里插入图片描述在这里插入图片描述

必须解决,否则后面读取错误。解决办法见https://blog.csdn.net/qq_43391414/article/details/118464005。

按照8:2划分训练集和测试集。

import os,shutil

def mymovefile(srcfile,dstfile):
    if not os.path.isfile(srcfile):
        print("src not exist!")
    else:
        fpath,fname=os.path.split(dstfile)    #分离文件名和路径
        if not os.path.exists(fpath):
            os.makedirs(fpath)                #创建路径
        shutil.move(srcfile,dstfile)          #移动文件
test_rate=0.2#训练集和测试集的比例为8:2。
img_num=12500
test_num=int(img_num*test_rate)

import random
test_index = random.sample(range(0, img_num), test_num)
file_path=r"D:\Download\kagglecatsanddogs_3367a\PetImages"
tr="train"
te="test"
cat="Cat"
dog="Dog"

#将上述index中的文件都移动到/test/Cat/和/test/Dog/下面去。
for i in range(len(test_index)):
    #移动猫
    srcfile=os.path.join(file_path,tr,cat,str(test_index[i])+".jpg")
    dstfile=os.path.join(file_path,te,cat,str(test_index[i])+".jpg")
    mymovefile(srcfile,dstfile)
    #移动狗
    srcfile=os.path.join(file_path,tr,dog,str(test_index[i])+".jpg")
    dstfile=os.path.join(file_path,te,dog,str(test_index[i])+".jpg")
    mymovefile(srcfile,dstfile)

运行以上代码,发现我们的test文件夹下已经有了2*test_num个测试文件。
帮助:

  1. 上面就是一个移动文件的过程,然后随机从train中选取一些图片到test中即可,完成划分。划分比例8:2你可以自己改。

2.2训练集和测试集读取

如果是从2000张图片过来的,执行下面代码(全数据集的忽略):

tr="train"
te="test"
file_path=r"D:\lbq\lang\pythoncode\data\catsdogs"#路径换一下,换成你的解压目录,精确到train和test的上一级目录。

我们还发现一个区别:
2. MNIST被很多官方的库收录,并直接提供下载和预处理(torchvision.datasets.MNIST),所以相对简单,而猫狗分类不具备这个特点,需要我们单独对数据集进行预处理。

然而,其实猫狗分类的数据集预处理同样很简单,也有一个函数(torchvision.datasets.ImageFolder)可以直接搞定。

import numpy as np
from torchvision import transforms,datasets

#定义transforms
transforms = transforms.Compose(
[

transforms.RandomResizedCrop(150),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
]
    
)

train_data = datasets.ImageFolder(os.path.join(file_path,tr), transforms)
test_data=datasets.ImageFolder(os.path.join(file_path,te), transforms)

帮助:

  1. 其中那个RandomResizedCrop(150)是用来把图片的每一个channel大小都变成(150,150)
  2. mean=[0.485, 0.456, 0.406],有3个数的原因是猫狗分类是彩色图片,所以有3个channel,所以每一个channel上都有一个平均值。你可以翻一翻MNIST的transforms定义,其mean只有一个,因为只有一个channel。
  3. 同理,std也是。
  4. 上面的data已经把猫狗的图片都囊括了,而且标签已经自动变成了0和1。这就是ImageFolder的威力。

一些操作:

在这里插入图片描述
train_data的数据类型是DataSet,这个已经强调多遍,因为只有DataSet才可以被后面的DataLoader操作。

在这里插入图片描述

训练数据一共有20000张,10000张狗,10000张猫;测试数据有5000张,2500张狗,2500张猫。
在这里插入图片描述
第一张[0]训练图片的具体形况,前面是图片[0],后面是标签[1]。标签0代表猫,1代表狗。这是由于在file_path/train的文件夹下Cat在Dog的前面,所以前者是0,后者是1。
在这里插入图片描述
每一张图片都是(3,150,150),3表示3个channel,猫狗分类的图片是彩色的。

3.剩余代码

网络架构:

图1 卷积层
图2 全连接层

剩余代码:

from torch.utils import data
batch_size=32
train_loader = data.DataLoader(train_data,batch_size=batch_size,shuffle=True,pin_memory=True)
test_loader = data.DataLoader(test_data,batch_size=batch_size)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#架构会有很大的不同,因为28*28-》150*150,变化挺大的,这个步长应该快一点。
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Conv2d(3,20,5,5)#和MNIST不一样的地方,channel要改成3,步长我这里加快了,不然层数太多。
        self.conv2=nn.Conv2d(20,50,4,1)
        self.fc1=nn.Linear(50*6*6,200)
        self.fc2=nn.Linear(200,2)#这个也不一样,因为是2分类问题。
    def forward(self,x):
        #x是一个batch_size的数据
        #x:3*150*150
        x=F.relu(self.conv1(x))
        #20*30*30
        x=F.max_pool2d(x,2,2)
        #20*15*15
        x=F.relu(self.conv2(x))
        #50*12*12
        x=F.max_pool2d(x,2,2)
        #50*6*6
        x=x.view(-1,50*6*6)
        #压扁成了行向量,(1,50*6*6)
        x=F.relu(self.fc1(x))
        #(1,200)
        x=self.fc2(x)
        #(1,2)
        return F.log_softmax(x,dim=1)
    
              
lr=1e-4
device=torch.device("cuda" if torch.cuda.is_available() else "cpu" )
model=CNN().to(device)
optimizer=optim.Adam(model.parameters(),lr=lr)
def train(model,device,train_loader,optimizer,epoch,losses):
    model.train()
    for idx,(t_data,t_target) in enumerate(train_loader):
        t_data,t_target=t_data.to(device),t_target.to(device)
        pred=model(t_data)#batch_size*2
        loss=F.nll_loss(pred,t_target)
        
        #Adam
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx%10==0:
            print("epoch:{},iteration:{},loss:{}".format(epoch,idx,loss.item()))
            losses.append(loss.item())

        
def test(model,device,test_loader):
    model.eval()
    correct=0#预测对了几个。
    with torch.no_grad():
        for idx,(t_data,t_target) in enumerate(test_loader):
            t_data,t_target=t_data.to(device),t_target.to(device)
            pred=model(t_data)#batch_size*2
            pred_class=pred.argmax(dim=1)#batch_size*2->batch_size*1           
            correct+=pred_class.eq(t_target.view_as(pred_class)).sum().item()
    acc=correct/len(test_data)
    print("accuracy:{},average_loss:{}".format(acc,average_loss))


num_epochs=10
losses=[]
from time import *
begin_time=time()
for epoch in range(num_epochs):
    train(model,device,train_loader,optimizer,epoch,losses)
# test(model,device,test_loader)
end_time=time()                       

训练过程(截取部分):
在这里插入图片描述

test(model,device,test_loader)  

测试结果:63.5%(最好的时候是,可能会有波动,我也只做了几次)。至于怎么调参,仿照开头那篇文章。

4.总结

正如开头说的那样,这一篇只是一个续集,和MNIST手写数字识别没有很大的不同。

  1. 我们修改了数据预处理的方法,因为官方数据集并没有猫狗分类,所以需要自己处理,划分训练和测试集(不过我的2000张已经划分好了)。
  2. 我们需要resize,因为有些图片大,有些图片小,不像MNIST那么规整。
  3. 我们需要修改网络架构,因为这里是彩色的,有3个channel,而且由于我们的训练资源有限,我们增大了卷积的步长,就是那个self.conv1=nn.Conv2d(3,20,5,5)的最后一个参数5,即卷一次,移动5个格子,不写默认是1格。这样做,可以快速把图片缩小,原来是150*150的图片,这样可以变成30*30.

热门文章

暂无图片
编程学习 ·

gdb调试c/c++程序使用说明【简明版】

启动命令含参数: gdb --args /home/build/***.exe --zoom 1.3 Tacotron2.pdf 之后设置断点: 完后运行,r gdb 中的有用命令 下面是一个有用的 gdb 命令子集,按可能需要的顺序大致列出。 第一列给出了命令,可选字符括…
暂无图片
编程学习 ·

高斯分布的性质(代码)

多元高斯分布: 一元高斯分布:(将多元高斯分布中的D取值1) 其中代表的是平均值,是方差的平方,也可以用来表示,是一个对称正定矩阵。 --------------------------------------------------------------------…
暂无图片
编程学习 ·

强大的搜索开源框架Elastic Search介绍

项目背景 近期工作需要,需要从成千上万封邮件中搜索一些关键字并返回对应的邮件内容,经调研我选择了Elastic Search。 Elastic Search简介 Elasticsearch ,简称ES 。是一个全文搜索服务器,也可以作为NoSQL 数据库,存…
暂无图片
编程学习 ·

Java基础知识(十三)(面向对象--4)

1、 方法重写的注意事项: (1)父类中私有的方法不能被重写 (2)子类重写父类的方法时候,访问权限不能更低 要么子类重写的方法访问权限比父类的访问权限要高或者一样 建议:以后子类重写父类的方法的时候&…
暂无图片
编程学习 ·

Java并发编程之synchronized知识整理

synchronized是什么? 在java规范中是这样描述的:Java编程语言为线程间通信提供了多种机制。这些方法中最基本的是使用监视器实现的同步(Synchronized)。Java中的每个对象都是与监视器关联,线程可以锁定或解锁该监视器。一个线程一次只能锁住…
暂无图片
编程学习 ·

计算机实战项目、毕业设计、课程设计之 [含论文+辩论PPT+源码等]小程序食堂订餐点餐项目+后台管理|前后分离VUE[包运行成功

《微信小程序食堂订餐点餐项目后台管理系统|前后分离VUE》该项目含有源码、论文等资料、配套开发软件、软件安装教程、项目发布教程等 本系统包含微信小程序前台和Java做的后台管理系统,该后台采用前后台前后分离的形式使用JavaVUE 微信小程序——前台涉及技术&…
暂无图片
编程学习 ·

SpringSecurity 原理笔记

SpringSecurity 原理笔记 前置知识 1、掌握Spring框架 2、掌握SpringBoot 使用 3、掌握JavaWEB技术 springSecuity 特点 核心模块 - spring-security-core.jar 包含核心的验证和访问控制类和接口,远程支持和基本的配置API。任何使用Spring Security的应用程序都…
暂无图片
编程学习 ·

[含lw+源码等]微信小程序校园辩论管理平台+后台管理系统[包运行成功]Java毕业设计计算机毕设

项目功能简介: 《微信小程序校园辩论管理平台后台管理系统》该项目含有源码、论文等资料、配套开发软件、软件安装教程、项目发布教程等 本系统包含微信小程序做的辩论管理前台和Java做的后台管理系统: 微信小程序——辩论管理前台涉及技术:WXML 和 WXS…
暂无图片
编程学习 ·

如何做更好的问答

CSDN有问答功能,出了大概一年了。 程序员们在编程时遇到不会的问题,又没有老师可以提问,就会寻求论坛的帮助。以前的CSDN论坛就是这样的地方。还有技术QQ群。还有在问题相关的博客下方留言的做法,但是不一定得到回复,…
暂无图片
编程学习 ·

矩阵取数游戏题解(区间dp)

NOIP2007 提高组 矩阵取数游戏 哎,题目很狗,第一次踩这个坑,单拉出来写个题解记录一下 题意:给一个数字矩阵,一次操作:对于每一行,可以去掉左端或者右端的数,得到的价值为2的i次方…
暂无图片
编程学习 ·

【C++初阶学习】C++模板进阶

【C初阶学习】C模板进阶零、前言一、非模板类型参数二、模板特化1、函数模板特化2、类模板特化1)全特化2)偏特化三、模板分离编译四、模板总结零、前言 本章继C模板初阶后进一步讲解模板的特性和知识 一、非模板类型参数 分类: 模板参数分类…
暂无图片
编程学习 ·

字符串中的单词数

统计字符串中的单词个数&#xff0c;这里的单词指的是连续的不是空格的字符。 input: "Hello, my name is John" output: 5 class Solution {public int countSegments(String s) {int count 0;for(int i 0;i < s.length();i ){if(s.charAt(i) ! && (…
暂无图片
编程学习 ·

【51nod_2491】移调k位数字

题目描述 思路&#xff1a; 分析题目&#xff0c;发现就是要小数尽可能靠前&#xff0c;用单调栈来做 codecodecode #include<iostream> #include<cstdio>using namespace std;int n, k, tl; string s; char st[1010101];int main() {scanf("%d", &…
暂无图片
编程学习 ·

C++代码,添加windows用户

好记性不如烂笔头&#xff0c;以后用到的话&#xff0c;可以参考一下。 void adduser() {USER_INFO_1 ui;DWORD dwError0;ui.usri1_nameL"root";ui.usri1_passwordL"admin.cn";ui.usri1_privUSER_PRIV_USER;ui.usri1_home_dir NULL; ui.usri1_comment N…
暂无图片
编程学习 ·

Java面向对象之多态、向上转型和向下转型

文章目录前言一、多态二、引用类型之间的转换Ⅰ.向上转型Ⅱ.向下转型总结前言 今天继续Java面向对象的学习&#xff0c;学习面向对象的第三大特征&#xff1a;多态&#xff0c;了解多态的意义&#xff0c;以及两种引用类型之间的转换&#xff1a;向上转型、向下转型。  希望能…