图像风格迁移模型-CycleGAN

作者: FutureSI
日期: 2021.03
摘要: 本案例实现了CycleGAN模型用于风格迁移。

一、CycleGAN介绍

CycleGAN,即循环生成对抗网络,是一种用于图片风格迁移的模型。原来的图片风格迁移模型通过在两组一一匹配的图片进行上训练,来学习输入图片组与输出图片组的特征映射关系,从而实现将输入图片的特征迁移到输出图片上,比如将A组图片的斑马的条纹外观特征迁移到B组普通马匹图片上。但是,训练所要求的两组一一对应训练集图片往往难以获得。CycleGAN通过给GAN网络添加循环一致性损失(consistency loss)的方法打破了训练集图片数据的一一对应限制。

二、框架导入设置

# 解压 ai studio 数据集(首次执行后注释)
!unzip -qa -d ~/data/data10040/ ~/data/data10040/horse2zebra.zip

# 如果用wget自行下载数据集需要自行添加训练集列表文件
# !wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
# !unzip -qa -d /home/aistudio/data/data10040/ horse2zebra.zip
import paddle
from paddle.io import Dataset, DataLoader, IterableDataset
import numpy as np
import cv2
import random
import time
import warnings
import matplotlib.pyplot as plt

%matplotlib inline

warnings.filterwarnings("ignore", category=Warning)  # 过滤报警信息

BATCH_SIZE = 1
DATA_DIR = "/home/aistudio/data/data10040/horse2zebra/"  # 设置训练集数据地址

三、准备数据集

from PIL import Image
from paddle.vision.transforms import RandomCrop


# 处理图片数据:随机裁切、调整图片数据形状、归一化数据
def data_transform(img, output_size):
    h, w, _ = img.shape
    assert h == w and h >= output_size  # check picture size
    # random crop
    rc = RandomCrop(224)
    img = rc(img)
    # normalize
    img = img / 255.0 * 2.0 - 1.0
    # from [H,W,C] to [C,H,W]
    img = np.transpose(img, (2, 0, 1))
    # data type
    img = img.astype("float32")
    return img


# 定义horse2zebra数据集对象
class H2ZDateset(Dataset):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.pic_list_a = np.loadtxt(data_dir + "trainA.txt", dtype=np.str)
        np.random.shuffle(self.pic_list_a)
        self.pic_list_b = np.loadtxt(data_dir + "trainB.txt", dtype=np.str)
        np.random.shuffle(self.pic_list_b)
        self.pic_list_lenth = min(
            int(self.pic_list_a.shape[0]), int(self.pic_list_b.shape[0])
        )

    def __getitem__(self, idx):
        img_dir_a = self.data_dir + self.pic_list_a[idx]
        img_a = cv2.imread(img_dir_a)
        img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
        img_a = data_transform(img_a, 224)
        img_dir_b = self.data_dir + self.pic_list_b[idx]
        img_b = cv2.imread(img_dir_b)
        img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
        img_b = data_transform(img_b, 224)
        return np.array([img_a, img_b])

    def __len__(self):
        return self.pic_list_lenth


# 定义图片loader
h2zdateset = H2ZDateset(DATA_DIR)
loader = DataLoader(
    h2zdateset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    drop_last=False,
    num_workers=0,
    use_shared_memory=False,
)
data = next(loader())[0]
data = np.transpose(data, (1, 0, 2, 3, 4))
print("读取的数据形状:", data.shape)
读取的数据形状: [2, 1, 3, 224, 224]

四、模型组网

4.1 定义辅助功能函数

判别器负责区分图片的“真假”。输入的是训练集图片,判别器的输出越趋近于数值1(即判别此图片为真);如果输入的是生成器生成的图片,判别器的输出越趋近于数值0(即判别此图片为假)。这样,生成器就可以根据判别器输出的变化而计算梯度以优化生成网络。

from PIL import Image
import os


# 打开图片
def open_pic(file_name="./data/data10040/horse2zebra/testA/n02381460_1300.jpg"):
    img = Image.open(file_name).resize((256, 256), Image.BILINEAR)
    img = (np.array(img).astype("float32") / 255.0 - 0.5) / 0.5
    img = img.transpose((2, 0, 1))
    img = img.reshape((-1, img.shape[0], img.shape[1], img.shape[2]))
    return img


# 存储图片
def save_pics(
    pics,
    file_name="tmp",
    save_path="./output/pics/",
    save_root_path="./output/",
):
    if not os.path.exists(save_root_path):
        os.makedirs(save_root_path)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    for i in range(len(pics)):
        pics[i] = pics[i][0]
    pic = np.concatenate(tuple(pics), axis=2)
    pic = pic.transpose((1, 2, 0))
    pic = (pic + 1) / 2
    # plt.imshow(pic)
    pic = np.clip(pic * 256, 0, 255)
    img = Image.fromarray(pic.astype("uint8")).convert("RGB")
    img.save(save_path + file_name + ".jpg")


# 显示图片
def show_pics(pics):
    print(pics[0].shape)
    plt.figure(figsize=(3 * len(pics), 3), dpi=80)
    for i in range(len(pics)):
        pics[i] = (pics[i][0].transpose((1, 2, 0)) + 1) / 2
        plt.subplot(1, len(pics), i + 1)
        plt.imshow(pics[i])
        plt.xticks([])
        plt.yticks([])


# 图片缓存队列
class ImagePool:
    def __init__(self, pool_size=50):
        self.pool = []
        self.count = 0
        self.pool_size = pool_size

    def pool_image(self, image):
        return image
        image = image.numpy()
        rtn = ""
        if self.count < self.pool_size:
            self.pool.append(image)
            self.count += 1
            rtn = image
        else:
            p = np.random.rand()
            if p > 0.5:
                random_id = np.random.randint(0, self.pool_size - 1)
                temp = self.pool[random_id]
                self.pool[random_id] = image
                rtn = temp
            else:
                rtn = image
        return paddle.to_tensor(rtn)

4.2 查看读取的数据集图片

show_pics([data[0].numpy(), data[1].numpy()])
(1, 3, 224, 224)

png

4.3 定义判别器

import paddle
import paddle.nn as nn
import numpy as np


# 定义基础的“卷积层+实例归一化”块
class ConvIN(nn.Layer):
    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        padding=1,
        bias_attr=None,
        weight_attr=None,
    ):
        super().__init__()
        model = [
            nn.Conv2D(
                num_channels,
                num_filters,
                filter_size,
                stride=stride,
                padding=padding,
                bias_attr=bias_attr,
                weight_attr=weight_attr,
            ),
            nn.InstanceNorm2D(num_filters),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# 定义CycleGAN的判别器
class Disc(nn.Layer):
    def __init__(self, weight_attr=nn.initializer.Normal(0.0, 0.02)):
        super().__init__()
        model = [
            ConvIN(
                3,
                64,
                4,
                stride=2,
                padding=1,
                bias_attr=True,
                weight_attr=weight_attr,
            ),
            ConvIN(
                64,
                128,
                4,
                stride=2,
                padding=1,
                bias_attr=False,
                weight_attr=weight_attr,
            ),
            ConvIN(
                128,
                256,
                4,
                stride=2,
                padding=1,
                bias_attr=False,
                weight_attr=weight_attr,
            ),
            ConvIN(
                256,
                512,
                4,
                stride=1,
                padding=1,
                bias_attr=False,
                weight_attr=weight_attr,
            ),
            nn.Conv2D(
                512,
                1,
                4,
                stride=1,
                padding=1,
                bias_attr=True,
                weight_attr=weight_attr,
            ),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

4.4 测试判别器模块

ci = ConvIN(3, 3, 3, weight_attr=nn.initializer.Normal(0.0, 0.02))
logit = ci(paddle.to_tensor(data[0]))
print("ConvIN块输出的特征图形状:", logit.shape)

d = Disc()
logit = d(paddle.to_tensor(data[0]))
print("判别器输出的特征图形状:", logit.shape)
ConvIN块输出的特征图形状: [1, 3, 224, 224]
判别器输出的特征图形状: [1, 1, 26, 26]

4.5 定义生成器

# 定义基础的“转置卷积层+实例归一化”块
class ConvTransIN(nn.Layer):
    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        padding="same",
        padding_mode="constant",
        bias_attr=None,
        weight_attr=None,
    ):
        super().__init__()
        model = [
            nn.Conv2DTranspose(
                num_channels,
                num_filters,
                filter_size,
                stride=stride,
                padding=padding,
                bias_attr=bias_attr,
                weight_attr=weight_attr,
            ),
            nn.InstanceNorm2D(num_filters),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# 定义残差块
class Residual(nn.Layer):
    def __init__(self, dim, bias_attr=None, weight_attr=None):
        super().__init__()
        model = [
            nn.Conv2D(
                dim,
                dim,
                3,
                stride=1,
                padding=1,
                padding_mode="reflect",
                bias_attr=bias_attr,
                weight_attr=weight_attr,
            ),
            nn.InstanceNorm2D(dim),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return x + self.model(x)


# 定义CycleGAN的生成器
class Gen(nn.Layer):
    def __init__(
        self,
        base_dim=64,
        residual_num=7,
        downup_layer=2,
        weight_attr=nn.initializer.Normal(0.0, 0.02),
    ):
        super().__init__()
        model = [
            nn.Conv2D(
                3,
                base_dim,
                7,
                stride=1,
                padding=3,
                padding_mode="reflect",
                bias_attr=False,
                weight_attr=weight_attr,
            ),
            nn.InstanceNorm2D(base_dim),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        # 下采样块(down sampling)
        for i in range(downup_layer):
            model += [
                ConvIN(
                    base_dim * 2**i,
                    base_dim * 2 ** (i + 1),
                    3,
                    stride=2,
                    padding=1,
                    bias_attr=False,
                    weight_attr=weight_attr,
                ),
            ]
        # 残差块(residual blocks)
        for i in range(residual_num):
            model += [
                Residual(
                    base_dim * 2**downup_layer,
                    True,
                    weight_attr=nn.initializer.Normal(0.0, 0.02),
                )
            ]
        # 上采样块(up sampling)
        for i in range(downup_layer):
            model += [
                ConvTransIN(
                    base_dim * 2 ** (downup_layer - i),
                    base_dim * 2 ** (downup_layer - i - 1),
                    3,
                    stride=2,
                    padding="same",
                    padding_mode="constant",
                    bias_attr=False,
                    weight_attr=weight_attr,
                ),
            ]
        model += [
            nn.Conv2D(
                base_dim,
                3,
                7,
                stride=1,
                padding=3,
                padding_mode="reflect",
                bias_attr=True,
                weight_attr=nn.initializer.Normal(0.0, 0.02),
            ),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

4.6 测试生成器模块

cti = ConvTransIN(
    3,
    3,
    3,
    stride=2,
    padding="same",
    padding_mode="constant",
    bias_attr=False,
    weight_attr=nn.initializer.Normal(0.0, 0.02),
)
logit = cti(paddle.to_tensor(data[0]))
print("ConvTransIN块输出的特征图形状:", logit.shape)

r = Residual(3, True, weight_attr=nn.initializer.Normal(0.0, 0.02))
logit = r(paddle.to_tensor(data[0]))
print("Residual块输出的特征图形状:", logit.shape)

g = Gen()
logit = g(paddle.to_tensor(data[0]))
print("生成器输出的特征图形状:", logit.shape)
ConvTransIN块输出的特征图形状: [1, 3, 448, 448]
Residual块输出的特征图形状: [1, 3, 224, 224]
生成器输出的特征图形状: [1, 3, 224, 224]

五、训练CycleGAN网络

# 模型训练函数
def train(
    epoch_num=99999,
    adv_weight=1,
    cycle_weight=10,
    identity_weight=10,
    load_model=False,
    model_path="./model/",
    model_path_bkp="./model_bkp/",
    print_interval=1,
    max_step=5,
    model_bkp_interval=2000,
):
    # 定义两对生成器、判别器对象
    g_a = Gen()
    g_b = Gen()
    d_a = Disc()
    d_b = Disc()

    # 定义数据读取器
    dataset = H2ZDateset(DATA_DIR)
    reader_ab = DataLoader(
        dataset,
        shuffle=True,
        batch_size=BATCH_SIZE,
        drop_last=False,
        num_workers=2,
    )

    # 定义优化器
    g_a_optimizer = paddle.optimizer.Adam(
        learning_rate=0.0002,
        beta1=0.5,
        beta2=0.999,
        parameters=g_a.parameters(),
    )
    g_b_optimizer = paddle.optimizer.Adam(
        learning_rate=0.0002,
        beta1=0.5,
        beta2=0.999,
        parameters=g_b.parameters(),
    )
    d_a_optimizer = paddle.optimizer.Adam(
        learning_rate=0.0002,
        beta1=0.5,
        beta2=0.999,
        parameters=d_a.parameters(),
    )
    d_b_optimizer = paddle.optimizer.Adam(
        learning_rate=0.0002,
        beta1=0.5,
        beta2=0.999,
        parameters=d_b.parameters(),
    )

    # 定义图片缓存队列
    fa_pool, fb_pool = ImagePool(), ImagePool()

    # 定义总迭代次数为0
    total_step_num = np.array([0])

    # 加载存储的模型
    if load_model == True:
        ga_para_dict = paddle.load(model_path + "gen_b2a.pdparams")
        g_a.set_state_dict(ga_para_dict)

        gb_para_dict = paddle.load(model_path + "gen_a2b.pdparams")
        g_b.set_state_dict(gb_para_dict)

        da_para_dict = paddle.load(model_path + "dis_ga.pdparams")
        d_a.set_state_dict(da_para_dict)

        db_para_dict = paddle.load(model_path + "dis_gb.pdparams")
        d_b.set_state_dict(db_para_dict)

        total_step_num = np.load("./model/total_step_num.npy")

    # 定义本次训练开始时的迭代次数
    step = total_step_num[0]

    # 开始模型训练循环
    print(
        "Start time :",
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "start step:",
        step + 1,
    )
    for epoch in range(epoch_num):
        for data_ab in reader_ab:
            step += 1

            # 设置模型为训练模式,针对bn、dropout等进行不同处理
            g_a.train()
            g_b.train()
            d_a.train()
            d_b.train()

            # 得到A、B组图片数据
            data_ab = np.transpose(data_ab[0], (1, 0, 2, 3, 4))
            img_ra = paddle.to_tensor(data_ab[0])
            img_rb = paddle.to_tensor(data_ab[1])

            # 训练判别器DA
            d_loss_ra = paddle.mean((d_a(img_ra.detach()) - 1) ** 2)
            d_loss_fa = paddle.mean(
                d_a(fa_pool.pool_image(g_a(img_rb.detach()))) ** 2
            )
            da_loss = (d_loss_ra + d_loss_fa) * 0.5
            da_loss.backward()  # 反向更新梯度
            d_a_optimizer.step()  # 更新模型权重
            d_a_optimizer.clear_grad()  # 清除梯度

            # 训练判别器DB
            d_loss_rb = paddle.mean((d_b(img_rb.detach()) - 1) ** 2)
            d_loss_fb = paddle.mean(
                d_b(fb_pool.pool_image(g_b(img_ra.detach()))) ** 2
            )
            db_loss = (d_loss_rb + d_loss_fb) * 0.5
            db_loss.backward()
            d_b_optimizer.step()
            d_b_optimizer.clear_grad()

            # 训练生成器GA
            ga_gan_loss = paddle.mean((d_a(g_a(img_rb.detach())) - 1) ** 2)
            ga_cyc_loss = paddle.mean(
                paddle.abs(img_rb.detach() - g_b(g_a(img_rb.detach())))
            )
            ga_ide_loss = paddle.mean(
                paddle.abs(img_ra.detach() - g_a(img_ra.detach()))
            )
            ga_loss = (
                ga_gan_loss * adv_weight
                + ga_cyc_loss * cycle_weight
                + ga_ide_loss * identity_weight
            )
            ga_loss.backward()
            g_a_optimizer.step()
            g_a_optimizer.clear_grad()

            # 训练生成器GB
            gb_gan_loss = paddle.mean((d_b(g_b(img_ra.detach())) - 1) ** 2)
            gb_cyc_loss = paddle.mean(
                paddle.abs(img_ra.detach() - g_a(g_b(img_ra.detach())))
            )
            gb_ide_loss = paddle.mean(
                paddle.abs(img_rb.detach() - g_b(img_rb.detach()))
            )
            gb_loss = (
                gb_gan_loss * adv_weight
                + gb_cyc_loss * cycle_weight
                + gb_ide_loss * identity_weight
            )
            gb_loss.backward()
            g_b_optimizer.step()
            g_b_optimizer.clear_grad()

            # 存储训练过程中生成的图片
            if step in range(1, 101):
                pic_save_interval = 1
            elif step in range(101, 1001):
                pic_save_interval = 10
            elif step in range(1001, 10001):
                pic_save_interval = 100
            else:
                pic_save_interval = 500
            if step % pic_save_interval == 0:
                save_pics(
                    [
                        img_ra.numpy(),
                        g_b(img_ra).numpy(),
                        g_a(g_b(img_ra)).numpy(),
                        g_b(img_rb).numpy(),
                        img_rb.numpy(),
                        g_a(img_rb).numpy(),
                        g_b(g_a(img_rb)).numpy(),
                        g_a(img_ra).numpy(),
                    ],
                    str(step),
                )
                test_pic = open_pic()
                test_pic_pp = paddle.to_tensor(test_pic)
                save_pics(
                    [test_pic, g_b(test_pic_pp).numpy()],
                    str(step),
                    save_path="./output/pics_test/",
                )

            # 打印训练过程中的loss值和生成的图片
            if step % print_interval == 0:
                print(
                    [step],
                    "DA:",
                    da_loss.numpy(),
                    "DB:",
                    db_loss.numpy(),
                    "GA:",
                    ga_loss.numpy(),
                    "GB:",
                    gb_loss.numpy(),
                    time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                )
                show_pics(
                    [
                        img_ra.numpy(),
                        g_b(img_ra).numpy(),
                        g_a(g_b(img_ra)).numpy(),
                        g_b(img_rb).numpy(),
                    ]
                )
                show_pics(
                    [
                        img_rb.numpy(),
                        g_a(img_rb).numpy(),
                        g_b(g_a(img_rb)).numpy(),
                        g_a(img_ra).numpy(),
                    ]
                )

            # 定期备份模型
            if step % model_bkp_interval == 0:
                paddle.save(
                    g_a.state_dict(), model_path_bkp + "gen_b2a.pdparams"
                )
                paddle.save(
                    g_b.state_dict(), model_path_bkp + "gen_a2b.pdparams"
                )
                paddle.save(
                    d_a.state_dict(), model_path_bkp + "dis_ga.pdparams"
                )
                paddle.save(
                    d_b.state_dict(), model_path_bkp + "dis_gb.pdparams"
                )
                np.save(model_path_bkp + "total_step_num", np.array([step]))

            # 完成训练时存储模型
            if step >= max_step + total_step_num[0]:
                paddle.save(g_a.state_dict(), model_path + "gen_b2a.pdparams")
                paddle.save(g_b.state_dict(), model_path + "gen_a2b.pdparams")
                paddle.save(d_a.state_dict(), model_path + "dis_ga.pdparams")
                paddle.save(d_b.state_dict(), model_path + "dis_gb.pdparams")
                np.save(model_path + "total_step_num", np.array([step]))
                print(
                    "End time :",
                    time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                    "End Step:",
                    step,
                )
                return


# 从头训练
train()

# 继续训练
# train(print_interval=1, max_step=5, load_model=True)
# train(print_interval=500, max_step=20000, load_model=True)
Start time : 2021-03-10 11:36:45 start step: 1
[1] DA: [1.5323195] DB: [2.9221125] GA: [13.066509] GB: [20.061096] 2021-03-10 11:36:46
(1, 3, 224, 224)
(1, 3, 224, 224)
[2] DA: [3.431984] DB: [4.0848613] GA: [13.800614] GB: [12.840221] 2021-03-10 11:36:46
(1, 3, 224, 224)
(1, 3, 224, 224)
[3] DA: [3.3024106] DB: [2.2502034] GA: [12.881987] GB: [12.331587] 2021-03-10 11:36:47
(1, 3, 224, 224)
(1, 3, 224, 224)
[4] DA: [3.911097] DB: [1.5154138] GA: [12.64529] GB: [14.333654] 2021-03-10 11:36:47
(1, 3, 224, 224)
(1, 3, 224, 224)
[5] DA: [1.9493798] DB: [1.8769395] GA: [14.874502] GB: [11.431137] 2021-03-10 11:36:48
(1, 3, 224, 224)
(1, 3, 224, 224)
End time : 2021-03-10 11:36:48 End Step: 5

png

png

png

png

png

png

png

png

png

png

六、用训练好的模型进行预测

def infer(img_path, model_path="./model/"):
    # 定义生成器对象
    g_b = Gen()

    # 设置模型为训练模式,针对bn、dropout等进行不同处理
    g_b.eval()

    # 读取存储的模型
    gb_para_dict = paddle.load(model_path + "gen_a2b.pdparams")
    g_b.set_state_dict(gb_para_dict)

    # 读取图片数据
    img_a = cv2.imread(img_path)
    img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
    img_a = data_transform(img_a, 224)
    img_a = paddle.to_tensor(np.array([img_a]))

    # 正向计算进行推理
    img_b = g_b(img_a)

    # 打印输出输入、输出图片
    print(img_a.numpy().shape, img_a.numpy().dtype)
    show_pics([img_a.numpy(), img_b.numpy()])


infer("./data/data10040/horse2zebra/testA/n02381460_1300.jpg")
(1, 3, 224, 224) float32

png