El entrenamiento de precisión mixta es una técnica utilizada en el aprendizaje profundo (AD) para acelerar el entrenamiento del modelo y reducir el consumo de memoria sin afectar significativamente a la precisión del modelo. Lo consigue utilizando estratégicamente una combinación de diferentes formatos de precisión numérica para almacenar y calcular valores dentro de una red neuronal (NN). Normalmente, esto implica utilizar el formato estándar de coma flotante de 32 bits (FP32 o de precisión única) para las partes críticas, como el almacenamiento de los pesos del modelo, mientras se emplean los formatos de coma flotante de 16 bits (FP16 o de precisión media, y a veces BF16 o BFloat16), más rápidos y que consumen menos memoria, para los cálculos durante las pasadas hacia delante y hacia atrás(retropropagación).
Cómo funciona la Precisión Mixta
La idea central de la precisión mixta es aprovechar las ventajas de velocidad y memoria de los formatos de menor precisión, mitigando al mismo tiempo los posibles problemas de estabilidad numérica. Un enfoque común implica estos pasos:
- Mantener pesos maestros en FP32: Se mantiene una copia primaria de los pesos del modelo en el formato estándar FP32 para garantizar una alta precisión en las actualizaciones de los pesos.
- Utiliza FP16/BF16 para los cálculos: Durante el bucle de entrenamiento, los pesos FP32 se convierten en FP16 o BF16 para los pases hacia delante y hacia atrás. Los cálculos que utilizan estos formatos de menor precisión son significativamente más rápidos en hardware moderno como las GPUsNVIDIA equipadas con Tensor Cores, que están específicamente diseñados para acelerar las multiplicaciones matriciales a bajas precisiones.
- Escalado de pérdidas: Cuando se utiliza FP16, el rango de números representables es mucho menor que FP32. Esto puede hacer que los pequeños valores de gradiente calculados durante la retropropagación se conviertan en cero (subdesbordamiento), dificultando el aprendizaje. Para evitarlo, el valor de pérdida se escala antes de la retropropagación, escalando efectivamente los gradientes a un rango representable por FP16. Antes de actualizar el peso, estos gradientes se vuelven a reducir. BF16, con su rango dinámico más amplio similar a FP32 pero de menor precisión, a menudo evita la necesidad de escalar las pérdidas.
- Actualizar pesos maestros: Los gradientes calculados (reducidos si se utilizó la escala de pérdidas) se utilizan para actualizar la copia maestra de los pesos, que permanecen en FP32.
Este cuidadoso equilibrio permite que los modelos se entrenen más rápido y utilicen menos GPU de la GPU.
Ventajas de la precisión mixta
- Entrenamiento más rápido: Los cálculos de menor precisión (FP16/BF16) se ejecutan mucho más rápido en hardware compatible, reduciendo significativamente el tiempo necesario para cada época de entrenamiento. Esto permite una iteración y experimentación más rápidas.
- Consumo de memoria reducido: Los valores FP16/BF16 requieren la mitad de memoria que los valores FP32. Esta reducción se aplica a las activaciones almacenadas durante el paso hacia delante y a los gradientes calculados durante el paso hacia atrás. Un menor uso de memoria permite entrenar modelos más grandes o utilizar lotes de mayor tamaño, lo que puede mejorar el rendimiento del modelo y la estabilidad del entrenamiento.
- Eficiencia mejorada: La combinación de computación más rápida y menores requisitos de ancho de banda de memoria conduce a un uso más eficiente de los recursos de hardware, reduciendo potencialmente los costes de formación para la computación en nube o los clusters locales.
Precisión mixta frente a conceptos afines
- Precisión total (FP32): El entrenamiento tradicional utiliza FP32 para todo el almacenamiento y el cálculo. Suele ser más estable numéricamente, pero más lento y requiere más memoria que la precisión mixta.
- Media precisión (FP16/BF16): Utilizar sólo FP16 o BF16 durante todo el entrenamiento puede provocar una inestabilidad numérica importante (especialmente FP16 sin técnicas como el escalado de pérdidas) y una pérdida potencial de precisión. La precisión mixta es un enfoque más robusto que combina FP32 y FP16/BF16.
- Cuantización del modelo: Suele referirse a la conversión de los pesos y/o activaciones del modelo a formatos de precisión aún más bajos, como los enteros de 8 bits (INT8), principalmente para optimizar la velocidad y la eficacia de la inferencia, sobre todo en los dispositivos de borde. Aunque a veces se utiliza durante el entrenamiento(Quantization-Aware Training), es distinta de la típica precisión mixta FP32/FP16 utilizada durante las fases de entrenamiento estándar.