Posted in

Jax vs Tensorflow vs pytorch: Construindo um autoencoder variacional (VAE)

Jax vs Tensorflow vs pytorch: Construindo um autoencoder variacional (VAE)

Fiquei muito curioso para ver como o JAX é comparado ao Pytorch ou Tensorflow. Imaginei que a melhor maneira de alguém comparar estruturas é construir a mesma coisa do zero nos dois. E foi exatamente isso que eu fiz. Neste artigo, estou desenvolvendo um autoencoder variacional com Jax, Tensorflow e Pytorch ao mesmo tempo. Apresentarei o código para cada componente lado a lado, a fim de encontrar diferenças, semelhanças, fraquezas e pontos fortes.

Devemos começar?

Prólogo

Algumas coisas a serem observadas antes de explorarmos o código:

  • Eu vou usar Linho No topo da JAX, que é uma biblioteca de rede neural desenvolvida pelo Google. Ele contém muitos módulos, camadas, funções e operações de aprendizado profundo pronto para uso

  • Para a implementação do tensorflow, vou confiar em Duro Abstrações.

  • Para Pytorch, usarei o padrão nn.module.

Como a maioria de nós conhece o Tensorflow e o Pytorch, prestaremos mais atenção em Jax e Linho. É por isso que explicarei as coisas ao longo do caminho que podem não estar familiarizadas com muitos. Assim, você pode considerar este artigo como um tutorial de luz sobre linho também.

Além disso, presumo que você esteja familiarizado com os princípios básicos por trás de Vaes. Caso contrário, você pode aconselhar meu artigo anterior sobre Modelos variáveis ​​latentes. Se tudo parecer claro, vamos continuar.

Recapitulação rápida: O AutoEncoder de baunilha consiste em um codificador e um decodificador. O codificador converte a entrada em uma representação latente zz e o decodificador tenta reconstruir a entrada com base nessa representação. Nos autoencoders variacionais, a estocástica também é adicionada à mistura em termos de que a representação latente fornece uma distribuição de probabilidade. Isso está acontecendo com o truque de reparo.




Jax vs Tensorflow vs pytorch: Construindo um autoencoder variacional (VAE)


Imagem por autor

O codificador

Para o codificador, uma camada linear simples seguida de uma ativação RelU deve ser suficiente para um exemplo de brinquedo. A saída da camada será a média e o desvio padrão da distribuição de probabilidade.

O bloco básico de construção da API de linho é o Module Abstração, que é o que usaremos para implementar nosso codificador no JAX. O module faz parte do linen Subpackage. Semelhante ao de Pytorch nn.modulenovamente precisamos definir nossos argumentos de classe. Em Pytorch, estamos acostumados a declará -los dentro do __init__ função e implementação do passe direto dentro do forward método. Em linho, as coisas são um pouco diferentes. Os argumentos são definidos como atributos de dataclass ou como argumentos do método. Geralmente, as propriedades fixas são definidas como argumentos de dataclass, enquanto propriedades dinâmicas como argumentos do método. Também em vez de implementar um forward Método, implementamos __call__

O Módulo Dataclass é introduzido no Python 3.7 como uma ferramenta utilitária para fazer classes estruturadas, especialmente para armazenar dados. Essas classes mantêm certas propriedades e funções para lidar especificamente com os dados e sua representação. Eles também reduzem muito código de caldeira em comparação com as classes regulares.

Então, para criar um novo módulo de linho, precisamos:

  • Inicialize uma classe que herda flax.linen.nn.Module

  • Defina os argumentos estáticos como argumentos de dataclass

  • Implementar o passe direto dentro do __call_ método.

Para amarrar os argumentos com o modelo e ser capaz de definir submódulos diretamente dentro do módulo, também precisamos anotar o __call__ método com @nn.compact.

Observe que, em vez de usar argumentos de dataclass e o @nn.compact anotação, poderíamos ter declarado todos os argumentos dentro de um setup Método exatamente da mesma maneira que fazemos em Pytorch ou Tensorflow __init__.

import numpy as np

import jax

import jax.numpy as jnp

from jax import random

from flax import linen as nn

from flax import optim

class Encoder(nn.Module):

latents: int

@nn.compact

def __call__(self, x):

x = nn.Dense(500, name='fc1')(x)

x = nn.relu(x)

mean_x = nn.Dense(self.latents, name='fc2_mean')(x)

logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)

return mean_x, logvar_x

import tensorflow as tf

from tensorflow.keras import layers

class Encoder(layers.Layer):

def __init__(self,

latent_dim =20,

name='encoder',

**kwargs):

super(Encoder, self).__init__(name=name, **kwargs)

self.enc1 = layers.Dense(500, activation='relu')

self.mean_x = layers.Dense(latent_dim)

self.logvar_x = layers.Dense(latent_dim)

def call(self, inputs):

x = self.enc1(inputs)

z_mean = self.mean_x(x)

z_log_var = self.logvar_x(x)

return z_mean, z_log_var

import torch

import torch.nn.functional as F

class Encoder(torch.nn.Module):

def __init__(self, latent_dim=20):

super(Encoder, self).__init__()

self.enc1 = torch.nn.Linear(784, 500)

self.mean_x = torch.nn.Linear(500,latent_dim)

self.logvar_x = torch.nn.Linear(500, latent_dim)

def forward(self,inputs):

x = self.enc1(inputs)

x= F.relu(x)

z_mean = self.mean_x(x)

z_log_var = self.logvar_x(x)

return z_mean, z_log_var

Mais algumas coisas a perceber aqui antes de prosseguirmos:

  • Linho nn.linen o pacote contém a maioria das camadas de aprendizado profundo e operação, como DenseAssim, relue muito mais

  • O código em linho, tensorflow e pytorch é quase indistinguível um do outro.

O decodificador

De uma maneira muito semelhante, podemos desenvolver o decodificador em todas as três estruturas. O decodificador será duas camadas lineares que receberão a representação latente zz e emitir a entrada reconstruída.

Novamente, as implementações são muito semelhantes.

class Decoder(nn.Module):

@nn.compact

def __call__(self, z):

z = nn.Dense(500, name='fc1')(z)

z = nn.relu(z)

z = nn.Dense(784, name='fc2')(z)

return z

class Decoder(layers.Layer):

def __init__(self,

name='decoder',

**kwargs):

super(Decoder, self).__init__(name=name, **kwargs)

self.dec1 = layers.Dense(500, activation='relu')

self.out = layers.Dense(784)

def call(self, z):

z = self.dec1(z)

return self.out(z)

class Decoder(torch.nn.Module):

def __init__(self, latent_dim=20):

super(Decoder, self).__init__()

self.dec1 = torch.nn.Linear(latent_dim, 500)

self.out = torch.nn.Linear(500, 784)

def forward(self,z):

z = self.dec1(z)

z = F.relu(z)

return self.out(z)

Carscoder variacional

Para combinar o codificador e o decodificador, vamos ter mais uma classe, chamado VAEisso representará toda a arquitetura. Aqui também precisamos escrever algum código para o truque de reparameterização. No geral, temos: a variável latente do codificador é reparameterizada e alimentada ao decodificador, que produz a entrada reconstruída.

Como lembrete, aqui está uma imagem intuitiva que explica o truque de reparameterização:




Reparameterização-trick


Fonte: Alexander Amini e Ava Soleimany, modelagem generativa profunda | MIT 6.S191, http://introtodeeplearning.com/

Observe que desta vez, em Jax, fazemos uso do setup método em vez do nn.compact anotação. Além disso, confira como as funções de reparameterização são semelhantes. Claro que cada estrutura usa suas próprias funções e operações, mas a imagem geral é quase idêntica.

class VAE(nn.Module):

latents: int = 20

def setup(self):

self.encoder = Encoder(self.latents)

self.decoder = Decoder()

def __call__(self, x, z_rng):

mean, logvar = self.encoder(x)

z = reparameterize(z_rng, mean, logvar)

recon_x = self.decoder(z)

return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):

std = jnp.exp(0.5 * logvar)

eps = random.normal(rng, logvar.shape)

return mean + eps * std

def model():

return VAE(latents=LATENTS)

class VAE(tf.keras.Model):

def __init__(self,

latent_dim=20,

name='vae',

**kwargs):

super(VAE, self).__init__(name=name, **kwargs)

self.encoder = Encoder(latent_dim=latent_dim)

self.decoder = Decoder()

def call(self, inputs):

z_mean, z_log_var = self.encoder(inputs)

z = self.reparameterize(z_mean, z_log_var)

reconstructed = self.decoder(z)

return reconstructed, z_mean, z_log_var

def reparameterize(self, mean, logvar):

eps = tf.random.normal(shape=mean.shape)

return mean + eps * tf.exp(logvar * .5)

class VAE(torch.nn.Module):

def __init__(self, latent_dim=20):

super(VAE, self).__init__()

self.encoder = Encoder(latent_dim)

self.decoder = Decoder(latent_dim)

def forward(self,inputs):

z_mean, z_log_var = self.encoder(inputs)

z = self.reparameterize(z_mean, z_log_var)

reconstructed = self.decoder(z)

return reconstructed, z_mean, z_log_var

def reparameterize(self, mu, log_var):

std = torch.exp(0.5 * log_var)

eps = torch.randn_like(std)

return mu + (eps * std)

Etapa de perda e treinamento

As coisas estão começando a diferir quando começamos a implementar a etapa de treinamento e a função de perda. Mas não muito.

  1. Para aproveitar completamente de Capacidades JAXprecisamos adicionar vetorização automática e compilação XLA ao nosso código. Isso pode ser feito facilmente com a ajuda de vmap e jit anotações.

  2. Além disso, temos que permitir a diferenciação automática, que pode ser realizada com o grad_fn transformação

  3. Nós usamos o flax.optim Pacote para algoritmos de otimização

Outra pequena diferença que precisamos estar ciente é como passamos dados para o nosso modelo. Isso pode ser alcançado através do método de aplicação na forma de model().apply({'params': params}, batch, z_rng)onde batch são nossos dados de treinamento.

@jax.vmap

def kl_divergence(mean, logvar):

return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap

def binary_cross_entropy_with_logits(logits, labels):

logits = nn.log_sigmoid(logits)

return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit

def train_step(optimizer, batch, z_rng):

def loss_fn(params):

recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)

bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()

kld_loss = kl_divergence(mean, logvar).mean()

loss = bce_loss + kld_loss

return loss, recon_x

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

_, grad = grad_fn(optimizer.target)

optimizer = optimizer.apply_gradient(grad)

return optimizer

def kl_divergence(mean, logvar):

return -0.5 * tf.reduce_sum(

1 + logvar - tf.square(mean) -

tf.exp(logvar), axis=1)

def binary_cross_entropy_with_logits(logits, labels):

logits = tf.math.log(logits)

return - tf.reduce_sum(

labels * logits +

(1-labels) * tf.math.log(- tf.math.expm1(logits)),

axis=1

)

@tf.function

def train_step(model, x, optimizer):

with tf.GradientTape() as tape:

recon_x, mean, logvar = model(x)

bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, batch))

kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))

loss = bce_loss + kld_loss

print(loss, kld_loss, bce_loss)

gradients = tape.gradient(loss, model.trainable_variables)

optimizer.apply_gradients(zip(gradients, model.trainable_variables))

def final_loss(reconstruction, train_x, mu, logvar):

BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

return BCE + KLD

def train_step(train_x):

train_x = torch.from_numpy(train_x)

optimizer.zero_grad()

reconstruction, mu, logvar = model(train_x)

loss = final_loss(reconstruction, train_x, mu, logvar)

running_loss += loss.item()

loss.backward()

optimizer.step()

Lembre -se de que os VAEs são treinados ao maximizar a evidência limite inferior, conhecido como Elbo.

LthAssim,ϕ(x)=Eqϕ(zx)(logpth(xz))KL(qϕ(zx)pth(z))L _ {\ theta, \ heta, (x) (x) (x) (x) (x) (x) (x) (x) (x) _ {e} _ {q _ {\ phi} (z | x) -s} (log p _ {{\ theuta} (x | \ textb´) (q- \ phi} (zix) (z | x || p_ \ thet} (z} (z))

Loop de treinamento

Finalmente, é hora de todo o ciclo de treinamento que executará o train_step função iterativamente.

Em linho, o modelo deve ser inicializado antes do treinamento, o que é feito pelo init função como: params = model().init(key, init_data, rng)('params'). Uma inicialização semelhante também é necessária para o otimizador: optimizer = optim.Adam( learning_rate = LEARNING_RATE ).create( params ).

jax.device_put é usado para transferir o otimizador para a memória da GPU.

rng = random.PRNGKey(0)

rng, key = random.split(rng)

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)

params = model().init(key, init_data, rng)('params')

optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)

optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)

z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):

for _ in range(steps_per_epoch):

batch = next(train_ds)

rng, key = random.split(rng)

optimizer = train_step(optimizer, batch, key)

vae = VAE(latent_dim=LATENTS)

optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):

for train_x in train_ds:

train_step(vae, train_x, optimizer)

def train(model,training_data):

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

running_loss = 0.0

for epoch in range(NUM_EPOCHS):

for i, train_x in enumerate(training_data, 0):

train_step(train_x)

vae = VAE(LATENTS)

train(vae, train_ds)

Carregar e processar dados

Uma coisa que não mencionei são dados. Como carregamos dados de pré -processamento em linho? Bem, o linho não inclui pacotes de manipulação de dados, além das operações básicas de jax.numpy. No momento, o nosso melhor é emprestar pacotes de outras estruturas, como o TensorFlow DataSets (TFDS) ou a Torchvision. Para tornar o artigo auto-preenchido, incluirei o código que usei para carregar um conjunto de dados de treinamento de amostra com tfds. Sinta -se à vontade para usar seu próprio Dataloader se você planeja executar as implementações apresentadas neste artigo.

import tensorflow_datasets as tfds

tf.config.experimental.set_visible_devices((), 'GPU')

def prepare_image(x):

x = tf.cast(x('image'), tf.float32)

x = tf.reshape(x, (-1,))

return x

ds_builder = tfds.builder('binarized_mnist')

ds_builder.download_and_prepare()

train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)

train_ds = train_ds.map(prepare_image)

train_ds = train_ds.cache()

train_ds = train_ds.repeat()

train_ds = train_ds.shuffle(50000)

train_ds = train_ds.batch(BATCH_SIZE)

train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)

test_ds = test_ds.map(prepare_image).batch(10000)

test_ds = np.array(list(test_ds)(0))

Observações finais

Para encerrar o artigo, vamos discutir algumas observações finais que aparecem após uma análise estreita do código:

  • Todas as três estruturas reduziram o código da caldeira ao mínimo, com o linho sendo aquele que requer um pouco mais, especialmente na parte de treinamento. No entanto, isso é apenas para garantir que exploremos todas as transformações disponíveis, como diferenciação automática, vetorização e compilador just-in-time.

  • A definição de módulos, camadas e modelos é quase idêntica em todos eles

  • Linho e Jax são por design bastante flexíveis e expansíveis

  • O linho não possui recursos de carregamento e processamento de dados ainda

  • Em termos de camadas e otimizadores prontos para uso, o linho não precisa ter ciúmes do Tensorflow e Pytorch. Com certeza não possui a biblioteca gigante de seus concorrentes, mas está gradualmente chegando lá.

Aprendizagem profunda no livro de produção 📖

Aprenda a construir, treinar, implantar, escalar e manter modelos de aprendizado profundo. Entenda a infraestrutura de ML e os MLOPs usando exemplos práticos.

Saber mais

* Divulgação: Observe que alguns dos links acima podem ser links de afiliados e, sem custo adicional, ganharemos uma comissão se você decidir fazer uma compra depois de clicar.

Luis es un experto en Ciberseguridad, Computación en la Nube, Criptomonedas e Inteligencia Artificial. Con amplia experiencia en tecnología, su objetivo es compartir conocimientos prácticos para ayudar a los lectores a entender y aprovechar estas áreas digitales clave.

Leave a Reply

Your email address will not be published. Required fields are marked *

router with sim card slot