FGSM快速梯度符号法非定向攻击代码(PyTorch)

news/2024/7/8 4:28:55 标签: pytorch, 人工智能, python

数据集:手写字体识别MNIST

模型:LeNet

import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")


# LeNet 模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)  # 防止过拟合,实现时必须标明training的状态为self.training
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


test_loader = torch.utils.data.DataLoader(#导入数据
    datasets.MNIST('data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            ])),
        batch_size=1, shuffle=True)


model = Net().to(device)
pretrained_model = "lenet_mnist_model.pth"
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
model.eval()


def fgsm_attack(image, epsilon, data_grad):  # 此函数的功能是进行fgsm攻击,需要输入三个变量,干净的图片,扰动量和输入图片梯度
    sign_data_grad = data_grad.sign()  # 梯度符号
    # print(sign_data_grad)
    perturbed_image = image+epsilon*sign_data_grad  # 公式
    perturbed_image = torch.clamp(perturbed_image, 0, 1)  # 为了保持图像的原始范围,将受干扰的图像裁剪到一定的范围【0,1】
    return perturbed_image


epsilons = [0, .05, .1, .15, .2, .25, .3]


def test(model, device, test_loader, epsilon):
    correct = 0
    adv_examples = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]  # 选取最大的类别概率
        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        output = model(perturbed_data)
        final_pred = output.max(1, keepdim=True)[1]
        if final_pred.item() == target.item():  # 判断类别是否相等
            correct += 1
        if len(adv_examples) < 6:
            adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    final_acc = correct / float(len(test_loader))  # 算正确率
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))
    return final_acc, adv_examples


accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)

plt.plot(epsilons, accuracies)
plt.show()

cnt = 0
plt.figure(figsize=(8, 10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[0]), cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
        orig, adv, ex = examples[i][j]
        plt.title("{} -> {}".format(orig, adv))
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()

在这里插入图片描述

在这里插入图片描述


http://www.niftyadmin.cn/n/5110474.html

相关文章

Canvas系列绘制图片学习:绘制图片和渐变效果

我们现在已经可以绘制好多东西了&#xff0c;不过在实际开发中&#xff0c;绘制最多的当然是图片了&#xff0c;这章我们就讲讲图片的绘制。 绘制图片 绘制图片的API是drawImage&#xff0c;它的参数有三种情况&#xff1a; // 将图片绘制在canvas的(dX, dY)坐标处 context.…

配置Linux

首先安装VMware&#xff1a; 安装说明&#xff1a;&#xff08;含许可证的key&#xff09; https://mp.weixin.qq.com/s/XE-BmeKHlhfiRA1bkNHTtg 给大家提供了VMware Workstation Pro16&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1q8VE3TkPzDnM3u9bkTdA_g 提取码&…

高校教务系统登录页面JS分析——巢湖学院

高校教务系统密码加密逻辑及JS逆向 本文将介绍高校教务系统的密码加密逻辑以及使用JavaScript进行逆向分析的过程。通过本文&#xff0c;你将了解到密码加密的基本概念、常用加密算法以及如何通过逆向分析来破解密码。 本文仅供交流学习&#xff0c;勿用于非法用途。 一、密码加…

16、监测数据采集物联网应用开发步骤(12.1)

阶段性源码将于本章节末尾给出下载 监测数据采集物联网应用开发步骤(11) 本章节进行前端web UI开发web数据接口服务开发 web数据接口服务利用SOCKET TCP服务方式解析http协议内容模式 在com.zxy.common.Com_Para.py中添加如下内容 #修改web数据接口服务端口 port 9000 #是…

3D双目跟踪瞳孔识别

人眼数据集通常用于眼部相关的计算机视觉、眼动追踪、瞳孔检测、情感识别以及生物特征识别等领域的研究和开发。以下是一些常见的人眼数据集&#xff1a; BioID Face Database: 这个数据库包含1,521张近距离的人脸图像&#xff0c;其中包括瞳孔位置的标记。它通常用于瞳孔检测和…

EPLAN_010#STEP格式_箱柜模型的定义、拼柜

一、导入 首先创建一个宏项目——在布局空间中导航器新建一个布局空间 菜单栏——布局空间——导入(3D图形&#xff09;——导入下载下来的STEP 如果导入进来的箱柜是这种模样的&#xff0c;表示可以使用。如果左侧只显示一个逻辑组件&#xff0c;则无法使用。&#xff08;如果…

最近又火了!吴恩达《生成式 AI》重磅发布!

吴恩达教授可能是许多人接触 AI 的启蒙课导师吧&#xff0c;在过去的十多年中&#xff0c;他的《Machine Learning》课程已经对数百万的学习者产生了积极影响。 而随着 ChatGPT 的推出&#xff0c;大模型和各类生成式人工智能&#xff08;GenAI&#xff09;技术在行业内外备受…

【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割2(基础数据流篇)

构建pytorch训练模型读取的数据,是有模版可以参考的,是有套路的,这点相信使用过的人都知道。我也会给出一个套路的模版,方便学习和查询。 同时,也可以先去参考学习之前的一篇较为简单的3D分类任务的数据构建方法,链接在这里:【3D图像分类】基于Pytorch的3D立体图像分类…