Posted in

Jax para aprendizado de máquina: como funciona e por que aprender

Jax para aprendizado de máquina: como funciona e por que aprender

Jax é o novo garoto na cidade de aprendizado de máquina (ML) e promete tornar a programação da ML mais intuitiva, estruturada e limpa. Pode substituir os gostos do TensorFlow e Pytorch, apesar de ser muito diferente em seu núcleo.

Como disse um amigo meu, tínhamos todos os tipos de ases, reis e rainhas. Agora temos Jax.

Neste artigo, exploraremos o que é Jax e por que alguém deve usá -lo em todas as outras bibliotecas. Faremos nossos pontos usando trechos de código que capturam o poder do JAX e apresentaremos alguns recursos bons para saber.

Se isso parece interessante, entre.

O que é Jax?

Jax é uma biblioteca Python projetada para pesquisa de ML de alto desempenho. Jax nada mais é do que uma biblioteca de computação numérica, assim como Numpy, mas com algumas melhorias importantes. Foi desenvolvido pelo Google e usado internamente por equipes do Google e DeepMind.




JAX-LOGO


Fonte: Documentação Jax

Instale Jax

Antes de discutirmos as principais vantagens do JAX, sugiro que você instale o JAX em seu ambiente Python ou em um Google Colab para que você possa acompanhar e executar o código sozinho. Obviamente, deixarei um link para o código completo no final do artigo.

Para instalar Jax, podemos simplesmente usar pip da nossa linha de comando:

$ pip install --upgrade jax jaxlib

Observe que isso apoiará apenas a execução na CPU. Se você também deseja apoiar a GPU, primeiro precisa CUDA e importância e depois execute o seguinte comando (certifique -se de mapear a versão JaxLib com sua versão CUDA):

$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Para solução de problemas, verifique o oficial Instruções do GitHub.

Agora vamos importar Jax ao lado de Numpy. Usaremos o Numpy para comparar diferentes casos de uso.

import jax

import jax.numpy as jnp

import numpy as np

Jax Basics

Vamos começar com o básico. Como já dissemos, o principal e o único objetivo de Jax é executar operações numéricas de maneira expressa e de alto desempenho. Isso significa que a sintaxe é quase idêntica a Numpy. Por exemplo, se queremos criar uma variedade de zeros, teríamos:

x = np.zeros(10)

y= jnp.zeros(10)

A diferença está nos bastidores.

O deviceArray

Você vê uma das principais vantagens de Jax é que Podemos executar o mesmo programa, sem nenhuma mudança, em aceleradores de hardware como GPUs e TPUs.

Isso é realizado por uma estrutura subjacente chamada DeviceArrayque essencialmente substitui Array padrão de Numpy.

Os deviquearys são preguiçosos, o que significa que eles mantêm os valores no acelerador e os puxam apenas Quando necessário.

x

y

Podemos usar deviquears, assim como usamos matrizes padrão. Podemos passar para outras bibliotecas, plotar gráficos, executar a diferenciação e as coisas funcionarão. Observe também que a maioria da API de Numpy (funções e operações) é suportada pelo JAX, portanto, seu código JAX será quase idêntico ao Numpy.

A outra grande coisa é a velocidade. Bem, Jax é mais rápido. Muito mais rápido. Vejamos um exemplo simples. Criamos duas matrizes com tamanho (1000, 1000), um com Numpy e outro com Jax, e calculamos o produto interno consigo mesmo.

Vamos timeit as duas operações

x = np.random.rand(1000,1000)

y = jnp.array(x)

%timeit -n 1 -r 1 np.dot(x,x)

%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()

Impressionante certo? Bem, é esperado. Os cálculos são mais rápidos nas GPUs. Você também notou o block_until_ready() função. Como o JAX é assíncrono, precisamos esperar até que a execução seja concluída para medir adequadamente o tempo.

Você não pode acreditar que isso é tudo o que Jax tem a oferecer, certo?

Agora, para as coisas boas …

Por que Jax?

Se a velocidade e o suporte automático para as GPUs não forem suficientes para você, não o culpo. Parece que todas as outras bibliotecas podem lidar com isso. Para entender melhor os benefícios do JAX, temos que mergulhar mais fundo. O JAX pode ser visto como um conjunto de transformações de função do código Python e Numpy regular.

Um exemplo de tais transformações é diferenciação. O JAX suporta a diferenciação automática?

Tenho certeza que você adivinhou corretamente.

Diferenciação automática com função grad ()

Jax é capaz de diferenciar todos os tipos de funções Python e Numpy, incluindo loops, ramificações, recursões e muito mais.

Isso é incrivelmente útil para aplicativos de aprendizagem profunda, pois podemos executar a backpropagação sem esforço. A principal função para realizar isso é chamada grad(). Aqui está um exemplo. Definimos uma função quadrática simples e tomamos sua derivada no ponto 1.0.

Para provar que o resultado está correto, também calcularemos a derivada manualmente.

from jax import grad

def f(x):

return 3*x**2 + 2*x + 5

def f_prime(x):

return 6*x +2

grad(f)(1.0)

f_prime(1.0)

Uma coisa muito surpreendente para mim foi que Jax está realmente fazendo analítico O gradiente resolve sob o capô, em vez de outra técnica sofisticada. Simplesmente assume a forma da função e executa a regra da cadeia. Como a diferenciação automática é muito mais do que isso, eu recomendo olhar para o documentação oficial para um entendimento mais completo.

Álgebra linear acelerada (compilador XLA)

Um dos fatores que tornam o JAX tão rápido também é acelerado em álgebra linear ou XLA.

O XLA é um compilador específico do domínio para álgebra linear que tem sido usada extensivamente pelo TensorFlow.

Para executar operações da matriz o mais rápido possível, o código é compilado em um conjunto de kernels de computação que podem ser extensivamente otimizados com base na natureza do código.

Exemplo de tais otimizações incluem:

Compilação de Just In Time (JIT)

A compilação bem no tempo vem de mãos dadas com o XLA. Para aproveitar o poder do XLA, o código deve ser compilado nos núcleos XLA. É aqui que jit entra em jogo.

A compilação just-in-time (JIT) é uma maneira de executar o código do computador que envolve a compilação durante a execução de um programa-no tempo de execução-e não antes da execução.

Para usar XLA e JIT, pode -se usar o jit() função ou o @jit anotação.

from jax import jit

x = np.random.rand(1000,1000)

y = jnp.array(x)

def f(x):

for _ in range(10):

x = 0.5*x + 0.1* jnp.sin(x)

return x

g = jit(f)

%timeit -n 5 -r 5 f(y).block_until_ready()

%timeit -n 5 -r 5 g(y).block_until_ready()

Mais uma vez, a melhoria no tempo de execução é mais do que óbvia. Claro, jit também pode ser combinado com grad Transformação (ou qualquer outra transformação para esse assunto), tornando a retropacagação super rápida.

Além disso, observe isso jit Tem algumas deficiências: por exemplo, se não puder representar com precisão a função (que geralmente acontece com as ramificações “se”), provavelmente falhará. No entanto, para os casos de maior uso relacionados ao aprendizado profundo, é incrivelmente útil.

Replicar a computação entre os dispositivos com PMAP

O PMAP é outra transformação que nos permite replicar o cálculo em vários núcleos ou dispositivos e executá -los em paralelo (P no PMAP significa paralelo).

Ele distribui automaticamente a computação em todos os dispositivos atuais e lida com toda a comunicação entre eles. Para inspecionar os dispositivos disponíveis, você pode executar jax.devices().

from jax import pmap

def f(x):

return jnp.sin(x) + x**2

f(np.arange(4))

pmap(f)(np.arange(4))

Observe que o DeviceArray agora se tornou ShardedDeviceArra, que é a estrutura que lida com a execução paralela.

Outra coisa muito legal que Jax nos permite fazer é comunicação coletiva entre dispositivos. Digamos que queremos executar uma operação de “reduzir” entre os valores em todos os dispositivos (por exemplo, pegue a soma). Para executar isso, precisamos reunir todos os dados de todos os dispositivos e executar a soma. Isso pode ser facilmente realizado da seguinte maneira:

from functools import partial

from jax.lax import psum

@partial(pmap, axis_name="i")

def normalize(x):

return x/ psum(x,'i')

normalize(np.arange(8.))

O código acima mapeia o vetor x em todos os dispositivos e executa uma operação de comunicação coletiva para executar o psum (soma paralela). Em outras palavras, ele coleta todos os “x” dos dispositivos, resume -os e retorna o resultado para cada dispositivo para continuar com a computação paralela. Peguei emprestado o exemplo acima deste conversa incrível de Matthew Johnson durante o GTC 2020.

Você também pode imaginar isso com pmap Podemos definir nossos próprios padrões de computação e explorar nossos dispositivos da melhor maneira possível. Assim como costumamos fazer com CUDA para núcleos individuais, mas desta vez é para dispositivos separados.

Vectorização automática com VMAP

O VMAP é, como o nome sugere, uma transformação de função que nos permite vetorizar funções (V significa Vector!).

Podemos tirar uma função que opera em um único ponto de dados e o vetorize para que ele possa aceitar um lote desses pontos de dados (ou um vetor) de tamanho arbitrário. Aqui está um exemplo:

from jax import vmap

def f(x):

return jnp.square(x)

f(jnp.arange(10))

vmap(f)(jnp.arange(10))

Você pode se perguntar o que ganhamos aqui. Para entender isso, vamos dar uma olhada no que acontece quando f(x) executa sem o vmap:

  • Uma lista de saída é inicializada.

  • O quadrado de 0 é calculado e retornado.

  • O resultado 0 é anexado à lista.

  • O quadrado de 1 é calculado e retornado.

  • O resultado 1 é anexado à lista.

  • O quadrado de 2 é calculado e retornado.

  • O resultado 4 é anexado à lista.

  • E assim por diante …

O que o VMAP faz é que ele executa a operação quadrada apenas uma vez, porque ele lotam todos os valores e os passa pela função. E, é claro, isso resulta em um aumento no consumo de velocidade e memória.

Embora as transformações acima mencionadas sejam as que você definitivamente precisa saber, eu gostaria de mencionar mais algumas coisas que me surpreenderam durante minha jornada de Jax.

Gerador de números pseudo-aleatórios

O gerador de números aleatórios de Jax funciona um pouco diferente dos de Numpy. Em vez de ser um gerador de números de pseudorandom padrão padrão (PRNGs) como em Numpy e Scipy, as funções aleatórias JAX exigem que um estado de PRNG explícito seja aprovado como um primeiro argumento.

Um gerador de números aleatórios tem um estado. O próximo número “aleatório” é uma função do número anterior e da semente/estado. A sequência de valores aleatórios é finita e se repete.

Uma coisa importante a perceber é que os PRNGs estão funcionando bem em termos de vetorização e computação paralela entre dispositivos

from jax import random

key = random.PRNGKey(5)

random.uniform(key)

Despacho assíncrono

Outro aspecto de Jax que me impressionou é que ele usa despacho assíncrono. Isso significa que ele não espera que as operações sejam concluídas antes de retornar o controle ao programa Python. Em vez disso, ele retorna um DeviceArray que é um futuro (assim como Futuro completo em Java)

Um futuro é um valor que será produzido no futuro em um dispositivo acelerador, mas não está necessariamente disponível imediatamente.

O futuro pode ser passado para outras operações sem esperar que o cálculo seja concluído. Dessa forma, o JAX permite que o código Python seja executado à frente do acelerador, garantindo que ele possa envolver operações para o acelerador de hardware (por exemplo, GPU) sem ter que esperar.

Profiler de perfil JAX e Memória do dispositivo

O último recurso que quero mencionar é o perfil. Você ficará satisfeito em saber disso Tensoboard suporta o perfil JAX.

! (Perfil Jax Tensorboard)(Tensorboard Jax Profiling.png)
Fonte: Documentação Jax

O mesmo é verdadeiro para NVIDIA de Nsightque é usado para depurar e perfil do código da GPU. Ao lado, também é possível usar o Profiler de memória de dispositivo interno da Jax, que fornece visibilidade sobre como o código JAX é executado nas GPUs e TPUs. Aqui está um trecho da documentação:

import jax

import jax.numpy as jnp

import jax.profiler

def func1(x):

return jnp.tile(x, 10) * 0.5

def func2(x):

y = func1(x)

return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))

y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

Se você instalou pprofuma biblioteca do Google, você pode executar o seguinte comando, que abrirá uma janela do navegador com todas as informações necessárias.

$ pprof --web memory.prof

! (Perfil de memória do dispositivo)(Profiling de memória do dispositivo.png)
Fonte: Documentação Jax

Isso é incrível ou o quê?

Sinta -se à vontade para brincar com isso. Eu sei que sim.

Conclusão

Nesta postagem, tentei dar uma visão geral dos benefícios da JAX em relação a outras bibliotecas e apresentar trechos de código simples para aprender sua sintaxe e meandros básicos. A propósito, você pode encontrar o código completo neste Caderno de Colab ou em nosso Repositório do GitHub.

Nos próximos artigos, daremos um passo adiante e exploraremos como construir e treinar redes neurais profundas com Jax, além de dar uma olhada nas diferentes estruturas construídas sobre ela.

Se você achar este artigo interessante, não se esqueça de compartilhá -lo nas mídias sociais.

Referências

Aprendizagem profunda no livro de produção 📖

Aprenda a construir, treinar, implantar, escalar e manter modelos de aprendizado profundo. Entenda a infraestrutura de ML e os MLOPs usando exemplos práticos.

Saber mais

* Divulgação: Observe que alguns dos links acima podem ser links de afiliados e, sem custo adicional, ganharemos uma comissão se você decidir fazer uma compra depois de clicar.

Luis es un experto en Ciberseguridad, Computación en la Nube, Criptomonedas e Inteligencia Artificial. Con amplia experiencia en tecnología, su objetivo es compartir conocimientos prácticos para ayudar a los lectores a entender y aprovechar estas áreas digitales clave.

Leave a Reply

Your email address will not be published. Required fields are marked *

fun bingo casino app