Skip to content

Commit 5320414

Browse files
committed
Add tests for stix-extract
1 parent f7ffdeb commit 5320414

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed

tests/test_stix_extract.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""Test the stix_extract module."""
2+
# Standard Python Libraries
3+
import os
4+
import sys
5+
import tempfile
6+
from unittest.mock import MagicMock, patch
7+
8+
# Third-Party Libraries
9+
import pytest
10+
11+
# cisagov Libraries
12+
import ioc_scan
13+
from ioc_scan import stix_extract
14+
from ioc_scan.stix_extract import extract_stix_info, sort_ip_address
15+
16+
PROJECT_VERSION = ioc_scan.__version__
17+
18+
19+
def test_extract_stix_info_ip():
20+
"""Test extracting IP addresses from a STIX package."""
21+
observable_mock = MagicMock()
22+
observable_mock.object_.properties._XSI_TYPE = "AddressObjectType"
23+
observable_mock.object_.properties.address_value = "127.0.0.1"
24+
indicator_mock = MagicMock()
25+
indicator_mock.observables = [observable_mock]
26+
stix_package_mock = MagicMock()
27+
stix_package_mock.indicators = [indicator_mock]
28+
with patch(
29+
"ioc_scan.stix_extract.STIXPackage.from_xml", return_value=stix_package_mock
30+
):
31+
ip_addresses, hashes, fqdns, urls = extract_stix_info("stix_file")
32+
assert ip_addresses == ["127.0.0.1"]
33+
assert hashes == []
34+
assert fqdns == []
35+
assert urls == []
36+
37+
38+
def test_sort_ip_address():
39+
"""Test sorting IP addresses."""
40+
result = sort_ip_address("127.0.0.1")
41+
assert result == (4, 2130706433)
42+
43+
44+
@pytest.fixture
45+
def mock_domain_observable():
46+
"""Return a mock STIX DomainNameObjectType observable."""
47+
observable = MagicMock()
48+
observable.object_.properties._XSI_TYPE = "DomainNameObjectType"
49+
observable.object_.properties.value.value = "www.example.com"
50+
return observable
51+
52+
53+
@pytest.fixture
54+
def mock_uri_observable():
55+
"""Return a mock STIX URIObjectType observable."""
56+
observable = MagicMock()
57+
observable.object_.properties._XSI_TYPE = "URIObjectType"
58+
observable.object_.properties.value.value = "www.example.com/path"
59+
return observable
60+
61+
62+
def test_extract_stix_info_with_domain_and_uri_observables(
63+
mock_domain_observable, mock_uri_observable
64+
):
65+
"""Test extracting FQDNs and URLs from a STIX package."""
66+
stix_package_mock = MagicMock()
67+
stix_package_mock.indicators = [
68+
MagicMock(observables=[mock_domain_observable, mock_uri_observable])
69+
]
70+
with patch(
71+
"ioc_scan.stix_extract.STIXPackage.from_xml", return_value=stix_package_mock
72+
):
73+
ips, hashes, fqdns, urls = extract_stix_info("fake_file.xml")
74+
assert fqdns == ["www.example.com"]
75+
assert urls == ["www.example.com/path"]
76+
77+
78+
@pytest.fixture
79+
def mock_hash_observable_with_valid_types():
80+
"""Return a mock STIX FileObjectType observable with valid hash types."""
81+
observable = MagicMock()
82+
observable.object_.properties._XSI_TYPE = "FileObjectType"
83+
observable.object_.properties.hashes = [
84+
MagicMock(
85+
type_=MagicMock(value="SHA1"),
86+
simple_hash_value=MagicMock(value="SHA1_HASH"),
87+
),
88+
MagicMock(
89+
type_=MagicMock(value="MD5"), simple_hash_value=MagicMock(value="MD5_HASH")
90+
),
91+
MagicMock(
92+
type_=MagicMock(value="SHA256"),
93+
simple_hash_value=MagicMock(value="SHA256_HASH"),
94+
),
95+
]
96+
return observable
97+
98+
99+
@pytest.fixture
100+
def mock_hash_observable_with_invalid_types():
101+
"""Return a mock STIX FileObjectType observable with invalid hash types."""
102+
observable = MagicMock()
103+
observable.object_.properties._XSI_TYPE = "FileObjectType"
104+
observable.object_.properties.hashes = [
105+
MagicMock(
106+
type_=MagicMock(value="INVALID1"),
107+
simple_hash_value=MagicMock(value="INVALID1_HASH"),
108+
),
109+
MagicMock(
110+
type_=MagicMock(value="INVALID2"),
111+
simple_hash_value=MagicMock(value="INVALID2_HASH"),
112+
),
113+
MagicMock(
114+
type_=MagicMock(value="INVALID3"),
115+
simple_hash_value=MagicMock(value="INVALID3_HASH"),
116+
),
117+
]
118+
return observable
119+
120+
121+
@pytest.mark.parametrize(
122+
"hash_observable, expected",
123+
[
124+
("mock_hash_observable_with_valid_types", ["SHA256_HASH"]),
125+
("mock_hash_observable_with_invalid_types", []),
126+
],
127+
)
128+
def test_extract_stix_info_with_hash_observable(hash_observable, expected, request):
129+
"""Test extracting hashes from a STIX package."""
130+
mock_hash_observable = request.getfixturevalue(hash_observable)
131+
stix_package_mock = MagicMock()
132+
stix_package_mock.indicators = [MagicMock(observables=[mock_hash_observable])]
133+
with patch(
134+
"ioc_scan.stix_extract.STIXPackage.from_xml", return_value=stix_package_mock
135+
):
136+
ips, hashes, fqdns, urls = extract_stix_info("fake_file.xml")
137+
assert hashes == expected
138+
139+
140+
def test_extract_stix_info_with_invalid_stix_file():
141+
"""Test invalid filename."""
142+
with pytest.raises(Exception):
143+
extract_stix_info("invalid.stix")
144+
145+
146+
def test_extract_stix_info_with_unexpected_object_type():
147+
"""Test extracting observables from a STIX package with an unexpected object type."""
148+
observable_mock = MagicMock()
149+
observable_mock.object_.properties._XSI_TYPE = "UnexpectedObjectType"
150+
indicator_mock = MagicMock()
151+
indicator_mock.observables = [observable_mock]
152+
stix_package_mock = MagicMock()
153+
stix_package_mock.indicators = [indicator_mock]
154+
with patch(
155+
"ioc_scan.stix_extract.STIXPackage.from_xml", return_value=stix_package_mock
156+
):
157+
ip_addresses, hashes, fqdns, urls = extract_stix_info("stix_file")
158+
assert ip_addresses == []
159+
assert hashes == []
160+
assert fqdns == []
161+
assert urls == []
162+
163+
164+
def test_extract_stix_info_with_file_object_without_hashes():
165+
"""Test extracting observables from a STIX package where the file object does not have hashes."""
166+
observable_mock = MagicMock()
167+
observable_mock.object_.properties._XSI_TYPE = "FileObjectType"
168+
observable_mock.object_.properties.hashes = None
169+
indicator_mock = MagicMock()
170+
indicator_mock.observables = [observable_mock]
171+
stix_package_mock = MagicMock()
172+
stix_package_mock.indicators = [indicator_mock]
173+
with patch(
174+
"ioc_scan.stix_extract.STIXPackage.from_xml", return_value=stix_package_mock
175+
):
176+
ip_addresses, hashes, fqdns, urls = extract_stix_info("stix_file")
177+
assert ip_addresses == []
178+
assert hashes == []
179+
assert fqdns == []
180+
assert urls == []
181+
182+
183+
def test_version(capsys):
184+
"""Verify that version string sent to stdout, and agrees with the module."""
185+
with pytest.raises(SystemExit):
186+
with patch.object(sys, "argv", ["bogus", "--version"]):
187+
stix_extract.main()
188+
captured = capsys.readouterr()
189+
assert (
190+
captured.out == f"{PROJECT_VERSION}\n"
191+
), "standard output by '--version' should agree with module.__version__"
192+
193+
194+
def test_help(capsys):
195+
"""Verify that the help text is sent to stdout."""
196+
with pytest.raises(SystemExit):
197+
with patch.object(sys, "argv", ["bogus", "--help"]):
198+
stix_extract.main()
199+
captured = capsys.readouterr()
200+
assert (
201+
"This script parses" in captured.out
202+
), "help text did not have expected string"
203+
204+
205+
def test_main():
206+
"""Test the main function of the script."""
207+
# Mock the command line arguments
208+
with patch("ioc_scan.stix_extract.docopt") as mock_docopt:
209+
# Create a temporary STIX file
210+
temp_file = tempfile.NamedTemporaryFile(delete=False)
211+
temp_file.write(b"<xml></xml>") # Minimal XML content to prevent parse errors
212+
temp_file.close()
213+
214+
mock_docopt.return_value = {"<file>": temp_file.name}
215+
216+
# Mock the extraction function to return some test data
217+
with patch("ioc_scan.stix_extract.extract_stix_info") as mock_extract:
218+
mock_extract.return_value = (
219+
["1.1.1.1", "2.2.2.2"],
220+
["hash1", "hash2"],
221+
["fqdn1", "fqdn2"],
222+
["url1", "url2"],
223+
)
224+
225+
# Mock the print function to do nothing
226+
with patch("builtins.print") as mock_print:
227+
stix_extract.main()
228+
229+
# Verify the mock calls.
230+
mock_docopt.assert_called_once_with(
231+
stix_extract.__doc__, version=PROJECT_VERSION
232+
)
233+
mock_extract.assert_called_once_with(temp_file.name)
234+
assert mock_print.call_count == 12 # Check how many times print is called
235+
236+
os.unlink(temp_file.name) # Delete the temporary file

0 commit comments

Comments
 (0)