Posted in

Construa um transformador em Jax do zero: como escrever e treinar seus próprios modelos

Construa um transformador em Jax do zero: como escrever e treinar seus próprios modelos

Neste tutorial, exploraremos como desenvolver uma rede neural (NN) com o JAX. E que modelo melhor escolher do que o Transformador. À medida que Jax está crescendo em popularidade, mais e mais equipes de desenvolvedores estão começando a experimentá -lo e incorporando -a em seus projetos. Apesar do fato de não ter a maturidade do Tensorflow ou Pytorch, ele fornece ótimos recursos para criar e treinar modelos de aprendizado profundo.

Para uma sólida compreensão do Basics Jax, verifique meu Artigo anterior Se você ainda não o fez. Também você pode encontrar o código completo em nosso Repositório do GitHub.

Um dos problemas comuns que as pessoas têm ao começar com Jax é a escolha de uma estrutura. As pessoas em DeepMind parecem estar muito ocupadas e já lançaram uma infinidade de estruturas no topo de Jax. Aqui está uma lista dos mais famosos:

  • Haiku: Haiku é a estrutura pretendida para o aprendizado profundo e é usado por muitas equipes internas do Google e DeepMind. Ele fornece algumas abstrações simples e compostas para pesquisas de aprendizado de máquina, bem como módulos e camadas prontos para uso.

  • Optax: Optax é uma biblioteca de processamento e otimização de gradiente que contém otimizadores prontos para uso e operações matemáticas relacionadas.

  • Rlax: RLAX é uma estrutura de aprendizado de reforço com muitos subcomponentes e operações da RL.

  • Chex: Chex é uma biblioteca de utilitários para testar e depurar o código JAX.

  • Jrph: Jraph é uma biblioteca de redes neurais gráficas em Jax.

  • Linho: O linho é outra biblioteca de rede neural com uma variedade de módulos, otimizadores e utilitários prontos para uso. É provavelmente o mais próximo que temos em uma estrutura JAX All-In.

  • Objax: Objax é uma terceira biblioteca de ML que se concentra na programação e legibilidade do código orientadas a objetos. Mais uma vez ele contém os módulos mais populares, funções de ativação, perdas, otimizadores e um punhado de modelos pré-treinados.

  • Trax: Trax é uma biblioteca de ponta a ponta para aprendizado profundo que se concentra nos transformadores

  • Jaxline: Jaxline é uma biblioteca de aprendizado supervisionado que é usado para Treinamento Jax distribuído e avaliação.

  • ACME: ACME é outra estrutura de pesquisa para a aprendizagem de reforço.

  • JAX-MD: JAX-MD é uma estrutura de nicho que lida com a dinâmica molecular.

  • Jachchem: Jaxchem é outra biblioteca de nicho que enfatiza a modelagem química.

Claro, a questão é qual eu escolho?

Para ser sincero, não tenho certeza.

Mas se eu fosse você e quisesse aprender Jax, começaria com os mais populares. Haiku e linho parecem ser muito usados ​​no Google/DeepMind e têm a comunidade mais ativa do Github. Para este artigo, começarei com o primeiro e verei se vou precisar de outro no futuro.

Então você está pronto para construir um transformador com Jax e Haiku? A propósito, presumo que você tenha um sólido entendimento dos transformadores. Se não o fizer, informe nossos artigos sobre atenção e transformadores.

Vamos começar com o bloco de auto-atuação.

O bloco de auto-ataque

Primeiro, precisamos importar Jax e Haiku

import jax

import jax.numpy as jnp

import haiku as hk

Import numpy as np

Felizmente para nós, Haiku tem um embutido MultiHeadAttention Bloco que pode ser estendido para construir um bloco de auto-ataque mascarado. Nosso bloco aceita a consulta, a chave, o valor e a máscara e retorna a saída como uma matriz JAX. Você pode ver que o código está muito familiarizado com o código Pytorch ou Tensorflow padrão. Tudo o que fazemos é construir a máscara causal, usando np.trill()que anulam todos hk.MultiHeadAttention módulo.

class SelfAttention(hk.MultiHeadAttention):

"""Self attention with a causal mask applied."""

def __call__(

self,

query: jnp.ndarray,

key: Optional(jnp.ndarray) = None,

value: Optional(jnp.ndarray) = None,

mask: Optional(jnp.ndarray) = None,

) -> jnp.ndarray:

key = key if key is not None else query

value = value if value is not None else query

seq_len = query.shape(1)

causal_mask = np.tril(np.ones((seq_len, seq_len)))

mask = mask * causal_mask if mask is not None else causal_mask

return super().__call__(query, key, value, mask)

Este trecho me permite apresentar o primeiro princípio -chave do haiku. Todos os módulos devem ser uma subclasse de hk.Module. Isso significa que eles devem implementar __init__ e __call__juntamente com qualquer outro método. Em certo sentido, é a mesma arquitetura com módulos pytorch, onde implementamos um __init__ e a forward.

Para deixar isso claro, vamos construir uma simples multilayerperceptron de 2 camadas como um hk.Moduleque convenientemente será usado no transformador abaixo.

A camada linear

Um MLP simples de 2 camadas ficará assim. Mais uma vez, você pode notar como é familiar.

class DenseBlock(hk.Module):

"""A 2-layer MLP"""

def __init__(self,

init_scale: float,

widening_factor: int = 4,

name: Optional(str) = None):

super().__init__(name=name)

self._init_scale = init_scale

self._widening_factor = widening_factor

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:

hiddens = x.shape(-1)

initializer = hk.initializers.VarianceScaling(self._init_scale)

x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)

x = jax.nn.gelu(x)

return hk.Linear(hiddens, w_init=initializer)(x)

Algumas coisas a perceber aqui:

  • Haiku nos fornece um conjunto de inicializadores de pesos sob hk.initializersonde podemos encontrar o máximo abordagens comuns.

  • Também possui muitas camadas e módulos populares, como hk.Linear. Para a lista completa, dê uma olhada no documentação oficial.

  • As funções de ativação não são fornecidas porque Jax já tem um subpackage chamado jax.nnonde podemos encontrar Funções de ativação como relu ou softmax.

A camada de normalização

A normalização da camada é outro bloco integral da arquitetura do transformador, que também podemos encontrar nos módulos comuns dentro do haiku.

def layer_norm(x: jnp.ndarray, name: Optional(str) = None) -> jnp.ndarray:

"""Apply a unique LayerNorm to x with default settings."""

return hk.LayerNorm(axis=-1,

create_scale=True,

create_offset=True,

name=name)(x)

O transformador

E agora para as coisas boas. Abaixo, você pode encontrar um transformador muito simplista, que utiliza nossos módulos predefinidos. Dentro __init__definimos as variáveis ​​básicas, como o número de camadas, cabeças de atenção e a taxa de abandono. Dentro __call__comemos uma lista de blocos usando um for laço.

Como você pode ver, cada bloco inclui:

No final, também adicionamos uma camada de normalização final.

class Transformer(hk.Module):

"""A transformer stack."""

def __init__(self,

num_heads: int,

num_layers: int,

dropout_rate: float,

name: Optional(str) = None):

super().__init__(name=name)

self._num_layers = num_layers

self._num_heads = num_heads

self._dropout_rate = dropout_rate

def __call__(self,

h: jnp.ndarray,

mask: Optional(jnp.ndarray),

is_training: bool) -> jnp.ndarray:

"""Connects the transformer.

Args:

h: Inputs, (B, T, H).

mask: Padding mask, (B, T).

is_training: Whether we're training or not.

Returns:

Array of shape (B, T, H).

"""

init_scale = 2. / self._num_layers

dropout_rate = self._dropout_rate if is_training else 0.

if mask is not None:

mask = mask(:, None, None, :)

for i in range(self._num_layers):

h_norm = layer_norm(h, name=f'h{i}_ln_1')

h_attn = SelfAttention(

num_heads=self._num_heads,

key_size=64,

w_init_scale=init_scale,

name=f'h{i}_attn')(h_norm, mask=mask)

h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)

h = h + h_attn

h_norm = layer_norm(h, name=f'h{i}_ln_2')

h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)

h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)

h = h + h_dense

h = layer_norm(h, name='ln_f')

return h

Eu acho que agora você já percebeu que a construção de uma rede neural com Jax é morta simples.

A camada de incorporação

Para conclusão, vamos também incluir a camada de incorporação. É bom saber que o Haiku também fornece uma camada de incorporação que criará os tokens a partir de nossa frase de entrada. O token é então adicionado ao incorporações posicionaisque produz a entrada final.

def embeddings(data: Mapping(str, jnp.ndarray), vocab_size: int) :

tokens = data('obs')

input_mask = jnp.greater(tokens, 0)

seq_length = tokens.shape(1)

embed_init = hk.initializers.TruncatedNormal(stddev=0.02)

token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)

token_embs = token_embedding_map(tokens)

positional_embeddings = hk.get_parameter(

'pos_embs', (seq_length, d_model), init=embed_init)

input_embeddings = token_embs + positional_embeddings

return input_embeddings, input_mask

hk.get_parameter(param_name, ...) é usado para acessar os parâmetros treináveis ​​de um módulo. Mas você pode perguntar, por que não apenas usar as propriedades do objeto como em Pytorch. É aqui que o segundo princípio -chave de Haiku entra em jogo. Usamos esta API para que possamos converter o código em uma função pura usando hk.transform. Isso não é muito simples de entender, mas tentarei deixá -lo o mais claro possível.

Por que funções puras?

O poder de Jax entra em sua função Transformações: a capacidade de vetorizar uma função com vmapa paralelização automática com pmapbem no tempo compilação com jit. A ressalva aqui é que, para transformar uma função, ela precisa ser pura.

UM função pura é uma função que possui as seguintes propriedades:

  • Os valores de retorno da função são idênticos para argumentos idênticos (nenhuma variação com variáveis ​​estáticas locais, variáveis ​​não locais, argumentos de referência mutável ou fluxos de entrada).

  • O aplicativo de função não tem efeitos colaterais (nenhuma mutação de variáveis ​​estáticas locais, variáveis ​​não locais, argumentos de referência mutável ou fluxos de entrada/saída).




Construa um transformador em Jax do zero: como escrever e treinar seus próprios modelos


Fonte: Funções puras da Scala

Isso praticamente significa que uma função pura sempre será:

  • Retorne o mesmo resultado se invocado com as mesmas entradas

  • Todos os dados de entrada são passados ​​através dos argumentos da função, todos os resultados são emitidos através dos resultados da função

Haiku fornece uma transformação de função, chamada hk.transformisso transforma funções com módulos orientados a objetos e funcionalmente “impuros” em funções puras que podem ser usadas com o JAX. Para ver isso na prática, vamos continuar com o treinamento do nosso modelo de transformador.

O passe para a frente

Um passe para a frente típico inclui:

  1. Pegando a entrada e calcula a incorporação de entrada

  2. Percorrer os blocos do transformador

  3. Retornar a saída

As etapas acima mencionadas podem ser facilmente compostas com Jax como seguinte:

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,

num_layers: int, dropout_rate: float):

"""Create the model's forward pass."""

def forward_fn(data: Mapping(str, jnp.ndarray),

is_training: bool = True) -> jnp.ndarray:

"""Forward pass."""

input_embeddings, input_mask = embeddings(data, vocab_size)

transformer = Transformer(

num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)

output_embeddings = transformer(input_embeddings, input_mask, is_training)

return hk.Linear(vocab_size)(output_embeddings)

return forward_fn

Embora o código seja direto, sua estrutura pode parecer um pouco estranha. O passe para a frente real é executado através do forward_fn função. No entanto, embrulhamos isso com o build_forward_fn função que retorna o forward_fn. Que diabos?

No caminho, precisaremos transformar o forward_fn função em uma função pura usando hk.transform para que possamos aproveitar a diferenciação automática, a paralelização etc.

Isso será realizado por:

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,

num_layers, dropout_rate)

forward_fn = hk.transform(forward_fn)

É por isso que, em vez de simplesmente definir uma função, embrulhamos e retornamos a própria função, ou um chamável para ser mais preciso. Este chamável pode então ser passado para o hk.transform e se tornar uma função pura. Se isso estiver claro, vamos continuar com nossa função de perda.

A função de perda

A função de perda é a nossa conhecida função entre entropia, com a diferença de que também estamos levando a máscara em consideração. Mais uma vez, Jax fornece one_hot e log_softmax funcionalidades.

def lm_loss_fn(forward_fn,

vocab_size: int,

params,

rng,

data: Mapping(str, jnp.ndarray),

is_training: bool = True) -> jnp.ndarray:

"""Compute the loss on data wrt params."""

logits = forward_fn(params, rng, data, is_training)

targets = jax.nn.one_hot(data('target'), vocab_size)

assert logits.shape == targets.shape

mask = jnp.greater(data('obs'), 0)

loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)

loss = jnp.sum(loss * mask) / jnp.sum(mask)

return loss

Se você ainda estiver comigo, tome um gole de café, porque as coisas vão levar a sério a partir de agora. É hora de construir nosso ciclo de treinamento.

O loop de treinamento

Como nem Jax nem Haiku têm funcionalidades de otimização embutidas, faremos uso de outra estrutura, chamada Optax. Como mencionado no início, o Optax é o pacote GoTo para processamento de gradiente.

Primeiro, aqui estão algumas coisas que você precisa saber sobre Optax:

A principal transformação do optax é o GradientTransformation. A transformação é definida por duas funções, o __init__ e o __update__. O __init__ inicializa o estado e o __update__ transforma os gradientes em relação ao estado e ao valor atual dos parâmetros

state = init(params)

grads, state = update(grads, state, params=None)

Mais uma coisa a saber antes de vermos o código é embutido de Python functools.partial função. O functools O pacote lida com funções e operações de ordem superior em objetos chamáveis.

Uma função é chamada de função de ordem superior se contiver outras funções como um parâmetro ou retornar uma função como uma saída.

O partialque também pode ser usado como anotação, retorna uma nova função com base em uma original, mas com menos ou argumentos fixos. Se, por exemplo, F multiplicar dois valores x, y, o parcial criará uma nova função em que x será fixo e igual a 2

from functools import partial

def f(x,y):

return x * y

g = partial(f,2)

print(g(4))

Após esse pequeno desvio, vamos prosseguir. Para descongestionar nosso main Função, extrairemos a atualização dos gradientes em sua própria classe.

Primeiro de tudo GradientUpdater Aceita o modelo, a função de perda e um otimizador.

  1. O modelo será um puro forward_fn função transformada por hk.transform

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,

num_layers, dropout_rate)

forward_fn = hk.transform(forward_fn)

  1. A função de perda será o resultado de um parcial com um fixo forward_fn e `vocab_size

loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  1. O otimizador é um conjunto de transformações de otimização que serão executadas sequencialmente (as operações podem ser combinadas usando optax.chain )

optimizer = optax.chain(

optax.clip_by_global_norm(grad_clip_value),

optax.adam(learning_rate, b1=0.9, b2=0.99))

O atualizador do gradiente será inicializado da seguinte forma:

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

e vai ficar assim:

class GradientUpdater:

"""A stateless abstraction around an init_fn/update_fn pair.

This extracts some common boilerplate from the training loop.

"""

def __init__(self, net_init, loss_fn,

optimizer: optax.GradientTransformation):

self._net_init = net_init

self._loss_fn = loss_fn

self._opt = optimizer

@functools.partial(jax.jit, static_argnums=0)

def init(self, master_rng, data):

"""Initializes state of the updater."""

out_rng, init_rng = jax.random.split(master_rng)

params = self._net_init(init_rng, data)

opt_state = self._opt.init(params)

out = dict(

step=np.array(0),

rng=out_rng,

opt_state=opt_state,

params=params,

)

return out

@functools.partial(jax.jit, static_argnums=0)

def update(self, state: Mapping(str, Any), data: Mapping(str, jnp.ndarray)):

"""Updates the state using some data and returns metrics."""

rng, new_rng = jax.random.split(state('rng'))

params = state('params')

loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

updates, opt_state = self._opt.update(g, state('opt_state'))

params = optax.apply_updates(params, updates)

new_state = {

'step': state('step') + 1,

'rng': new_rng,

'opt_state': opt_state,

'params': params,

}

metrics = {

'step': state('step'),

'loss': loss,

}

return new_state, metrics

Dentro __init__inicializamos nosso otimizador com self._opt.init(params) e declaramos o estado da otimização. O estado será um dicionário com:

O update A função atualizará o estado do otimizador e os parâmetros treináveis. No final, ele retornará o novo estado.

updates, opt_state = self._opt.update(g, state('opt_state'))

params = optax.apply_updates(params, updates)

Mais duas coisas a notar aqui:

  • jax.value_and_grad() é a função especial que retorna uma função diferenciável com seus gradientes

  • Ambos __init__ e __update__ são anotados com @functools.partial(jax.jit, static_argnums=0)que acionará o compilador just-in-time e os compilará no XLA durante o tempo de execução. Observe que se não transformarmosforward_fn Em uma função pura, isso não seria possível.

Finalmente, estamos prontos para construir todo o ciclo de treinamento, que combina todas as idéias e código mencionados até agora.

def main():

train_dataset, vocab_size = load(batch_size,

sequence_length)

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,

num_layers, dropout_rate)

forward_fn = hk.transform(forward_fn)

loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

optimizer = optax.chain(

optax.clip_by_global_norm(grad_clip_value),

optax.adam(learning_rate, b1=0.9, b2=0.99))

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

logging.info('Initializing parameters...')

rng = jax.random.PRNGKey(428)

data = next(train_dataset)

state = updater.init(rng, data)

logging.info('Starting train loop...')

prev_time = time.time()

for step in range(MAX_STEPS):

data = next(train_dataset)

state, metrics = updater.update(state, data)

Observe como incorporamos o GradientUpdate. São apenas duas linhas de código:

  • state = updater.init(rng, data)

  • state, metrics = updater.update(state, data)

E é isso. Espero que agora você tenha um entendimento mais claro do JAX e de suas capacidades.

Agradecimentos

O código apresentado é fortemente inspirado nos exemplos oficiais da estrutura do Haiku. Foi modificado para atender às necessidades deste artigo. Para a lista completa de exemplos, verifique o repositório oficial

Conclusão

Neste artigo, vimos como se pode desenvolver e treinar um transformador de baunilha em Jax usando o Haiku. Embora o código não seja necessariamente difícil de entender, ele ainda não possui a legibilidade de pytorch ou tensorflow. Eu recomendo brincar com ele, descobrir os pontos fortes e fracos de Jax e ver se seria uma boa opção para o seu próximo projeto. Na minha experiência, o JAX é muito forte para aplicativos de pesquisa que exigem alto desempenho, mas bastante imaturos para projetos da vida real. Deixe -nos saber o que você pensa em nosso Discord Channel.

Aprendizagem profunda no livro de produção 📖

Aprenda a construir, treinar, implantar, escalar e manter modelos de aprendizado profundo. Entenda a infraestrutura de ML e os MLOPs usando exemplos práticos.

Saber mais

* Divulgação: Observe que alguns dos links acima podem ser links de afiliados e, sem custo adicional, ganharemos uma comissão se você decidir fazer uma compra depois de clicar.

Luis es un experto en Ciberseguridad, Computación en la Nube, Criptomonedas e Inteligencia Artificial. Con amplia experiencia en tecnología, su objetivo es compartir conocimientos prácticos para ayudar a los lectores a entender y aprovechar estas áreas digitales clave.

Leave a Reply

Your email address will not be published. Required fields are marked *

lodi646