Skip to content

Commit

Permalink
fix: issue TheAlgorithms#12233, added a small constant beta to the nu…
Browse files Browse the repository at this point in the history
…merator and denominator and added a test case
  • Loading branch information
evan.zhang5 committed Oct 22, 2024
1 parent 03a4251 commit 9367b3b
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions machine_learning/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def categorical_cross_entropy(
def categorical_focal_cross_entropy(
y_true: np.ndarray,
y_pred: np.ndarray,
alpha: np.ndarray = None,
alpha: np.ndarray | None = None,
gamma: float = 2.0,
epsilon: float = 1e-15,
) -> float:
Expand Down Expand Up @@ -648,7 +648,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
>>> true_labels = np.array([0.2, 0.3, 0.5])
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
0.030478754035472025
0.0304787540354719
>>> true_labels = np.array([0, 0.5, 0.5])
>>> predicted_probs = np.array([0.3, 0.3, 0.4])
>>> float(kullback_leibler_divergence(true_labels, predicted_probs))
0.3669845875400667
>>> true_labels = np.array([0.2, 0.3, 0.5])
>>> predicted_probs = np.array([0.3, 0.3, 0.4, 0.5])
>>> kullback_leibler_divergence(true_labels, predicted_probs)
Expand All @@ -658,7 +662,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
"""
if len(y_true) != len(y_pred):
raise ValueError("Input arrays must have the same length.")

beta = 1e-15
y_true = y_true + beta
y_pred = y_pred + beta
kl_loss = y_true * np.log(y_true / y_pred)
return np.sum(kl_loss)

Expand Down

0 comments on commit 9367b3b

Please sign in to comment.