Machine Learning Engineer (PyTorch to JAX Migration & Systems Optimization)

Jobs via Dice • United States
Remote
AI Summary

Seeking a specialized Machine Learning Engineer to re-architect LLMs from PyTorch to JAX for high-performance TPU/GPU clusters. The role involves structural porting, the transition from imperative to functional state management, and advanced profiling for maximum throughput and hardware efficiency. Key requirements include deep expertise in the high-performance AI stack, the JAX and PyTorch ecosystems, and hardware-aware optimization. This is a 100% remote role that requires working Eastern Time (EST) hours; agency and C2C candidates will not be considered.

Key Highlights
Migrate and re-architect Large Language Models (LLMs) from PyTorch to JAX.
Optimize models for TPU and GPU clusters at scale, focusing on performance and hardware efficiency.
Requires deep expertise in the high-performance AI stack, JAX/PyTorch, and distributed training.
Key Responsibilities
Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).
Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates.
Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks.
Use profiling tools to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.
Implement precision-tracking tools to ensure stable training runs during the transition.
Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or sharding APIs).
Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored for the interconnect topology of the target hardware.
Optimize HLO graphs to minimize 'jit-time' and maximize 'run-time' efficiency.
Apply optimizations like Selective Activation Checkpointing and memory-efficient attention based on hardware constraints.
Port PyTorch-based PEFT (LoRA, DoRA) methods into the JAX stack.
Adapt JAX implementations for newer primitives like Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention.
Technical Skills Required
JAX, Flax, Equinox, Optax, Orbax, PyTorch, torch.compile, DistributedDataParallel, FSDP, HLO, MLIR, Grain, tf.data, JAX Profiler, TensorBoard Profiler, NVIDIA Nsight Systems, Perfetto, NVIDIA H100, NVIDIA A100, Google TPU v4, Google TPU v5p, Slurm, Kubernetes, Google Cloud Platform, AWS, Azure, Functional Programming, Asynchronous Programming, All-Reduce, All-Gather, Reduce-Scatter, BF16, FP8
Benefits & Perks
100% remote role

Job Description


Dice is the leading career destination for tech experts at every stage of their careers. Our client, Paradigm Infotech, is seeking the following. Apply via Dice today!

100% remote role

Must work EST (Eastern Time) hours

Agency and C2C candidates will not be considered, and visa sponsorship is not available.

Machine Learning Engineer: Framework Migration & Systems Optimization (PyTorch to JAX)

We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.

  • Core Framework Migration

Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).

State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision.
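To illustrate the state-management shift, here is a minimal sketch in plain JAX. It assumes a toy linear layer with dropout, and the names `init_params`, `loss_fn`, `train_step`, and the constants are hypothetical. PyTorch draws dropout masks from hidden global RNG state and mutates parameters in place. In JAX, the PRNG key is passed in explicitly and split before use, and each step returns a new parameter pytree.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # hypothetical plain-SGD step size
DROPOUT_RATE = 0.1    # hypothetical dropout rate

def init_params(key, d_in, d_out):
    # The key is consumed here; callers split before passing it in.
    return {"w": jax.random.normal(key, (d_in, d_out)) * 0.02,
            "b": jnp.zeros((d_out,))}

def loss_fn(params, x, y, dropout_key):
    h = x @ params["w"] + params["b"]
    # Dropout takes an explicit key instead of hidden global RNG state.
    keep = jax.random.bernoulli(dropout_key, 1.0 - DROPOUT_RATE, h.shape)
    h = jnp.where(keep, h / (1.0 - DROPOUT_RATE), 0.0)
    return jnp.mean((h - y) ** 2)

@jax.jit
def train_step(params, key, x, y):
    # Use one subkey for this step and hand the other back to the caller.
    key, dropout_key = jax.random.split(key)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y, dropout_key)
    # No in-place optimizer.step(): build and return a new params pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, key, loss

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)  # never reuse a key for two draws
params = init_params(init_key, 4, 2)
x, y = jnp.ones((8, 4)), jnp.zeros((8, 2))
new_params, next_key, loss = train_step(params, key, x, y)
```

The original `params` is untouched after the step. Threading `next_key` into the following call keeps every step's randomness distinct and reproducible.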

Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error tolerance checks.
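Here is a minimal sketch of one conversion-plus-parity step. It assumes a PyTorch `nn.Linear` state dict that has already been exported to NumPy arrays; the key names `fc.weight`/`fc.bias`, the helper `convert_linear`, and the tolerance are hypothetical. `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T + b`. Flax's `Dense` stores an `(in_features, out_features)` kernel, so the conversion is a transpose. The parity check compares against a float64 reference under an explicit tolerance.

```python
import numpy as np
import jax.numpy as jnp

def convert_linear(torch_sd, prefix):
    # torch.nn.Linear: weight (out, in), y = x @ W.T + b
    # flax.linen.Dense: kernel (in, out), y = x @ K + b
    return {"kernel": jnp.asarray(torch_sd[f"{prefix}.weight"].T),
            "bias": jnp.asarray(torch_sd[f"{prefix}.bias"])}

rng = np.random.default_rng(0)
# Hypothetical exported state dict (tensor.detach().cpu().numpy() in practice).
torch_sd = {"fc.weight": rng.standard_normal((3, 5)).astype(np.float32),
            "fc.bias": rng.standard_normal(3).astype(np.float32)}
x = rng.standard_normal((2, 5)).astype(np.float32)

params = convert_linear(torch_sd, "fc")
y_jax = np.asarray(jnp.asarray(x) @ params["kernel"] + params["bias"])

# Reference: PyTorch's Linear formula evaluated in float64 with NumPy.
w64 = torch_sd["fc.weight"].astype(np.float64)
b64 = torch_sd["fc.bias"].astype(np.float64)
y_ref = x.astype(np.float64) @ w64.T + b64

TOLERANCE = 1e-4  # hypothetical float32 tolerance; tune per layer and dtype
max_abs_err = float(np.max(np.abs(y_jax - y_ref)))
print(f"max abs error: {max_abs_err:.2e}, within tolerance: {max_abs_err < TOLERANCE}")
```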

  • Advanced Profiling & Numerical Stability

Bottleneck Analysis: Use NVIDIA Nsight Systems and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.

Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence.
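One simple form of precision tracking is a per-step health summary of the gradients plus a drift check of a low-precision computation against a float32 reference. This is a minimal sketch; the helper names `grad_health` and `relative_error` are hypothetical, and real tooling would log these values every step.

```python
import jax
import jax.numpy as jnp

def grad_health(grads):
    # Global L2 norm (accumulated in float32) and count of non-finite entries:
    # a spiking norm or any NaN/Inf is an early sign of divergence.
    leaves = jax.tree_util.tree_leaves(grads)
    sq_sum = sum(jnp.sum(jnp.square(g.astype(jnp.float32))) for g in leaves)
    nonfinite = sum(jnp.sum(~jnp.isfinite(g)) for g in leaves)
    return {"global_norm": jnp.sqrt(sq_sum), "nonfinite": nonfinite}

def relative_error(low, ref):
    # Relative L2 error of a low-precision result against a float32 reference.
    low = low.astype(jnp.float32)
    return jnp.linalg.norm(low - ref) / (jnp.linalg.norm(ref) + 1e-12)

healthy = grad_health({"w": jnp.array([3.0, 4.0], dtype=jnp.bfloat16)})
broken = grad_health({"w": jnp.array([1.0, jnp.inf])})

x = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
ref = x @ x.T                                            # float32 reference
low = x.astype(jnp.bfloat16) @ x.astype(jnp.bfloat16).T  # BF16 computation
drift = relative_error(low, ref)
```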

  • Scaling & Distributed Training

Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or sharding APIs).
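Here is a minimal sketch of the FSDP placement idea with the `jax.sharding` APIs, assuming a 1-D mesh over all visible devices (it also runs on a single CPU device, where the sharding is trivial). The parameters and the batch are sharded along one mesh axis, here given the hypothetical name `fsdp`. `jax.jit` then leaves it to XLA's SPMD partitioner to insert whatever collectives the matmul needs. A full FSDP setup would also shard optimizer state and control where gathers happen; this only shows the placement mechanics.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh; on a real cluster this axis spans the data-parallel devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))
n = len(devices)

# Each device stores only a 1/n slice of the parameter (sharded on dim 0)...
w = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)
w_sharded = jax.device_put(w, NamedSharding(mesh, P("fsdp", None)))

# ...and the batch is split across the same mesh axis.
x = jnp.ones((2 * n, 8 * n), dtype=jnp.float32)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("fsdp", None)))

@jax.jit
def forward(x, w):
    # The SPMD partitioner decides which collectives (e.g., all-gathers) to emit.
    return x @ w

y = forward(x_sharded, w_sharded)
```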

Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored for the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.

  • Hardware-Aware Optimization

XLA Mastery: Understand and influence the XLA (Accelerated Linear Algebra) compiler behavior. You will optimize HLO (High-Level Optimizer) graphs to minimize "jit-time" and maximize "run-time" efficiency.
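One way to see the jit-time vs. run-time split is JAX's ahead-of-time API:

- `jax.jit(f).lower(...)` traces and lowers the function to StableHLO.
- `.compile()` runs XLA's optimization passes.
- The resulting executable can then be timed on its own.

This is a minimal sketch assuming a toy two-layer MLP (`mlp` is hypothetical). Where the backend supports it, `compiled.as_text()` can similarly be used to inspect the optimized HLO for fusion decisions.

```python
import time
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    return jax.nn.gelu(x @ w1) @ w2

x = jnp.ones((8, 16))
w1 = jnp.ones((16, 32)) * 0.01
w2 = jnp.ones((32, 4)) * 0.01

t0 = time.perf_counter()
lowered = jax.jit(mlp).lower(x, w1, w2)  # trace + lower to StableHLO
compiled = lowered.compile()             # XLA optimization + codegen ("jit-time")
t1 = time.perf_counter()
out = compiled(x, w1, w2).block_until_ready()  # execution only ("run-time")
t2 = time.perf_counter()

print(f"compile: {t1 - t0:.4f}s, run: {t2 - t1:.6f}s")
hlo_text = lowered.as_text()  # the IR handed to XLA, before optimization
```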

Memory Management: Apply optimizations like Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware.
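In JAX, selective activation checkpointing maps onto `jax.checkpoint` (also exported as `jax.remat`) with a rematerialization policy. This is a minimal sketch assuming a toy residual MLP block; `block` and `loss` are hypothetical. The `dots_saveable` policy keeps matmul outputs and recomputes the cheaper elementwise ops during the backward pass, and the gradients match the non-checkpointed version.

```python
import jax
import jax.numpy as jnp

def block(x, w1, w2):
    # Residual MLP block: matmuls are costly to recompute, tanh is cheap.
    return x + jnp.tanh(x @ w1) @ w2

# Save matmul (dot) outputs; recompute everything else in the backward pass.
remat_block = jax.checkpoint(block, policy=jax.checkpoint_policies.dots_saveable)

def loss(params, x, use_remat):
    f = remat_block if use_remat else block
    for w1, w2 in params:
        x = f(x, w1, w2)
    return jnp.mean(x ** 2)

keys = jax.random.split(jax.random.PRNGKey(0), 6)
params = [(jax.random.normal(keys[2 * i], (16, 64)) * 0.1,
           jax.random.normal(keys[2 * i + 1], (64, 16)) * 0.1) for i in range(3)]
x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))

g_remat = jax.grad(loss)(params, x, True)   # lower peak activation memory
g_plain = jax.grad(loss)(params, x, False)  # stores all intermediates
```

Swapping the policy (for example `jax.checkpoint_policies.nothing_saveable`) trades more recomputation for less memory. The right choice depends on the HBM budget of the target hardware.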

  • Fine-Tuning & Adaptation

Efficient Fine-Tuning: Port PyTorch-based PEFT (LoRA, DoRA) methods into the JAX stack.
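Here is a minimal LoRA sketch in plain JAX. The rank, `alpha`, and helper names are hypothetical, and DoRA is not shown. The frozen weight `W` is augmented with a low-rank update `(alpha / r) * A @ B`, where `B` starts at zero so the adapter begins as an exact no-op. Gradients are taken only with respect to the adapter. PEFT-style PyTorch layers typically store `A` as an `(r, in)` and `B` as an `(out, r)` weight, so ported adapters need the same transpose as any linear layer.

```python
import jax
import jax.numpy as jnp

RANK, ALPHA = 4, 8.0  # hypothetical LoRA hyperparameters

def init_lora(key, d_in, d_out, rank=RANK):
    # A is small and random, B is zero: the adapter starts as a no-op.
    return {"A": jax.random.normal(key, (d_in, rank)) * 0.01,
            "B": jnp.zeros((rank, d_out))}

def lora_linear(x, w_frozen, lora):
    # y = x W + (alpha / r) * x A B, with W frozen and only A, B trainable.
    scale = ALPHA / lora["A"].shape[1]
    return x @ w_frozen + scale * (x @ lora["A"]) @ lora["B"]

def loss(lora, w_frozen, x, y):
    return jnp.mean((lora_linear(x, w_frozen, lora) - y) ** 2)

k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k_w, (16, 8))
lora = init_lora(k_a, 16, 8)
x = jax.random.normal(k_x, (5, 16))
y = jnp.zeros((5, 8))

# Differentiate w.r.t. the adapter only (argnums=0); w_frozen gets no gradient.
grads = jax.grad(loss)(lora, w, x, y)
```

At initialization, the gradient for `A` is exactly zero because `B` is zero. Only `B` moves on the first step.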

Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives like Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.
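As one example of these primitives, here is a minimal sketch of Grouped-Query Attention (non-causal, no masking) using `einsum`. The function names and shapes are hypothetical. Query heads are reshaped into groups that share one K/V head, which avoids materializing repeated K/V tensors. The result matches standard multi-head attention run on K/V heads repeated with `jnp.repeat`.

```python
import jax
import jax.numpy as jnp

def gqa_attention(q, k, v):
    # q: (batch, seq, n_heads, d); k, v: (batch, kv_seq, n_kv_heads, d)
    b, s, n_heads, d = q.shape
    n_kv = k.shape[2]
    group = n_heads // n_kv  # query heads per shared K/V head
    # Query head j = kv_index * group + g uses K/V head kv_index.
    qg = q.reshape(b, s, n_kv, group, d)
    scores = jnp.einsum("bqhgd,bkhd->bhgqk", qg, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhgqk,bkhd->bqhgd", weights, v)
    return out.reshape(b, s, n_heads, d)

def mha_reference(q, k, v):
    # Plain multi-head attention with one K/V head per query head.
    d = q.shape[-1]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 5, 8, 4))  # 8 query heads
k = jax.random.normal(key_k, (2, 5, 2, 4))  # 2 shared K/V heads
v = jax.random.normal(key_v, (2, 5, 2, 4))

out = gqa_attention(q, k, v)
ref = mha_reference(q, jnp.repeat(k, 4, axis=2), jnp.repeat(v, 4, axis=2))
```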

Familiarity with the following technical stack and tooling:

  • Core Frameworks & Libraries:

JAX Ecosystem: Expertise in Flax or Equinox (for model definition), Optax (for optimization/schedules), and Orbax (for checkpointing).

PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP.

Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR to understand how JAX code translates to hardware instructions.

Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.

  • Profiling & Observability

JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic.

NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink.

Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.

  • Infrastructure & Hardware

Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology.

Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs.

Cloud Providers: Proficiency in Google Cloud Platform (GCP) for TPUs or AWS/Azure for high-end GPU instances.

Core Skills & Competencies

  • Software Engineering Excellence

Functional Programming: A shift in mindset from object-oriented programming (OOP) to pure functions, immutability, and stateless logic.

Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks.
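Here is a minimal sketch of the host-sync issue, assuming a toy jitted `step` function (hypothetical). JAX dispatches work asynchronously, so a call returns before the device finishes. Converting a result to a Python value, for example with `float(loss)` or a print, blocks the host until the device catches up. Keeping per-step metrics on device and fetching them once avoids a sync on every iteration. For benchmarking, `block_until_ready()` is the explicit way to wait.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    x = jnp.tanh(x @ x.T) @ x
    return x, jnp.mean(x)

x = jnp.ones((64, 64)) * 0.01

# Anti-pattern: float(loss) inside the loop forces a host sync every step.
#   for _ in range(10):
#       x, loss = step(x)
#       print(float(loss))

# Better: queue the work, keep losses on device, transfer once at the end.
losses = []
for _ in range(10):
    x, loss = step(x)  # returns immediately; the device runs asynchronously
    losses.append(loss)
history = np.asarray(jax.device_get(jnp.stack(losses)))  # single host sync
```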

Testing Rigor: Ability to write property-based tests for numerical stability.
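Dedicated libraries such as Hypothesis generate and shrink test inputs automatically. To stay dependency-free, this sketch uses a seeded random loop to check two properties of `jax.nn.softmax` over a wide range of logit scales:

- Outputs are finite and sum to 1.
- The result is invariant to adding a constant to every logit.

The function name and the tolerances are hypothetical.

```python
import numpy as np
import jax.numpy as jnp
import jax

def check_softmax_properties(num_trials=50, seed=0):
    rng = np.random.default_rng(seed)
    for _ in range(num_trials):
        n = int(rng.integers(1, 64))
        # Wide dynamic range: large draws can exceed ~88, where a naive
        # float32 exp() would overflow.
        scale = float(10.0 ** rng.uniform(-3, 1.5))
        logits = jnp.asarray(rng.standard_normal(n) * scale, dtype=jnp.float32)
        shift = float(rng.uniform(-50, 50))

        p = jax.nn.softmax(logits)
        assert bool(jnp.all(jnp.isfinite(p))), "softmax produced non-finite values"
        assert abs(float(jnp.sum(p)) - 1.0) < 1e-4, "probabilities must sum to 1"
        # Shift invariance: softmax(x + c) == softmax(x) up to rounding.
        p_shift = jax.nn.softmax(logits + shift)
        assert bool(jnp.allclose(p, p_shift, atol=1e-4)), "not shift-invariant"
    return True
```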

  • Distributed Systems Knowledge

Collective Communications: Deep understanding of All-Reduce, All-Gather, and Reduce-Scatter primitives.
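The three collectives can be exercised directly with `jax.lax` inside `jax.pmap`, which is used here for brevity (newer code often uses `shard_map`). It runs on any number of local devices, including one CPU device. The function name and shapes are hypothetical. The check at the end shows the standard identity: reduce-scatter followed by all-gather yields the same result as all-reduce.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def collectives(x):
    # x is this device's shard, shape (2 * n,)
    all_reduced = jax.lax.psum(x, axis_name="i")      # every device: full sum, (2n,)
    gathered = jax.lax.all_gather(x, axis_name="i")   # every device: all shards, (n, 2n)
    # Device i keeps only slice i of the summed vector: (2,)
    scattered = jax.lax.psum_scatter(x, axis_name="i", tiled=True)
    return all_reduced, gathered, scattered

x = jnp.arange(n * 2 * n, dtype=jnp.float32).reshape(n, 2 * n)  # one row per device
ar, ag, rs = jax.pmap(collectives, axis_name="i")(x)

# Concatenating the reduce-scatter slices recovers the all-reduce result.
recovered = rs.reshape(-1)
```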

Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.

  • Mathematical & AI Domain Expertise (Desirable)

Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decomposition.

Mixed Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training.
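Here is dynamic loss scaling sketched in plain JAX. The helpers `scaled_grads` and `update_scale` and the growth interval are hypothetical. Loss scaling matters mostly for FP16, whose narrow exponent range lets small gradients underflow. BF16 keeps FP32's exponent range and is usually trained without it, while FP8 recipes generally rely on per-tensor or finer-grained scaling factors instead.

```python
import jax
import jax.numpy as jnp

def scaled_grads(loss_fn, params, batch, scale):
    # Multiply the loss by `scale` so tiny gradients stay representable in
    # low precision, then divide the gradients back down.
    grads = jax.grad(lambda p: loss_fn(p, batch) * scale)(params)
    grads = jax.tree_util.tree_map(lambda g: g / scale, grads)
    leaves = jax.tree_util.tree_leaves(grads)
    finite = jnp.all(jnp.stack([jnp.all(jnp.isfinite(g)) for g in leaves]))
    return grads, finite

def update_scale(scale, finite, good_steps, growth_interval=2000):
    # Overflow: halve the scale and reset the counter (caller skips the update).
    # Otherwise: double after `growth_interval` consecutive finite steps.
    grow = good_steps + 1 >= growth_interval
    new_scale = jnp.where(finite, jnp.where(grow, scale * 2.0, scale), scale / 2.0)
    new_good = jnp.where(finite, jnp.where(grow, 0, good_steps + 1), 0)
    return new_scale, new_good

def loss_fn(p, batch):
    return jnp.sum((p["w"] * batch) ** 2)

params = {"w": jnp.array([1.0, 2.0])}
batch = jnp.array([1.0, 1.0])
grads, finite = scaled_grads(loss_fn, params, batch, jnp.float32(1024.0))
```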

Architecture Insight: Ability to decompose modern LLM components (KV Caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
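As an example of that decomposition, Rotary Embeddings reduce to elementwise `cos`/`sin` multiplies and adds that rotate each pair of features by a position-dependent angle. This is a minimal sketch assuming interleaved feature pairs and the common base of 10000. Implementations differ on whether pairs are interleaved or split into halves (the "rotate_half" layout), so a port must match the source checkpoint's convention. The checks confirm two defining properties: vector norms are preserved, and the q·k score depends only on the relative offset between positions.

```python
import numpy as np
import jax.numpy as jnp

def rope(x, positions, base=10000.0):
    # x: (seq, d) with even d; positions: (seq,)
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2) / d))   # (d/2,)
    angles = positions[:, None] * inv_freq[None, :]         # (seq, d/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]                     # interleaved pairs
    # Each (x1, x2) pair is rotated by its angle: a 2x2 rotation.
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)

x = jnp.asarray(np.random.default_rng(0).standard_normal((6, 8)), dtype=jnp.float32)
y = rope(x, jnp.arange(6, dtype=jnp.float32))

# Relative-position property: positions (3, 1) and (10, 8) share offset 2.
q, k = x[0:1], x[1:2]
s1 = jnp.dot(rope(q, jnp.array([3.0]))[0], rope(k, jnp.array([1.0]))[0])
s2 = jnp.dot(rope(q, jnp.array([10.0]))[0], rope(k, jnp.array([8.0]))[0])
```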

