Nota: Para seguir esta publicación, necesitará torch
versión 0.5, que al momento de escribir este artículo aún no está en CRAN. Mientras tanto, instale la versión de desarrollo desde GitHub.
Cada dominio tiene sus conceptos, y estos son los que uno necesita entender, en algún momento, en su viaje desde copiar y hacer que funcione hasta una utilización deliberada y con propósito. Además, lamentablemente cada ámbito tiene su propia jerga, en la que los términos se utilizan de forma técnicamente correcta, pero no logran evocar una imagen clara para los no iniciados. El JIT de (Py-)Torch es un ejemplo.
Introducción terminológica
“El JIT”, del que se habla mucho en PyTorch-world y una característica eminente de R torch
además, es dos cosas al mismo tiempo – dependiendo de cómo se mire: un compilador optimizador; y un pase libre para la ejecución en muchos entornos donde ni R ni Python están presentes.
Compilado, interpretado, compilado justo a tiempo
“JIT” es un acrónimo común de “justo a tiempo” [to wit: compilation]. Compilación significa generar código ejecutable por máquina; es algo que tiene que sucederle a cada programa para que sea ejecutable. La pregunta es cuándo.
El código C, por ejemplo, se compila “a mano”, en algún momento arbitrario antes de su ejecución. Sin embargo, muchos otros lenguajes (entre ellos Java, R y Python) son, al menos en sus implementaciones predeterminadas, interpretado: Vienen con ejecutables (java
, R
y python
resp.) que crean código de máquina en tiempo de ejecuciónbasado en el programa authentic tal como está escrito o en un formato intermedio llamado código de bytes. La interpretación puede realizarse línea por línea, como cuando ingresa algún código en el REPL (bucle de lectura-evaluación-impresión) de R, o en fragmentos (si hay un script o una aplicación completa para ejecutar). En el último caso, dado que el intérprete sabe qué es possible que se ejecute a continuación, puede implementar optimizaciones que de otro modo serían imposibles. Este proceso se conoce comúnmente como compilación justo a tiempo. Por lo tanto, en el lenguaje common, la compilación JIT es compilación, pero en un momento en el que el programa ya se está ejecutando.
El torch
compilador justo a tiempo
En comparación con esa noción de JIT, a la vez genérica (en términos técnicos) y específica (en el tiempo), lo que la gente de (Py-)Torch tiene en mente cuando hablan de “el JIT” está definido de manera más estricta (en términos de operaciones) y más inclusivo (en el tiempo): Lo que se entiende es el proceso completo desde proporcionar una entrada de código que se puede convertir en una representación intermedia (IR), pasando por la generación de esa IR, pasando por la optimización sucesiva del mismo por parte del compilador JIT, vía conversión (de nuevo, por el compilador) al código de bytes y, finalmente, a la ejecución, nuevamente a cargo del mismo compilador, que ahora actúa como una máquina digital.
Si eso suena complicado, no te asustes. Para realmente hacer uso de esta característica de R, no es necesario aprender mucho en términos de sintaxis; Una sola función, complementada con algunos ayudantes especializados, frena toda la pesada carga. Sin embargo, lo que importa es comprender un poco cómo funciona la compilación JIT, para saber qué esperar y no sorprenderse con resultados no deseados.
Lo que viene (en este texto)
Este submit tiene tres partes más.
En el primero, explicamos cómo hacer uso de las capacidades JIT en R. torch
. Más allá de la sintaxis, nos centramos en la semántica (lo que esencialmente sucede cuando se hace un “rastreo JIT” de un fragmento de código) y en cómo eso afecta el resultado.
En el segundo, “miramos un poco debajo del capó”; siéntete libre de hojearlo brevemente si esto no te interesa demasiado.
En el tercero, mostramos un ejemplo del uso de la compilación JIT para permitir la implementación en un entorno que no tiene R instalado.
Cómo hacer uso de torch
Compilación JIT
En el mundo de Python, o más específicamente, en las encarnaciones de Python de los marcos de aprendizaje profundo, hay un verbo mágico “rastro” que se refiere a una forma de obtener una representación gráfica a partir de la ejecución de código con entusiasmo. Es decir, ejecuta un fragmento de código (una función, por ejemplo, que contiene operaciones de PyTorch) en entradas de ejemplo. Estas entradas de ejemplo son arbitrarias en cuanto a valores, pero (naturalmente) deben ajustarse a las formas esperadas por la función. El seguimiento registrará las operaciones ejecutadas, es decir: aquellas operaciones que eran de hecho ejecutado, y solo esos. Cualquier ruta de código no ingresada queda relegada al olvido.
También en R, el rastreo es la forma en que obtenemos una primera representación intermedia. Esto se hace usando la función con el nombre apropiado jit_trace()
. Por ejemplo:
<script_function>
Ahora podemos llamar a la función rastreada como la authentic:
f_t(torch_randn(c(3, 3)))
torch_tensor
3.19587
[ CPUFloatType{} ]
¿Qué sucede si hay un flujo de management, como un if
¿declaración?
f <- operate(x) {
if (as.numeric(torch_sum(x)) > 0) torch_tensor(1) else torch_tensor(2)
}
f_t <- jit_trace(f, torch_tensor(c(2, 2)))
Aquí el rastreo debe haber entrado en el if
rama. Ahora llame a la función trazada con un tensor que no suma un valor mayor que cero:
torch_tensor
1
[ CPUFloatType{1} ]
Así es como funciona el rastreo. Los caminos no tomados se pierden para siempre. La lección aquí es no tener nunca un flujo de management dentro de una función que se va a rastrear.
Antes de continuar, mencionemos rápidamente dos de los más utilizados, además jit_trace()
funciona en el torch
Ecosistema JIT: jit_save()
y jit_load()
. Aquí están:
jit_save(f_t, "/tmp/f_t")
f_t_new <- jit_load("/tmp/f_t")
Un primer vistazo a las optimizaciones
Optimizaciones realizadas por el torch
El compilador JIT ocurre en etapas. En la primera pasada, vemos cosas como la eliminación de códigos muertos y el cálculo previo de constantes. Tome esta función:
f <- operate(x) {
a <- 7
b <- 11
c <- 2
d <- a + b + c
e <- a + b + c + 25
x + d
}
Aquí cálculo de e
es inútil, nunca se usa. En consecuencia, en la representación intermedia, e
ni siquiera aparece. Además, como los valores de a
, b
y c
ya se conocen en el momento de la compilación, la única constante presente en el IR es d
su suma.
Bueno, podemos comprobarlo por nosotros mismos. Para echar un vistazo al IR (el IR inicial, para ser precisos) primero rastreamos f
y luego acceda a la función rastreada graph
propiedad:
f_t <- jit_trace(f, torch_tensor(0))
f_t$graph
graph(%0 : Float(1, strides=[1], requires_grad=0, system=cpu)):
%1 : float = prim::Fixed[value=20.]()
%2 : int = prim::Fixed[value=1]()
%3 : Float(1, strides=[1], requires_grad=0, system=cpu) = aten::add(%0, %1, %2)
return (%3)
Y realmente, el único cálculo registrado es el que suma 20 al tensor pasado.
Hasta ahora, hemos estado hablando del paso inicial del compilador JIT. Pero el proceso no termina ahí. En pasadas posteriores, la optimización se expande al ámbito de las operaciones tensoriales.
Tome la siguiente función:
f <- operate(x) {
m1 <- torch_eye(5, system = "cuda")
x <- x$mul(m1)
m2 <- torch_arange(begin = 1, finish = 25, system = "cuda")$view(c(5,5))
x <- x$add(m2)
x <- torch_relu(x)
x$matmul(m2)
}
Aunque esta función pueda parecer inofensiva, implica bastantes gastos generales de programación. Una GPU separada núcleo (Se requiere una función C, que se paralelizará en muchos subprocesos CUDA) para cada uno de torch_mul()
, torch_add()
, torch_relu()
y torch_matmul()
.
Bajo ciertas condiciones, se pueden encadenar varias operaciones (o fusionadopara usar el término técnico) en uno solo. Aquí, tres de esos cuatro métodos (es decir, todos menos torch_matmul()
) operar puntualmente; es decir, modifican cada elemento de un tensor de forma aislada. En consecuencia, no sólo se prestan óptimamente a la paralelización individualmente, sino que lo mismo sería cierto para una función que fuera componer (“fusionarlos”): Para calcular una función compuesta “multiplica, luego suma y luego ReLU”
[
relu() circ (+) circ (*)
]
en un tensor elementono es necesario saber nada sobre otros elementos del tensor. Luego, la operación agregada podría ejecutarse en la GPU en un solo núcleo.
Para que esto suceda, normalmente tendría que escribir un código CUDA personalizado. Gracias al compilador JIT, en muchos casos no es necesario: creará dicho núcleo sobre la marcha.
Para ver la fusión en acción, usamos graph_for()
(un método) en lugar de graph
(una propiedad):
v <- jit_trace(f, torch_eye(5, system = "cuda"))
v$graph_for(torch_eye(5, system = "cuda"))
graph(%x.1 : Tensor):
%1 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
%24 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0), %25 : bool = prim::TypeCheck[types=[Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0)]](%x.1)
%26 : Tensor = prim::If(%25)
block0():
%x.14 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::TensorExprGroup_0(%24)
-> (%x.14)
block1():
%34 : Operate = prim::Fixed[name="fallback_function", fallback=1]()
%35 : (Tensor) = prim::CallFunction(%34, %x.1)
%36 : Tensor = prim::TupleUnpack(%35)
-> (%36)
%14 : Tensor = aten::matmul(%26, %1) # <stdin>:7:0
return (%14)
with prim::TensorExprGroup_0 = graph(%x.1 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0)):
%4 : int = prim::Fixed[value=1]()
%3 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
%7 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
%x.10 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::mul(%x.1, %7) # <stdin>:4:0
%x.6 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::add(%x.10, %3, %4) # <stdin>:5:0
%x.2 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::relu(%x.6) # <stdin>:6:0
return (%x.2)
De este resultado, aprendemos que tres de las cuatro operaciones se han agrupado para formar un TensorExprGroup
. Este TensorExprGroup
se compilará en un único kernel CUDA. Sin embargo, la multiplicación de matrices, al no ser una operación puntual, debe realizarse sola.
En este punto, detenemos nuestra exploración de las optimizaciones JIT y pasamos al último tema: la implementación del modelo en entornos sin R. Si desea saber más, Thomas Viehmann weblog tiene publicaciones que detallan increíblemente la compilación JIT de (Py-)Torch.
torch
sin R
Nuestro plan es el siguiente: definimos y entrenamos un modelo en R. Luego, lo rastreamos y lo guardamos. El archivo guardado es entonces jit_load()
ed en otro entorno, un entorno que no tiene R instalado. Cualquier lenguaje que tenga una implementación de Torch servirá, siempre que esa implementación incluya la funcionalidad JIT. La forma más sencilla de mostrar cómo funciona esto es utilizando Python. Para la implementación con C++, consulte la instrucciones detalladas en el sitio net de PyTorch.
Definir modelo
Nuestro modelo de ejemplo es un perceptrón multicapa sencillo. Sin embargo, tenga en cuenta que tiene dos capas de abandono. Las capas de abandono se comportan de manera diferente durante el entrenamiento y la evaluación; y como hemos aprendido, las decisiones tomadas durante el rastreo son inamovibles. Esto es algo de lo que tendremos que ocuparnos una vez que hayamos terminado de entrenar el modelo.
library(torch)
internet <- nn_module(
initialize = operate() {
self$l1 <- nn_linear(3, 8)
self$l2 <- nn_linear(8, 16)
self$l3 <- nn_linear(16, 1)
self$d1 <- nn_dropout(0.2)
self$d2 <- nn_dropout(0.2)
},
ahead = operate(x) {
x %>%
self$l1() %>%
nnf_relu() %>%
self$d1() %>%
self$l2() %>%
nnf_relu() %>%
self$d2() %>%
self$l3()
}
)
train_model <- internet()
Modelo de tren en un conjunto de datos de juguetes.
Con fines de demostración, creamos un conjunto de datos de juguete con tres predictores y un objetivo escalar.
toy_dataset <- dataset(
identify = "toy_dataset",
initialize = operate(input_dim, n) {
df <- na.omit(df)
self$x <- torch_randn(n, input_dim)
self$y <- self$x[, 1, drop = FALSE] * 0.2 -
self$x[, 2, drop = FALSE] * 1.3 -
self$x[, 3, drop = FALSE] * 0.5 +
torch_randn(n, 1)
},
.getitem = operate(i) {
record(x = self$x[i, ], y = self$y[i])
},
.size = operate() {
self$x$dimension(1)
}
)
input_dim <- 3
n <- 1000
train_ds <- toy_dataset(input_dim, n)
train_dl <- dataloader(train_ds, shuffle = TRUE)
Entrenamos el tiempo suficiente para asegurarnos de que podamos distinguir el resultado de un modelo no entrenado del de uno entrenado.
optimizer <- optim_adam(train_model$parameters, lr = 0.001)
num_epochs <- 10
train_batch <- operate(b) {
optimizer$zero_grad()
output <- train_model(b$x)
goal <- b$y
loss <- nnf_mse_loss(output, goal)
loss$backward()
optimizer$step()
loss$merchandise()
}
for (epoch in 1:num_epochs) {
train_loss <- c()
coro::loop(for (b in train_dl) {
loss <- train_batch(b)
train_loss <- c(train_loss, loss)
})
cat(sprintf("nEpoch: %d, loss: %3.4fn", epoch, imply(train_loss)))
}
Epoch: 1, loss: 2.6753
Epoch: 2, loss: 1.5629
Epoch: 3, loss: 1.4295
Epoch: 4, loss: 1.4170
Epoch: 5, loss: 1.4007
Epoch: 6, loss: 1.2775
Epoch: 7, loss: 1.2971
Epoch: 8, loss: 1.2499
Epoch: 9, loss: 1.2824
Epoch: 10, loss: 1.2596
Trazar en eval
modo
Ahora, para la implementación, queremos un modelo que no no elimine cualquier elemento tensorial. Esto significa que antes de rastrear, debemos poner el modelo en eval()
modo.
train_model$eval()
train_model <- jit_trace(train_model, torch_tensor(c(1.2, 3, 0.1)))
jit_save(train_model, "/tmp/mannequin.zip")
El modelo guardado ahora se puede copiar a un sistema diferente.
Modelo de consulta de Python
Para hacer uso de este modelo de Python, nosotros jit.load()
luego llámelo como lo haríamos en R. Veamos: Para un tensor de entrada de (1, 1, 1)
esperamos una predicción alrededor de -1,6:
import torch
= torch.jit.load("/tmp/mannequin.zip")
deploy_model 1, 1, 1), dtype = torch.float)) deploy_model(torch.tensor((
tensor([-1.3630], system='cuda:0', grad_fn=<AddBackward0>)
Esto es lo suficientemente cercano como para asegurarnos de que el modelo implementado ha mantenido los pesos del modelo entrenado.
Conclusión
En esta publicación, nos hemos centrado en resolver un poco la confusión terminológica que rodea al torch
Compilador JIT y mostró cómo entrenar un modelo en R, rastro y consultar el modelo recién cargado desde Python. Deliberadamente, no hemos entrado en casos complejos y/o extremos; en R, esta característica aún está en desarrollo activo. Si tiene problemas con su propio código de uso JIT, ¡no dude en crear un problema en GitHub!
Y como siempre, ¡gracias por leer!
Foto por Jonny Kennaugh en desempaquetar