TensorFlow variational autoencoder (VAE): reconstruction and generation of Fashion-MNIST images

The input is a flattened Fashion-MNIST image vector (784 values). After three fully connected layers, we obtain the mean and variance of the latent vector z. The first layer, fc1, has 128 output nodes. It is followed by two fully connected layers, fc2 and fc3, each with 20 output nodes. The 20 outputs of fc2 are the mean vector μ of the 20 latent feature distributions. The 20 outputs of fc3 are the logarithm of the variance vector of those 20 distributions. A latent vector z of length 20 is then sampled with the reparameterization trick, and the decoder layers fc4 and fc5 reconstruct the image from z.

As a generative model, a VAE can not only reconstruct input samples but also generate new samples using the decoder alone. The latent vector z is sampled directly from the prior distribution p(z) (here the standard normal N(0, I)). Decoding z then produces a generated sample.

Code

import tensorflow as tf 
from tensorflow import keras
import numpy as np
from    matplotlib import pyplot as plt
from    PIL import Image


# Load Fashion-MNIST. The labels are loaded but not used by the VAE below.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Rescale pixel intensities from [0, 255] to [0, 1] and store them as float32 tensors.
x_train, x_test = (tf.convert_to_tensor(arr / 255., tf.float32) for arr in (x_train, x_test))

batchsz = 100
# Training stream: shuffle with a buffer of 5 batches, batch, then repeat
# the whole (batched) dataset 10 times.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(batchsz * 5)
    .batch(batchsz)
    .repeat(10)
)
# Test stream: batched only, in the original order.
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)


class VAE(keras.Model):
    """Variational autoencoder made of fully connected (Dense) layers.

    Encoder: x -> fc1 (hidden_dim, ReLU) -> fc2 (mean mu) and fc3 (log-variance),
    both of size latent_dim.
    Decoder: z -> fc4 (hidden_dim, ReLU) -> fc5 (input_dim). fc5 has no
    activation, so the decoder output is a vector of logits.

    Args:
        latent_dim: size of the latent vector z (default 20).
        hidden_dim: width of the hidden layers fc1 and fc4 (default 128).
        input_dim: size of the flattened input / reconstruction (default 784).
    """

    def __init__(self, latent_dim=20, hidden_dim=128, input_dim=784):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        # Encoder network
        self.fc1 = keras.layers.Dense(hidden_dim)
        self.fc2 = keras.layers.Dense(latent_dim)  # mean mu of q(z|x)
        self.fc3 = keras.layers.Dense(latent_dim)  # log-variance of q(z|x)
        # Decoder network
        self.fc4 = keras.layers.Dense(hidden_dim)
        self.fc5 = keras.layers.Dense(input_dim)

    def encoder(self, x):
        """Map inputs x to (mu, log_var) of the approximate posterior."""
        h = tf.nn.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        # tf.shape yields the dynamic shape, so this also works when the batch
        # dimension is not statically known (e.g. inside tf.function).
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        # sigma = exp(0.5 * log_var). Taking sqrt(exp(log_var)) instead has an
        # infinite gradient (-> NaN) once exp(log_var) underflows to 0.
        std = tf.exp(0.5 * log_var)
        return mu + std * eps

    def decoder(self, z):
        """Map latent vectors z to reconstruction logits."""
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)
        return out

    def call(self, inputs, training=None):
        """Encode, sample z, and decode.

        Returns:
            (x_hat_logits, mu, log_var)
        """
        mu, log_var = self.encoder(inputs)
        z = self.reparameterize(mu, log_var)

        x_hat = self.decoder(z)
        return x_hat, mu, log_var

# Create the model and materialize its weights for a (batch=4, 784) input,
# then print a per-layer parameter summary.
model = VAE()
model.build((4, 784))
model.summary()


optimizer = keras.optimizers.Adam(learning_rate=1e-3)
for step, batch in enumerate(train_db):
    # Flatten each 28x28 image into a 784-vector.
    x = tf.reshape(batch, [-1, 784])
    batch_size = x.shape[0]
    with tf.GradientTape() as tape:
        x_rec_logits, mu, log_var = model(x)
        # Reconstruction term: per-pixel sigmoid cross-entropy, summed and
        # averaged over the batch.
        rec_loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
        ) / batch_size
        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), averaged over the batch.
        kl_div = tf.reduce_sum(-0.5 * (log_var + 1 - mu**2 - tf.exp(log_var))) / batch_size
        # Unweighted ELBO loss (KL weight 1).
        loss = rec_loss + 1. * kl_div

    params = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, params), params))

    if not step % 100:
        print(step, 'kl div: ', float(kl_div), 'loss: ', float(loss))

def save_images(imgs, name, grid=10):
    """Tile equally sized grayscale images into a grid x grid mosaic and save it.

    Tiles are filled column by column: images 0..grid-1 go top-to-bottom in
    the left-most column, the next `grid` images fill the second column, and
    so on. Images beyond grid * grid are ignored. If fewer images are given,
    the remaining tiles stay black instead of raising IndexError.

    Args:
        imgs: non-empty sequence of 2-D arrays of identical shape (h, w);
            values are cast to uint8 so each tile is a mode 'L' image.
        name: output path; PIL infers the file format from the extension.
        grid: number of tiles per row and per column (default 10).

    Raises:
        ValueError: if imgs is empty.
    """
    if len(imgs) == 0:
        raise ValueError("save_images needs at least one image")
    tile_h, tile_w = np.asarray(imgs[0]).shape[:2]
    new_im = Image.new('L', (grid * tile_w, grid * tile_h))
    for index, im in enumerate(imgs[:grid * grid]):
        # PIL's paste box is (left, upper): the slow index picks the column.
        col, row = divmod(index, grid)
        tile = Image.fromarray(np.asarray(im, dtype=np.uint8))
        new_im.paste(tile, (col * tile_w, row * tile_h))

    # Save the mosaic
    new_im.save(name)

# Generation: draw 100 latent codes from the N(0, I) prior, decode them,
# turn the logits into [0, 1] intensities, and save a 10x10 mosaic.
z = tf.random.normal((100, 20))
probs = tf.sigmoid(model.decoder(z))
samples = (tf.reshape(probs, [-1, 28, 28]).numpy() * 255.).astype(np.uint8)
save_images(samples, 'vaebuild.png')

# Reconstruction: run the first test batch through the full VAE and save
# 50 originals together with their 50 reconstructions in one mosaic.
x = next(iter(test_db))
logits, _, _ = model(tf.reshape(x, [-1, 784]))
x_hat = tf.reshape(tf.sigmoid(logits), [-1, 28, 28])

pairs = tf.concat([x[:50], x_hat[:50]], axis=0)
x_concat = (pairs.numpy() * 255.).astype(np.uint8)
save_images(x_concat, '10_vae.png')

 

 

93 original articles published, praised 2, visited 3003
Private letter follow

Tags: network

Posted on Tue, 17 Mar 2020 10:30:23 -0400 by Lashiec