Skip to content

Commit a04993a

Browse files
CopilotLIghtJUNctionCopilot
authored
Replace insecure random with secrets module in cryptographic contexts (#3248)
* Initial plan * Security fixes: Replace insecure random with secrets module and improve SSL context Co-authored-by: LIghtJUNction <[email protected]> * Address code review feedback: fix POST method and add named constants Co-authored-by: LIghtJUNction <[email protected]> * Improve documentation for random number generation constants Co-authored-by: LIghtJUNction <[email protected]> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <[email protected]> * Update astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py Co-authored-by: Copilot <[email protected]> * Update tests/test_security_fixes.py Co-authored-by: Copilot <[email protected]> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <[email protected]> * Update astrbot/core/utils/io.py Co-authored-by: Copilot <[email protected]> * Fix: Handle path parameter in SSL fallback for download_image_by_url Co-authored-by: LIghtJUNction <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: LIghtJUNction <[email protected]> Co-authored-by: LIghtJUNction <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 74f845b commit a04993a

File tree

5 files changed

+199
-13
lines changed

5 files changed

+199
-13
lines changed

astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import hashlib
1111
import json
1212
import logging
13-
import random
13+
import secrets
1414
import socket
1515
import struct
1616
import time
@@ -139,6 +139,12 @@ def decode(self, decrypted):
139139
class Prpcrypt:
140140
"""提供接收和推送给企业微信消息的加解密接口"""
141141

142+
# 16位随机字符串的范围常量
143+
# randbelow(RANDOM_RANGE) 返回 [0, 8999999999999999](两端都包含,即包含0和8999999999999999)
144+
# 加上 MIN_RANDOM_VALUE 后得到 [1000000000000000, 9999999999999999](两端都包含)即16位数字
145+
MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位)
146+
RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位)
147+
142148
def __init__(self, key):
143149
# self.key = base64.b64decode(key+"=")
144150
self.key = key
@@ -207,7 +213,9 @@ def get_random_str(self):
207213
"""随机生成16位字符串
208214
@return: 16位字符串
209215
"""
210-
return str(random.randint(1000000000000000, 9999999999999999)).encode()
216+
return str(
217+
secrets.randbelow(self.RANDOM_RANGE) + self.MIN_RANDOM_VALUE
218+
).encode()
211219

212220

213221
class WXBizJsonMsgCrypt:

astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import base64
77
import hashlib
8-
import random
8+
import secrets
99
import string
1010
from typing import Any
1111

@@ -53,7 +53,7 @@ def generate_random_string(length: int = 10) -> str:
5353
5454
"""
5555
letters = string.ascii_letters + string.digits
56-
return "".join(random.choice(letters) for _ in range(length))
56+
return "".join(secrets.choice(letters) for _ in range(length))
5757

5858

5959
def calculate_image_md5(image_data: bytes) -> str:

astrbot/core/provider/sources/azure_tts_source.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import asyncio
22
import hashlib
33
import json
4-
import random
54
import re
5+
import secrets
66
import time
77
import uuid
88
from pathlib import Path
@@ -54,7 +54,9 @@ async def _sync_time(self):
5454
async def _generate_signature(self) -> str:
5555
await self._sync_time()
5656
timestamp = int(time.time()) + self.time_offset
57-
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
57+
nonce = "".join(
58+
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
59+
)
5860
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
5961
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
6062

astrbot/core/utils/io.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,31 @@ async def download_image_by_url(
105105
f.write(await resp.read())
106106
return path
107107
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
108-
# 关闭SSL验证
108+
# 关闭SSL验证(仅在证书验证失败时作为fallback)
109+
logger.warning(
110+
f"SSL certificate verification failed for {url}. "
111+
"Disabling SSL verification (CERT_NONE) as a fallback. "
112+
"This is insecure and exposes the application to man-in-the-middle attacks. "
113+
"Please investigate and resolve certificate issues."
114+
)
109115
ssl_context = ssl.create_default_context()
110-
ssl_context.set_ciphers("DEFAULT")
116+
ssl_context.check_hostname = False
117+
ssl_context.verify_mode = ssl.CERT_NONE
111118
async with aiohttp.ClientSession() as session:
112119
if post:
113-
async with session.get(url, ssl=ssl_context) as resp:
114-
return save_temp_img(await resp.read())
120+
async with session.post(url, json=post_data, ssl=ssl_context) as resp:
121+
if not path:
122+
return save_temp_img(await resp.read())
123+
with open(path, "wb") as f:
124+
f.write(await resp.read())
125+
return path
115126
else:
116127
async with session.get(url, ssl=ssl_context) as resp:
117-
return save_temp_img(await resp.read())
128+
if not path:
129+
return save_temp_img(await resp.read())
130+
with open(path, "wb") as f:
131+
f.write(await resp.read())
132+
return path
118133
except Exception as e:
119134
raise e
120135

@@ -157,9 +172,19 @@ async def download_file(url: str, path: str, show_progress: bool = False):
157172
end="",
158173
)
159174
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
160-
# 关闭SSL验证
175+
# 关闭SSL验证(仅在证书验证失败时作为fallback)
176+
logger.warning(
177+
"SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
178+
)
179+
logger.warning(
180+
f"SSL certificate verification failed for {url}. "
181+
"Falling back to unverified connection (CERT_NONE). "
182+
"This is insecure and exposes the application to man-in-the-middle attacks. "
183+
"Please investigate certificate issues with the remote server."
184+
)
161185
ssl_context = ssl.create_default_context()
162-
ssl_context.set_ciphers("DEFAULT")
186+
ssl_context.check_hostname = False
187+
ssl_context.verify_mode = ssl.CERT_NONE
163188
async with aiohttp.ClientSession() as session:
164189
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
165190
total_size = int(resp.headers.get("content-length", 0))

tests/test_security_fixes.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Tests for security fixes - cryptographic random number generation and SSL context."""
2+
3+
import os
4+
import ssl
5+
import sys
6+
7+
# Add project root to sys.path
8+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
9+
10+
import pytest
11+
12+
13+
def test_wecom_crypto_uses_secrets():
14+
"""Test that WXBizJsonMsgCrypt uses secrets module instead of random."""
15+
from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import Prpcrypt
16+
17+
# Create an instance and test that random string generation works
18+
prpcrypt = Prpcrypt(b"test_key_32_bytes_long_value!")
19+
20+
# Generate multiple random strings and verify they are different and valid
21+
random_strings = [prpcrypt.get_random_str() for _ in range(10)]
22+
23+
# All strings should be 16 bytes long
24+
assert all(len(s) == 16 for s in random_strings)
25+
26+
# All strings should be different (extremely high probability with cryptographic random)
27+
assert len(set(random_strings)) == 10
28+
29+
# All strings should be numeric when decoded
30+
for s in random_strings:
31+
decoded = s.decode()
32+
assert decoded.isdigit()
33+
assert 1000000000000000 <= int(decoded) <= 9999999999999999
34+
35+
36+
def test_wecomai_utils_uses_secrets():
37+
"""Test that wecomai_utils uses secrets module for random string generation."""
38+
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_utils import (
39+
generate_random_string,
40+
)
41+
42+
# Generate multiple random strings and verify they are different
43+
random_strings = [generate_random_string(10) for _ in range(20)]
44+
45+
# All strings should be 10 characters long
46+
assert all(len(s) == 10 for s in random_strings)
47+
48+
# All strings should be alphanumeric
49+
for s in random_strings:
50+
assert s.isalnum()
51+
52+
# All strings should be different (extremely high probability with cryptographic random)
53+
assert len(set(random_strings)) >= 19 # Allow for 1 collision in 20 (very unlikely)
54+
55+
56+
def test_azure_tts_signature_uses_secrets():
57+
"""Test that Azure TTS signature generation uses secrets module."""
58+
import asyncio
59+
60+
from astrbot.core.provider.sources.azure_tts_source import OTTSProvider
61+
62+
# Create a provider with test config
63+
config = {
64+
"OTTS_SKEY": "test_secret_key",
65+
"OTTS_URL": "https://example.com/api/tts",
66+
"OTTS_AUTH_TIME": "https://example.com/api/time",
67+
}
68+
69+
async def test_nonce_generation():
70+
async with OTTSProvider(config) as provider:
71+
# Mock time sync to avoid actual API calls
72+
provider.time_offset = 0
73+
provider.last_sync_time = 9999999999
74+
75+
# Generate multiple signatures and extract nonces
76+
signatures = []
77+
for _ in range(10):
78+
sig = await provider._generate_signature()
79+
signatures.append(sig)
80+
81+
# Extract nonces (second field in signature format: timestamp-nonce-0-hash)
82+
nonces = [sig.split("-")[1] for sig in signatures]
83+
84+
# All nonces should be 10 characters long
85+
assert all(len(n) == 10 for n in nonces)
86+
87+
# All nonces should be alphanumeric (lowercase letters and digits)
88+
for n in nonces:
89+
assert all(c in "abcdefghijklmnopqrstuvwxyz0123456789" for c in n)
90+
91+
# All nonces should be different (cryptographic random ensures uniqueness)
92+
assert len(set(nonces)) == 10
93+
94+
asyncio.run(test_nonce_generation())
95+
96+
97+
def test_ssl_context_fallback_explicit():
98+
"""Test that SSL context fallback is properly configured."""
99+
# This test verifies the SSL context configuration
100+
# We can't easily test the full io.py functions without network calls,
101+
# but we can verify that ssl.CERT_NONE and check_hostname=False are valid settings
102+
103+
# Create a context similar to what's used in io.py
104+
ssl_context = ssl.create_default_context()
105+
ssl_context.check_hostname = False
106+
ssl_context.verify_mode = ssl.CERT_NONE
107+
108+
# Verify the settings are applied correctly
109+
assert ssl_context.check_hostname is False
110+
assert ssl_context.verify_mode == ssl.CERT_NONE
111+
112+
# This configuration should work but is intentionally insecure for fallback
113+
# The actual code only uses this when certificate validation fails
114+
115+
116+
def test_io_module_has_ssl_imports():
117+
"""Verify that io.py properly imports ssl module."""
118+
from astrbot.core.utils import io
119+
120+
# Check that ssl is available in the module
121+
assert hasattr(io, "ssl")
122+
123+
# Check that CERT_NONE constant is accessible
124+
assert hasattr(io.ssl, "CERT_NONE")
125+
126+
127+
def test_secrets_module_randomness_quality():
128+
"""Test that secrets module provides high-quality randomness."""
129+
import secrets
130+
131+
# Generate a large set of random numbers
132+
random_numbers = [secrets.randbelow(100) for _ in range(1000)]
133+
134+
# Basic statistical test: should have good distribution
135+
unique_values = len(set(random_numbers))
136+
137+
# With 1000 random numbers from 0-99, we should see most values at least once
138+
# This is a very basic test - real cryptographic random should pass this easily
139+
assert unique_values >= 60 # Should see at least 60 different values out of 100
140+
141+
# Test secrets.choice for string generation
142+
chars = "abcdefghijklmnopqrstuvwxyz0123456789"
143+
random_chars = [secrets.choice(chars) for _ in range(1000)]
144+
145+
# Should have good character distribution
146+
unique_chars = len(set(random_chars))
147+
assert unique_chars >= 20 # Should see at least 20 different characters
148+
149+
150+
if __name__ == "__main__":
151+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)