
Commit af06a81

Merge branch 'main' into cache-key-and-lock
2 parents: 97c4b03 + 60c1c3e


41 files changed: +850 −329 lines

README.md

Lines changed: 4 additions & 4 deletions
@@ -31,11 +31,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f
 
 ### 🦄 Currently on Unitxt Catalog
 
-![Abstract Tasks](https://img.shields.io/badge/Abstract_Tasks-62-blue)
-![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-3025-blue)
+![Abstract Tasks](https://img.shields.io/badge/Abstract_Tasks-64-blue)
+![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-3174-blue)
 ![Templates](https://img.shields.io/badge/Templates-342-blue)
-![Benchmarks](https://img.shields.io/badge/Benchmarks-4-blue)
-![Metrics](https://img.shields.io/badge/Metrics-422-blue)
+![Benchmarks](https://img.shields.io/badge/Benchmarks-6-blue)
+![Metrics](https://img.shields.io/badge/Metrics-462-blue)
 
 ### 🦄 Run Unitxt Exploration Dashboard
 

docs/_static/custom.css

Lines changed: 5 additions & 0 deletions
@@ -206,3 +206,8 @@ div.document div.documentwrapper {
 .red {
   color: red;
 }
+
+#unitxtImports {
+  /* Display nothing for the element */
+  display: none;
+}

docs/catalog.py

Lines changed: 49 additions & 6 deletions
@@ -1,8 +1,10 @@
 import json
 import os
 import re
+from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
+from typing import List
 
 from docutils.core import publish_parts
 from pygments import highlight
@@ -43,6 +45,44 @@ def dict_to_syntax_highlighted_html(nested_dict):
     # Apply syntax highlighting
     return highlight(py_str, PythonLexer(), formatter)
 
+def imports_to_syntax_highlighted_html(subtypes: List[str]) -> str:
+    if len(subtypes) == 0:
+        return ""
+    module_to_class_names = defaultdict(list)
+    for subtype in subtypes:
+        subtype_class = Artifact._class_register.get(subtype)
+        module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)
+
+    imports_txt = ""
+    for modu in sorted(module_to_class_names.keys()):
+        classes_string = ", ".join(sorted(module_to_class_names[modu]))
+        imports_txt += f"from {modu} import {classes_string}\n"
+
+    formatter = HtmlFormatter(nowrap=True)
+    htm = highlight(imports_txt, PythonLexer(), formatter)
+
+    imports_html = f'\n<p><div><pre><span id="unitxtImports">{htm}</span></pre>\n'
+    imports_html += """<button onclick="toggleText()" id="textButton">
+    Show Imports
+    </button>
+
+    <script>
+    function toggleText() {
+        let showImports = document.getElementById("unitxtImports");
+        let buttonText = document.getElementById("textButton");
+        if (showImports.style.display === "none" || showImports.style.display === "") {
+            showImports.style.display = "inline";
+            buttonText.innerHTML = "Close";
+        }
+
+        else {
+            showImports.style.display = "none";
+            buttonText.innerHTML = "Show Imports";
+        }
+    }
+    </script>
+    </div></p>\n"""
+    return imports_html
 
 def write_title(title, label):
     title = f"📁 {title}"
@@ -177,26 +217,29 @@ def make_content(artifact, label, all_labels):
 
     # Replacement function
     html_for_dict = re.sub(pattern, r"\1\2\3", html_for_dict)
+
+    subtypes = all_subtypes_of_artifact(artifact)
+    subtypes = list(set(subtypes))
+    subtypes.remove(artifact_type)  # this was already documented
+    html_for_imports = imports_to_syntax_highlighted_html(subtypes)
+
     source_link = f"""<a class="reference external" href="https://github.com/IBM/unitxt/blob/main/src/unitxt/catalog/{catalog_id.replace(".", "/")}.json"><span class="viewcode-link"><span class="pre">[source]</span></span></a>"""
-    html_for_dict = f"""<div class="admonition note">
+    html_for_element = f"""<div class="admonition note">
     <p class="admonition-title">{catalog_id}</p>
     <div class="highlight-json notranslate">
    <div class="highlight"><pre>
    {html_for_dict.strip()}
-    </pre>{source_link}</div></div>
+    </pre>{source_link}{html_for_imports.strip()}</div></div>
    </div>""".replace("\n", "\n    ")
 
-    result += "    " + html_for_dict + "\n"
+    result += "    " + html_for_element + "\n"
 
     if artifact_class.__doc__:
        explanation_str = f"Explanation about `{type_class_name}`"
        result += f"\n{explanation_str}\n"
        result += "+" * len(explanation_str) + "\n\n"
        result += artifact_class.__doc__ + "\n"
 
-    subtypes = all_subtypes_of_artifact(artifact)
-    subtypes = list(set(subtypes))
-    subtypes.remove(artifact_type)  # this was already documented
     for subtype in subtypes:
        subtype_class = Artifact._class_register.get(subtype)
        subtype_class_name = subtype_class.__name__
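
The new imports_to_syntax_highlighted_html helper groups each subtype's class name under its defining module and emits one import line per module, wrapped in the #unitxtImports span that the custom.css rule above keeps hidden until the Show Imports button toggles it. A minimal sketch of just the grouping step, using hypothetical module/class pairs in place of real Artifact._class_register lookups:

from collections import defaultdict

# Hypothetical (module, class-name) pairs standing in for resolved catalog subtypes.
resolved = [
    ("unitxt.operators", "Rename"),
    ("unitxt.operators", "Set"),
    ("unitxt.loaders", "LoadHF"),
]

module_to_class_names = defaultdict(list)
for module, class_name in resolved:
    module_to_class_names[module].append(class_name)

for module in sorted(module_to_class_names):
    print(f"from {module} import {', '.join(sorted(module_to_class_names[module]))}")

# Prints:
# from unitxt.loaders import LoadHF
# from unitxt.operators import Rename, Set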

examples/evaluate_existing_dataset_with_install.py

Lines changed: 4 additions & 3 deletions
@@ -23,8 +23,9 @@
 
 results = evaluate(predictions=predictions, data=dataset)
 
-print("Global Results:")
-print(results.global_scores.summary)
 
 print("Instance Results:")
-print(results.instance_scores.summary)
+print(results.instance_scores)
+
+print("Global Results:")
+print(results.global_scores.summary)
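
This reordering prints per-instance scores before the global summary. For context, a sketch of how a script of this shape typically starts, assuming the documented unitxt API; the card, template, and model below are placeholders rather than what this example file necessarily uses:

from unitxt import evaluate, load_dataset
from unitxt.inference import HFPipelineBasedInferenceEngine

# Placeholder recipe; the real example may load a different card/template.
dataset = load_dataset(
    card="cards.wnli",
    template="templates.classification.multi_class.relation.default",
    split="test",
    loader_limit=20,
)

model = HFPipelineBasedInferenceEngine(
    model_name="google/flan-t5-base", max_new_tokens=32
)
predictions = model.infer(dataset)

results = evaluate(predictions=predictions, data=dataset)
print("Instance Results:")
print(results.instance_scores)
print("Global Results:")
print(results.global_scores.summary)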

prepare/cards/ag_news.py

Lines changed: 4 additions & 15 deletions
@@ -1,4 +1,3 @@
-from datasets import load_dataset_builder
 from unitxt import add_to_catalog
 from unitxt.blocks import (
     LoadHF,
@@ -9,26 +8,16 @@
 )
 from unitxt.test_utils.card import test_card
 
-dataset_name = "ag_news"
-
-ds_builder = load_dataset_builder(dataset_name)
-classlabels = ds_builder.info.features["label"]
-
-mappers = {}
-for i in range(len(classlabels.names)):
-    mappers[str(i)] = classlabels.names[i]
-
-
 card = TaskCard(
-    loader=LoadHF(path=f"{dataset_name}"),
+    loader=LoadHF(path="fancyzhx/ag_news"),
     preprocess_steps=[
         SplitRandomMix(
             {"train": "train[87.5%]", "validation": "train[12.5%]", "test": "test"}
         ),
-        MapInstanceValues(mappers={"label": mappers}),
+        MapInstanceValues(mappers={"label": {"0": "World", "1": "Sports", "2": "Business", "3": "Sci/Tech"}}),
         Set(
             fields={
-                "classes": classlabels.names,
+                "classes": ["World", "Sports", "Business", "Sci/Tech"],
                 "text_type": "sentence",
             }
         ),
@@ -52,4 +41,4 @@
     ),
 )
 test_card(card, debug=False)
-add_to_catalog(card, f"cards.{dataset_name}", overwrite=True)
+add_to_catalog(card, "cards.ag_news", overwrite=True)
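
The card now pins the four AG News label names instead of deriving them via load_dataset_builder at preparation time, dropping the datasets-builder call and making the card deterministic. A toy sketch of what the MapInstanceValues step does to each instance's label field (plain dicts, not actual unitxt streams):

label_map = {"0": "World", "1": "Sports", "2": "Business", "3": "Sci/Tech"}

instances = [
    {"text": "Stocks rally on earnings", "label": "2"},
    {"text": "Champions league final tonight", "label": "1"},
]
for instance in instances:
    instance["label"] = label_map[instance["label"]]

print(instances)
# [{'text': 'Stocks rally on earnings', 'label': 'Business'},
#  {'text': 'Champions league final tonight', 'label': 'Sports'}]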

prepare/cards/head_qa.py

Lines changed: 6 additions & 3 deletions
@@ -6,13 +6,15 @@
     Set,
     TaskCard,
 )
+from unitxt.splitters import RenameSplits
 from unitxt.test_utils.card import test_card
 
 with unitxt.settings.context(allow_unverified_code=True):
-    for subset in ["es", "en"]:
+    for subset in ["es", "en", "gl", "it", "ru"]:
         card = TaskCard(
-            loader=LoadHF(path="dvilares/head_qa", name=subset),
+            loader=LoadHF(path="alesi12/head_qa_v2", name=subset),
             preprocess_steps=[
+                RenameSplits({"train": "test"}),
                 Rename(field_to_field={"qtext": "text", "category": "label"}),
                 Set(
                     fields={
@@ -46,5 +48,6 @@
                 "task_ids": "multiple-choice-qa",
             },
         )
-        test_card(card, debug=False)
+        if subset == "es":
+            test_card(card, debug=False)
         add_to_catalog(card, f"cards.head_qa.{subset}", overwrite=True)
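
Assuming RenameSplits follows unitxt's usual old-name-to-new-name mapper convention, the added step exposes the source's train split under the name test, presumably because the v2 loader ships its data as a single train split. A toy illustration of the intended renaming (plain dicts, not unitxt internals):

# Hypothetical split dict before and after RenameSplits({"train": "test"}).
splits = {"train": ["q1", "q2", "q3"]}
mapper = {"train": "test"}
splits = {mapper.get(name, name): data for name, data in splits.items()}
print(list(splits))  # ['test']
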
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from unitxt import add_to_catalog
+from unitxt.blocks import (
+    InputOutputTemplate,
+    Task,
+    TaskCard,
+)
+from unitxt.loaders import LoadCSV
+from unitxt.operators import Rename
+from unitxt.templates import TemplatesDict
+from unitxt.test_utils.card import test_card
+
+card = TaskCard(
+    loader=LoadCSV(files={"test": "https://raw.githubusercontent.com/mlcommons/ailuminate/0fe054f3d34209a0a1acf1b95174c5e8926f5c57/airr_official_1.0_practice_prompt_set_release_public_subset.csv"},  # pragma: allowlist secret
+                   data_classification_policy=["public"]),
+    preprocess_steps=[
+        Rename(field_to_field={"prompt_text": "input",
+                               "hazard": "label"}),
+    ],
+    task=Task(
+        input_fields={"release_prompt_id": str, "input": str, "label": str, "persona": str, "locale": str, "prompt_hash": str},
+        reference_fields={"label": str},
+        prediction_type=str,
+        metrics=["metrics.llm_as_judge.safety.llamaguard"]
+    ),
+    templates=TemplatesDict({"default": InputOutputTemplate(input_format="{input}\n", output_format="")}),
+
+    __description__="The DEMO prompt library of the AILuminate 1.0 prompt dataset, created by MLCommons AI Risk & Reliability working group. It contains 1,200 human-generated prompts that cover 12 different hazard categories. The dataset is a mix of adversarial and non-adversarial prompts. This dataset is a 10% subset of the full AILuminate training dataset.",
+    __tags__={
+        "languages": ["english"]
+    },
+)
+
+test_card(
+    card,
+    strict=False,
+    demos_taken_from="test",
+    num_demos=0,
+)
+
+add_to_catalog(card, "cards.safety.mlcommons_ailuminate", overwrite=True)
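
A sketch of pulling the new card for a quick look, assuming the standard unitxt load_dataset API: template_card_index selects the "default" entry of the TemplatesDict, and loader_limit just keeps the download small.

from unitxt import load_dataset

# Assumes the card registered above; "default" is its only template.
dataset = load_dataset(
    card="cards.safety.mlcommons_ailuminate",
    template_card_index="default",
    split="test",
    loader_limit=5,
)
print(dataset[0]["source"])  # the rendered "{input}\n" prompt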

prepare/cards/safety/simple_safety_tests.py

Lines changed: 13 additions & 8 deletions
@@ -1,6 +1,7 @@
 from unitxt import add_to_catalog
 from unitxt.blocks import (
     LoadHF,
+    Task,
     TaskCard,
 )
 from unitxt.operators import (
@@ -17,14 +18,19 @@
         Rename(field_to_field={"prompt": "input"}),
         Set(
             fields={
-                "output": "None",
-                "type_of_input": "question",
-                "type_of_output": "answer",
-                "prediction_type": "str",
+                "output": "",
             }
         ),
     ],
-    task="tasks.generation",
+    task=Task(
+        input_fields={"input": str, "id": str, "harm_area": str, "category": str},
+        reference_fields={"output": str},
+        prediction_type=str,
+        metrics=[
+            "metrics.granite_guardian.assistant_risk.harm[prediction_type=str,user_message_field=input,assistant_message_field=output,score_prefix=graniteguardian_]",
+            "metrics.llm_as_judge.safety.llamaguard[score_prefix=llamaguard_]"
+        ],
+    ),
     templates=["templates.generation.empty"],
     __description__="100 test prompts across five harm areas that LLMs, for the vast majority of applications, should refuse to comply with.",
     __tags__={
@@ -33,7 +39,6 @@
     },
 )
 
-test_card(
-    card, format="formats.empty", strict=False, demos_taken_from="test", num_demos=0
-)
+test_card(card, strict=False, demos_taken_from="test", num_demos=0)
+
 add_to_catalog(card, "cards.safety.simple_safety_tests", overwrite=True)

prepare/cards/sst2.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 from unitxt.blocks import LoadHF, MapInstanceValues, TaskCard
 from unitxt.catalog import add_to_catalog
-from unitxt.operators import ExtractFieldValues, Rename, Set
+from unitxt.operators import Rename, Set
 from unitxt.test_utils.card import test_card
 
 card = TaskCard(
-    loader=LoadHF(path="glue", name="sst2"),
+    loader=LoadHF(path="stanfordnlp/sst2"),
     preprocess_steps=[
         "splitters.small_no_test",
         MapInstanceValues(mappers={"label": {"0": "negative", "1": "positive"}}),
@@ -13,9 +13,9 @@
             fields={
                 "text_type": "sentence",
                 "type_of_class": "sentiment",
+                "classes": ["negative", "positive"]
             }
         ),
-        ExtractFieldValues(field="label", to_field="classes", stream_name="train"),
     ],
     task="tasks.classification.multi_class",
     templates="templates.classification.multi_class.all",

prepare/cards/wnli.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from unitxt.test_utils.card import test_card
 
 card = TaskCard(
-    loader=LoadHF(path="glue", name="wnli", data_classification_policy=["public"]),
+    loader=LoadHF(path="nyu-mll/glue", name="wnli", data_classification_policy=["public"]),
     preprocess_steps=[
         SplitRandomMix(
             {"train": "train[95%]", "validation": "train[5%]", "test": "validation"}
@@ -57,7 +57,7 @@
 
 
 card = TaskCard(
-    loader=LoadHF(path="glue", name="wnli"),
+    loader=LoadHF(path="nyu-mll/glue", name="wnli"),
     preprocess_steps=[
         SplitRandomMix(
             {"train": "train[95%]", "validation": "train[5%]", "test": "validation"}
Lines changed: 28 additions & 27 deletions
@@ -1,34 +1,35 @@
 from unitxt import add_to_catalog
-from unitxt.inference import IbmGenAiInferenceEngine
+from unitxt.inference import CrossProviderInferenceEngine
 from unitxt.llm_as_judge_from_template import LLMAsJudge
 from unitxt.random_utils import get_seed
 
-model_list = [
-    "meta-llama/llama-3-8b-instruct",
-    "meta-llama/llama-3-70b-instruct",
-]  # will point to llamaguard2
-format = "formats.llama3_instruct"
-template = "templates.safety.unsafe_content"
+model = "llama-3-3-70b-instruct"
+format = "formats.chat_api"
+template = "templates.safety.llamaguard"
 task = "rating.single_turn"
 
-for model_id in model_list:
-    inference_model = IbmGenAiInferenceEngine(
-        model_name=model_id, max_new_tokens=252, random_seed=get_seed()
-    )
-    model_label = model_id.split("/")[1].replace("-", "_").replace(".", ",").lower()
-    model_label = f"{model_label}_ibm_genai"
-    template_label = template.split(".")[-1]
-    metric_label = f"{model_label}_template_{template_label}"
-    metric = LLMAsJudge(
-        inference_model=inference_model,
-        template=template,
-        task=task,
-        format=format,
-        main_score=metric_label,
-    )
+inference_model = CrossProviderInferenceEngine(
+    model=model, max_tokens=20, seed=get_seed(), temperature=1e-7
+)
 
-    add_to_catalog(
-        metric,
-        f"metrics.llm_as_judge.safety.{model_label}_template_{template_label}",
-        overwrite=True,
-    )
+model_label = (
+    model.replace("-", "_").replace(".", ",").lower() + "_cross_provider"
+)
+
+template_label = template.split(".")[-1]
+
+metric_label = f"{model_label}_template_{template_label}"
+
+metric = LLMAsJudge(
+    inference_model=inference_model,
+    template=template,
+    task=task,
+    format=format,
+    main_score=metric_label,
+)
+
+add_to_catalog(
+    metric,
+    "metrics.llm_as_judge.safety.llamaguard",
+    overwrite=True,
+)
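
With the judge registered under a fixed catalog name, cards can reference it directly (as simple_safety_tests above now does) instead of rebuilding a per-model label, optionally with bracketed overrides such as [score_prefix=llamaguard_]. A minimal sketch of attaching it to a task:

from unitxt.blocks import Task

# Sketch: a generation-style safety task scored by the relocated judge metric.
safety_task = Task(
    input_fields={"input": str},
    reference_fields={"output": str},
    prediction_type=str,
    metrics=["metrics.llm_as_judge.safety.llamaguard"],
)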
