
Commit 10605aa

Merge pull request #22 from modelscope/autorubric_gt
[update] autorubric src
2 parents: a6c03c7 + f5e4f3a

File tree

7 files changed: +768 −28 lines


examples/rubric/analysis.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Rubric Analysis Runner Script

Evaluate rubric performance on validation datasets using comprehensive metrics.
This script analyzes generated or structured rubrics to assess their quality,
coverage, precision, and contribution to ensemble performance.

This is useful for:
1. Evaluating rubric quality and effectiveness
2. Comparing different rubric sets or generation methods
3. Analyzing individual rubric contributions to ensemble performance

Features:
- Comprehensive rubric evaluation (Coverage, Precision, Contribution)
- Ensemble accuracy calculation with multiple rubrics
- Source vs. Target rubric comparison analysis
- Multithreaded evaluation for high performance
- Detailed statistics and performance metrics

"""

import argparse
import json
import sys
import time
from pathlib import Path
from typing import List

from rm_gallery.core.reward.rubric.analyzer import EvaluationConfig, RubricAnalyzer


def load_rubrics(rubrics_path: str) -> List[str]:
    """Load rubrics from JSON file"""
    with open(rubrics_path, "r", encoding="utf-8") as f:
        rubrics = json.load(f)

    if isinstance(rubrics, list):
        return rubrics
    else:
        raise ValueError(f"Invalid rubrics format in {rubrics_path}")


def run_analysis(
    rubrics_path: str,
    dataset_path: str,
    model: str = "qwen3-32b",
    max_samples: int = 100,
    max_workers: int = 256,
    output_dir: str = None,
    source_rubrics_path: str = None,
):
    """
    Run comprehensive rubric analysis

    Args:
        rubrics_path: Path to target rubrics (main evaluation set)
        dataset_path: Path to validation dataset
        model: LLM model name for evaluation
        max_samples: Maximum samples to evaluate
        max_workers: Number of worker threads
        output_dir: Output directory for results
        source_rubrics_path: Optional path to source rubrics for comparison

    Note:
        - Target rubrics: Calculate Coverage, Precision, and Contribution
        - Source rubrics: Calculate only Coverage and Precision (for comparison baseline)
    """
    print("🔍 Running Rubric Analysis")
    print("=" * 50)

    # Load target rubrics
    rubrics = load_rubrics(rubrics_path)

    print(f"✅ Loaded {len(rubrics)} target rubrics")

    # Load source rubrics (optional)
    source_rubrics = []
    if source_rubrics_path:
        source_rubrics = load_rubrics(source_rubrics_path)
        print(f"✅ Loaded {len(source_rubrics)} source rubrics")

    print(f"🔧 Using {max_workers} worker threads for parallel processing")

    # Initialize analyzer with multithreading support
    config = EvaluationConfig(
        model=model,
        max_workers=max_workers,  # Configurable worker threads
        optimization_strategy="sampling",
        target_sample_ratio=1.0,
    )

    analyzer = RubricAnalyzer(config)

    # Load dataset
    dataset = analyzer.load_dataset(
        dataset_path, domains=["general"], max_samples=max_samples
    )

    print(f"✅ Loaded {len(dataset)} validation samples")

    # Evaluate target rubrics
    print("\n🎯 Evaluating target rubrics...")
    ensemble_accuracy, metrics = analyzer.evaluate_rubric_set(
        rubrics, dataset, "target", calculate_contribution=True
    )

    # Evaluate source rubrics (if provided)
    source_metrics = []
    if source_rubrics:
        print("\n📊 Evaluating source rubrics...")
        print(
            " ℹ️ Note: Source rubrics only calculate Coverage and Precision (no Contribution)"
        )
        print(
            f" 🚀 Using parallel evaluation for {len(source_rubrics)} source rubrics..."
        )
        _, source_metrics = analyzer.evaluate_rubric_set(
            source_rubrics,
            dataset,
            "source",
            calculate_contribution=False,
            parallel_rubrics=True,  # Enable parallel evaluation for source rubrics
        )

    # Generate output directory name if not provided
    if output_dir is None:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_dir = f"rubric_analysis_results_{timestamp}"

    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)

    # Save results using analyzer's built-in method
    results_file = output_path / "analysis_results.json"
    analyzer.save_analysis_results(
        ensemble_accuracy, source_metrics, metrics, str(results_file)
    )

    print(f"\n💾 Results saved to: {output_path}")
    print(f" 📄 Analysis results: {results_file}")

    return ensemble_accuracy, metrics


def main():
    """Main function with command line interface"""
    parser = argparse.ArgumentParser(description="Simple Rubric Analysis Runner")

    # Input options
    parser.add_argument(
        "--rubrics", required=True, help="Rubrics JSON file or output directory"
    )
    parser.add_argument(
        "--dataset",
        default="./data/helpsteer3_preference_valid.jsonl",
        help="Validation dataset path",
    )
    parser.add_argument("--model", default="qwen3-32b", help="Model name")
    parser.add_argument(
        "--max-samples", type=int, default=100, help="Maximum samples for evaluation"
    )
    parser.add_argument(
        "--max-workers",
        type=int,
        default=256,
        help="Maximum number of worker threads for parallel processing",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Output directory for results (default: auto-generated with timestamp)",
    )
    parser.add_argument(
        "--source-rubrics",
        default=None,
        help="Optional source rubrics JSON file or directory for comparison",
    )

    args = parser.parse_args()

    try:
        run_analysis(
            args.rubrics,
            args.dataset,
            args.model,
            args.max_samples,
            args.max_workers,
            args.output,
            args.source_rubrics,
        )
        print("\n🎉 Analysis completed successfully!")

    except Exception as e:
        print(f"❌ Analysis failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
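Input format note: load_rubrics simply json.load's the file and requires a top-level list, and its return hint is List[str], which suggests the rubrics file is a flat JSON array of criterion strings. A minimal sketch of preparing such an input from Python; the rubric texts below are invented placeholders, not taken from the repository:

    import json

    # Hypothetical rubric criteria; any JSON list of strings satisfies load_rubrics().
    example_rubrics = [
        "The response directly addresses every part of the user's request.",
        "The response contains no factual errors or unsupported claims.",
    ]

    with open("rubrics.json", "w", encoding="utf-8") as f:
        json.dump(example_rubrics, f, ensure_ascii=False, indent=2)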

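The intended entry point is the command line, e.g. python examples/rubric/analysis.py --rubrics rubrics.json --max-samples 50 --max-workers 8, but run_analysis can also be called directly. A minimal sketch, assuming rm_gallery is installed and analysis.py is importable from the working directory; the file names and worker count are illustrative, while the dataset path mirrors the --dataset default:

    from analysis import run_analysis  # assumes examples/rubric/ is on sys.path

    ensemble_accuracy, metrics = run_analysis(
        rubrics_path="rubrics.json",                               # hypothetical input file
        dataset_path="./data/helpsteer3_preference_valid.jsonl",   # --dataset default
        model="qwen3-32b",
        max_samples=20,   # keep the first run small
        max_workers=8,    # scale down from the 256-thread default
        output_dir="rubric_analysis_results_demo",
    )
    print(f"Ensemble accuracy: {ensemble_accuracy}")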