0


Python——图像缺失弥补

    最近开始确认自己想要在Python和深度学习学习的一个方向,就是图像处理,自己对这部分还是很有兴趣的,所以最近看视频,然后根据代码做了一个图像缺失弥补的程序。这个课程我2年前是看过的,但是因为那时候的笔记本没办法跑这种吃资源的项目,所以工作后自己凑了一台3060的笔记本和2060的台式,专门用来跑程序。以下是对程序的理解。

    一、模型解读

    这个项目来源于一篇论文Globally and Locally Consistent Image Completion,如果想要理解这个模型,需要先大致了解一下这个论文。论文的中心思想是:先给图片挖掉一部分区域——用

这个图片去跑global completion网络,并且把网络参数保存——然后在completion基础上,用global completion得到的全局图片和生成的local图片分别跑Global Discriminator和local Discriminator,项目模型可以看下图:注意这里的图片输入,一个是完整未动过的图片(completion生成的),一个是从网络自己生成的图片中截取的local图片。我们本文的模型是跑一个completion和一个completion+discriminator,然后结果可以比较。

    二、网络解读

    通过模型可以看到这里面有两个大网络:completion和discriminator,而discriminator又分为global和local两部分,论文中对网络组成进行了详细描述,如下图:

其中的dilated conv是指空洞卷积,其目的是为了增加感受野,而deconv conv是反卷积,目的是把图片进行还原。在tensorflow中用空洞卷积使用tf.nn.atrous_conv2d(x, filters, dilation, padding='SAME'),而反卷积使用tf.nn.conv2d_transpose(x, filters, output_shape, [1, stride, stride, 1]),这个是比较方便的。网络的定义如下代码所示:

from layer import *
import tensorflow as tf
class Network:
    def __init__(self, x, mask, local_x, global_completion, local_completion, is_training, batch_size):
        self.batch_size = batch_size
        self.imitation = self.generator(x * (1 - mask), is_training)
        self.completion = self.imitation * mask + x * (1 - mask)
        # 由真的图片上截取下来的local_x跟原图x,输出结果就是True
        self.real = self.discriminator(x, local_x, reuse=False)
        # 由completion自己补的图片跟local discriminator补出来的图片,输出结果就是Fake
        self.fake = self.discriminator(global_completion, local_completion, reuse=True)
        self.g_loss = self.calc_g_loss(x, self.completion)
        self.d_loss = self.calc_d_loss(self.real, self.fake)
        self.g_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.d_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    def generator(self, x, is_training):
        with tf.compat.v1.variable_scope('generator'):
            with tf.compat.v1.variable_scope('conv1'):
                x = conv_layer(x, [5, 5, 3, 64], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv2'):
                x = conv_layer(x, [3, 3, 64, 128], 2)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv3'):
                x = conv_layer(x, [3, 3, 128, 128], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv4'):
                x = conv_layer(x, [3, 3, 128, 256], 2)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv5'):
                x = conv_layer(x, [3, 3, 256, 256], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv6'):
                x = conv_layer(x, [3, 3, 256, 256], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('dilated1'):
                x = dilated_conv_layer(x, [3, 3, 256, 256], 2)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('dilated2'):
                x = dilated_conv_layer(x, [3, 3, 256, 256], 4)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('dilated3'):
                x = dilated_conv_layer(x, [3, 3, 256, 256], 8)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('dilated4'):
                x = dilated_conv_layer(x, [3, 3, 256, 256], 16)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv7'):
                x = conv_layer(x, [3, 3, 256, 256], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv8'):
                x = conv_layer(x, [3, 3, 256, 256], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('deconv1'):
                x = deconv_layer(x, [4, 4, 128, 256], [self.batch_size, 64, 64, 128], 2)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv9'):
                x = conv_layer(x, [3, 3, 128, 128], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('deconv2'):
                x = deconv_layer(x, [4, 4, 64, 128], [self.batch_size, 128, 128, 64], 2)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv10'):
                x = conv_layer(x, [3, 3, 64, 32], 1)
                x = batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.compat.v1.variable_scope('conv11'):
                x = conv_layer(x, [3, 3, 32, 3], 1)
                x = tf.nn.tanh(x)

        return x

    def discriminator(self, global_x, local_x, reuse):
        def global_discriminator(x):
            is_training = tf.constant(True)
            with tf.compat.v1.variable_scope('global'):
                with tf.compat.v1.variable_scope('conv1'):
                    x = conv_layer(x, [5, 5, 3, 64], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv2'):
                    x = conv_layer(x, [5, 5, 64, 128], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv3'):
                    x = conv_layer(x, [5, 5, 128, 256], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv4'):
                    x = conv_layer(x, [5, 5, 256, 512], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv5'):
                    x = conv_layer(x, [5, 5, 512, 512], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('fc'):
                    x = flatten_layer(x)
                    x = full_connection_layer(x, 1024)
            return x

        def local_discriminator(x):
            is_training = tf.constant(True)
            with tf.compat.v1.variable_scope('local'):
                with tf.compat.v1.variable_scope('conv1'):
                    x = conv_layer(x, [5, 5, 3, 64], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv2'):
                    x = conv_layer(x, [5, 5, 64, 128], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv3'):
                    x = conv_layer(x, [5, 5, 128, 256], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('conv4'):
                    x = conv_layer(x, [5, 5, 256, 512], 2)
                    x = batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.compat.v1.variable_scope('fc'):
                    x = flatten_layer(x)
                    x = full_connection_layer(x, 1024)
            return x

        with tf.compat.v1.variable_scope('discriminator', reuse=reuse):
            global_output = global_discriminator(global_x)
            local_output = local_discriminator(local_x)
            with tf.compat.v1.variable_scope('concatenation'):
                output = tf.compat.v1.concat((global_output, local_output), 1)
                output = full_connection_layer(output, 1)
               
        return output

    def calc_g_loss(self, x, completion):
        loss = tf.compat.v1.nn.l2_loss(x - completion)
        return tf.reduce_mean(loss)

    def calc_d_loss(self, real, fake):
        alpha = 4e-4
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real, labels=tf.ones_like(real)))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake, labels=tf.zeros_like(fake)))
        return tf.add(d_loss_real, d_loss_fake) * alpha
    关于loss值的选取:对于completion比较简单,采用MSE值来计算,就是简单地用生成的图片和真实图片做一个减法,就可以得出loss值;而discriminator则比较复杂一点,我理解了很久,因为论文提及用fake和real来进行判别,使用tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake, labels=tf.zeros_like(fake)))来进行计算,但是对于谁是fake谁是real并不是很清晰,我分析代码后得出:
由真的图片上截取下来的local_x跟原图x,输出结果就是True
由completion自己补的图片跟local discriminator补出来的图片,输出结果就是Fake

在程序中,它们的定义为

self.real = self.discriminator(x, local_x, reuse=False)
self.fake = self.discriminator(global_completion, local_completion, reuse=True)
    三、程序分析

    程序的架构是:数据处理——网络定义——建立模型——模型计算——结果展示

    数据处理:这次补全的图片采用人像,整个数据有20万+,如果每次都导入这么多数据,将会非常浪费时间跟资源,因此,程序先将这些图片进行压缩,并且转为npy格式,同时,为了节省资源,只选取其中5000张图片,x_train为95%。

    网络定义:已经在上面讲过,这里不再赘述。

    建立模型:模型建立就是把一些过程量定义出来,这里需要解释一下mask,它本身是一个黑白图,把要被填充的地方,置为1,其它地方置为0,而这部分是通过get_points函数来实现,这个函数计算出local的大小和坐标,然后对该部分进行填充,代码如下所示。
x = tf.compat.v1.placeholder(tf.float32,[BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,3])
    mask = tf.compat.v1.placeholder(tf.float32,[BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,1])
    local_x = tf.compat.v1.placeholder(tf.float32,[BATCH_SIZE,LOCAL_SIZE,LOCAL_SIZE,3])
    global_completion = tf.compat.v1.placeholder(tf.float32,[BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,3])
    local_completion = tf.compat.v1.placeholder(tf.float32,[BATCH_SIZE,LOCAL_SIZE,LOCAL_SIZE,3])
    is_training = tf.compat.v1.placeholder(tf.bool,[])
    model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
    sess = tf.compat.v1.Session()
    global_step = tf.compat.v1.Variable(0,name='global_step',trainable=False)
    epoch = tf.compat.v1.Variable(0,name='epoch',trainable=False)
    opt = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    # var_list:默认是GraphKeys.TRAINABLE_VARIABLES
    g_train_op = opt.minimize(model.g_loss, global_step=global_step, var_list=model.g_variables)
    d_train_op = opt.minimize(model.d_loss, global_step=global_step, var_list=model.d_variables)
    init_opt = tf.compat.v1.global_variables_initializer()
    sess.run(init_opt)
def get_points():
    points = []
    mask = []
    for i in range(BATCH_SIZE):
        # starting coordinate of the hole
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])

        # weight,height
        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h

        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)

    return np.array(points), np.array(mask)
     模型计算:通过模型定义,我们给定一个PRETRAIN_EPOCH值,如果epoch超过这个值,就停止completion的计算,保存模型参数,然后开始计算discriminator,而这部分源程序中没有给出停止的条件,所以我给定一个stop_loss:1e-4,当loss值低于这个数就保存模型跳出。这里就是跑的最久的地方,如果batch_size给的太大,电脑资源容易不够,我用台式电脑:6g 2060显卡跑这个模型,只能用batch_size=16,不然就会算不下去。

    模型展示:最后,我们通过x_test来看一下计算结果,结果分为两个,一个是completion完成的,一个是completion+discriminator完成的。

    下图是completion最后出来的图,效果还可以,有点像打了马赛克;

    下图是原始图和模型图的对照,结果也还不错,如果模型继续训练可以得到更好的结果,论文中的图是跑了好几天的:

    ![](https://img-blog.csdnimg.cn/66e354fcd2e946f28ee01546a9000c0b.png)

    下图是一个是completion(上)完成的,一个是completion+discriminator(下)完成的,下面那张的肤色比上面的偏白。

代码下载链接python图像缺失弥补源码资源-CSDN文库


本文转载自: https://blog.csdn.net/wenpeitao/article/details/127723125
版权归原作者 我是灵魂人物 所有, 如有侵权,请联系我们删除。

“Python——图像缺失弥补”的评论:

还没有评论