Neste tutorial, exploraremos como desenvolver uma rede neural (NN) com o JAX. E que modelo melhor escolher do que o Transformador. À medida que Jax está crescendo em popularidade, mais e mais equipes de desenvolvedores estão começando a experimentá -lo e incorporando -a em seus projetos. Apesar do fato de não ter a maturidade do Tensorflow ou Pytorch, ele fornece ótimos recursos para criar e treinar modelos de aprendizado profundo.
Para uma sólida compreensão do Basics Jax, verifique meu Artigo anterior Se você ainda não o fez. Também você pode encontrar o código completo em nosso Repositório do GitHub.
Um dos problemas comuns que as pessoas têm ao começar com Jax é a escolha de uma estrutura. As pessoas em DeepMind parecem estar muito ocupadas e já lançaram uma infinidade de estruturas no topo de Jax. Aqui está uma lista dos mais famosos:
-
Haiku: Haiku é a estrutura pretendida para o aprendizado profundo e é usado por muitas equipes internas do Google e DeepMind. Ele fornece algumas abstrações simples e compostas para pesquisas de aprendizado de máquina, bem como módulos e camadas prontos para uso.
-
Optax: Optax é uma biblioteca de processamento e otimização de gradiente que contém otimizadores prontos para uso e operações matemáticas relacionadas.
-
Rlax: RLAX é uma estrutura de aprendizado de reforço com muitos subcomponentes e operações da RL.
-
Chex: Chex é uma biblioteca de utilitários para testar e depurar o código JAX.
-
Jrph: Jraph é uma biblioteca de redes neurais gráficas em Jax.
-
Linho: O linho é outra biblioteca de rede neural com uma variedade de módulos, otimizadores e utilitários prontos para uso. É provavelmente o mais próximo que temos em uma estrutura JAX All-In.
-
Objax: Objax é uma terceira biblioteca de ML que se concentra na programação e legibilidade do código orientadas a objetos. Mais uma vez ele contém os módulos mais populares, funções de ativação, perdas, otimizadores e um punhado de modelos pré-treinados.
-
Trax: Trax é uma biblioteca de ponta a ponta para aprendizado profundo que se concentra nos transformadores
-
Jaxline: Jaxline é uma biblioteca de aprendizado supervisionado que é usado para Treinamento Jax distribuído e avaliação.
-
ACME: ACME é outra estrutura de pesquisa para a aprendizagem de reforço.
-
JAX-MD: JAX-MD é uma estrutura de nicho que lida com a dinâmica molecular.
-
Jachchem: Jaxchem é outra biblioteca de nicho que enfatiza a modelagem química.
Claro, a questão é qual eu escolho?
Para ser sincero, não tenho certeza.
Mas se eu fosse você e quisesse aprender Jax, começaria com os mais populares. Haiku e linho parecem ser muito usados no Google/DeepMind e têm a comunidade mais ativa do Github. Para este artigo, começarei com o primeiro e verei se vou precisar de outro no futuro.
Então você está pronto para construir um transformador com Jax e Haiku? A propósito, presumo que você tenha um sólido entendimento dos transformadores. Se não o fizer, informe nossos artigos sobre atenção e transformadores.
Vamos começar com o bloco de auto-atuação.
O bloco de auto-ataque
Primeiro, precisamos importar Jax e Haiku
import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np
Felizmente para nós, Haiku tem um embutido MultiHeadAttention
Bloco que pode ser estendido para construir um bloco de auto-ataque mascarado. Nosso bloco aceita a consulta, a chave, o valor e a máscara e retorna a saída como uma matriz JAX. Você pode ver que o código está muito familiarizado com o código Pytorch ou Tensorflow padrão. Tudo o que fazemos é construir a máscara causal, usando np.trill()
que anulam todos hk.MultiHeadAttention
módulo.
class SelfAttention(hk.MultiHeadAttention):
"""Self attention with a causal mask applied."""
def __call__(
self,
query: jnp.ndarray,
key: Optional(jnp.ndarray) = None,
value: Optional(jnp.ndarray) = None,
mask: Optional(jnp.ndarray) = None,
) -> jnp.ndarray:
key = key if key is not None else query
value = value if value is not None else query
seq_len = query.shape(1)
causal_mask = np.tril(np.ones((seq_len, seq_len)))
mask = mask * causal_mask if mask is not None else causal_mask
return super().__call__(query, key, value, mask)
Este trecho me permite apresentar o primeiro princípio -chave do haiku. Todos os módulos devem ser uma subclasse de hk.Module
. Isso significa que eles devem implementar __init__
e __call__
juntamente com qualquer outro método. Em certo sentido, é a mesma arquitetura com módulos pytorch, onde implementamos um __init__
e a forward
.
Para deixar isso claro, vamos construir uma simples multilayerperceptron de 2 camadas como um hk.Module
que convenientemente será usado no transformador abaixo.
A camada linear
Um MLP simples de 2 camadas ficará assim. Mais uma vez, você pode notar como é familiar.
class DenseBlock(hk.Module):
"""A 2-layer MLP"""
def __init__(self,
init_scale: float,
widening_factor: int = 4,
name: Optional(str) = None):
super().__init__(name=name)
self._init_scale = init_scale
self._widening_factor = widening_factor
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hiddens = x.shape(-1)
initializer = hk.initializers.VarianceScaling(self._init_scale)
x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
x = jax.nn.gelu(x)
return hk.Linear(hiddens, w_init=initializer)(x)
Algumas coisas a perceber aqui:
-
Haiku nos fornece um conjunto de inicializadores de pesos sob
hk.initializers
onde podemos encontrar o máximo abordagens comuns. -
Também possui muitas camadas e módulos populares, como
hk.Linear
. Para a lista completa, dê uma olhada no documentação oficial. -
As funções de ativação não são fornecidas porque Jax já tem um subpackage chamado
jax.nn
onde podemos encontrar Funções de ativação comorelu
ousoftmax
.
A camada de normalização
A normalização da camada é outro bloco integral da arquitetura do transformador, que também podemos encontrar nos módulos comuns dentro do haiku.
def layer_norm(x: jnp.ndarray, name: Optional(str) = None) -> jnp.ndarray:
"""Apply a unique LayerNorm to x with default settings."""
return hk.LayerNorm(axis=-1,
create_scale=True,
create_offset=True,
name=name)(x)
O transformador
E agora para as coisas boas. Abaixo, você pode encontrar um transformador muito simplista, que utiliza nossos módulos predefinidos. Dentro __init__
definimos as variáveis básicas, como o número de camadas, cabeças de atenção e a taxa de abandono. Dentro __call__
comemos uma lista de blocos usando um for
laço.
Como você pode ver, cada bloco inclui:
No final, também adicionamos uma camada de normalização final.
class Transformer(hk.Module):
"""A transformer stack."""
def __init__(self,
num_heads: int,
num_layers: int,
dropout_rate: float,
name: Optional(str) = None):
super().__init__(name=name)
self._num_layers = num_layers
self._num_heads = num_heads
self._dropout_rate = dropout_rate
def __call__(self,
h: jnp.ndarray,
mask: Optional(jnp.ndarray),
is_training: bool) -> jnp.ndarray:
"""Connects the transformer.
Args:
h: Inputs, (B, T, H).
mask: Padding mask, (B, T).
is_training: Whether we're training or not.
Returns:
Array of shape (B, T, H).
"""
init_scale = 2. / self._num_layers
dropout_rate = self._dropout_rate if is_training else 0.
if mask is not None:
mask = mask(:, None, None, :)
for i in range(self._num_layers):
h_norm = layer_norm(h, name=f'h{i}_ln_1')
h_attn = SelfAttention(
num_heads=self._num_heads,
key_size=64,
w_init_scale=init_scale,
name=f'h{i}_attn')(h_norm, mask=mask)
h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
h = h + h_attn
h_norm = layer_norm(h, name=f'h{i}_ln_2')
h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
h = h + h_dense
h = layer_norm(h, name='ln_f')
return h
Eu acho que agora você já percebeu que a construção de uma rede neural com Jax é morta simples.
A camada de incorporação
Para conclusão, vamos também incluir a camada de incorporação. É bom saber que o Haiku também fornece uma camada de incorporação que criará os tokens a partir de nossa frase de entrada. O token é então adicionado ao incorporações posicionaisque produz a entrada final.
def embeddings(data: Mapping(str, jnp.ndarray), vocab_size: int) :
tokens = data('obs')
input_mask = jnp.greater(tokens, 0)
seq_length = tokens.shape(1)
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
token_embs = token_embedding_map(tokens)
positional_embeddings = hk.get_parameter(
'pos_embs', (seq_length, d_model), init=embed_init)
input_embeddings = token_embs + positional_embeddings
return input_embeddings, input_mask
hk.get_parameter(param_name, ...)
é usado para acessar os parâmetros treináveis de um módulo. Mas você pode perguntar, por que não apenas usar as propriedades do objeto como em Pytorch. É aqui que o segundo princípio -chave de Haiku entra em jogo. Usamos esta API para que possamos converter o código em uma função pura usando hk.transform
. Isso não é muito simples de entender, mas tentarei deixá -lo o mais claro possível.
Por que funções puras?
O poder de Jax entra em sua função Transformações: a capacidade de vetorizar uma função com vmap
a paralelização automática com pmap
bem no tempo compilação com jit
. A ressalva aqui é que, para transformar uma função, ela precisa ser pura.
UM função pura é uma função que possui as seguintes propriedades:
-
Os valores de retorno da função são idênticos para argumentos idênticos (nenhuma variação com variáveis estáticas locais, variáveis não locais, argumentos de referência mutável ou fluxos de entrada).
-
O aplicativo de função não tem efeitos colaterais (nenhuma mutação de variáveis estáticas locais, variáveis não locais, argumentos de referência mutável ou fluxos de entrada/saída).
Fonte: Funções puras da Scala
Isso praticamente significa que uma função pura sempre será:
-
Retorne o mesmo resultado se invocado com as mesmas entradas
-
Todos os dados de entrada são passados através dos argumentos da função, todos os resultados são emitidos através dos resultados da função
Haiku fornece uma transformação de função, chamada hk.transform
isso transforma funções com módulos orientados a objetos e funcionalmente “impuros” em funções puras que podem ser usadas com o JAX. Para ver isso na prática, vamos continuar com o treinamento do nosso modelo de transformador.
O passe para a frente
Um passe para a frente típico inclui:
-
Pegando a entrada e calcula a incorporação de entrada
-
Percorrer os blocos do transformador
-
Retornar a saída
As etapas acima mencionadas podem ser facilmente compostas com Jax como seguinte:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
num_layers: int, dropout_rate: float):
"""Create the model's forward pass."""
def forward_fn(data: Mapping(str, jnp.ndarray),
is_training: bool = True) -> jnp.ndarray:
"""Forward pass."""
input_embeddings, input_mask = embeddings(data, vocab_size)
transformer = Transformer(
num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
output_embeddings = transformer(input_embeddings, input_mask, is_training)
return hk.Linear(vocab_size)(output_embeddings)
return forward_fn
Embora o código seja direto, sua estrutura pode parecer um pouco estranha. O passe para a frente real é executado através do forward_fn
função. No entanto, embrulhamos isso com o build_forward_fn
função que retorna o forward_fn
. Que diabos?
No caminho, precisaremos transformar o forward_fn
função em uma função pura usando hk.transform
para que possamos aproveitar a diferenciação automática, a paralelização etc.
Isso será realizado por:
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
É por isso que, em vez de simplesmente definir uma função, embrulhamos e retornamos a própria função, ou um chamável para ser mais preciso. Este chamável pode então ser passado para o hk.transform
e se tornar uma função pura. Se isso estiver claro, vamos continuar com nossa função de perda.
A função de perda
A função de perda é a nossa conhecida função entre entropia, com a diferença de que também estamos levando a máscara em consideração. Mais uma vez, Jax fornece one_hot
e log_softmax
funcionalidades.
def lm_loss_fn(forward_fn,
vocab_size: int,
params,
rng,
data: Mapping(str, jnp.ndarray),
is_training: bool = True) -> jnp.ndarray:
"""Compute the loss on data wrt params."""
logits = forward_fn(params, rng, data, is_training)
targets = jax.nn.one_hot(data('target'), vocab_size)
assert logits.shape == targets.shape
mask = jnp.greater(data('obs'), 0)
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.sum(loss * mask) / jnp.sum(mask)
return loss
Se você ainda estiver comigo, tome um gole de café, porque as coisas vão levar a sério a partir de agora. É hora de construir nosso ciclo de treinamento.
O loop de treinamento
Como nem Jax nem Haiku têm funcionalidades de otimização embutidas, faremos uso de outra estrutura, chamada Optax. Como mencionado no início, o Optax é o pacote GoTo para processamento de gradiente.
Primeiro, aqui estão algumas coisas que você precisa saber sobre Optax:
A principal transformação do optax é o GradientTransformation
. A transformação é definida por duas funções, o __init__
e o __update__
. O __init__
inicializa o estado e o __update__
transforma os gradientes em relação ao estado e ao valor atual dos parâmetros
state = init(params)
grads, state = update(grads, state, params=None)
Mais uma coisa a saber antes de vermos o código é embutido de Python functools.partial
função. O functools
O pacote lida com funções e operações de ordem superior em objetos chamáveis.
Uma função é chamada de função de ordem superior se contiver outras funções como um parâmetro ou retornar uma função como uma saída.
O partial
que também pode ser usado como anotação, retorna uma nova função com base em uma original, mas com menos ou argumentos fixos. Se, por exemplo, F multiplicar dois valores x, y, o parcial criará uma nova função em que x será fixo e igual a 2
from functools import partial
def f(x,y):
return x * y
g = partial(f,2)
print(g(4))
Após esse pequeno desvio, vamos prosseguir. Para descongestionar nosso main
Função, extrairemos a atualização dos gradientes em sua própria classe.
Primeiro de tudo GradientUpdater
Aceita o modelo, a função de perda e um otimizador.
- O modelo será um puro
forward_fn
função transformada porhk.transform
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
- A função de perda será o resultado de um parcial com um fixo
forward_fn
e `vocab_size
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
- O otimizador é um conjunto de transformações de otimização que serão executadas sequencialmente (as operações podem ser combinadas usando
optax.chain
)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
O atualizador do gradiente será inicializado da seguinte forma:
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
e vai ficar assim:
class GradientUpdater:
"""A stateless abstraction around an init_fn/update_fn pair.
This extracts some common boilerplate from the training loop.
"""
def __init__(self, net_init, loss_fn,
optimizer: optax.GradientTransformation):
self._net_init = net_init
self._loss_fn = loss_fn
self._opt = optimizer
@functools.partial(jax.jit, static_argnums=0)
def init(self, master_rng, data):
"""Initializes state of the updater."""
out_rng, init_rng = jax.random.split(master_rng)
params = self._net_init(init_rng, data)
opt_state = self._opt.init(params)
out = dict(
step=np.array(0),
rng=out_rng,
opt_state=opt_state,
params=params,
)
return out
@functools.partial(jax.jit, static_argnums=0)
def update(self, state: Mapping(str, Any), data: Mapping(str, jnp.ndarray)):
"""Updates the state using some data and returns metrics."""
rng, new_rng = jax.random.split(state('rng'))
params = state('params')
loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
updates, opt_state = self._opt.update(g, state('opt_state'))
params = optax.apply_updates(params, updates)
new_state = {
'step': state('step') + 1,
'rng': new_rng,
'opt_state': opt_state,
'params': params,
}
metrics = {
'step': state('step'),
'loss': loss,
}
return new_state, metrics
Dentro __init__
inicializamos nosso otimizador com self._opt.init(params)
e declaramos o estado da otimização. O estado será um dicionário com:
O update
A função atualizará o estado do otimizador e os parâmetros treináveis. No final, ele retornará o novo estado.
updates, opt_state = self._opt.update(g, state('opt_state'))
params = optax.apply_updates(params, updates)
Mais duas coisas a notar aqui:
-
jax.value_and_grad()
é a função especial que retorna uma função diferenciável com seus gradientes -
Ambos
__init__
e__update__
são anotados com@functools.partial(jax.jit, static_argnums=0)
que acionará o compilador just-in-time e os compilará no XLA durante o tempo de execução. Observe que se não transformarmosforward_fn
Em uma função pura, isso não seria possível.
Finalmente, estamos prontos para construir todo o ciclo de treinamento, que combina todas as idéias e código mencionados até agora.
def main():
train_dataset, vocab_size = load(batch_size,
sequence_length)
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
logging.info('Initializing parameters...')
rng = jax.random.PRNGKey(428)
data = next(train_dataset)
state = updater.init(rng, data)
logging.info('Starting train loop...')
prev_time = time.time()
for step in range(MAX_STEPS):
data = next(train_dataset)
state, metrics = updater.update(state, data)
Observe como incorporamos o GradientUpdate
. São apenas duas linhas de código:
-
state = updater.init(rng, data)
-
state, metrics = updater.update(state, data)
E é isso. Espero que agora você tenha um entendimento mais claro do JAX e de suas capacidades.
Agradecimentos
O código apresentado é fortemente inspirado nos exemplos oficiais da estrutura do Haiku. Foi modificado para atender às necessidades deste artigo. Para a lista completa de exemplos, verifique o repositório oficial
Conclusão
Neste artigo, vimos como se pode desenvolver e treinar um transformador de baunilha em Jax usando o Haiku. Embora o código não seja necessariamente difícil de entender, ele ainda não possui a legibilidade de pytorch ou tensorflow. Eu recomendo brincar com ele, descobrir os pontos fortes e fracos de Jax e ver se seria uma boa opção para o seu próximo projeto. Na minha experiência, o JAX é muito forte para aplicativos de pesquisa que exigem alto desempenho, mas bastante imaturos para projetos da vida real. Deixe -nos saber o que você pensa em nosso Discord Channel.
* Divulgação: Observe que alguns dos links acima podem ser links de afiliados e, sem custo adicional, ganharemos uma comissão se você decidir fazer uma compra depois de clicar.

Luis es un experto en Ciberseguridad, Computación en la Nube, Criptomonedas e Inteligencia Artificial. Con amplia experiencia en tecnología, su objetivo es compartir conocimientos prácticos para ayudar a los lectores a entender y aprovechar estas áreas digitales clave.