Implementing ResNet in PyTorch

Tags: convolution, deep learning, neural networks

I. Introduction to ResNet

  1. ResNet was proposed in 2015 and won first place in the ImageNet classification and detection tasks, as well as first place in detection and segmentation on the COCO dataset. Because it is both simple and effective, many later methods are built on top of ResNet-50 or ResNet-101, and it is widely used in detection, segmentation, recognition, and other fields.

  2. ResNet residual block structure:

[Figure: residual block structure]

  3. ResNet architecture parameter table:

[Figure: ResNet architecture parameter table]

  4. Highlights of ResNet

    • Introduces the residual structure (residual block)
    • Makes it possible to build ultra-deep networks (beyond 1,000 layers)
    • Uses Batch Normalization to speed up training (dropout is discarded)

II. The Core of ResNet: Residual Learning

  1. Residual

    A residual is defined with respect to a reference: each layer takes its input x as the reference and learns the residual function F(x) relative to it, rather than learning the full target mapping directly.

  2. The two branches of a residual learning block

    • identity mapping: the curved shortcut on the right of the figure in Section I.2. As the name suggests, it is the mapping of the input onto itself, i.e., x.
    • residual mapping: the other branch, the F(x) part, which is called the residual mapping.
  3. The defining equation of residual learning

    y = F(x, {W_i}) + x

    where F(x, {W_i}) is the residual mapping produced by the block's weight layers and x is the identity shortcut.
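
    As a minimal sketch (illustrative only; the full BasicBlock and Bottleneck implementations follow in Section III), the equation maps directly to code: the block computes F(x) with a couple of weight layers and adds the untouched input x back before the final activation.

    import torch
    import torch.nn as nn

    # Minimal illustration of y = F(x, {W_i}) + x
    class TinyResidualBlock(nn.Module):
        def __init__(self, channels):
            super().__init__()
            # F(x): two 3x3 convolutions with BN, preserving the shape of x
            self.f = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            )
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.relu(self.f(x) + x)  # y = F(x) + x

    print(TinyResidualBlock(64)(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 64, 56, 56])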

III. Code Implementation of ResNet

  1. The ResNet network model

    import torch
    import torch.nn as nn
    
    class BasicBlock(nn.Module):
        # expansion: ratio of the block's final output channels to its base channel count
        expansion = 1
    
        def __init__(self, in_channel, out_channel, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=in_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)  # no bias: a BatchNorm layer follows
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=1,  # the second conv always uses stride 1; only conv1 downsamples
                                   padding=1,
                                   bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            self.downsample = downsample    # downsample module for the dashed (projection) shortcut
    
        def forward(self, x):
            # shortcut branch: keep the input, applying downsampling for the dashed residual structure
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.conv2(x)
            x = self.bn2(x)
            x += identity
            x = self.relu(x)
    
            return x
    
    class Bottleneck(nn.Module):
        # expansion: the final 1x1 conv expands the channel count by a factor of 4
        expansion = 4
    
        def __init__(self, in_channel, out_channel, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=in_channel,
                                   out_channels=out_channel,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0,  # a 1x1 conv needs no padding
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            self.conv3 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel * self.expansion,  # expand channels by 4x
                                   kernel_size=1,
                                   stride=1,
                                   padding=0,
                                   bias=False)
            self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
            self.downsample = downsample
    
        def forward(self, x):
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
    
            x = self.conv3(x)
            x = self.bn3(x)
            x += identity
            x = self.relu(x)
    
            return x
    
    class ResNet(nn.Module):
        def __init__(self, block, block_list, num_classes=1000, include_top=True):
            super(ResNet, self).__init__()
            self.include_top = include_top
            self.in_channel = 64
    
            self.conv1 = nn.Conv2d(in_channels=3,
                                   out_channels=self.in_channel,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(self.in_channel)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
            # layer names follow torchvision's ResNet (layer1-layer4) so the official
            # pretrained checkpoint keys match when loading weights in the training script
            self.layer1 = self.make_layer(block, 64, block_list[0])
            self.layer2 = self.make_layer(block, 128, block_list[1], stride=2)
            self.layer3 = self.make_layer(block, 256, block_list[2], stride=2)
            self.layer4 = self.make_layer(block, 512, block_list[3], stride=2)
            if self.include_top:
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling to 1x1
                self.fc = nn.Linear(512 * block.expansion, num_classes)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
        def make_layer(self, block, channel, block_list, stride=1):
            downsample = None
            if stride != 1 or self.in_channel != channel * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(channel * block.expansion))
    
            layers = []
            layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
            self.in_channel = channel * block.expansion
    
            for _ in range(1, block_list):
                layers.append(block(self.in_channel, channel))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
    
            if self.include_top:
                x = self.avgpool(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
    
            return x

    def ResNet18(num_classes=1000, include_top=True):
        return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

    def ResNet34(num_classes=1000, include_top=True):
        return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

    def ResNet50(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

    def ResNet101(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

    def ResNet152(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)
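
    # Quick sanity check (not part of the original listing, for illustration only):
    # build ResNet-34 and confirm the output shape for an ImageNet-sized input.
    if __name__ == '__main__':
        model = ResNet34(num_classes=1000)
        dummy = torch.randn(1, 3, 224, 224)
        print(model(dummy).shape)  # expected: torch.Size([1, 1000])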
    
  2. Training ResNet (5-class flower classification)

    from nlp.task.CIFAR10_try.ResNet import ResNet34  # local import of the model defined above
    import os
    import json
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import transforms, datasets
    from tqdm import tqdm
    def main():
        # use the GPU if available, otherwise fall back to the CPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))
        
        # data preprocessing and augmentation
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
            "val": transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
        # dataset paths
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
        assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
        # load the training set
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        train_num = len(train_dataset)
    
        # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
        flower_list = train_dataset.class_to_idx
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write dict into json file
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)
    
        batch_size = 16
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))
    
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)
    
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        val_num = len(validate_dataset)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
    
        print("using {} images for training, {} images for validation.".format(train_num,
                                                                               val_num))
    
        net = ResNet34()
        model_weight_path = "./resnet34-pre.pth"  # official pretrained weights: https://download.pytorch.org/models/resnet34-333f7ec4.pth
        assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
        net.load_state_dict(torch.load(model_weight_path, map_location=device))
        in_channel = net.fc.in_features
        net.fc = nn.Linear(in_channel, 5)
        net.to(device)
    
        # loss function and optimizer
        loss_function = nn.CrossEntropyLoss()
        params = [p for p in net.parameters() if p.requires_grad]
        optimizer = optim.Adam(params, lr=0.0001)
    
        epochs = 5
        best_acc = 0.0
        save_path = './resNet34.pth'
        train_steps = len(train_loader)
        for epoch in range(epochs):
            # train
            net.train()
            running_loss = 0.0
            train_bar = tqdm(train_loader)
            for step, data in enumerate(train_bar):
                images, labels = data
                optimizer.zero_grad()
                logits = net(images.to(device))
                loss = loss_function(logits, labels.to(device))
                loss.backward()
                optimizer.step()
    
                # print statistics
                running_loss += loss.item()
    
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         loss)
    
            # validate
            net.eval()
            acc = 0.0  # accumulate accurate number / epoch
            with torch.no_grad():
                val_bar = tqdm(validate_loader)
                for val_data in val_bar:
                    val_images, val_labels = val_data
                    outputs = net(val_images.to(device))
                    # loss = loss_function(outputs, test_labels)
                    predict_y = torch.max(outputs, dim=1)[1]
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    
                    val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                               epochs)
    
            val_accurate = acc / val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))
    
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
    
        print('Finished Training')
    
    if __name__ == '__main__':
        main()
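
    The script fine-tunes the ImageNet-pretrained ResNet-34 weights (download link in the comment above) on a 5-class flower dataset: the 1000-class fully connected layer is replaced with a 5-output one before training. datasets.ImageFolder expects data_set/flower_data/train and data_set/flower_data/val to each contain one subfolder per class, and the resulting class-to-index mapping is written to class_indices.json so it can be reused at inference time.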
    

**Tips:** Batch Normalization normalizes each batch of feature maps so that they follow a distribution with mean 0 and variance 1. When using BN, put the model in training mode during training and in evaluation mode during validation/inference, place the BN layer between the conv layer and the ReLU layer, and do not give that conv layer a bias.
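
A minimal sketch of these points (illustrative only): the conv is created with bias=False because BN has its own learnable shift, and net.train() / net.eval() switch BN between batch statistics and its running estimates.

    import torch
    import torch.nn as nn

    # conv -> BN -> ReLU, with the conv bias disabled since BN's beta plays that role
    block = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
    )

    x = torch.randn(4, 3, 32, 32)
    block.train()  # training mode: normalize with the current batch's mean and variance
    _ = block(x)
    block.eval()   # evaluation mode: use the running mean and variance instead
    _ = block(x)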
