Source code for tensortrade.models.generative.wgan

# Copyright 2019 The TensorTrade Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified source: https://github.com/timsainb/tensorflow2-generative-models/blob/master/3.0-WGAN-GP-fashion-mnist.ipynb
# Source reference: https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/
# Original paper: https://arxiv.org/abs/1701.07875
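# Gradient penalty paper: https://arxiv.org/abs/1704.00028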

import tensorflow as tf


class WGAN(tf.keras.Model):
    """A Wasserstein GAN trained with a gradient penalty (WGAN-GP)."""

    def __init__(self,
                 generator: tf.keras.Sequential,
                 discriminator: tf.keras.Sequential,
                 **kwargs):
        super().__init__()

        self.n_samples = kwargs.get('n_samples', 64)
        self.gradient_penalty_weight = kwargs.get('gradient_penalty_weight', 10.0)
        self.generator_lr = kwargs.get('generator_lr', 0.0001)
        self.generator_beta_1 = kwargs.get('generator_beta_1', 0.5)
        self.discriminator_lr = kwargs.get('discriminator_lr', 0.0005)

        self.generator = generator
        self.discriminator = discriminator

        self.gen_optimizer = tf.keras.optimizers.Adam(self.generator_lr,
                                                      beta_1=self.generator_beta_1)
        self.disc_optimizer = tf.keras.optimizers.RMSprop(self.discriminator_lr)
    def generate(self, z):
        """Maps latent codes z to generated samples."""
        return self.generator(z)

    def discriminate(self, x):
        """Returns the critic's score for samples x."""
        return self.discriminator(x)

    def generate_random(self):
        """Generates a single sample from a random latent code."""
        return self.generate(tf.random.normal(shape=(1, self.n_samples)))
    def gradient_penalty(self, x, x_gen):
        """Penalizes critic gradient norms that deviate from 1, a soft form
        of the Lipschitz constraint required by the Wasserstein loss. The
        squared-gradient sum runs over the spatial axes (per channel),
        following the source notebook."""
        epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
        # Random interpolation between real and generated samples.
        x_hat = epsilon * x + (1 - epsilon) * x_gen

        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = self.discriminate(x_hat)

        gradients = t.gradient(d_hat, x_hat)
        ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
        d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)

        return d_regularizer
    def compute_loss(self, x):
        """Computes the critic and generator losses for a batch x."""
        # Sample latent codes from a standard normal distribution.
        z_samp = tf.random.normal([x.shape[0], 1, 1, self.n_samples])
        x_gen = self.generate(z_samp)

        # Critic scores for real and generated samples.
        logits_x = self.discriminate(x)
        logits_x_gen = self.discriminate(x_gen)

        d_regularizer = self.gradient_penalty(x, x_gen)

        # Sign convention follows the source notebook: minimizing disc_loss
        # raises the critic's score on generated samples relative to real
        # ones, and the generator counteracts by minimizing that score.
        disc_loss = (
            tf.reduce_mean(logits_x)
            - tf.reduce_mean(logits_x_gen)
            + d_regularizer * self.gradient_penalty_weight
        )
        gen_loss = tf.reduce_mean(logits_x_gen)

        return disc_loss, gen_loss
    def compute_gradients(self, x):
        """Computes gradients of each loss w.r.t. its network's weights."""
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            disc_loss, gen_loss = self.compute_loss(x)

        gen_gradients = gen_tape.gradient(gen_loss,
                                          self.generator.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss,
                                            self.discriminator.trainable_variables)

        return gen_gradients, disc_gradients

    def apply_gradient(self, model, optimizer, gradients):
        """Applies gradients to the model's trainable variables."""
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    @tf.function
    def train(self, train_x):
        """Runs one training step: a single critic and generator update."""
        gen_gradients, disc_gradients = self.compute_gradients(train_x)
        self.apply_gradient(self.generator, self.gen_optimizer, gen_gradients)
        self.apply_gradient(self.discriminator, self.disc_optimizer, disc_gradients)
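
A minimal usage sketch follows; it is not part of the module. The layer sizes, the 28x28x1 input shape, and the batch size are assumptions chosen only so the tensor shapes line up with the latent sampling in compute_loss ([batch, 1, 1, n_samples]) and the 4-D interpolation in gradient_penalty.

if __name__ == "__main__":
    # Illustrative only: small convolutional networks sized for
    # 28x28x1 inputs and a 64-dimensional latent code.
    latent_dim = 64  # must match the n_samples keyword below

    generator = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1, 1, latent_dim)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(7 * 7 * 32, activation='relu'),
        tf.keras.layers.Reshape((7, 7, 32)),
        tf.keras.layers.Conv2DTranspose(16, 3, strides=2, padding='same',
                                        activation='relu'),
        tf.keras.layers.Conv2DTranspose(1, 3, strides=2, padding='same',
                                        activation='sigmoid')
    ])

    discriminator = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(16, 3, strides=2, padding='same',
                               activation='relu'),
        tf.keras.layers.Conv2D(32, 3, strides=2, padding='same',
                               activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1)  # linear critic score, no sigmoid
    ])

    wgan = WGAN(generator, discriminator, n_samples=latent_dim)

    # Stand-in batch of "real" data; replace with an actual dataset.
    batch = tf.random.normal([8, 28, 28, 1])
    wgan.train(batch)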