Fiquei muito curioso para ver como o JAX é comparado ao Pytorch ou Tensorflow. Imaginei que a melhor maneira de alguém comparar estruturas é construir a mesma coisa do zero nos dois. E foi exatamente isso que eu fiz. Neste artigo, estou desenvolvendo um autoencoder variacional com Jax, Tensorflow e Pytorch ao mesmo tempo. Apresentarei o código para cada componente lado a lado, a fim de encontrar diferenças, semelhanças, fraquezas e pontos fortes.
Devemos começar?
Prólogo
Algumas coisas a serem observadas antes de explorarmos o código:
-
Eu vou usar Linho No topo da JAX, que é uma biblioteca de rede neural desenvolvida pelo Google. Ele contém muitos módulos, camadas, funções e operações de aprendizado profundo pronto para uso
-
Para a implementação do tensorflow, vou confiar em Duro Abstrações.
-
Para Pytorch, usarei o padrão
nn.module
.
Como a maioria de nós conhece o Tensorflow e o Pytorch, prestaremos mais atenção em Jax e Linho. É por isso que explicarei as coisas ao longo do caminho que podem não estar familiarizadas com muitos. Assim, você pode considerar este artigo como um tutorial de luz sobre linho também.
Além disso, presumo que você esteja familiarizado com os princípios básicos por trás de Vaes. Caso contrário, você pode aconselhar meu artigo anterior sobre Modelos variáveis latentes. Se tudo parecer claro, vamos continuar.
Recapitulação rápida: O AutoEncoder de baunilha consiste em um codificador e um decodificador. O codificador converte a entrada em uma representação latente e o decodificador tenta reconstruir a entrada com base nessa representação. Nos autoencoders variacionais, a estocástica também é adicionada à mistura em termos de que a representação latente fornece uma distribuição de probabilidade. Isso está acontecendo com o truque de reparo.
O codificador
Para o codificador, uma camada linear simples seguida de uma ativação RelU deve ser suficiente para um exemplo de brinquedo. A saída da camada será a média e o desvio padrão da distribuição de probabilidade.
O bloco básico de construção da API de linho é o Module
Abstração, que é o que usaremos para implementar nosso codificador no JAX. O module
faz parte do linen
Subpackage. Semelhante ao de Pytorch nn.module
novamente precisamos definir nossos argumentos de classe. Em Pytorch, estamos acostumados a declará -los dentro do __init__
função e implementação do passe direto dentro do forward
método. Em linho, as coisas são um pouco diferentes. Os argumentos são definidos como atributos de dataclass ou como argumentos do método. Geralmente, as propriedades fixas são definidas como argumentos de dataclass, enquanto propriedades dinâmicas como argumentos do método. Também em vez de implementar um forward
Método, implementamos __call__
O Módulo Dataclass é introduzido no Python 3.7 como uma ferramenta utilitária para fazer classes estruturadas, especialmente para armazenar dados. Essas classes mantêm certas propriedades e funções para lidar especificamente com os dados e sua representação. Eles também reduzem muito código de caldeira em comparação com as classes regulares.
Então, para criar um novo módulo de linho, precisamos:
-
Inicialize uma classe que herda
flax.linen.nn.Module
-
Defina os argumentos estáticos como argumentos de dataclass
-
Implementar o passe direto dentro do
__call_
método.
Para amarrar os argumentos com o modelo e ser capaz de definir submódulos diretamente dentro do módulo, também precisamos anotar o __call__
método com @nn.compact
.
Observe que, em vez de usar argumentos de dataclass e o @nn.compact
anotação, poderíamos ter declarado todos os argumentos dentro de um setup
Método exatamente da mesma maneira que fazemos em Pytorch ou Tensorflow __init__
.
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax import optim
class Encoder(nn.Module):
latents: int
@nn.compact
def __call__(self, x):
x = nn.Dense(500, name='fc1')(x)
x = nn.relu(x)
mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
return mean_x, logvar_x
import tensorflow as tf
from tensorflow.keras import layers
class Encoder(layers.Layer):
def __init__(self,
latent_dim =20,
name='encoder',
**kwargs):
super(Encoder, self).__init__(name=name, **kwargs)
self.enc1 = layers.Dense(500, activation='relu')
self.mean_x = layers.Dense(latent_dim)
self.logvar_x = layers.Dense(latent_dim)
def call(self, inputs):
x = self.enc1(inputs)
z_mean = self.mean_x(x)
z_log_var = self.logvar_x(x)
return z_mean, z_log_var
import torch
import torch.nn.functional as F
class Encoder(torch.nn.Module):
def __init__(self, latent_dim=20):
super(Encoder, self).__init__()
self.enc1 = torch.nn.Linear(784, 500)
self.mean_x = torch.nn.Linear(500,latent_dim)
self.logvar_x = torch.nn.Linear(500, latent_dim)
def forward(self,inputs):
x = self.enc1(inputs)
x= F.relu(x)
z_mean = self.mean_x(x)
z_log_var = self.logvar_x(x)
return z_mean, z_log_var
Mais algumas coisas a perceber aqui antes de prosseguirmos:
-
Linho
nn.linen
o pacote contém a maioria das camadas de aprendizado profundo e operação, comoDense
Assim,relu
e muito mais -
O código em linho, tensorflow e pytorch é quase indistinguível um do outro.
O decodificador
De uma maneira muito semelhante, podemos desenvolver o decodificador em todas as três estruturas. O decodificador será duas camadas lineares que receberão a representação latente e emitir a entrada reconstruída.
Novamente, as implementações são muito semelhantes.
class Decoder(nn.Module):
@nn.compact
def __call__(self, z):
z = nn.Dense(500, name='fc1')(z)
z = nn.relu(z)
z = nn.Dense(784, name='fc2')(z)
return z
class Decoder(layers.Layer):
def __init__(self,
name='decoder',
**kwargs):
super(Decoder, self).__init__(name=name, **kwargs)
self.dec1 = layers.Dense(500, activation='relu')
self.out = layers.Dense(784)
def call(self, z):
z = self.dec1(z)
return self.out(z)
class Decoder(torch.nn.Module):
def __init__(self, latent_dim=20):
super(Decoder, self).__init__()
self.dec1 = torch.nn.Linear(latent_dim, 500)
self.out = torch.nn.Linear(500, 784)
def forward(self,z):
z = self.dec1(z)
z = F.relu(z)
return self.out(z)
Carscoder variacional
Para combinar o codificador e o decodificador, vamos ter mais uma classe, chamado VAE
isso representará toda a arquitetura. Aqui também precisamos escrever algum código para o truque de reparameterização. No geral, temos: a variável latente do codificador é reparameterizada e alimentada ao decodificador, que produz a entrada reconstruída.
Como lembrete, aqui está uma imagem intuitiva que explica o truque de reparameterização:
Fonte: Alexander Amini e Ava Soleimany, modelagem generativa profunda | MIT 6.S191, http://introtodeeplearning.com/
Observe que desta vez, em Jax, fazemos uso do setup
método em vez do nn.compact
anotação. Além disso, confira como as funções de reparameterização são semelhantes. Claro que cada estrutura usa suas próprias funções e operações, mas a imagem geral é quase idêntica.
class VAE(nn.Module):
latents: int = 20
def setup(self):
self.encoder = Encoder(self.latents)
self.decoder = Decoder()
def __call__(self, x, z_rng):
mean, logvar = self.encoder(x)
z = reparameterize(z_rng, mean, logvar)
recon_x = self.decoder(z)
return recon_x, mean, logvar
def reparameterize(rng, mean, logvar):
std = jnp.exp(0.5 * logvar)
eps = random.normal(rng, logvar.shape)
return mean + eps * std
def model():
return VAE(latents=LATENTS)
class VAE(tf.keras.Model):
def __init__(self,
latent_dim=20,
name='vae',
**kwargs):
super(VAE, self).__init__(name=name, **kwargs)
self.encoder = Encoder(latent_dim=latent_dim)
self.decoder = Decoder()
def call(self, inputs):
z_mean, z_log_var = self.encoder(inputs)
z = self.reparameterize(z_mean, z_log_var)
reconstructed = self.decoder(z)
return reconstructed, z_mean, z_log_var
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return mean + eps * tf.exp(logvar * .5)
class VAE(torch.nn.Module):
def __init__(self, latent_dim=20):
super(VAE, self).__init__()
self.encoder = Encoder(latent_dim)
self.decoder = Decoder(latent_dim)
def forward(self,inputs):
z_mean, z_log_var = self.encoder(inputs)
z = self.reparameterize(z_mean, z_log_var)
reconstructed = self.decoder(z)
return reconstructed, z_mean, z_log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + (eps * std)
Etapa de perda e treinamento
As coisas estão começando a diferir quando começamos a implementar a etapa de treinamento e a função de perda. Mas não muito.
-
Para aproveitar completamente de Capacidades JAXprecisamos adicionar vetorização automática e compilação XLA ao nosso código. Isso pode ser feito facilmente com a ajuda de
vmap
ejit
anotações. -
Além disso, temos que permitir a diferenciação automática, que pode ser realizada com o
grad_fn
transformação -
Nós usamos o
flax.optim
Pacote para algoritmos de otimização
Outra pequena diferença que precisamos estar ciente é como passamos dados para o nosso modelo. Isso pode ser alcançado através do método de aplicação na forma de model().apply({'params': params}, batch, z_rng)
onde batch
são nossos dados de treinamento.
@jax.vmap
def kl_divergence(mean, logvar):
return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
logits = nn.log_sigmoid(logits)
return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))
@jax.jit
def train_step(optimizer, batch, z_rng):
def loss_fn(params):
recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
kld_loss = kl_divergence(mean, logvar).mean()
loss = bce_loss + kld_loss
return loss, recon_x
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
_, grad = grad_fn(optimizer.target)
optimizer = optimizer.apply_gradient(grad)
return optimizer
def kl_divergence(mean, logvar):
return -0.5 * tf.reduce_sum(
1 + logvar - tf.square(mean) -
tf.exp(logvar), axis=1)
def binary_cross_entropy_with_logits(logits, labels):
logits = tf.math.log(logits)
return - tf.reduce_sum(
labels * logits +
(1-labels) * tf.math.log(- tf.math.expm1(logits)),
axis=1
)
@tf.function
def train_step(model, x, optimizer):
with tf.GradientTape() as tape:
recon_x, mean, logvar = model(x)
bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, batch))
kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
loss = bce_loss + kld_loss
print(loss, kld_loss, bce_loss)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
def final_loss(reconstruction, train_x, mu, logvar):
BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
def train_step(train_x):
train_x = torch.from_numpy(train_x)
optimizer.zero_grad()
reconstruction, mu, logvar = model(train_x)
loss = final_loss(reconstruction, train_x, mu, logvar)
running_loss += loss.item()
loss.backward()
optimizer.step()
Lembre -se de que os VAEs são treinados ao maximizar a evidência limite inferior, conhecido como Elbo.
Loop de treinamento
Finalmente, é hora de todo o ciclo de treinamento que executará o train_step
função iterativamente.
Em linho, o modelo deve ser inicializado antes do treinamento, o que é feito pelo init
função como: params = model().init(key, init_data, rng)('params')
. Uma inicialização semelhante também é necessária para o otimizador: optimizer = optim.Adam( learning_rate = LEARNING_RATE ).create( params )
.
jax.device_put
é usado para transferir o otimizador para a memória da GPU.
rng = random.PRNGKey(0)
rng, key = random.split(rng)
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)('params')
optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)
rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))
steps_per_epoch = 50000 // BATCH_SIZE
for epoch in range(NUM_EPOCHS):
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
optimizer = train_step(optimizer, batch, key)
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)
for epoch in range(NUM_EPOCHS):
for train_x in train_ds:
train_step(vae, train_x, optimizer)
def train(model,training_data):
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
running_loss = 0.0
for epoch in range(NUM_EPOCHS):
for i, train_x in enumerate(training_data, 0):
train_step(train_x)
vae = VAE(LATENTS)
train(vae, train_ds)
Carregar e processar dados
Uma coisa que não mencionei são dados. Como carregamos dados de pré -processamento em linho? Bem, o linho não inclui pacotes de manipulação de dados, além das operações básicas de jax.numpy
. No momento, o nosso melhor é emprestar pacotes de outras estruturas, como o TensorFlow DataSets (TFDS) ou a Torchvision. Para tornar o artigo auto-preenchido, incluirei o código que usei para carregar um conjunto de dados de treinamento de amostra com tfds
. Sinta -se à vontade para usar seu próprio Dataloader se você planeja executar as implementações apresentadas neste artigo.
import tensorflow_datasets as tfds
tf.config.experimental.set_visible_devices((), 'GPU')
def prepare_image(x):
x = tf.cast(x('image'), tf.float32)
x = tf.reshape(x, (-1,))
return x
ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))
test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)(0))
Observações finais
Para encerrar o artigo, vamos discutir algumas observações finais que aparecem após uma análise estreita do código:
-
Todas as três estruturas reduziram o código da caldeira ao mínimo, com o linho sendo aquele que requer um pouco mais, especialmente na parte de treinamento. No entanto, isso é apenas para garantir que exploremos todas as transformações disponíveis, como diferenciação automática, vetorização e compilador just-in-time.
-
A definição de módulos, camadas e modelos é quase idêntica em todos eles
-
Linho e Jax são por design bastante flexíveis e expansíveis
-
O linho não possui recursos de carregamento e processamento de dados ainda
-
Em termos de camadas e otimizadores prontos para uso, o linho não precisa ter ciúmes do Tensorflow e Pytorch. Com certeza não possui a biblioteca gigante de seus concorrentes, mas está gradualmente chegando lá.
* Divulgação: Observe que alguns dos links acima podem ser links de afiliados e, sem custo adicional, ganharemos uma comissão se você decidir fazer uma compra depois de clicar.

Luis es un experto en Ciberseguridad, Computación en la Nube, Criptomonedas e Inteligencia Artificial. Con amplia experiencia en tecnología, su objetivo es compartir conocimientos prácticos para ayudar a los lectores a entender y aprovechar estas áreas digitales clave.