Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Bash commands
- `uv run pytest`: Run the tests.
- `make style && make quality`: Run the linter and formatter.

# Workflow
- Always run the linter and make sure the tests pass before finishing a task.
- Prefer running single tests, not the whole suite, when developing.
- To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
- Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
11 changes: 1 addition & 10 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
# Bash commands
- `uv run pytest`: Run the tests.
- `make style && make quality`: Run the linter and formatter.

# Workflow
- Always run the linter and make sure the tests pass before finishing a task.
- Prefer running single tests, not the whole suite, when developing.
- To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
- Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
@AGENTS.md
18 changes: 0 additions & 18 deletions open_instruct/IFEvalG/instructions_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,21 +292,3 @@
_KEYWORD + "keyword_specific_position": {_KEYWORD + "keyword_specific_position"},
_KEYWORD + "start_end": {_KEYWORD + "start_end"},
}


def conflict_make(conflicts):
    """Make the conflict relation symmetric and reflexive, in place.

    Args:
      conflicts: Dictionary of potential conflicts where key is instruction id
        and value is set of instruction ids that it conflicts with.

    Returns:
      The same dictionary object, revised. All instructions conflict with
      themselves. If A conflicts with B, B will conflict with A. An
      instruction id that appears only inside a value set gets its own entry
      rather than raising KeyError.
    """
    # Iterate over snapshots: entries may be inserted below, and neither a dict
    # nor a set may change size while it is being iterated.
    for key in list(conflicts):
        for other in list(conflicts[key]):
            # A freshly created entry starts out conflicting with itself.
            conflicts.setdefault(other, {other}).add(key)
        conflicts[key].add(key)
    return conflicts
9 changes: 0 additions & 9 deletions open_instruct/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@
original_builtins = __builtins__


def encode_tests(tests: list) -> str:
    """Serialize ``tests`` into a compact text string.

    The list is pickled, the pickle is zlib-compressed, and the compressed
    bytes are base64-encoded and returned as UTF-8 text. A falsy ``tests``
    (for example an empty list or ``None``) yields the empty string.
    """
    if tests:
        packed = zlib.compress(pickle.dumps(tests))
        return base64.b64encode(packed).decode("utf-8")
    return ""


def decode_tests(tests: Any) -> list:
if not tests:
return []
Expand Down
11 changes: 0 additions & 11 deletions open_instruct/code_utils/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@
# Source text of a Python prelude: wildcard imports followed by plain imports of
# common standard-library modules, ending with a call that raises the recursion
# limit to 50000.
# NOTE(review): presumably prepended to the code under test -- confirm at the use sites.
import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"


def truncatefn(s, length=300):
    """Shorten ``s`` to at most ``length`` of its characters around a marker.

    Args:
        s: Value to shorten; anything that is not a ``str`` is converted
            with ``str()`` first.
        length: Maximum number of characters of ``s`` to keep.

    Returns:
        ``s`` (as a string) unchanged when it already fits in ``length``
        characters; otherwise its first ``length // 2`` characters, the
        literal ``"...(truncated) ..."``, and its last
        ``length - length // 2`` characters.
    """
    if not isinstance(s, str):
        s = str(s)
    if len(s) <= length:
        return s

    head_len = length // 2
    tail_len = length - head_len
    # Slice the tail from an explicit start index: a negative-index slice
    # degenerates to ``s[0:]`` (the whole string) when the tail length is 0.
    return s[:head_len] + "...(truncated) ..." + s[len(s) - tail_len :]


class CODE_TYPE(Enum):
call_based = 0
standard_input = 1
Expand Down
16 changes: 0 additions & 16 deletions open_instruct/context_window_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,22 +385,6 @@ async def safe_acompletion_with_context_check(


# Convenience function for quick context checking
def will_exceed_context_window(
    messages: list[dict[str, str]],
    max_completion_tokens: int,
    model_name: str,
    max_context_length: int = 8192,
    safety_margin: int = 100,
) -> bool:
    """
    Report whether a request would overflow the context window.

    This is the negation of ``check_context_window_limit`` called with the
    same arguments in the same order.

    Returns:
        bool: True if the request would exceed context window, False otherwise
    """
    fits_in_window = check_context_window_limit(
        messages, max_completion_tokens, model_name, max_context_length, safety_margin
    )
    return not fits_in_window


def truncate_str_for_prompt_template(
Expand Down
15 changes: 0 additions & 15 deletions open_instruct/dataset_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_p
return min(num_required_cpus, num_available_cpus, dataset_len)


def select_nested(dataset: DatasetDict, max_examples_per_split: int):
    """Cap every split of ``dataset`` at ``max_examples_per_split`` rows.

    Returns a plain ``dict`` mapping each key of ``dataset`` to
    ``split.select(range(k))``, where ``k`` is the smaller of the split's
    length and ``max_examples_per_split``.
    """
    capped = {}
    for split_name in dataset:
        split = dataset[split_name]
        keep = min(max_examples_per_split, len(split))
        capped[split_name] = split.select(range(keep))
    return capped


class DatasetProcessor:
def __init__(self, tokenizer: PreTrainedTokenizer, config: DatasetConfig) -> None:
self.tokenizer = tokenizer
Expand Down Expand Up @@ -475,16 +470,6 @@ def get_token_length_visualization(self, dataset: DatasetDict, save_path: str =
)


def convert_preference_dataset_to_binary_dataset(ds: Dataset):
    """Flatten a preference dataset into a binary-labelled dataset.

    Every row of ``ds`` contributes two output rows, in order: its
    ``"chosen"`` value labelled ``True``, then its ``"rejected"`` value
    labelled ``False``. Values are stored under ``SFT_MESSAGE_KEY`` and labels
    under ``BINARY_LABEL_KEY``.

    Returns:
        A new dataset built with ``Dataset.from_dict``.
    """
    binary_ds = defaultdict(list)
    for i in tqdm(range(len(ds))):
        # Fetch the row once and reuse it for both fields instead of indexing
        # the dataset twice per iteration.
        row = ds[i]
        binary_ds[SFT_MESSAGE_KEY].append(row["chosen"])
        binary_ds[BINARY_LABEL_KEY].append(True)
        binary_ds[SFT_MESSAGE_KEY].append(row["rejected"])
        binary_ds[BINARY_LABEL_KEY].append(False)
    return Dataset.from_dict(binary_ds)


def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
i = 0
console = Console()
Expand Down
46 changes: 0 additions & 46 deletions open_instruct/dataset_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,6 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
console.print(rich_text)


def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrainedTokenizer):
    """Print the decoded tokens to the console, colored by their mask value.

    Each token is decoded on its own with ``tokenizer.decode`` and styled with
    ``COLORS[mask % len(COLORS)]``. Tokens and masks are paired by position;
    surplus entries in the longer sequence are ignored.
    """
    console = Console()
    rich_text = Text()
    # zip stops at the shorter sequence, matching an index loop over
    # range(min(len(tokens), len(masks))).
    for token, mask in zip(tokens, masks):
        color = COLORS[mask % len(COLORS)]
        decoded_token = tokenizer.decode(token)
        # The f-string coerces whatever decode() returned to str before appending.
        rich_text.append(f"{decoded_token}", style=color)
    console.print(rich_text)


# ----------------------------------------------------------------------------
# Tokenization
# Chat templates
Expand Down Expand Up @@ -1274,25 +1261,6 @@ def rlvr_tokenize_v3(
return row


def rlvr_filter_v1(
    row: Dict[str, Any],
    tokenizer: PreTrainedTokenizer,
    need_contain_labels: bool = True,
    max_prompt_token_length: Optional[int] = None,
    max_token_length: Optional[int] = None,
):
    """Decide whether a tokenized row passes the length and label filters.

    Args:
        row: Row holding the ``INPUT_IDS_PROMPT_KEY``, ``INPUT_IDS_KEY`` and
            ``LABELS_KEY`` entries that are inspected here.
        tokenizer: Not used by this function.
        need_contain_labels: When true, at least one label must differ from -100.
        max_prompt_token_length: Upper bound (inclusive) on the prompt length;
            ``None`` disables the check.
        max_token_length: Upper bound (inclusive) on the full sequence length;
            ``None`` disables the check.

    Returns:
        ``True`` when every enabled check passes, ``False`` otherwise.
    """
    prompt_fits = max_prompt_token_length is None or len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length
    sequence_fits = max_token_length is None or len(row[INPUT_IDS_KEY]) <= max_token_length
    # Labels are always inspected, even when a length check has already failed.
    has_real_label = any(label != -100 for label in row[LABELS_KEY])
    labels_ok = has_real_label or not need_contain_labels
    return prompt_fits and sequence_fits and labels_ok


def rlvr_max_length_filter_v2(
row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None
):
Expand Down Expand Up @@ -1686,20 +1654,6 @@ def count_tokens(sample):
return loaded_dataset, all_statistics


def get_cached_dataset(
    dcs: List[DatasetConfig],
    tc: TokenizerConfig,
    hf_entity: Optional[str] = None,
    dataset_local_cache_dir: Optional[str] = None,
    dataset_skip_cache: bool = False,
) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
    """Load or build the transformed dataset through a cache object.

    A ``LocalDatasetTransformationCache`` is used when
    ``dataset_local_cache_dir`` is given; otherwise a
    ``DatasetTransformationCache`` is created with ``hf_entity``. The result of
    the cache's ``load_or_transform_dataset`` is returned unchanged.
    """
    if dataset_local_cache_dir is None:
        cache = DatasetTransformationCache(hf_entity=hf_entity)
    else:
        cache = LocalDatasetTransformationCache(dataset_local_cache_dir=dataset_local_cache_dir)
    return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)


def get_cached_dataset_tulu_with_statistics(
dataset_mixer_list: List[str],
dataset_mixer_list_splits: List[str],
Expand Down
12 changes: 0 additions & 12 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@

import asyncio
import gc
import json
import math
import random
import shutil
Expand Down Expand Up @@ -367,17 +366,6 @@ def __post_init__(self):
print("WARNING: number_samples_per_prompt is 1. This reduces GRPO to REINFORCE. ")


def process_dataset_mixer(value) -> tuple[dict | None, str | None]:
# if passed through cli: convert the dataset mixers to dictionaries
if isinstance(value, str):
return json.loads(value), value
# if passed through yaml: convert the dataset mixers to strings
elif isinstance(value, dict):
return value, json.dumps(value)
else:
raise ValueError("Input must be either a string or a dictionary")


def get_train_ds_config(
offload,
adam_offload=False,
Expand Down
54 changes: 5 additions & 49 deletions open_instruct/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,23 @@
# limitations under the License.


import asyncio
import itertools
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Literal, Union

try:
import deepspeed
from deepspeed.runtime.engine import DeepSpeedEngine
except ImportError:
pass
import asyncio
from typing import Literal

import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from deepspeed.runtime.engine import DeepSpeedEngine
from huggingface_hub import HfApi
from rich import print as rprint
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from torch.nn.parallel.distributed import DistributedDataParallel
from transformers import PreTrainedModel, PreTrainedTokenizer

from open_instruct import logger_utils
Expand Down Expand Up @@ -419,24 +412,6 @@ def generate(
return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits


@torch.no_grad()
def batch_generation(
    model: torch.nn.Module,
    queries: torch.Tensor,
    local_rollout_forward_batch_size: int,
    pad_token_id: int,
    generation_config: dict,
):
    """Run ``generate`` over ``queries`` in slices along the first dimension.

    ``queries`` is cut into consecutive slices of at most
    ``local_rollout_forward_batch_size`` rows; ``generate`` is called on each
    slice, and the per-slice results are concatenated along dimension 0.

    Returns:
        A ``(query_responses, logits)`` pair of concatenated tensors.
    """
    step = local_rollout_forward_batch_size
    response_chunks = []
    logit_chunks = []
    for start in range(0, queries.shape[0], step):
        chunk = queries[start : start + step]
        chunk_response, chunk_logits = generate(model, chunk, pad_token_id, generation_config)
        response_chunks.append(chunk_response)
        logit_chunks.append(chunk_logits)
    all_responses = torch.cat(response_chunks, 0)
    all_logits = torch.cat(logit_chunks, 0)
    return all_responses, all_logits


def get_olmo3_generation_config(tokenizer):
return transformers.GenerationConfig(
temperature=None,
Expand Down Expand Up @@ -556,7 +531,7 @@ def iter_params(module, recurse=False):
return [param for _, param in get_all_parameters(module, recurse)]


def remove_hooks(model: "DeepSpeedEngine") -> None:
def remove_hooks(model: DeepSpeedEngine) -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
Expand All @@ -575,7 +550,7 @@ def remove_hooks(model: "DeepSpeedEngine") -> None:
optimizer_offload.backward_hooks = []


def add_hooks(model: "DeepSpeedEngine") -> None:
def add_hooks(model: DeepSpeedEngine) -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
Expand All @@ -584,25 +559,6 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
optimizer_offload._register_hooks_recursively(optimizer_offload.module)


@contextmanager
def unwrap_model_for_generation(
    model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
) -> Union["transformers.PreTrainedModel", "DeepSpeedEngine"]:
    """Context manager to unwrap a model for generation.

    For ZeRO-3 models, we gather the weights once to speed up generation.

    Args:
        model: The wrapped model handed to ``accelerator.unwrap_model``.
        accelerator: Accelerator whose state tells whether a DeepSpeed plugin
            with ZeRO stage 3 is active.
        is_peft_model: When true, ``disable_adapter()`` is called on the
            unwrapped model's ``pretrained_model``.

    Yields:
        The unwrapped model. On the ZeRO-3 path it is yielded inside
        ``deepspeed.zero.GatheredParameters`` with the optimizer hooks removed;
        the hooks are re-added when the ``with`` body finishes, including when
        it raises.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    if is_peft_model:
        # NOTE(review): confirm that a bare call (not used as a context
        # manager) has the intended effect for the PEFT version in use.
        unwrapped_model.pretrained_model.disable_adapter()
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        with deepspeed.zero.GatheredParameters(model.parameters()):
            remove_hooks(model)
            try:
                yield accelerator.unwrap_model(model)
            finally:
                # An exception raised in the caller's with-body is re-raised at
                # the yield; restore the hooks regardless so the engine is not
                # left without them.
                add_hooks(model)
    else:
        yield unwrapped_model


def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int, mixed_precision: str):
"""
Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and
Expand Down
12 changes: 0 additions & 12 deletions open_instruct/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
# isort: on

import gc
import json
import math
import random
import shutil
Expand Down Expand Up @@ -352,17 +351,6 @@ class Args:
"""What dataset to upload the metadata to. If unset, don't upload metadata"""


def process_dataset_mixer(value) -> tuple[dict | None, str | None]:
# if passed through cli: convert the dataset mixers to dictionaries
if isinstance(value, str):
return json.loads(value), value
# if passed through yaml: convert the dataset mixers to strings
elif isinstance(value, dict):
return value, json.dumps(value)
else:
raise ValueError("Input must be either a string or a dictionary")


def get_train_ds_config(
offload,
adam_offload=False,
Expand Down
12 changes: 0 additions & 12 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@

import asyncio
import gc
import json
import math
import random
import shutil
Expand Down Expand Up @@ -369,17 +368,6 @@ class Args:
"""whether to apply a performance penalty to the code verifier"""


def process_dataset_mixer(value) -> tuple[dict | None, str | None]:
# if passed through cli: convert the dataset mixers to dictionaries
if isinstance(value, str):
return json.loads(value), value
# if passed through yaml: convert the dataset mixers to strings
elif isinstance(value, dict):
return value, json.dumps(value)
else:
raise ValueError("Input must be either a string or a dictionary")


def get_train_ds_config(
offload,
adam_offload=False,
Expand Down
Loading