Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchprime/torch_xla_models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import transformers
from datasets import load_dataset
from omegaconf import DictConfig, OmegaConf
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch_xla._internal.jax_workarounds import jax_env_context
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
from transformers import (
Expand All @@ -44,6 +44,11 @@
from torchprime.torch_xla_models import offloading, remat_all, scan_layers
from torchprime.torch_xla_models.topology import get_mesh, is_1d_sharding

# Import jax_env_context from whichever module the installed torch_xla
# exposes it in. The import path is selected from the version string; any
# local build suffix ("+...") is dropped before parsing.
# NOTE(review): under PEP 440 ordering a pre-release/dev version such as
# "2.8.0.dev..." compares below "2.8.0" and therefore takes the fallback
# branch -- confirm that is the intended behavior for nightly builds.
if version.parse(torch_xla.__version__.split("+")[0]) >= version.parse("2.8.0"):
    from torch_xla._internal.jax_workarounds import jax_env_context
else:
    from torch_xla.experimental.custom_kernel import jax_env_context

# Presumably the version guard from the `transformers` import list (that list
# is collapsed in this view -- TODO confirm); called with minimum "4.39.3".
check_min_version("4.39.3")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

Expand Down