Learning About GANs

  • Basic structure of a GAN:
    (figure: basic GAN model structure)
  • The most important piece of a GAN is the definition of its loss function, restated below.
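In standard form (Goodfellow et al., 2014), with generator G, discriminator D, data distribution p_data, and latent prior p_z, the objective is the minimax game

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The code below realizes both terms with nn.BCELoss: the discriminator is trained toward label 1 on real images and label 0 on generated ones, while the generator uses the non-saturating variant (label 1 on its own samples, i.e. maximizing log D(G(z))) plus a small pixel-wise L1 reconstruction term.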

A GAN network based on MNIST

""" A generative adversarial network (GAN) implemented on MNIST """

import torch
import torchvision
import torch.nn as nn
import numpy as np

image_size = [1, 28, 28]
latent_dim = 96
batch_size = 64
use_gpu = torch.cuda.is_available()


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        # MLP that maps a latent vector to a flattened 28x28 image in [0, 1]
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),

            nn.Linear(128, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.GELU(),
            nn.Linear(256, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            nn.Linear(512, 1024),
            torch.nn.BatchNorm1d(1024),
            torch.nn.GELU(),
            nn.Linear(1024, np.prod(image_size, dtype=np.int32)),
            # nn.Tanh(),
            nn.Sigmoid(),
        )

    def forward(self, z):
        # shape of z: [batchsize, latent_dim]

        output = self.model(z)
        image = output.reshape(z.shape[0], *image_size)

        return image


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(np.prod(image_size, dtype=np.int32), 512),
            torch.nn.GELU(),
            nn.Linear(512, 256),
            torch.nn.GELU(),
            nn.Linear(256, 128),
            torch.nn.GELU(),
            nn.Linear(128, 64),
            torch.nn.GELU(),
            nn.Linear(64, 32),
            torch.nn.GELU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        # shape of image: [batchsize, 1, 28, 28]

        prob = self.model(image.reshape(image.shape[0], -1))

        return prob


# Training
dataset = torchvision.datasets.MNIST(
    "data", train=True, download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            # torchvision.transforms.Normalize([0.5], [0.5]),
        ]
    )
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, drop_last=True)

generator = Generator()
discriminator = Discriminator()


g_optimizer = torch.optim.Adam(
    generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

loss_fn = nn.BCELoss()
labels_one = torch.ones(batch_size, 1)
labels_zero = torch.zeros(batch_size, 1)

if use_gpu:
    print("use gpu for training")
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    loss_fn = loss_fn.cuda()
    labels_one = labels_one.to("cuda")
    labels_zero = labels_zero.to("cuda")

num_epoch = 200
for epoch in range(num_epoch):
    for i, mini_batch in enumerate(dataloader):
        gt_images, _ = mini_batch

        z = torch.randn(batch_size, latent_dim)

        if use_gpu:
            gt_images = gt_images.to("cuda")
            z = z.to("cuda")

        # Generator step: push D toward predicting "real" on generated images
        pred_images = generator(z)
        g_optimizer.zero_grad()

        recons_loss = torch.abs(pred_images - gt_images).mean()

        g_loss = recons_loss * 0.05 + loss_fn(discriminator(pred_images), labels_one)

        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real images -> 1, generated images -> 0
        d_optimizer.zero_grad()

        real_loss = loss_fn(discriminator(gt_images), labels_one)
        fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
        d_loss = (real_loss + fake_loss)

        # Watch real_loss and fake_loss: when they decrease together, reach their
        # minimum together, and are roughly equal, D has stabilized.

        d_loss.backward()
        d_optimizer.step()

        if i % 50 == 0:
            print(f"step:{len(dataloader)*epoch+i}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")

        if i % 400 == 0:
            image = pred_images[:16].data
            torchvision.utils.save_image(
                image, f"image_{len(dataloader)*epoch+i}.png", nrow=4)
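Once training has produced a reasonable generator, new digits can be sampled directly from noise. The snippet below is a minimal sketch that reuses the generator, latent_dim, and use_gpu defined above; run it in the same session, or first save and reload the generator's state_dict yourself, since the training script does no checkpointing.

# Sampling sketch: draw latent vectors and decode them with the trained generator.
generator.eval()          # switch the BatchNorm1d layers to inference statistics
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    if use_gpu:
        z = z.to("cuda")
    samples = generator(z)              # shape [16, 1, 28, 28], values in [0, 1]
torchvision.utils.save_image(samples, "samples.png", nrow=4)

Calling generator.eval() matters here because of the BatchNorm1d layers, which otherwise keep using per-batch statistics.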

Generation process

Step 0: (figure: generated samples at step 0)

Step 400: (figure: generated samples at step 400)

Step 5000: (figure: generated samples at step 5000)

Step 20000: (figure: generated samples at step 20000)

Step 8000: