1
0
Files
nex/scripts/core.py
lanyuanxiaoyao 7f0f831226 feat: 抽取 scripts/core.py 公共模块,重构检测脚本
将 anthropic_detect.py 和 openai_detect.py 中的公共功能抽取到
core.py 模块,包括:
- HTTP 请求(普通/流式)及重试逻辑
- SSL 上下文管理
- 测试用例/结果数据结构 (TestCase, TestResult)
- 错误分类 (ErrorType)
- 响应验证辅助函数 (validate_response_structure 等)
- 测试执行框架 (run_test, run_test_suite)

两个检测脚本重构后更聚焦于各自 API 的测试用例定义。
2026-04-21 11:45:21 +08:00

472 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""兼容性测试脚本的核心公共函数
提供 HTTP 请求、SSL 上下文、JSON 格式化、验证辅助等通用功能。
"""
import json
import time
import ssl
import urllib.request
import urllib.error
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, List, Union, Type
from enum import Enum
TIMEOUT = 30
MAX_RETRIES = 2 # 最大重试次数
class ErrorType(Enum):
"""错误类型分类"""
NETWORK = "network" # 网络错误
CLIENT = "client" # 4xx 错误
SERVER = "server" # 5xx 错误
SUCCESS = "success" # 成功
@dataclass
class TestCase:
"""测试用例数据结构"""
desc: str # 测试描述
method: str # HTTP 方法
url: str # 请求 URL
headers: Dict[str, str] # 请求头
body: Optional[Any] = None # 请求体
stream: bool = False # 是否流式请求
validator: Optional[Any] = None # 响应验证函数(可选)
@dataclass
class TestResult:
"""测试结果数据结构"""
status: Optional[int] # HTTP 状态码
elapsed: float # 耗时(秒)
error_type: ErrorType # 错误类型
response: str # 响应内容
def create_ssl_context() -> ssl.SSLContext:
"""创建不验证证书的 SSL 上下文(用于测试环境)"""
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx
def classify_error(status: Optional[int]) -> ErrorType:
"""根据状态码分类错误类型"""
if status is None:
return ErrorType.NETWORK
if 200 <= status < 300:
return ErrorType.SUCCESS
if 400 <= status < 500:
return ErrorType.CLIENT
if status >= 500:
return ErrorType.SERVER
return ErrorType.NETWORK
def http_request(
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
ssl_ctx: Optional[ssl.SSLContext] = None,
retries: int = MAX_RETRIES
) -> TestResult:
"""执行普通 HTTP 请求(支持重试)
Args:
url: 请求 URL
method: HTTP 方法 (GET/POST/PUT/DELETE)
headers: 请求头字典
body: 请求体 (dict 或 str)
ssl_ctx: SSL 上下文
retries: 重试次数
Returns:
TestResult 对象
"""
req = urllib.request.Request(url, method=method)
if headers:
for k, v in headers.items():
req.add_header(k, v)
if body is not None:
if isinstance(body, str):
req.data = body.encode("utf-8")
else:
req.data = json.dumps(body).encode("utf-8")
start = time.time()
last_error = None
for attempt in range(retries + 1):
try:
resp = urllib.request.urlopen(req, timeout=TIMEOUT, context=ssl_ctx)
elapsed = time.time() - start
status = resp.getcode()
return TestResult(
status=status,
elapsed=elapsed,
error_type=classify_error(status),
response=resp.read().decode("utf-8")
)
except urllib.error.HTTPError as e:
elapsed = time.time() - start
return TestResult(
status=e.code,
elapsed=elapsed,
error_type=classify_error(e.code),
response=e.read().decode("utf-8")
)
except Exception as e:
last_error = str(e)
if attempt < retries:
time.sleep(0.5 * (attempt + 1)) # 递增延迟
continue
elapsed = time.time() - start
return TestResult(
status=None,
elapsed=elapsed,
error_type=ErrorType.NETWORK,
response=last_error or "Unknown error"
)
def http_stream_request(
url: str,
headers: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
ssl_ctx: Optional[ssl.SSLContext] = None,
retries: int = MAX_RETRIES
) -> TestResult:
"""执行流式 HTTP 请求 (SSE支持重试)
Args:
url: 请求 URL
headers: 请求头字典
body: 请求体 (dict)
ssl_ctx: SSL 上下文
retries: 重试次数
Returns:
TestResult 对象
"""
req = urllib.request.Request(url, method="POST")
if headers:
for k, v in headers.items():
req.add_header(k, v)
if body is not None:
req.data = json.dumps(body).encode("utf-8")
start = time.time()
last_error = None
for attempt in range(retries + 1):
try:
resp = urllib.request.urlopen(req, timeout=TIMEOUT, context=ssl_ctx)
status = resp.getcode()
lines = []
for raw_line in resp:
line = raw_line.decode("utf-8").rstrip("\n\r")
if line:
lines.append(line)
elapsed = time.time() - start
return TestResult(
status=status,
elapsed=elapsed,
error_type=classify_error(status),
response="\n".join(lines)
)
except urllib.error.HTTPError as e:
elapsed = time.time() - start
return TestResult(
status=e.code,
elapsed=elapsed,
error_type=classify_error(e.code),
response=e.read().decode("utf-8")
)
except Exception as e:
last_error = str(e)
if attempt < retries:
time.sleep(0.5 * (attempt + 1))
continue
elapsed = time.time() - start
return TestResult(
status=None,
elapsed=elapsed,
error_type=ErrorType.NETWORK,
response=last_error or "Unknown error"
)
def format_json(text: str) -> str:
"""格式化 JSON 文本(用于美化输出)
Args:
text: JSON 字符串或任意文本
Returns:
格式化后的 JSON 字符串,或原文本(如果不是有效 JSON
"""
try:
parsed = json.loads(text)
return json.dumps(parsed, ensure_ascii=False, indent=2)
except (json.JSONDecodeError, TypeError):
return text
def run_test(
index: int,
total: int,
test_case: TestCase,
ssl_ctx: ssl.SSLContext
) -> TestResult:
"""执行单个测试用例并打印结果
Args:
index: 测试序号
total: 总测试数
test_case: 测试用例对象
ssl_ctx: SSL 上下文
Returns:
TestResult 对象
"""
print(f"\n[{index}/{total}] {test_case.desc}")
print(f">>> {test_case.method} {test_case.url}")
if test_case.body is not None:
if isinstance(test_case.body, str):
print(test_case.body)
else:
print(format_json(json.dumps(test_case.body, ensure_ascii=False)))
if test_case.stream:
result = http_stream_request(
test_case.url,
test_case.headers,
test_case.body,
ssl_ctx
)
else:
result = http_request(
test_case.url,
test_case.method,
test_case.headers,
test_case.body,
ssl_ctx
)
if result.status is not None:
print(f"状态码: {result.status} | 耗时: {result.elapsed:.2f}s")
else:
print(f"请求失败 | 耗时: {result.elapsed:.2f}s")
if test_case.stream and result.status and result.status < 300:
# 流式响应按 SSE 行逐行输出
for line in result.response.split("\n"):
print(line)
else:
print(format_json(result.response))
# 执行响应验证
if test_case.validator and result.status and 200 <= result.status < 300:
is_valid, errors = test_case.validator(result.response)
if is_valid:
print("✓ 响应验证通过")
else:
print("✗ 响应验证失败:")
for error in errors:
print(f" - {error}")
return result
def run_test_suite(
cases: List[TestCase],
ssl_ctx: ssl.SSLContext,
title: str,
base_url: str,
model: str,
flags: Optional[List[str]] = None
) -> Tuple[int, int, int, int]:
"""执行测试套件并打印总结
Args:
cases: 测试用例列表
ssl_ctx: SSL 上下文
title: 测试标题
base_url: API 基础地址
model: 模型名称
flags: 扩展测试标记列表
Returns:
(总数, 成功数, 客户端错误数, 服务端错误数)
"""
total = len(cases)
count_success = 0
count_client_error = 0
count_server_error = 0
count_network_error = 0
print("=" * 60)
print(title)
print(f"目标: {base_url}")
print(f"模型: {model}")
print(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
if flags:
print(f"用例: {total} 个 | 扩展: {', '.join(flags)}")
else:
print(f"用例: {total}")
print("=" * 60)
for i, test_case in enumerate(cases, 1):
result = run_test(i, total, test_case, ssl_ctx)
if result.error_type == ErrorType.SUCCESS:
count_success += 1
elif result.error_type == ErrorType.CLIENT:
count_client_error += 1
elif result.error_type == ErrorType.SERVER:
count_server_error += 1
else:
count_network_error += 1
print()
print("=" * 60)
print(f"测试完成 | 总计: {total} | 成功: {count_success} | "
f"客户端错误: {count_client_error} | 服务端错误: {count_server_error} | "
f"网络错误: {count_network_error}")
print("=" * 60)
return total, count_success, count_client_error, count_server_error
# ==================== 通用验证辅助函数 ====================
def check_required_fields(data: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
"""检查必需字段是否存在
Args:
data: 待检查的数据字典
required_fields: 必需字段列表
Returns:
(是否全部存在, 缺失字段列表)
"""
missing = []
for field in required_fields:
if field not in data:
missing.append(field)
return len(missing) == 0, missing
def check_field_type(value: Any, expected_type: Union[Type, tuple]) -> bool:
"""检查字段类型是否正确
Args:
value: 待检查的值
expected_type: 期望的类型(可以是类型元组)
Returns:
类型是否匹配
"""
if value is None:
return True # None值通常表示可选字段允许
return isinstance(value, expected_type)
def check_enum_value(value: Any, allowed_values: List[Any]) -> bool:
"""检查值是否在允许的枚举值列表中
Args:
value: 待检查的值
allowed_values: 允许的值列表
Returns:
值是否合法
"""
if value is None:
return True # None值通常表示可选字段允许
return value in allowed_values
def check_array_items_type(arr: List[Any], expected_item_type: Union[Type, tuple]) -> bool:
"""检查数组中所有元素的类型
Args:
arr: 待检查的数组
expected_item_type: 期望的元素类型
Returns:
所有元素类型是否匹配
"""
if not isinstance(arr, list):
return False
return all(check_field_type(item, expected_item_type) for item in arr)
def format_validation_errors(errors: List[str]) -> str:
"""格式化验证错误信息
Args:
errors: 错误信息列表
Returns:
格式化后的错误字符串
"""
if not errors:
return "验证通过"
return "验证失败:\n - " + "\n - ".join(errors)
def validate_response_structure(
response_text: str,
required_fields: List[str],
field_types: Optional[Dict[str, Union[Type, tuple]]] = None,
enum_values: Optional[Dict[str, List[Any]]] = None
) -> Tuple[bool, List[str]]:
"""验证响应结构(通用验证函数)
Args:
response_text: 响应文本
required_fields: 必需字段列表
field_types: 字段类型映射 {字段名: 期望类型}
enum_values: 枚举值映射 {字段名: 允许值列表}
Returns:
(是否验证通过, 错误信息列表)
"""
errors = []
# 尝试解析JSON
try:
data = json.loads(response_text)
except json.JSONDecodeError as e:
errors.append(f"响应不是有效的JSON: {e}")
return False, errors
# 检查必需字段
has_required, missing = check_required_fields(data, required_fields)
if not has_required:
errors.append(f"缺少必需字段: {', '.join(missing)}")
# 检查字段类型
if field_types:
for field, expected_type in field_types.items():
if field in data and not check_field_type(data[field], expected_type):
actual_type = type(data[field]).__name__
expected_name = expected_type.__name__ if isinstance(expected_type, type) else str(expected_type)
errors.append(f"字段 '{field}' 类型错误: 期望 {expected_name}, 实际 {actual_type}")
# 检查枚举值
if enum_values:
for field, allowed in enum_values.items():
if field in data and not check_enum_value(data[field], allowed):
errors.append(f"字段 '{field}' 值非法: {data[field]}, 允许值: {allowed}")
return len(errors) == 0, errors