Skip to content

Commit 1783195

Browse files
committed
fix(cli): fix saving plots when they are also shown immediately
1 parent 3f80048 commit 1783195

File tree

1 file changed

+36
-31
lines changed

1 file changed

+36
-31
lines changed

eis_toolkit/cli.py

+36-31
Original file line numberDiff line numberDiff line change
@@ -683,17 +683,17 @@ def parallel_coordinates_cli(
683683
curved_lines=curved_lines,
684684
)
685685
typer.echo("Progress: 75%")
686-
if show_plot:
687-
plt.show()
688686

689-
echo_str_end = "."
690687
if output_file is not None:
691688
dpi = "figure" if save_dpi is None else save_dpi
692689
plt.savefig(output_file, dpi=dpi)
693-
echo_str_end = f", output figure saved to {output_file}."
694-
typer.echo("Progress: 100%")
690+
typer.echo(f"Output figure saved to {output_file}.")
695691

696-
typer.echo("Parallel coordinates plot completed" + echo_str_end)
692+
if show_plot:
693+
plt.show()
694+
695+
typer.echo("Progress: 100%")
696+
typer.echo("Parallel coordinates plot completed")
697697

698698

699699
# PCA FOR RASTER DATA
@@ -3540,16 +3540,17 @@ def plot_roc_curve_cli(
35403540

35413541
_ = plot_roc_curve(y_true=y_true, y_prob=y_prob)
35423542
typer.echo("Progress: 75%")
3543-
if show_plot:
3544-
plt.show()
35453543

35463544
if output_file is not None:
35473545
dpi = "figure" if save_dpi is None else save_dpi
35483546
plt.savefig(output_file, dpi=dpi)
3549-
echo_str_end = f", output figure saved to {output_file}."
3550-
typer.echo("Progress: 100% \n")
3547+
typer.echo(f"Output figure saved to {output_file}.")
35513548

3552-
typer.echo("ROC curve plot completed" + echo_str_end)
3549+
if show_plot:
3550+
plt.show()
3551+
3552+
typer.echo("Progress: 100%")
3553+
typer.echo("ROC curve plot completed")
35533554

35543555

35553556
@app.command()
@@ -3580,16 +3581,17 @@ def plot_det_curve_cli(
35803581

35813582
_ = plot_det_curve(y_true=y_true, y_prob=y_prob)
35823583
typer.echo("Progress: 75%")
3583-
if show_plot:
3584-
plt.show()
35853584

35863585
if output_file is not None:
35873586
dpi = "figure" if save_dpi is None else save_dpi
35883587
plt.savefig(output_file, dpi=dpi)
3589-
echo_str_end = f", output figure saved to {output_file}."
3590-
typer.echo("Progress: 100% \n")
3588+
typer.echo(f"Output figure saved to {output_file}.")
35913589

3592-
typer.echo("DET curve plot completed" + echo_str_end)
3590+
if show_plot:
3591+
plt.show()
3592+
3593+
typer.echo("Progress: 100%")
3594+
typer.echo("DET curve plot completed")
35933595

35943596

35953597
@app.command()
@@ -3619,16 +3621,17 @@ def plot_precision_recall_curve_cli(
36193621

36203622
_ = plot_precision_recall_curve(y_true=y_true, y_prob=y_prob)
36213623
typer.echo("Progress: 75%")
3622-
if show_plot:
3623-
plt.show()
36243624

36253625
if output_file is not None:
36263626
dpi = "figure" if save_dpi is None else save_dpi
36273627
plt.savefig(output_file, dpi=dpi)
3628-
echo_str_end = f", output figure saved to {output_file}."
3629-
typer.echo("Progress: 100% \n")
3628+
typer.echo(f"Output figure saved to {output_file}.")
36303629

3631-
typer.echo("Precision-Recall curve plot completed" + echo_str_end)
3630+
if show_plot:
3631+
plt.show()
3632+
3633+
typer.echo("Progress: 100%")
3634+
typer.echo("Precision-Recall curve plot completed")
36323635

36333636

36343637
@app.command()
@@ -3658,16 +3661,17 @@ def plot_calibration_curve_cli(
36583661

36593662
_ = plot_calibration_curve(y_true=y_true, y_prob=y_prob, n_bins=n_bins)
36603663
typer.echo("Progress: 75%")
3661-
if show_plot:
3662-
plt.show()
36633664

36643665
if output_file is not None:
36653666
dpi = "figure" if save_dpi is None else save_dpi
36663667
plt.savefig(output_file, dpi=dpi)
3667-
echo_str_end = f", output figure saved to {output_file}."
3668-
typer.echo("Progress: 100% \n")
3668+
typer.echo(f"Output figure saved to {output_file}.")
36693669

3670-
typer.echo("Calibration curve plot completed" + echo_str_end)
3670+
if show_plot:
3671+
plt.show()
3672+
3673+
typer.echo("Progress: 100%")
3674+
typer.echo("Calibration curve plot completed")
36713675

36723676

36733677
@app.command()
@@ -3693,16 +3697,17 @@ def plot_confusion_matrix_cli(
36933697
matrix = confusion_matrix(y_true, y_pred)
36943698
_ = plot_confusion_matrix(confusion_matrix=matrix)
36953699
typer.echo("Progress: 75%")
3696-
if show_plot:
3697-
plt.show()
36983700

36993701
if output_file is not None:
37003702
dpi = "figure" if save_dpi is None else save_dpi
37013703
plt.savefig(output_file, dpi=dpi)
3702-
echo_str_end = f", output figure saved to {output_file}."
3703-
typer.echo("Progress: 100% \n")
3704+
typer.echo(f"Output figure saved to {output_file}.")
37043705

3705-
typer.echo("Confusion matrix plot completed" + echo_str_end)
3706+
if show_plot:
3707+
plt.show()
3708+
3709+
typer.echo("Progress: 100%")
3710+
typer.echo("Confusion matrix plot completed.")
37063711

37073712

37083713
@app.command()

0 commit comments

Comments
 (0)