27.3 C
Colombia
martes, julio 8, 2025

Compilación justo a tiempo (JIT) para la implementación de modelos sin R



Compilación justo a tiempo (JIT) para la implementación de modelos sin R

Nota: Para seguir esta publicación, necesitará torch versión 0.5, que al momento de escribir este artículo aún no está en CRAN. Mientras tanto, instale la versión de desarrollo desde GitHub.

Cada dominio tiene sus conceptos, y estos son los que uno necesita entender, en algún momento, en su viaje desde copiar y hacer que funcione hasta una utilización deliberada y con propósito. Además, lamentablemente cada ámbito tiene su propia jerga, en la que los términos se utilizan de forma técnicamente correcta, pero no logran evocar una imagen clara para los no iniciados. El JIT de (Py-)Torch es un ejemplo.

Introducción terminológica

“El JIT”, del que se habla mucho en PyTorch-world y una característica eminente de R torchademás, es dos cosas al mismo tiempo – dependiendo de cómo se mire: un compilador optimizador; y un pase libre para la ejecución en muchos entornos donde ni R ni Python están presentes.

Compilado, interpretado, compilado justo a tiempo

“JIT” es un acrónimo común de “justo a tiempo” [to wit: compilation]. Compilación significa generar código ejecutable por máquina; es algo que tiene que sucederle a cada programa para que sea ejecutable. La pregunta es cuándo.

El código C, por ejemplo, se compila “a mano”, en algún momento arbitrario antes de su ejecución. Sin embargo, muchos otros lenguajes (entre ellos Java, R y Python) son, al menos en sus implementaciones predeterminadas, interpretado: Vienen con ejecutables (java, Ry pythonresp.) que crean código de máquina en tiempo de ejecuciónbasado en el programa authentic tal como está escrito o en un formato intermedio llamado código de bytes. La interpretación puede realizarse línea por línea, como cuando ingresa algún código en el REPL (bucle de lectura-evaluación-impresión) de R, o en fragmentos (si hay un script o una aplicación completa para ejecutar). En el último caso, dado que el intérprete sabe qué es possible que se ejecute a continuación, puede implementar optimizaciones que de otro modo serían imposibles. Este proceso se conoce comúnmente como compilación justo a tiempo. Por lo tanto, en el lenguaje common, la compilación JIT es compilación, pero en un momento en el que el programa ya se está ejecutando.

El torch compilador justo a tiempo

En comparación con esa noción de JIT, a la vez genérica (en términos técnicos) y específica (en el tiempo), lo que la gente de (Py-)Torch tiene en mente cuando hablan de “el JIT” está definido de manera más estricta (en términos de operaciones) y más inclusivo (en el tiempo): Lo que se entiende es el proceso completo desde proporcionar una entrada de código que se puede convertir en una representación intermedia (IR), pasando por la generación de esa IR, pasando por la optimización sucesiva del mismo por parte del compilador JIT, vía conversión (de nuevo, por el compilador) al código de bytes y, finalmente, a la ejecución, nuevamente a cargo del mismo compilador, que ahora actúa como una máquina digital.

Si eso suena complicado, no te asustes. Para realmente hacer uso de esta característica de R, no es necesario aprender mucho en términos de sintaxis; Una sola función, complementada con algunos ayudantes especializados, frena toda la pesada carga. Sin embargo, lo que importa es comprender un poco cómo funciona la compilación JIT, para saber qué esperar y no sorprenderse con resultados no deseados.

Lo que viene (en este texto)

Este submit tiene tres partes más.

En el primero, explicamos cómo hacer uso de las capacidades JIT en R. torch. Más allá de la sintaxis, nos centramos en la semántica (lo que esencialmente sucede cuando se hace un “rastreo JIT” de un fragmento de código) y en cómo eso afecta el resultado.

En el segundo, “miramos un poco debajo del capó”; siéntete libre de hojearlo brevemente si esto no te interesa demasiado.

En el tercero, mostramos un ejemplo del uso de la compilación JIT para permitir la implementación en un entorno que no tiene R instalado.

Cómo hacer uso de torch Compilación JIT

En el mundo de Python, o más específicamente, en las encarnaciones de Python de los marcos de aprendizaje profundo, hay un verbo mágico “rastro” que se refiere a una forma de obtener una representación gráfica a partir de la ejecución de código con entusiasmo. Es decir, ejecuta un fragmento de código (una función, por ejemplo, que contiene operaciones de PyTorch) en entradas de ejemplo. Estas entradas de ejemplo son arbitrarias en cuanto a valores, pero (naturalmente) deben ajustarse a las formas esperadas por la función. El seguimiento registrará las operaciones ejecutadas, es decir: aquellas operaciones que eran de hecho ejecutado, y solo esos. Cualquier ruta de código no ingresada queda relegada al olvido.

También en R, el rastreo es la forma en que obtenemos una primera representación intermedia. Esto se hace usando la función con el nombre apropiado jit_trace(). Por ejemplo:

library(torch)

f <- operate(x) {
  torch_sum(x)
}

# name with instance enter tensor
f_t <- jit_trace(f, torch_tensor(c(2, 2)))

f_t
<script_function>

Ahora podemos llamar a la función rastreada como la authentic:

f_t(torch_randn(c(3, 3)))
torch_tensor
3.19587
[ CPUFloatType{} ]

¿Qué sucede si hay un flujo de management, como un if ¿declaración?

f <- operate(x) {
  if (as.numeric(torch_sum(x)) > 0) torch_tensor(1) else torch_tensor(2)
}

f_t <- jit_trace(f, torch_tensor(c(2, 2)))

Aquí el rastreo debe haber entrado en el if rama. Ahora llame a la función trazada con un tensor que no suma un valor mayor que cero:

torch_tensor
 1
[ CPUFloatType{1} ]

Así es como funciona el rastreo. Los caminos no tomados se pierden para siempre. La lección aquí es no tener nunca un flujo de management dentro de una función que se va a rastrear.

Antes de continuar, mencionemos rápidamente dos de los más utilizados, además jit_trace()funciona en el torch Ecosistema JIT: jit_save() y jit_load(). Aquí están:

jit_save(f_t, "/tmp/f_t")

f_t_new <- jit_load("/tmp/f_t")

Un primer vistazo a las optimizaciones

Optimizaciones realizadas por el torch El compilador JIT ocurre en etapas. En la primera pasada, vemos cosas como la eliminación de códigos muertos y el cálculo previo de constantes. Tome esta función:

f <- operate(x) {
  
  a <- 7
  b <- 11
  c <- 2
  d <- a + b + c
  e <- a + b + c + 25
  
  
  x + d 
  
}

Aquí cálculo de e es inútil, nunca se usa. En consecuencia, en la representación intermedia, e ni siquiera aparece. Además, como los valores de a, by c ya se conocen en el momento de la compilación, la única constante presente en el IR es dsu suma.

Bueno, podemos comprobarlo por nosotros mismos. Para echar un vistazo al IR (el IR inicial, para ser precisos) primero rastreamos fy luego acceda a la función rastreada graph propiedad:

f_t <- jit_trace(f, torch_tensor(0))

f_t$graph
graph(%0 : Float(1, strides=[1], requires_grad=0, system=cpu)):
  %1 : float = prim::Fixed[value=20.]()
  %2 : int = prim::Fixed[value=1]()
  %3 : Float(1, strides=[1], requires_grad=0, system=cpu) = aten::add(%0, %1, %2)
  return (%3)

Y realmente, el único cálculo registrado es el que suma 20 al tensor pasado.

Hasta ahora, hemos estado hablando del paso inicial del compilador JIT. Pero el proceso no termina ahí. En pasadas posteriores, la optimización se expande al ámbito de las operaciones tensoriales.

Tome la siguiente función:

f <- operate(x) {
  
  m1 <- torch_eye(5, system = "cuda")
  x <- x$mul(m1)

  m2 <- torch_arange(begin = 1, finish = 25, system = "cuda")$view(c(5,5))
  x <- x$add(m2)
  
  x <- torch_relu(x)
  
  x$matmul(m2)
  
}

Aunque esta función pueda parecer inofensiva, implica bastantes gastos generales de programación. Una GPU separada núcleo (Se requiere una función C, que se paralelizará en muchos subprocesos CUDA) para cada uno de torch_mul() , torch_add(), torch_relu() y torch_matmul().

Bajo ciertas condiciones, se pueden encadenar varias operaciones (o fusionadopara usar el término técnico) en uno solo. Aquí, tres de esos cuatro métodos (es decir, todos menos torch_matmul()) operar puntualmente; es decir, modifican cada elemento de un tensor de forma aislada. En consecuencia, no sólo se prestan óptimamente a la paralelización individualmente, sino que lo mismo sería cierto para una función que fuera componer (“fusionarlos”): Para calcular una función compuesta “multiplica, luego suma y luego ReLU”

[
relu() circ (+) circ (*)
]

en un tensor elementono es necesario saber nada sobre otros elementos del tensor. Luego, la operación agregada podría ejecutarse en la GPU en un solo núcleo.

Para que esto suceda, normalmente tendría que escribir un código CUDA personalizado. Gracias al compilador JIT, en muchos casos no es necesario: creará dicho núcleo sobre la marcha.

Para ver la fusión en acción, usamos graph_for() (un método) en lugar de graph (una propiedad):

v <- jit_trace(f, torch_eye(5, system = "cuda"))

v$graph_for(torch_eye(5, system = "cuda"))
graph(%x.1 : Tensor):
  %1 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
  %24 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0), %25 : bool = prim::TypeCheck[types=[Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0)]](%x.1)
  %26 : Tensor = prim::If(%25)
    block0():
      %x.14 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::TensorExprGroup_0(%24)
      -> (%x.14)
    block1():
      %34 : Operate = prim::Fixed[name="fallback_function", fallback=1]()
      %35 : (Tensor) = prim::CallFunction(%34, %x.1)
      %36 : Tensor = prim::TupleUnpack(%35)
      -> (%36)
  %14 : Tensor = aten::matmul(%26, %1) # <stdin>:7:0
  return (%14)
with prim::TensorExprGroup_0 = graph(%x.1 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0)):
  %4 : int = prim::Fixed[value=1]()
  %3 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
  %7 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = prim::Fixed[value=<Tensor>]()
  %x.10 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::mul(%x.1, %7) # <stdin>:4:0
  %x.6 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::add(%x.10, %3, %4) # <stdin>:5:0
  %x.2 : Float(5, 5, strides=[5, 1], requires_grad=0, system=cuda:0) = aten::relu(%x.6) # <stdin>:6:0
  return (%x.2)

De este resultado, aprendemos que tres de las cuatro operaciones se han agrupado para formar un TensorExprGroup . Este TensorExprGroup se compilará en un único kernel CUDA. Sin embargo, la multiplicación de matrices, al no ser una operación puntual, debe realizarse sola.

En este punto, detenemos nuestra exploración de las optimizaciones JIT y pasamos al último tema: la implementación del modelo en entornos sin R. Si desea saber más, Thomas Viehmann weblog tiene publicaciones que detallan increíblemente la compilación JIT de (Py-)Torch.

torch sin R

Nuestro plan es el siguiente: definimos y entrenamos un modelo en R. Luego, lo rastreamos y lo guardamos. El archivo guardado es entonces jit_load()ed en otro entorno, un entorno que no tiene R instalado. Cualquier lenguaje que tenga una implementación de Torch servirá, siempre que esa implementación incluya la funcionalidad JIT. La forma más sencilla de mostrar cómo funciona esto es utilizando Python. Para la implementación con C++, consulte la instrucciones detalladas en el sitio net de PyTorch.

Definir modelo

Nuestro modelo de ejemplo es un perceptrón multicapa sencillo. Sin embargo, tenga en cuenta que tiene dos capas de abandono. Las capas de abandono se comportan de manera diferente durante el entrenamiento y la evaluación; y como hemos aprendido, las decisiones tomadas durante el rastreo son inamovibles. Esto es algo de lo que tendremos que ocuparnos una vez que hayamos terminado de entrenar el modelo.

library(torch)
internet <- nn_module( 
  
  initialize = operate() {
    
    self$l1 <- nn_linear(3, 8)
    self$l2 <- nn_linear(8, 16)
    self$l3 <- nn_linear(16, 1)
    self$d1 <- nn_dropout(0.2)
    self$d2 <- nn_dropout(0.2)
    
  },
  
  ahead = operate(x) {
    x %>%
      self$l1() %>%
      nnf_relu() %>%
      self$d1() %>%
      self$l2() %>%
      nnf_relu() %>%
      self$d2() %>%
      self$l3()
  }
)

train_model <- internet()

Modelo de tren en un conjunto de datos de juguetes.

Con fines de demostración, creamos un conjunto de datos de juguete con tres predictores y un objetivo escalar.

toy_dataset <- dataset(
  
  identify = "toy_dataset",
  
  initialize = operate(input_dim, n) {
    
    df <- na.omit(df) 
    self$x <- torch_randn(n, input_dim)
    self$y <- self$x[, 1, drop = FALSE] * 0.2 -
      self$x[, 2, drop = FALSE] * 1.3 -
      self$x[, 3, drop = FALSE] * 0.5 +
      torch_randn(n, 1)
    
  },
  
  .getitem = operate(i) {
    record(x = self$x[i, ], y = self$y[i])
  },
  
  .size = operate() {
    self$x$dimension(1)
  }
)

input_dim <- 3
n <- 1000

train_ds <- toy_dataset(input_dim, n)

train_dl <- dataloader(train_ds, shuffle = TRUE)

Entrenamos el tiempo suficiente para asegurarnos de que podamos distinguir el resultado de un modelo no entrenado del de uno entrenado.

optimizer <- optim_adam(train_model$parameters, lr = 0.001)
num_epochs <- 10

train_batch <- operate(b) {
  
  optimizer$zero_grad()
  output <- train_model(b$x)
  goal <- b$y
  
  loss <- nnf_mse_loss(output, goal)
  loss$backward()
  optimizer$step()
  
  loss$merchandise()
}

for (epoch in 1:num_epochs) {
  
  train_loss <- c()
  
  coro::loop(for (b in train_dl) {
    loss <- train_batch(b)
    train_loss <- c(train_loss, loss)
  })
  
  cat(sprintf("nEpoch: %d, loss: %3.4fn", epoch, imply(train_loss)))
  
}
Epoch: 1, loss: 2.6753

Epoch: 2, loss: 1.5629

Epoch: 3, loss: 1.4295

Epoch: 4, loss: 1.4170

Epoch: 5, loss: 1.4007

Epoch: 6, loss: 1.2775

Epoch: 7, loss: 1.2971

Epoch: 8, loss: 1.2499

Epoch: 9, loss: 1.2824

Epoch: 10, loss: 1.2596

Trazar en eval modo

Ahora, para la implementación, queremos un modelo que no no elimine cualquier elemento tensorial. Esto significa que antes de rastrear, debemos poner el modelo en eval() modo.

train_model$eval()

train_model <- jit_trace(train_model, torch_tensor(c(1.2, 3, 0.1))) 

jit_save(train_model, "/tmp/mannequin.zip")

El modelo guardado ahora se puede copiar a un sistema diferente.

Modelo de consulta de Python

Para hacer uso de este modelo de Python, nosotros jit.load() luego llámelo como lo haríamos en R. Veamos: Para un tensor de entrada de (1, 1, 1)esperamos una predicción alrededor de -1,6:

Jonny Kennaugh en desempaquetar

Related Articles

LEAVE A REPLY

Please enter your comment!
Please enter your name here

Latest Articles