ROCm7 Logo  Bild © AMDROCm7 Logo (Bild © AMD)

Im Zentrum des PyTorch-Dockers steht Primus, ein modulares Trainingsframework, das die Konfiguration und Orchestrierung über Backends hinweg standardisiert. Die neueste Version bietet nun auch Unterstützung für TorchTitan und Megatron-LM, sodass Teams flexibel ihren bevorzugten Stack auswählen können, ohne Pipelines neu aufbauen zu müssen. Eine Begleitbibliothek, Primus-Turbo, führt Transformer-fokussierte Kernel-Optimierungen ein, die darauf abzielen, mehr Token-Durchsatz aus MI355X herauszuholen.

AMD ROCm 7.0  PyTorcvh Single Node PerfomranceAMD ROCm 7.0 PyTorcvh Single Node Perfomrance (Bild © AMD)

Durchsatz pro Knoten

In PyTorch/Primus-Tests zeigt MI355X bei weit verbreiteten LLMs durchgängige Verbesserungen gegenüber B200:

  • Llama3 70B BF16: +16 % (1,16x)
  • Llama3 8B FP8/BF16: +8 % / +2 % (1,08x / 1,02x)
  • Mixtral 8×7B FP16: +15 % (1,15x)
  • Llama3 70B FP8: Parität (1,0x)

Die wichtigste Kennzahl ist die Anzahl der Tokens pro GPU und Sekunde, die direkt zeigt, wie viel nützliche Arbeit ein Cluster pro Taktzyklus leisten kann.

AMD ROCm 7.0  PyTorch Multi Node PerformanceAMD ROCm 7.0 PyTorch Multi Node Performance (Bild © AMD)

JAX MaxText: vorgefertigtes Image, vorhersehbare Skalierung

Das Docker-Bundle „MaxText” von ROCm packt die Bibliotheken JAX, XLA und ROCm zusammen mit MaxText-Dienstprogrammen, was JAX-Nutzern die Einrichtung erleichtert. Auf MI355X zeigen die Ergebnisse von JAX mit einem einzelnen Knoten eine konstante Leistungssteigerung bei dichten Modellen und Parität bei MoE:

  • Llama3.1 70B FP8: +11 % (1,11x)
  • Llama3.1 8B FP8: +7 % (1,07x)
  • Mixtral 8×7B FP16: Parität (1,00x)

Skalierung über mehrere Knoten

Verteilte Trainings-Benchmarks im Primus-Megatron-Stack zeigen robuste Skalierungseigenschaften auf MI355X:

  • Mixtral 8×22B BF16, 4 Knoten: +14 % (1,14x)
  • Llama3 70B FP8, 4 Knoten: Parität (1,01x)
  • Llama3.1 405B FP8, 8 Knoten: nahezu Parität (0,96x)

Diese Ergebnisse zeigen, dass MI355X eine wettbewerbsfähige Effizienz beibehält, wenn sich die Workloads über die Knoten verteilen, was ein wichtiger Faktor für das Vorabtraining mit Billionen von Tokens ist.