
Avoid log(0) in KL divergence #12233

Open
kevin1kevin1k opened this issue Oct 21, 2024 · 8 comments · May be fixed by #12262 or #12263

@kevin1kevin1k

Repository commit

03a4251

Python version (python --version)

Python 3.10.15

Dependencies version (pip freeze)

absl-py==2.1.0
astunparse==1.6.3
beautifulsoup4==4.12.3
certifi==2024.8.30
charset-normalizer==3.4.0
contourpy==1.3.0
cycler==0.12.1
dill==0.3.9
dom_toml==2.0.0
domdf-python-tools==3.9.0
fake-useragent==1.5.1
flatbuffers==24.3.25
fonttools==4.54.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.67.0
h5py==3.12.1
idna==3.10
imageio==2.36.0
joblib==1.4.2
keras==3.6.0
kiwisolver==1.4.7
libclang==18.1.1
lxml==5.3.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.2
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
namex==0.0.8
natsort==8.4.0
numpy==1.26.4
oauthlib==3.2.2
opencv-python==4.10.0.84
opt_einsum==3.4.0
optree==0.13.0
packaging==24.1
pandas==2.2.3
patsy==0.5.6
pbr==6.1.0
pillow==11.0.0
pip==24.2
protobuf==4.25.5
psutil==6.1.0
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytz==2024.2
qiskit==1.2.4
qiskit-aer==0.15.1
requests==2.32.3
requests-oauthlib==1.3.1
rich==13.9.2
rustworkx==0.15.1
scikit-learn==1.5.2
scipy==1.14.1
setuptools==74.1.2
six==1.16.0
soupsieve==2.6
sphinx-pyproject==0.3.0
statsmodels==0.14.4
stevedore==5.3.0
symengine==0.13.0
sympy==1.13.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.2
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.5.0
threadpoolctl==3.5.0
tomli==2.0.2
tweepy==4.14.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
Werkzeug==3.0.4
wheel==0.44.0
wrapt==1.16.0
xgboost==2.1.1

Expected behavior

The entries where y_true is 0 should be ignored in the summation (see Actual behavior)

Actual behavior

In

kl_loss = y_true * np.log(y_true / y_pred)
return np.sum(kl_loss)

if any entry of y_true is 0, np.log(y_true / y_pred) evaluates to -inf for that entry; multiplying that -inf by the zero in y_true yields nan, so the method returns nan.
Maybe it would be better to exclude the entries where y_true is 0?
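For illustration, here is a minimal, self-contained reproduction of the failure (the wrapper name and the example distributions are hypothetical; the two lines of the body are the implementation quoted above):

```python
import numpy as np


def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # Body copied from the implementation quoted above.
    kl_loss = y_true * np.log(y_true / y_pred)
    return np.sum(kl_loss)


y_true = np.array([0.0, 0.5, 0.5])  # first entry is exactly 0
y_pred = np.array([0.2, 0.4, 0.4])

# log(0 / 0.2) is -inf, 0 * -inf is nan, and any sum containing nan is nan.
print(kullback_leibler_divergence(y_true, y_pred))  # nan (plus RuntimeWarnings)
```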

bz-e pushed a commit to bz-e/Python-algorithms that referenced this issue on Oct 22, 2024: "…merator and denominator and added a test case"

@vedprakash226 mentioned this issue on Oct 23, 2024
@vedprakash226

If y_true is 0, then what do we have to return?

@kevin1kevin1k
Author

> I would like to work on this.

I left comments on your PR. Basically, we can use only the entries where y_true is nonzero, without adding an epsilon.

> If y_true is 0, then what do we have to return?

y_true is an array rather than a single number here, so we can still use the remaining entries.
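As a concrete sketch of that suggestion (the function and parameter names are illustrative, not the repository's final patch), one way to keep only the nonzero entries of y_true without introducing an epsilon is:

```python
import numpy as np


def kl_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """KL divergence summed only over the entries where y_true is nonzero.

    By the usual convention 0 * log(0) is treated as 0, so zero entries of
    y_true are simply dropped from the sum; no epsilon is added anywhere.
    """
    mask = y_true != 0
    return float(np.sum(y_true[mask] * np.log(y_true[mask] / y_pred[mask])))
```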

@brambhattabhishek commented Oct 24, 2024

import numpy as np


def kl_divergence(y_true, y_pred):
    # Clip y_pred to a small positive value to avoid division by zero and log(0)
    y_pred = np.clip(y_pred, 1e-10, None)
    # Calculate the KL divergence only for the entries where y_true is nonzero
    kl_loss = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0)
    return np.sum(kl_loss)
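As a quick sanity check of the snippet above (the sample distributions are made up), its output can be compared against scipy.stats.entropy, which computes the KL divergence when given two distributions. Note that np.where still evaluates both branches, so NumPy may emit RuntimeWarnings for the zero entries even though the returned value is correct:

```python
import numpy as np
from scipy.stats import entropy

y_true = np.array([0.0, 0.5, 0.5])
y_pred = np.array([0.2, 0.4, 0.4])

print(kl_divergence(y_true, y_pred))  # ~0.2231, no nan: zero entries contribute 0
print(entropy(y_true, y_pred))        # scipy's KL divergence, for comparison
```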

@brambhattabhishek

> y_true is an array rather than a single number here, so we can still use the remaining entries.

Is my solution correct?

@brambhattabhishek

/assign

@kevin1kevin1k
Author

> Is my solution correct?

I think it's correct, and it additionally handles the case where y_pred is 0. Great job.

@brambhattabhishek

> I think it's correct, and it additionally handles the case where y_pred is 0. Great job.

Thanks, sure.
