Skip to content

Commit 321fee3

Browse files
committed
Format test files with ruff
1 parent b761807 commit 321fee3

18 files changed

Lines changed: 111 additions & 140 deletions

tests/test_backward_compat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,5 +233,3 @@ def test_constants_import(self) -> None:
233233
assert len(ALPHABET) == 21 # 20 AAs + X
234234
assert AA_1_TO_3["A"] == "ALA"
235235
assert AA_3_TO_1["ALA"] == "A"
236-
237-

tests/test_cli.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ def test_info_shows_authors(self) -> None:
110110
assert result.exit_code == 0
111111
# Should contain at least one author name
112112
assert any(
113-
name in result.output
114-
for name in ["Beining", "Engelberger", "Schoeder", "Meiler"]
113+
name in result.output for name in ["Beining", "Engelberger", "Schoeder", "Meiler"]
115114
)
116115

117116
def test_info_shows_pytorch_info(self) -> None:
@@ -180,7 +179,9 @@ def test_predict_nonexistent_checkpoint(self, tmp_path) -> None:
180179

181180
# Create a dummy PDB file
182181
pdb_file = tmp_path / "test.pdb"
183-
pdb_file.write_text("ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\n")
182+
pdb_file.write_text(
183+
"ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\n"
184+
)
184185

185186
runner = CliRunner()
186187
result = runner.invoke(
@@ -199,7 +200,9 @@ def test_batch_missing_checkpoint(self, tmp_path) -> None:
199200

200201
# Create a dummy PDB file
201202
pdb_file = tmp_path / "test.pdb"
202-
pdb_file.write_text("ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\n")
203+
pdb_file.write_text(
204+
"ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\n"
205+
)
203206

204207
runner = CliRunner()
205208
result = runner.invoke(cli, ["batch", str(pdb_file)])
@@ -257,5 +260,3 @@ def test_cli_as_module(self) -> None:
257260
# Should show help (exit code 0) or fail gracefully
258261
# Note: This may fail if package is not installed
259262
assert result.returncode == 0 or "No module named" in result.stderr
260-
261-

tests/test_constants.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""Tests for constants module."""
22

33

4-
54
def test_alphabet():
65
"""Test ALPHABET constant."""
76
from frustrampnn import ALPHABET
87

98
assert len(ALPHABET) == 21
10-
assert 'A' in ALPHABET
11-
assert 'X' in ALPHABET
9+
assert "A" in ALPHABET
10+
assert "X" in ALPHABET
1211

1312

1413
def test_vocab_dim():
@@ -23,36 +22,34 @@ def test_aa_conversion():
2322
"""Test amino acid conversion dictionaries."""
2423
from frustrampnn import AA_1_TO_3, AA_3_TO_1
2524

26-
assert AA_3_TO_1['ALA'] == 'A'
27-
assert AA_3_TO_1['TRP'] == 'W'
28-
assert AA_3_TO_1['MSE'] == 'M' # Selenomethionine
29-
assert AA_1_TO_3['A'] == 'ALA'
30-
assert AA_1_TO_3['W'] == 'TRP'
25+
assert AA_3_TO_1["ALA"] == "A"
26+
assert AA_3_TO_1["TRP"] == "W"
27+
assert AA_3_TO_1["MSE"] == "M" # Selenomethionine
28+
assert AA_1_TO_3["A"] == "ALA"
29+
assert AA_1_TO_3["W"] == "TRP"
3130

3231

3332
def test_frustration_thresholds():
3433
"""Test frustration threshold values."""
3534
from frustrampnn import FRUSTRATION_THRESHOLDS
3635

37-
assert FRUSTRATION_THRESHOLDS['highly'] == -1.0
38-
assert FRUSTRATION_THRESHOLDS['minimally'] == 0.58
36+
assert FRUSTRATION_THRESHOLDS["highly"] == -1.0
37+
assert FRUSTRATION_THRESHOLDS["minimally"] == 0.58
3938

4039

4140
def test_frustration_colors():
4241
"""Test frustration color scheme."""
4342
from frustrampnn import FRUSTRATION_COLORS
4443

45-
assert FRUSTRATION_COLORS['highly'] == 'red'
46-
assert FRUSTRATION_COLORS['minimally'] == 'green'
47-
assert FRUSTRATION_COLORS['neutral'] == 'gray'
48-
assert FRUSTRATION_COLORS['native'] == 'blue'
44+
assert FRUSTRATION_COLORS["highly"] == "red"
45+
assert FRUSTRATION_COLORS["minimally"] == "green"
46+
assert FRUSTRATION_COLORS["neutral"] == "gray"
47+
assert FRUSTRATION_COLORS["native"] == "blue"
4948

5049

5150
def test_constants_from_model():
5251
"""Test that constants can be imported from model module."""
5352
from frustrampnn.model import ALPHABET, VOCAB_DIM
5453

55-
assert ALPHABET == 'ACDEFGHIKLMNPQRSTVWYX'
54+
assert ALPHABET == "ACDEFGHIKLMNPQRSTVWYX"
5655
assert VOCAB_DIM == 21
57-
58-

tests/test_inference_batch.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,3 @@ def test_predict_batch_combines_results():
8585
assert isinstance(result, pd.DataFrame)
8686
assert len(result) == 2
8787
assert list(result["pdb"]) == ["pdb1", "pdb2"]
88-
89-

tests/test_inference_predictor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,3 @@ def test_frustrampnn_integration(test_pdb_path):
146146
# This test requires a checkpoint to be available
147147
# Skip if not in the right environment
148148
pass
149-
150-

tests/test_integration.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def test_mutation_to_dict_integration(self) -> None:
150150
class TestVisualizationIntegration:
151151
"""Integration tests for visualization workflow."""
152152

153-
def test_plot_single_residue_with_data(
154-
self, sample_frustration_df: pd.DataFrame
155-
) -> None:
153+
def test_plot_single_residue_with_data(self, sample_frustration_df: pd.DataFrame) -> None:
156154
"""Test single residue plot with sample data."""
157155
from frustrampnn.visualization import plot_single_residue
158156

@@ -162,19 +160,15 @@ def test_plot_single_residue_with_data(
162160
axes = fig.get_axes()
163161
assert len(axes) == 1
164162

165-
def test_plot_heatmap_with_data(
166-
self, sample_frustration_df: pd.DataFrame
167-
) -> None:
163+
def test_plot_heatmap_with_data(self, sample_frustration_df: pd.DataFrame) -> None:
168164
"""Test heatmap plot with sample data."""
169165
from frustrampnn.visualization import plot_frustration_heatmap
170166

171167
fig = plot_frustration_heatmap(sample_frustration_df, chain="A")
172168

173169
assert fig is not None
174170

175-
def test_plot_single_residue_plotly(
176-
self, sample_frustration_df: pd.DataFrame
177-
) -> None:
171+
def test_plot_single_residue_plotly(self, sample_frustration_df: pd.DataFrame) -> None:
178172
"""Test plotly single residue plot."""
179173
from frustrampnn.visualization import plot_single_residue_plotly
180174

@@ -183,9 +177,7 @@ def test_plot_single_residue_plotly(
183177
assert fig is not None
184178
assert len(fig.data) > 0
185179

186-
def test_plot_heatmap_plotly(
187-
self, sample_frustration_df: pd.DataFrame
188-
) -> None:
180+
def test_plot_heatmap_plotly(self, sample_frustration_df: pd.DataFrame) -> None:
189181
"""Test plotly heatmap plot."""
190182
from frustrampnn.visualization import plot_frustration_heatmap_plotly
191183

@@ -408,12 +400,9 @@ def test_invalid_mutation_string_error(self) -> None:
408400
with pytest.raises(ValueError, match="Invalid mutation format"):
409401
parse_mutation_string("invalid")
410402

411-
def test_invalid_position_in_plot_error(
412-
self, sample_frustration_df: pd.DataFrame
413-
) -> None:
403+
def test_invalid_position_in_plot_error(self, sample_frustration_df: pd.DataFrame) -> None:
414404
"""Test error handling for invalid position in plot."""
415405
from frustrampnn.visualization import plot_single_residue
416406

417407
with pytest.raises(ValueError, match="No data found"):
418408
plot_single_residue(sample_frustration_df, position=999, chain="A")
419-

tests/test_model_features.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ def test_forward_shape(
6161
edge_features = 128
6262
node_features = 128
6363
top_k = min(5, seq_length) # Ensure top_k <= seq_length
64-
layer = CA_ProteinFeatures(
65-
edge_features, node_features, top_k=top_k
66-
)
64+
layer = CA_ProteinFeatures(edge_features, node_features, top_k=top_k)
6765

6866
E, E_idx = layer(
6967
sample_ca_coords,
@@ -171,9 +169,7 @@ def test_dist_returns_neighbors(
171169
top_k = 5
172170
layer = CA_ProteinFeatures(64, 64, top_k=top_k)
173171

174-
D_neighbors, E_idx, mask_neighbors = layer._dist(
175-
sample_ca_coords, sample_mask
176-
)
172+
D_neighbors, E_idx, mask_neighbors = layer._dist(sample_ca_coords, sample_mask)
177173

178174
batch_size, seq_length = sample_ca_coords.shape[:2]
179175
assert D_neighbors.shape == (batch_size, seq_length, top_k)
@@ -215,9 +211,7 @@ def test_forward_shape(
215211
edge_features = 128
216212
node_features = 128
217213
top_k = min(5, seq_length)
218-
layer = ProteinFeatures(
219-
edge_features, node_features, top_k=top_k
220-
)
214+
layer = ProteinFeatures(edge_features, node_features, top_k=top_k)
221215

222216
E, E_idx = layer(
223217
sample_backbone_coords,
@@ -426,4 +420,3 @@ def test_multiple_chains(self):
426420

427421
assert E.shape == (1, 10, 3, 64)
428422
assert not torch.isnan(E).any()
429-

tests/test_model_layers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,4 +384,3 @@ def test_exports_from_model_init(self):
384384
assert gather_nodes is not None
385385
assert gather_nodes_t is not None
386386
assert cat_neighbors_nodes is not None
387-

tests/test_model_pdb_parsing.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
information from PDB files for use in ProteinMPNN.
55
"""
66

7-
87
import numpy as np
98
import pytest
109

@@ -53,9 +52,7 @@ def test_parse_pdb_biounits_custom_atoms(self, test_pdb_path):
5352
pytest.skip("Test PDB file not found")
5453

5554
# Parse with N, CA, C, O
56-
xyz, seq = parse_PDB_biounits(
57-
str(test_pdb_path), atoms=["N", "CA", "C", "O"], chain="A"
58-
)
55+
xyz, seq = parse_PDB_biounits(str(test_pdb_path), atoms=["N", "CA", "C", "O"], chain="A")
5956

6057
assert xyz.shape[1] == 4 # 4 atoms
6158

@@ -461,4 +458,3 @@ def test_sequence_extraction(self, tmp_path):
461458
xyz, seq = parse_PDB_biounits(str(pdb_file), atoms=["CA"], chain="A")
462459

463460
assert seq[0] == "MQI"
464-

tests/test_model_protein_mpnn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,5 +90,3 @@ def test_parse_pdb(test_pdb_path):
9090
assert len(result) > 0
9191
assert "seq" in result[0]
9292
assert "name" in result[0]
93-
94-

0 commit comments

Comments
 (0)