[FEATURE] Add Task for multi-turn dialogue distillation #1120
Labels: enhancement (New feature or request)
Hi @AndreasMadsen, let me know if this helps:

```python
from typing import TYPE_CHECKING, Any, Dict, Union

from distilabel.steps import StepInput
from distilabel.steps.tasks import Task
from distilabel.models import MlxLLM
from distilabel.pipeline import Pipeline
from datasets import Dataset

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput, ChatType


class AssistantResponseDistillation(Task):
    @property
    def inputs(self) -> "StepColumns":
        return ["messages"]

    def format_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Not used: `process` is overridden below and formats the inputs itself.
        return {}

    @property
    def outputs(self) -> "StepColumns":
        return ["new_messages"]

    def format_output(
        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        # Not used: `process` is overridden below and builds the outputs itself.
        return {}

    def _get_user_messages(self, inputs: StepInput) -> tuple[list["ChatType"], int]:
        # Keep only the user turns of each conversation and track the maximum
        # number of user turns across the batch.
        conversations = []
        max_turns = -1
        for input in inputs:
            conversation = []
            user_turns = 0
            for message in input["messages"]:
                if message["role"] == "user":
                    user_turns += 1
                    conversation.append(message)
            conversations.append(conversation)
            if max_turns < user_turns:
                max_turns = user_turns
        return conversations, max_turns

    def process(self, inputs: StepInput) -> "StepOutput":
        user_messages, max_turns = self._get_user_messages(inputs)

        conversations = []
        active_indices = []
        for i in range(len(user_messages)):
            conversations.append([])
            active_indices.append(i)

        # Rebuild the conversations turn by turn: append the next user message of
        # every still-active conversation, generate the assistant replies for all
        # of them in a single batched call, then move on to the next turn.
        for i in range(max_turns):
            for idx, messages in enumerate(user_messages):
                if idx not in active_indices:
                    continue
                if len(messages) > i:
                    user_message = messages[i]
                    conversations[idx].append(user_message)
                else:
                    # This conversation has no more user turns.
                    active_indices.remove(idx)

            outputs = self.llm.generate(
                inputs=[conversations[idx] for idx in active_indices],
                num_generations=1,
                **self.llm.generation_kwargs,
            )

            active_conversations = [conversations[idx] for idx in active_indices]
            for conversation, output in zip(active_conversations, outputs):
                conversation.append(
                    {
                        "role": "assistant",
                        "content": output["generations"][0],
                    }
                )

        for input, conversation in zip(inputs, conversations):
            input["new_messages"] = conversation

        yield inputs


with Pipeline() as pipeline:
    AssistantResponseDistillation(
        llm=MlxLLM(path_or_hf_repo="Qwen/Qwen2.5-0.5B-Instruct")
    )


if __name__ == "__main__":
    dataset = Dataset.from_list(
        [
            {
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {
                        "role": "assistant",
                        "content": "I'm doing great. How can I help you today?",
                    },
                    {
                        "role": "user",
                        "content": "I'd like to show off how chat templating works!",
                    },
                    {"role": "assistant", "content": "Okay, let me show you ..."},
                ]
            },
            {
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {
                        "role": "assistant",
                        "content": "I'm doing great. How can I help you today?",
                    },
                ]
            },
        ]
    )

    distiset = pipeline.run(dataset=dataset)

    for row in distiset["default"]["train"]:
        messages = row["new_messages"]
        print("-" * 100)
        for message in messages:
            print(f"{message['role']}: {message['content']}")
```
Hi @gabrielmbmb, thanks a lot for demonstrating how this can be done with distilabel. That said, I'm not really a fan of having to store all of the observations in memory and schedule all the requests simultaneously; it would be great if there was an iterable pipeline. But thanks anyway :)
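To sketch what I mean by iterable, something along these lines in plain Python, with a hypothetical `generate_reply(conversation)` callable standing in for the actual LLM call (not a distilabel API):

```python
from typing import Callable, Iterable, Iterator

# Hypothetical stand-in: any callable mapping a chat history to the next assistant reply.
GenerateReply = Callable[[list[dict]], str]


def distill_iter(rows: Iterable[dict], generate_reply: GenerateReply) -> Iterator[dict]:
    """Yield distilled conversations one at a time, instead of keeping the whole
    dataset in memory and scheduling every request up front."""
    for row in rows:
        conversation: list[dict] = []
        for message in row["messages"]:
            if message["role"] != "user":
                continue  # old assistant answers are dropped and regenerated
            conversation.append(message)
            conversation.append(
                {"role": "assistant", "content": generate_reply(conversation)}
            )
        yield {"new_messages": conversation}
```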
Is your feature request related to a problem? Please describe.
I have a dataset like:
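Something like this (a shortened, representative example):

```python
[
    {
        "messages": [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
            {"role": "assistant", "content": "Okay, let me show you ..."},
        ]
    },
    {
        "messages": [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        ]
    },
]
```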
And I would like to distill all the `"role": "assistant"` answers from a newer model, assuming the user will ask the same question. However, it appears that distilabel doesn't support this kind of recursive interaction with an LLM.

Describe the solution you'd like
Instead, I would like to be able to have an LLM recursively fill in the assistant responses.
I apologize if this is already supported, but I couldn't see it anywhere.
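Concretely, for the first conversation above, the distilled row would keep the user turns and regenerate the assistant turns with the new model (placeholder content below):

```python
{
    "messages": [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "<new model's answer to turn 1>"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
        {"role": "assistant", "content": "<new model's answer to turn 2, conditioned on its own turn-1 answer>"},
    ]
}
```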
Describe alternatives you've considered
The only alternative choice would be to split the observation as:
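For the first conversation above, the split would look roughly like this (each new observation ends at a user turn and keeps the old assistant answers as context):

```python
[
    {
        "messages": [
            {"role": "user", "content": "Hello, how are you?"},
        ]
    },
    {
        "messages": [
            {"role": "user", "content": "Hello, how are you?"},
            # old model's answer, kept as context
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
        ]
    },
]
```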
However, this is not the same, as I would still be distilling with content from an older LLM. I'm also not sure how to even do this kind of input expansion, i.e. turning one observation into multiple observations.
Additional context
No response