[Перевод] Настройка Llama3 405B на AMD MI300x

[Перевод] Настройка Llama3 405B на AMD MI300x

Введение

Опенсорсные модели становятся всё объёмнее, поэтому потребность в надёжной инфраструктуре для выполнения крупномасштабного обучения ИИ сегодня как никогда высока. Недавно наша компания выполнила fine-tuning модели LLaMA 3.1 405B на GPU AMD, доказав их способность эффективно справляться с крупномасштабными задачами ИИ. Наш опыт был крайне положительным, и мы с радостью выложили всю свою работу на GitHub в опенсорс.

GPU AMD, и в особенности серия MI300X — это серьёзная альтернатива ИИ-оборудованию NVIDIA, обеспечивающая больше производительности на вложенный доллар. Наша система состояла из одного узла с 8 GPU AMD MI300x, а для fine-tuning мы использовали JAX. В этой статье мы расскажем всю историю fine-tuning LLaMA 405B, в том числе и подробности шардинга параметров и реализации LoRA.

Что такое JAX и почему мы его выбрали

JAX — это мощная библиотека для машинного обучения, объединяющая в себе NumPy-подобные API, автоматическое дифференцирование и компилятор Google XLA. Она имеет великолепные API для параллелизма моделей, идеально подходящие для обучения огромных моделей наподобие LLaMA 3.1 405B.

Почему я так люблю JAX: 

  1. Чистые функции: JAX мотивирует к написанию чистых функций (если вы хотите компилировать код при помощи JIT), что упрощает компоновку, отладку и чтение кода.

  2. Продвинутый параллелизм: гибкие JIT API библиотеки JAX изначально поддерживают продвинутый параллелизм данных и моделей, что крайне важно для крупномасштабного обучения.

  3. Повышение чистоты кодовых баз: философия дизайна JAX стимулирует к написанию кода, изначально портируемого между аппаратными платформами (CPU, GPU, TPU), что приводит к повышению чистоты и удобства поддержки кодовых баз.

Если вы хотите глубже изучить преимущества JAX перед PyTorch, то рекомендую прочитать пост PyTorch is dead. Long live JAX.

Особенно замечательна JAX при работе с оборудованием, произведённым не NVIDIA: 

При работе с AMD JAX обеспечивает множество преимуществ:

  1. Независимый от оборудования подход: JAX использует компилятор XLA (Accelerated Linear Algebra), компилирующий вычисления в независимое от оборудования промежуточное представление (граф HLO). Это позволяет оптимизировать и эффективно исполнять без модификаций один и тот же код JAX на разных аппаратных бэкендах, включая GPU AMD.

  2. Платформонезависимые оптимизации: компилятор XLA выполняет оптимизации вне зависимости от оборудования, от чего выигрывают все поддерживаемые платформы.

  3. Упрощённая портируемость: при работе с JAX переход с NVIDIA на AMD (или на другое поддерживаемое оборудование) требует лишь минимальных изменений в коде. Это сильно отличает её от PyTorch, который более тесно связан с экосистемой CUDA NVIDIA.

    • PyTorch часто использует специфичные для CUDA реализации (например, вызовы torch.cudascaled_dot_product_attention).

    • Хотя PyTorch поддерживает другие бэкенды наподобие ROCm для AMD GPU, портирование кода может быть трудной задачей из-за специфичных для NVIDIA путей исполнения кода.

    • Процесс «избавления от NVIDIA» кода PyTorch может повысить сложность и помешать портируемости.

Подготовить JAX для AMD крайне просто! 

Настройка JAX на GPU AMD — это очень простой процесс:

# Подтягиваем образ Docker:
docker pull rocm/jax:latest

# Запускаем контейнер Docker:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest

# Верифицируем установку:
python3 -c 'import jax; print(jax.devices())'

Я работал с узлом AMD, состоящим из 8 GPU AMD MI300x. У каждого из MI300x имелось 192 ГБ памяти HBM3. Они крайне хорошо проявляют себя по сравнению с новыми GPU NVIDIA H100. (См. сравнение ниже, источник: TensorWave)

Обучение LLaMA 405B: производительность и масштабируемость 

При помощи JAX мне удалось обучить модель LLaMA 405B на GPU AMD, добившись впечатляющих результатов.

Мы выполнили fine-tuning LoRA со всеми весами модели и параметрами lora с точностью bfloat16, с LoRA rank = 8 и LoRA alpha = 16:

  • Размер модели: веса модели LLaMA занимают примерно 800 ГБ VRAM.

  • Веса LoRA + состояние оптимизатора: приблизительно 400 ГБ VRAM.

  • Общее использование VRAM: 77% от общего объёма VRAM, примерно 1200 ГБ.

  • Ограничения: из-за большого размера модели 405B пространство для размеров батчей и длины последовательностей было ограничено. Я использовал размер батчей 16 и длину последовательностей 64.

  • JIT-компиляция: кроме того, из-за ограничений пространства я не смог запустить JIT-компилируемую версию; вероятно, для этого требуется чуть больше пространства, чем для графа eager mode.

  • Скорость обучения: примерно 35 токенов в секунду в eager mode JAX (1 этап обучения занимал 30 с)

  • Эффективность использования памяти: стабильно примерно 70%

  • Масштабирование: при работе с JAX масштабирование было примерно линейным среди всех 8 GPU.

Ниже представлены показатели GPU, эффективности использования памяти и результаты rocm-smi для 8 GPU на одном этапе обучения прогона fine-tuning:

  • Использование GPU:

  • Использование VRAM:

результаты rocm-smi:

Устройство

Температура

Мощность

Разделы

Кулер

Производительность

PwrCap

VRAM%

GPU%

0

58,0°C

232,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

27%

1

58,0°C

233,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

25%

2

56,0°C

236,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

24%

3

52,0°C

228,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

23%

4

59,0°C

232,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

22%

5

51,0°C

230,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

21%

6

61,0°C

235.0W

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

18%

7

56,0°C

227,0 Вт

NPS1, SPX, 0

0%

auto

750,0 Вт

77%

18%

Полную информацию об использовании GPU, VRAM и данные rocm-smi можно найти в нашем репозитории Github.

Наша система для обучения 

Мы перенесли архитектуру LLaMA 3.1 с PyTorch на JAX. Нашу реализацию можно изучить в репозитории GitHub.

Этот перенос открыл для нас новые возможности с точки зрения производительности и масштабируемости.

Загрузка модели и параметров шардинга

Для работы с такой огромной моделью, как LLaMA 405B, требуется эффективный шардинг параметров между несколькими устройствами. Ниже мы расскажем, как добились его при помощи JAX.

Параметры шардинга в JAX 

Чтобы эффективно распределить огромную модель LLaMA 405B на 8 GPU AMD, мы применили функцию меша устройств (device mesh) JAX (codepointer). Меш устройств упорядочивает имеющиеся устройства в многомерную сетку, позволяя нам указывать, как будут разбиты вычисления и данные. В своей системе мы создали меш с формой (1, 8, 1), а именно с такими осями, как параллелизм данных (data parallelism, dp), параллелизм данных с полным шардингом (fully sharded data parallelism, fsdp) и параллелизм модели (model parallelism, mp). Затем мы применили к параметрам модели конкретные правила шардинга, указав для каждого тензора модели способ разбиения его размерностей между осями меша.

DEVICES = jax.devices()
DEVICE_COUNT = len(DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh((1, 8, 1))
MESH = Mesh(devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

Визуализация шардинга 

Шардинг массивов можно визуализировать при помощи jax.debug.visualize_array_sharding. Это невероятно полезно для проверки правильности применения спецификаций шардинга.

Правила разбиения

Мы определили правила разбиения для различных компонентов модели:

Способ шардинга параметров 

  • Обычные параметры: разбиты шардингом на 8 GPU.

    • Например, тензор LM head (lm_head/kernel) имеет две оси, разбитые с PS("fsdp", "mp"); в наше случае это 8, 1, так что мы видим, что по первой оси тензор разбит на 8 GPU.

  • Нереплицированные параметры:

    • Параметры без спецификаций шардинга реплицируются между всеми устройствами.

    • Например, нормы слоёв (attention_norm/kernel и ffn_norm/kernel) используют PS(None).

Применение ограничений шардинга 

В процессе загрузки модели мы инкрементно выполняем шардинг весов модели при помощи специальных функций шардинга:

def make_shard_and_gather_fns(partition_specs):
    def make_shard_fn(partition_spec):
        out_sharding = NamedSharding(mesh, partition_spec)
        def shard_fn(tensor):
            return jax.device_put(tensor, out_sharding).block_until_ready()
        return shard_fn

    shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
    return shard_fns

# Создаём функции шардинга на основании правил разбиения
shard_fns = make_shard_and_gather_fns(partitioning_rules)

Это позволяет нам помещать каждый параметр на соответствующие устройства с указанным шардингом.

Шардинг батча обучения

Изначально батч обучения создаётся обычным образом. Перед передачей его модели мы выполняем его шардинг между GPU в соответствии со следующим кодом:

train_batch = jax.device_put(
    train_batch, NamedSharding(self.mesh, PS("dp", "fsdp"))
)

Здесь мы указываем, что батч обучения должен быть разделён шардингом между осями data parallel ("dp") и fully sharded data parallel ("fsdp"), которые в нашем случае соответствуют 1, 8; это приводит к следующей визуализации:

  • до шардинга

  • после вызова jax.device_put

Реализация обучения LoRA 

LoRA (Low-Rank Adaptation) снижает количество параметров для обучения, разбивая обновления весов на низкоранговые матрицы. Это особенно полезно для fine-tuning больших моделей.

Ключевые аспекты нашей реализации LoRA:

  1. Раздельная параметризация: мы храним параметры LoRA (lora_a и lora_b) отдельно от параметров основной модели.

  2. Прекращение градиента: мы используем jax.lax.stop_gradient(kernel), чтобы предотвратить обновления весов основной модели.

  3. Эффективное умножение матриц: мы используем lax.dot_general для быстрых матричных операций с контролем точности.

  4. Коэффициент масштабирования: перед добавлением к основным выходным данных выходные данные LoRA масштабируются на (self.lora_alpha / self.lora_rank).

Слой LoRADense 

Мы реализовали специальный слой LoRADense, включающий в себя параметры LoRA:

class LoRADense(nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0

    @nn.compact
    def __call__(self, inputs: Any) -> Any:
        # Параметр исходного ядра (заморожен)
        kernel = self.param('kernel', ...)
        y = lax.dot_general(inputs, jax.lax.stop_gradient(kernel), ...)

        # Параметры LoRA (обучаемые)
        lora_a = self.variable('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable('lora_params', 'lora_b', ..., ...)

        # Вычисление выходных данных LoRA
        lora_output = lax.dot_general(inputs, lora_a.value, ...)
        lora_output = lax.dot_general(lora_output, lora_b.value, ...)

        # Комбинирование исходных выходных данных с модификациями LoRA
        y += (self.lora_alpha / self.lora_rank) * lora_output

        return y.astype(self.dtype)

Шардинг параметров LoRA 

Для эффективного распределения параметров LoRA между устройствами мы с помощью JAX применили особые правила шардинга. Это гарантирует, что параметры LoRA выравниваются с шардингом параметров основной модели, оптимизируя при этом и использование памяти, и эффективность вычислений.

Матрицы LoRA A (lora_a) 

  • Использованная нами спецификация разбиенияPS("fsdp", "mp").

  • Визуализация:

    • Шардинг осей: шардинг параметров lora_a между осями будет выполняться как (8, 1), то есть первая ось разбивается шардингом на 8 устройств (ось fsdp), а вторая ось не разбивается.

      На иллюстрации показано, что первая ось разбита шардингом на 8 устройств (ось fsdp), а вторая ось не разбита.

Матрицы LoRA B (lora_b) 

  • Использованная нами спецификация разбиенияPS("mp", "fsdp").

  • Визуализация:

    • Шардинг осей: шардинг параметров lora_b по слоям будет выполняться как (1, 8), то есть вторая ось разбивается шардингом на 8 устройств (ось fsdp), а первая ось не разбивается.

      На иллюстрации показано, что вторая ось разбита шардингом на 8 устройств (ось fsdp), разбивая столбцы матрицы.

Такая стратегия шардинга оптимизирует распределение параметров, снижает лишнюю трату ресурсов на коммуникации и повышает параллелизм при обучении. Она гарантирует, что на каждом устройстве содержится только часть параметров LoRA, обеспечивая эффективное масштабирование для больших моделей наподобие LLaMA 405B.

Обновление только параметров LoRA 

Для оптимизации обучения при fine-tuning модели LLaMA 405B мы вычисляем градиенты только для параметров LoRA, оставляя параметры основной модели замороженными. При таком подходе снижается объём используемой памяти и ускоряется обучение, потому что мы обновляем меньшее количество параметров. Подробности реализации можно посмотреть в нашем репозитории GitHub.

В нашем цикле обучения на каждом этапе используется передача батча входных данных через модель. Так как обучаются только параметры LoRA, прогнозы модели и вычисляемая функция потерь зависят только от этих параметров. Затем мы выполняем обратное распространение градиентов с параметрами LoRA. Сосредоточив обновления только на этих параметрах, мы упрощаем процесс обучения, что позволяет эффективно выполнять на нескольких GPU fine-tuning чрезвычайно больших моделей наподобие LLaMA 405B.

Заключение

Fine-tuning огромной модели LLaMA 3.1 405B на GPU AMD при помощи JAX оставил у нас крайне положительное впечатление. Благодаря использованию мощных возможностей параллелизма JAX и её независящих от оборудования методик я смог эффективно распределить модель по 8 GPU AMD MI300x. Использование шардинга параметров позволило эффективно управлять огромным объёмом параметров модели между устройствами, что обеспечило почти линейную масштабируемость и высокую эффективность использования памяти.

Этот опыт подчёркивает способности GPU AMD в качестве мощной альтернативы оборудованию NVIDIA в крупномасштабном обучении ИИ. Беспроблемная интеграция JAX с поддержкой ROCm упрощает переход и открывает новые возможности для сообщества исследователей и разработчиков ИИ. Делясь своим опытом и кодом, я надеюсь, что это мотивирует других исследовать и применять эти инструменты в собственных крупномасштабных проектах машинного обучения.

 

Источник

Читайте также