Skip to content

Commit cc96c26

Browse files
2 parents 162f32d + 1f5d956 commit cc96c26

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

data_questionnaire_agent/model/confidence_schema.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ class ConfidenceDegree(StrEnum):
1414
low = "low"
1515

1616

17+
CONFIDENCE_DEGREE_DICT = {
18+
ConfidenceDegree.outstanding: 5,
19+
ConfidenceDegree.high: 4,
20+
ConfidenceDegree.medium: 3,
21+
ConfidenceDegree.mediocre: 2,
22+
ConfidenceDegree.low: 1,
23+
}
24+
25+
1726
class ConfidenceRating(BaseModel):
1827
"""Represents a rating of how confident the model is to give advice to a customer based on a questionnaire"""
1928

@@ -30,6 +39,29 @@ class ConfidenceRating(BaseModel):
3039
description="The confidence rating of the model to give advice to a customer based on a questionnaire",
3140
)
3241

42+
def _value(self) -> int:
43+
return CONFIDENCE_DEGREE_DICT[self.rating]
44+
45+
def __lt__(self, other):
46+
if isinstance(other, ConfidenceRating):
47+
return self._value() < other._value()
48+
return NotImplemented
49+
50+
def __le__(self, other):
51+
if isinstance(other, ConfidenceRating):
52+
return self._value() <= other._value()
53+
return NotImplemented
54+
55+
def __gt__(self, other):
56+
if isinstance(other, ConfidenceRating):
57+
return self._value() > other._value()
58+
return NotImplemented
59+
60+
def __ge__(self, other):
61+
if isinstance(other, ConfidenceRating):
62+
return self._value() >= other._value()
63+
return NotImplemented
64+
3365
def to_markdown(self, locale: str = "en") -> str:
3466
return f"""
3567
# {t("Confidence Degree", locale=locale)}
@@ -49,3 +81,5 @@ def to_html(self, language: str = "en") -> str:
4981
5082
<p>{self.reasoning}</p>
5183
"""
84+
85+

data_questionnaire_agent/server/questionnaire_server.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,19 @@ async def save_confidence_rating(
412412
questionnaire: Questionnaire,
413413
):
414414
if confidence_rating is not None:
415+
step = len(questionnaire.questions)
416+
417+
# Prevent the confidence rating from decreasing
418+
previous_confidence_rating = None
419+
if step is not None and step > 2:
420+
previous_confidence_rating = await select_confidence(session_id, step - 1)
421+
if previous_confidence_rating is not None and previous_confidence_rating is not None:
422+
if previous_confidence_rating > confidence_rating:
423+
confidence_rating = previous_confidence_rating
424+
415425
# Save the confidence
416426
if conditional_advice:
417427
conditional_advice.confidence = confidence_rating
418-
step = len(questionnaire.questions)
419428
await save_confidence(session_id, step, confidence_rating)
420429

421430

0 commit comments

Comments
 (0)