
[FEATURE] Add Task for multi-turn dialogue distillation #1120

Closed
AndreasMadsen opened this issue Feb 11, 2025 · 2 comments
Labels
enhancement New feature or request

Comments


AndreasMadsen commented Feb 11, 2025

Is your feature request related to a problem? Please describe.

I have a dataset like:

[
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
  {"role": "assistant", "content": "Okay, let me show you ..."},
]

And I would like to distill all the "role": "assistant" answers from a newer model, assuming the user asks the same questions. However, it appears that distilabel doesn't support this kind of recursive interaction with an LLM.

Describe the solution you'd like

Instead, I would like to be able to have an LLM recursively fill in the assistant responses:

input = [
  {"role": "user", "content": "Hello, how are you?"},
]
output = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": new_response_1},
]
input = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": new_response_1},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
]
output = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": new_response_1},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
  {"role": "assistant", "content": new_response_2},
]
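
In plain Python, the behaviour I have in mind is roughly the loop below. This is only a sketch: generate_reply is a placeholder for a call to the newer model, not an existing distilabel API.

# Sketch only: `generate_reply` is a hypothetical call to the newer LLM.
def distill_conversation(messages, generate_reply):
    new_messages = []
    for message in messages:
        if message["role"] != "user":
            continue
        new_messages.append(message)
        # Each assistant turn is re-generated from the rebuilt history so far,
        # so later answers are conditioned on the newer model's own replies.
        new_messages.append(
            {"role": "assistant", "content": generate_reply(new_messages)}
        )
    return new_messages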

I apologize if this is already supported, but I couldn't see it anywhere.

Describe alternatives you've considered

The only alternative would be to split the observation as follows:

input = [
  {"role": "user", "content": "Hello, how are you?"},
]
output = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": new_response_1},
]
input = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
]
output = [
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
  {"role": "user", "content": "I'd like to show off how chat templating works!"},
  {"role": "assistant", "content": new_response_2},
]

However, this is not the same, as I would still be distilling with content from an older LLM in the conversation history. I'm also not sure how to do this kind of input expansion, i.e. turning one observation into multiple observations.
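
For what it's worth, the expansion itself is easy to write in plain Python; what I don't know is how to express such a one-row-to-many-rows step in a distilabel pipeline. A rough sketch of the splitting:

# Sketch only: produce one observation per assistant turn. Each observation is
# the original (old-model) history up to and including the preceding user
# message; the assistant turn at that position is what would be re-generated.
def split_observation(messages):
    observations = []
    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            observations.append(messages[:i])
    return observations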

Additional context

No response

AndreasMadsen added the enhancement (New feature or request) label Feb 11, 2025
AndreasMadsen changed the title from "[FEATURE] Add Task for Dialogue distillation" to "[FEATURE] Add Task for multi-turn dialogue distillation" Feb 11, 2025
gabrielmbmb self-assigned this Feb 12, 2025
gabrielmbmb (Member) commented

Hi @AndreasMadsen, let me know if this helps:

from typing import TYPE_CHECKING, Any, Dict, Union
from distilabel.steps import StepInput
from distilabel.steps.tasks import Task
from distilabel.models import MlxLLM
from distilabel.pipeline import Pipeline
from datasets import Dataset

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput, ChatType


class AssistantResponseDistillation(Task):
    @property
    def inputs(self) -> "StepColumns":
        return ["messages"]

    def format_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return {}

    @property
    def outputs(self) -> "StepColumns":
        return ["new_messages"]

    def format_output(
        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        return {}

    def _get_user_messages(self, inputs: StepInput) -> tuple[list["ChatType"], int]:
        # Extract only the user turns from each conversation and track the
        # maximum number of user turns in the batch.
        conversations = []
        max_turns = -1
        for input in inputs:
            conversation = []
            user_turns = 0
            for message in input["messages"]:
                if message["role"] == "user":
                    user_turns += 1
                    conversation.append(message)
            conversations.append(conversation)
            if max_turns < user_turns:
                max_turns = user_turns
        return conversations, max_turns

    def process(self, inputs: StepInput) -> "StepOutput":
        user_messages, max_turns = self._get_user_messages(inputs)

        # One (initially empty) rebuilt conversation per input row; a row stays
        # "active" while it still has user turns left to replay.
        conversations = []
        active_indices = []
        for i in range(len(user_messages)):
            conversations.append([])
            active_indices.append(i)

        # Replay the conversations turn by turn: append the next user message to
        # every active conversation, then generate a fresh assistant reply for
        # all of them in a single batched call.
        for i in range(max_turns):
            for idx, messages in enumerate(user_messages):
                if idx not in active_indices:
                    continue

                if len(messages) > i:
                    user_message = messages[i]
                    conversations[idx].append(user_message)
                else:
                    # No user turns left for this conversation; stop generating
                    # for it. Remove by value, not by position ("pop(idx)" would
                    # drop the wrong entry once earlier indices are gone).
                    active_indices.remove(idx)

            outputs = self.llm.generate(
                inputs=[conversations[idx] for idx in active_indices],
                num_generations=1,
                **self.llm.generation_kwargs,
            )

            # Append the newly generated assistant reply to each still-active
            # conversation.
            active_conversations = [conversations[idx] for idx in active_indices]
            for conversation, output in zip(active_conversations, outputs):
                conversation.append(
                    {
                        "role": "assistant",
                        "content": output["generations"][0],
                    }
                )

        # Attach the rebuilt conversation to each input row.
        for input, conversation in zip(inputs, conversations):
            input["new_messages"] = conversation

        yield inputs


with Pipeline() as pipeline:
    AssistantResponseDistillation(
        llm=MlxLLM(path_or_hf_repo="Qwen/Qwen2.5-0.5B-Instruct")
    )


if __name__ == "__main__":
    dataset = Dataset.from_list(
        [
            {
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {
                        "role": "assistant",
                        "content": "I'm doing great. How can I help you today?",
                    },
                    {
                        "role": "user",
                        "content": "I'd like to show off how chat templating works!",
                    },
                    {"role": "assistant", "content": "Okay, let me show you ..."},
                ]
            },
            {
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {
                        "role": "assistant",
                        "content": "I'm doing great. How can I help you today?",
                    },
                ]
            },
        ]
    )

    distiset = pipeline.run(dataset=dataset)

    for row in distiset["default"]["train"]:
        messages = row["new_messages"]
        print("-" * 100)
        for message in messages:
            print(f"{message['role']}: {message['content']}")

AndreasMadsen (Author) commented

Hi @gabrielmbmb, thanks a lot for demonstrating how this can be done with distilabel. That said, I'm not really a fan of having to store all of the observations in memory and schedule all requests simultaneously; it would be great if there were an iterable pipeline. But thanks anyway :)
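
For reference, the kind of streaming I mean would look roughly like the loop below, reusing the llm.generate(...) call from the snippet above. Whether the LLM classes are meant to be used standalone like this, outside a pipeline, is an assumption on my part.

from distilabel.models import MlxLLM

# Sketch only: process one conversation at a time instead of buffering the
# whole dataset through a pipeline. Assumes the LLM can be used directly
# after calling `load()`.
llm = MlxLLM(path_or_hf_repo="Qwen/Qwen2.5-0.5B-Instruct")
llm.load()

def distill_one(messages):
    conversation = []
    for message in messages:
        if message["role"] != "user":
            continue
        conversation.append(message)
        output = llm.generate(inputs=[conversation], num_generations=1)[0]
        conversation.append(
            {"role": "assistant", "content": output["generations"][0]}
        )
    return conversation

# `dataset` as constructed in the snippet above.
for row in dataset:
    row["new_messages"] = distill_one(row["messages"])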
