
Commit f7796d4

adding dpo loss
1 parent 40c96c8 commit f7796d4

File tree

fast_llm/functional/config.py
fast_llm/functional/dpo.py
fast_llm/layers/language_model/config.py
fast_llm/layers/language_model/head.py
fast_llm/models/gpt/model.py

5 files changed: +148 -12 lines


fast_llm/functional/config.py

Lines changed: 4 additions & 0 deletions
@@ -91,3 +91,7 @@ class CrossEntropyImpl(str, enum.Enum):
     torch = "torch"
     fused = "fused"
     triton = "triton"
+
+class LossFunctionType(str, enum.Enum):
+    cross_entropy = "cross_entropy"
+    dpo = "dpo"
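
Because LossFunctionType subclasses both str and enum.Enum (like CrossEntropyImpl above it), string values taken from a parsed config map directly onto its members. A minimal standalone sketch of that behaviour (plain Python, outside Fast-LLM):

    import enum

    class LossFunctionType(str, enum.Enum):
        cross_entropy = "cross_entropy"
        dpo = "dpo"

    # A string from a config file resolves to the corresponding member...
    assert LossFunctionType("dpo") is LossFunctionType.dpo
    # ...and members compare equal to their raw string value.
    assert LossFunctionType.dpo == "dpo"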

fast_llm/functional/dpo.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+
+
+def compute_logps_for_spans(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    chosen_span: torch.Tensor,
+    rejected_span: torch.Tensor
+):
+    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+    # gather log probabilities corresponding to the target tokens
+    # selected_log_probs = log_probs[torch.arange(logits.shape[0] - 1), targets]
+    selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
+
+    # apply chosen mask
+    chosen_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool)
+    chosen_mask[chosen_span[:, 0]: chosen_span[:, 1] + 1] = 1
+    chosen_logp = (selected_log_probs * chosen_mask).sum()
+
+    # apply rejected mask
+    rejected_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool)
+    rejected_mask[rejected_span[:, 0]: rejected_span[:, 1] + 1] = 1
+    rejected_logp = (selected_log_probs * rejected_mask).sum()
+
+    # chosen_logp = selected_log_probs[chosen_span[:, 0]: chosen_span[:, 1] + 1].sum()
+    # rejected_logp = selected_log_probs[rejected_span[:, 0]: rejected_span[:, 1] + 1].sum()
+
+    return chosen_logp, rejected_logp
+
+def compute_simplified_dpo_loss(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    chosen_span: torch.Tensor,
+    rejected_span: torch.Tensor,
+    beta: float,
+    grad_output: float | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    with torch.enable_grad():
+        logits_ = logits.float().detach().requires_grad_()
+
+        policy_chosen_logps, policy_rejected_logps = compute_logps_for_spans(logits_, targets, chosen_span, rejected_span)
+
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+
+        losses = -F.logsigmoid(beta * pi_logratios)
+        if grad_output is None:
+            loss = None
+        else:
+            loss = losses.mean()
+            loss.backward(torch.full_like(loss, grad_output))
+            loss.detach()
+    return loss.detach(), logits_.grad.detach().to(logits.dtype)
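
For orientation: compute_simplified_dpo_loss implements a reference-free ("simplified") DPO objective, loss = -log sigmoid(beta * (log p(chosen span) - log p(rejected span))), where each span log-probability is the sum of the per-token target log-probabilities inside that span. Mirroring the existing forward/backward helpers, it detaches the incoming logits, recomputes the loss under enable_grad, and returns the logit gradient (scaled by grad_output) so the caller can continue the manual backward pass. The sketch below is a self-contained illustration of the same math for a single sequence with made-up shapes; it is not the Fast-LLM code path and omits the grad_output plumbing:

    import torch
    import torch.nn.functional as F

    def toy_simplified_dpo_loss(logits, targets, chosen_span, rejected_span, beta):
        # logits: [seq_len, vocab]; targets: [seq_len - 1] next-token labels.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        # Log-probability assigned to each target token.
        token_logps = log_probs[:-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Sum the token log-probabilities inside each (inclusive) span.
        chosen_logp = token_logps[chosen_span[0] : chosen_span[1] + 1].sum()
        rejected_logp = token_logps[rejected_span[0] : rejected_span[1] + 1].sum()
        # Reference-free DPO: -log sigmoid(beta * (chosen minus rejected log-prob)).
        return -F.logsigmoid(beta * (chosen_logp - rejected_logp))

    # Toy usage with hypothetical sizes and span positions.
    logits = torch.randn(8, 11, requires_grad=True)
    targets = torch.randint(0, 11, (7,))
    loss = toy_simplified_dpo_loss(logits, targets, torch.tensor([0, 2]), torch.tensor([3, 6]), beta=1.0)
    loss.backward()  # the gradient w.r.t. the logits is what the real helper hands back to the head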

fast_llm/layers/language_model/config.py

Lines changed: 13 additions & 1 deletion
@@ -4,7 +4,7 @@
 from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
 from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
 from fast_llm.engine.distributed.config import DistributedDimNames
-from fast_llm.functional.config import CrossEntropyImpl
+from fast_llm.functional.config import CrossEntropyImpl, LossFunctionType
 from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig
 from fast_llm.utils import Assert
 

@@ -28,6 +28,8 @@ class LanguageModelKwargs:
     # TODO: These are generic
     labels = "labels"
     phase = "phase"
+    chosen_spans = "chosen_spans"
+    rejected_spans = "rejected_spans"
 
 
 @config_class()

@@ -128,6 +130,16 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
         desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).",
         hint=FieldHint.feature,
     )
+    loss_function_type: LossFunctionType = Field(
+        default=LossFunctionType.cross_entropy,
+        desc="Type of loss function to use",
+        hint=FieldHint.feature,
+    )
+    beta: float | None = Field(
+        default=1.0,
+        desc="Beta value for DPO loss.",
+        hint=FieldHint.feature,
+    )
     cross_entropy_impl: CrossEntropyImpl = Field(
         default=CrossEntropyImpl.auto,
         desc="Implementation for the cross-entropy computation.",

fast_llm/layers/language_model/head.py

Lines changed: 49 additions & 10 deletions
@@ -10,9 +10,10 @@
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
 from fast_llm.engine.distributed.config import DistributedDimNames
 from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward
-from fast_llm.functional.config import CrossEntropyImpl, TritonConfig
+from fast_llm.functional.config import CrossEntropyImpl, TritonConfig, LossFunctionType
 from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward
+from fast_llm.functional.dpo import compute_simplified_dpo_loss
 from fast_llm.layers.common.auxiliary_loss import z_loss
 from fast_llm.layers.language_model.config import (
     LanguageModelBaseConfig,

@@ -74,14 +75,20 @@ def __init__(
             ),
         )
 
-        self._cross_entropy_impl = config.cross_entropy_impl
-        if self._cross_entropy_impl == CrossEntropyImpl.auto:
-            if self._parallel_embeddings:
-                self._cross_entropy_impl = CrossEntropyImpl.fused
-            elif TritonConfig.TRITON_ENABLED:
-                self._cross_entropy_impl = CrossEntropyImpl.triton
-            else:
-                self._cross_entropy_impl = CrossEntropyImpl.fused
+        self._loss_function_type = config.loss_function_type
+        if self._loss_function_type == LossFunctionType.cross_entropy:
+            self._cross_entropy_impl = config.cross_entropy_impl
+            if self._cross_entropy_impl == CrossEntropyImpl.auto:
+                if self._parallel_embeddings:
+                    self._cross_entropy_impl = CrossEntropyImpl.fused
+                elif TritonConfig.TRITON_ENABLED:
+                    self._cross_entropy_impl = CrossEntropyImpl.triton
+                else:
+                    self._cross_entropy_impl = CrossEntropyImpl.fused
+            self._loss_fcn = self._logits_cross_entropy_forward_backward_split
+        else:
+            self._loss_fcn = self._logits_dpo
+            self.dpo_beta = config.beta
 
         self._forward = wrap_forward_backward(self._forward_backward, grad_is_context)
 

@@ -127,7 +134,7 @@ def _forward_backward(
         )
 
         output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights
-        loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split(
+        loss, ln_output_grad = self._loss_fcn(
             ln_output.detach(), labels, output_weights, grad_output, kwargs, losses
         )
 

@@ -136,6 +143,38 @@
             return loss, input_.grad
         else:
             return loss, None
+
+    def _logits_dpo(
+        self,
+        input_: torch.Tensor,
+        labels: torch.Tensor | None,
+        weight: torch.Tensor,
+        grad_output: float,
+        kwargs: dict,
+        losses: dict | None = None
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        logits, context = output_parallel_linear_forward(
+            input_=input_,
+            weight=weight,
+            bias=None,
+            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+            sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
+        )
+
+        loss, grad = compute_simplified_dpo_loss(
+            logits.flatten(0, -2),
+            labels,
+            kwargs[LanguageModelKwargs.chosen_spans],
+            kwargs[LanguageModelKwargs.rejected_spans],
+            self.dpo_beta,
+            grad_output
+        )
+
+        # TODO: de-allocate earlier.
+        del logits
+        return loss, output_parallel_linear_backward(grad, context).view_as(input_)
+
+
 
     def _logits_cross_entropy_forward_backward_split(
         self,
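
The net effect in the head: the loss callable is chosen once at construction time, and _logits_dpo mirrors the cross-entropy path (output projection forward, loss and logit gradient, then output_parallel_linear_backward), so _forward_backward only needs the single self._loss_fcn call shown above. A minimal standalone sketch of that dispatch pattern (hypothetical class, not the Fast-LLM head):

    class ToyHead:
        def __init__(self, loss_function_type: str):
            # Pick the loss implementation once; both callables share one signature.
            if loss_function_type == "cross_entropy":
                self._loss_fcn = self._cross_entropy_loss
            else:
                self._loss_fcn = self._dpo_loss

        def forward(self, hidden, labels, **kwargs):
            # The forward path never branches on the loss type again.
            return self._loss_fcn(hidden, labels, **kwargs)

        def _cross_entropy_loss(self, hidden, labels, **kwargs):
            return "cross-entropy loss and logit gradient"

        def _dpo_loss(self, hidden, labels, **kwargs):
            return "DPO loss and logit gradient"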

fast_llm/models/gpt/model.py

Lines changed: 27 additions & 1 deletion
@@ -254,7 +254,7 @@ def preprocess(
                 TransformerKwargs.presents: presents,
             }
             if phase != PhaseType.inference:
-                sequence_offset = sequence_k - sequence_q + 1
+                sequence_offset = sequence_k - sequence_q + 1  # +1 for shift in labels
                 if sequence_first:
                     labels = batch.token_ids[sequence_offset : sequence_k + 1]
                 else:

@@ -266,8 +266,10 @@ def preprocess(
                     for i, spans in enumerate(batch.loss_masking_spans):
                         if not spans.numel():
                             continue
+                        # filter spans within the sequence or partially within the sequence
                         valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
                         if valid_spans.numel():
+                            # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
                             valid_spans[:, 0].clamp_(min=sequence_offset)
                             valid_spans[:, 1].clamp_(max=sequence_k)
                             valid_spans -= sequence_offset

@@ -276,6 +278,30 @@ def preprocess(
                                     labels[start : end + 1, i] = -100
                                 else:
                                     labels[i, start : end + 1] = -100
+                if batch.chosen_loss_masking_spans is not None:
+                    for i, spans in enumerate(batch.chosen_loss_masking_spans):
+                        if not spans.numel():
+                            continue
+                        # filter spans within the sequence or partially within the sequence
+                        valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)]
+                        if valid_spans.numel():
+                            # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
+                            valid_spans[:, 0].clamp_(min=sequence_offset)
+                            valid_spans[:, 1].clamp_(max=sequence_k)
+                            valid_spans -= sequence_offset
+                            kwargs[LanguageModelKwargs.chosen_spans] = valid_spans
+                if batch.rejected_loss_masking_spans is not None:
+                    for i, spans in enumerate(batch.rejected_loss_masking_spans):
+                        if not spans.numel():
+                            continue
+                        # filter spans within the sequence or partially within the sequence
+                        valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)]
+                        if valid_spans.numel():
+                            # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
+                            valid_spans[:, 0].clamp_(min=sequence_offset)
+                            valid_spans[:, 1].clamp_(max=sequence_k)
+                            valid_spans -= sequence_offset
+                            kwargs[LanguageModelKwargs.rejected_spans] = valid_spans
                 kwargs[LanguageModelKwargs.labels] = labels
             if self._config.use_absolute_position_embeddings:
                 self._position_embedding_preprocessor.preprocess(kwargs)
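
The chosen/rejected branches reuse the span arithmetic of the existing loss-masking block: with sequence_offset = sequence_k - sequence_q + 1, a span [start, end] (inclusive token positions) is kept if it overlaps the label window [sequence_offset, sequence_k], clipped to that window, and shifted so it indexes directly into the labels slice. A toy walk-through with made-up numbers, using the [:, 0] / [:, 1] form from the loss-masking block:

    import torch

    sequence_q, sequence_k = 4, 8                    # hypothetical micro-sequence bounds
    sequence_offset = sequence_k - sequence_q + 1    # = 5, +1 for the label shift

    spans = torch.tensor([[2, 6], [7, 12], [0, 3]])  # inclusive [start, end] pairs
    # keep spans that overlap the label window [sequence_offset, sequence_k]
    valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
    # clip partial overlaps to the window, then shift into label coordinates
    valid_spans[:, 0].clamp_(min=sequence_offset)
    valid_spans[:, 1].clamp_(max=sequence_k)
    valid_spans -= sequence_offset
    print(valid_spans)  # tensor([[0, 1], [2, 3]]) -> label indices within this micro-sequence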
