项目简介
这篇文章主要介绍了生成对抗网络(Generative Adversarial Network),简称 GAN。
GAN 可以看作是一种可以生成特定分布数据的模型。
2.生成人脸图像
下面的代码是使用 Generator 来生成人脸图像,Generator 已经训练好保存在 pkl 文件中,只需要加载参数即可。由于模型是在多 GPU 的机器上训练的,因此加载参数后需要使用
remove_module()
函数来修改
state_dict
中的
key
。
def remove_module(state_dict_g):
# remove module.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_g.items():
namekey = k[7:] if k.startswith('module.') else k
new_state_dict[namekey] = v
return new_state_dict
把随机的高斯噪声输入到模型中,就可以得到人脸输出,最后进行可视化。全部代码如下:
import os
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from common_tools i
版权归原作者 程序员uu 所有, 如有侵权,请联系我们删除。