将 anthropic_detect.py 和 openai_detect.py 中的公共功能抽取到 core.py 模块,包括: - HTTP 请求(普通/流式)及重试逻辑 - SSL 上下文管理 - 测试用例/结果数据结构 (TestCase, TestResult) - 错误分类 (ErrorType) - 响应验证辅助函数 (validate_response_structure 等) - 测试执行框架 (run_test, run_test_suite) 两个检测脚本重构后更聚焦于各自 API 的测试用例定义。
472 lines
13 KiB
Python
472 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""兼容性测试脚本的核心公共函数
|
||
|
||
提供 HTTP 请求、SSL 上下文、JSON 格式化、验证辅助等通用功能。
|
||
"""
|
||
|
||
import json
|
||
import time
|
||
import ssl
|
||
import urllib.request
|
||
import urllib.error
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Dict, Any, Tuple, List, Union, Type
|
||
from enum import Enum
|
||
|
||
TIMEOUT = 30
|
||
MAX_RETRIES = 2 # 最大重试次数
|
||
|
||
|
||
class ErrorType(Enum):
|
||
"""错误类型分类"""
|
||
NETWORK = "network" # 网络错误
|
||
CLIENT = "client" # 4xx 错误
|
||
SERVER = "server" # 5xx 错误
|
||
SUCCESS = "success" # 成功
|
||
|
||
|
||
@dataclass
|
||
class TestCase:
|
||
"""测试用例数据结构"""
|
||
desc: str # 测试描述
|
||
method: str # HTTP 方法
|
||
url: str # 请求 URL
|
||
headers: Dict[str, str] # 请求头
|
||
body: Optional[Any] = None # 请求体
|
||
stream: bool = False # 是否流式请求
|
||
validator: Optional[Any] = None # 响应验证函数(可选)
|
||
|
||
|
||
@dataclass
|
||
class TestResult:
|
||
"""测试结果数据结构"""
|
||
status: Optional[int] # HTTP 状态码
|
||
elapsed: float # 耗时(秒)
|
||
error_type: ErrorType # 错误类型
|
||
response: str # 响应内容
|
||
|
||
|
||
def create_ssl_context() -> ssl.SSLContext:
|
||
"""创建不验证证书的 SSL 上下文(用于测试环境)"""
|
||
ctx = ssl.create_default_context()
|
||
ctx.check_hostname = False
|
||
ctx.verify_mode = ssl.CERT_NONE
|
||
return ctx
|
||
|
||
|
||
def classify_error(status: Optional[int]) -> ErrorType:
|
||
"""根据状态码分类错误类型"""
|
||
if status is None:
|
||
return ErrorType.NETWORK
|
||
if 200 <= status < 300:
|
||
return ErrorType.SUCCESS
|
||
if 400 <= status < 500:
|
||
return ErrorType.CLIENT
|
||
if status >= 500:
|
||
return ErrorType.SERVER
|
||
return ErrorType.NETWORK
|
||
|
||
|
||
def http_request(
|
||
url: str,
|
||
method: str = "GET",
|
||
headers: Optional[Dict[str, str]] = None,
|
||
body: Optional[Any] = None,
|
||
ssl_ctx: Optional[ssl.SSLContext] = None,
|
||
retries: int = MAX_RETRIES
|
||
) -> TestResult:
|
||
"""执行普通 HTTP 请求(支持重试)
|
||
|
||
Args:
|
||
url: 请求 URL
|
||
method: HTTP 方法 (GET/POST/PUT/DELETE)
|
||
headers: 请求头字典
|
||
body: 请求体 (dict 或 str)
|
||
ssl_ctx: SSL 上下文
|
||
retries: 重试次数
|
||
|
||
Returns:
|
||
TestResult 对象
|
||
"""
|
||
req = urllib.request.Request(url, method=method)
|
||
if headers:
|
||
for k, v in headers.items():
|
||
req.add_header(k, v)
|
||
if body is not None:
|
||
if isinstance(body, str):
|
||
req.data = body.encode("utf-8")
|
||
else:
|
||
req.data = json.dumps(body).encode("utf-8")
|
||
|
||
start = time.time()
|
||
last_error = None
|
||
|
||
for attempt in range(retries + 1):
|
||
try:
|
||
resp = urllib.request.urlopen(req, timeout=TIMEOUT, context=ssl_ctx)
|
||
elapsed = time.time() - start
|
||
status = resp.getcode()
|
||
return TestResult(
|
||
status=status,
|
||
elapsed=elapsed,
|
||
error_type=classify_error(status),
|
||
response=resp.read().decode("utf-8")
|
||
)
|
||
except urllib.error.HTTPError as e:
|
||
elapsed = time.time() - start
|
||
return TestResult(
|
||
status=e.code,
|
||
elapsed=elapsed,
|
||
error_type=classify_error(e.code),
|
||
response=e.read().decode("utf-8")
|
||
)
|
||
except Exception as e:
|
||
last_error = str(e)
|
||
if attempt < retries:
|
||
time.sleep(0.5 * (attempt + 1)) # 递增延迟
|
||
continue
|
||
|
||
elapsed = time.time() - start
|
||
return TestResult(
|
||
status=None,
|
||
elapsed=elapsed,
|
||
error_type=ErrorType.NETWORK,
|
||
response=last_error or "Unknown error"
|
||
)
|
||
|
||
|
||
def http_stream_request(
|
||
url: str,
|
||
headers: Optional[Dict[str, str]] = None,
|
||
body: Optional[Any] = None,
|
||
ssl_ctx: Optional[ssl.SSLContext] = None,
|
||
retries: int = MAX_RETRIES
|
||
) -> TestResult:
|
||
"""执行流式 HTTP 请求 (SSE,支持重试)
|
||
|
||
Args:
|
||
url: 请求 URL
|
||
headers: 请求头字典
|
||
body: 请求体 (dict)
|
||
ssl_ctx: SSL 上下文
|
||
retries: 重试次数
|
||
|
||
Returns:
|
||
TestResult 对象
|
||
"""
|
||
req = urllib.request.Request(url, method="POST")
|
||
if headers:
|
||
for k, v in headers.items():
|
||
req.add_header(k, v)
|
||
if body is not None:
|
||
req.data = json.dumps(body).encode("utf-8")
|
||
|
||
start = time.time()
|
||
last_error = None
|
||
|
||
for attempt in range(retries + 1):
|
||
try:
|
||
resp = urllib.request.urlopen(req, timeout=TIMEOUT, context=ssl_ctx)
|
||
status = resp.getcode()
|
||
lines = []
|
||
for raw_line in resp:
|
||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||
if line:
|
||
lines.append(line)
|
||
elapsed = time.time() - start
|
||
return TestResult(
|
||
status=status,
|
||
elapsed=elapsed,
|
||
error_type=classify_error(status),
|
||
response="\n".join(lines)
|
||
)
|
||
except urllib.error.HTTPError as e:
|
||
elapsed = time.time() - start
|
||
return TestResult(
|
||
status=e.code,
|
||
elapsed=elapsed,
|
||
error_type=classify_error(e.code),
|
||
response=e.read().decode("utf-8")
|
||
)
|
||
except Exception as e:
|
||
last_error = str(e)
|
||
if attempt < retries:
|
||
time.sleep(0.5 * (attempt + 1))
|
||
continue
|
||
|
||
elapsed = time.time() - start
|
||
return TestResult(
|
||
status=None,
|
||
elapsed=elapsed,
|
||
error_type=ErrorType.NETWORK,
|
||
response=last_error or "Unknown error"
|
||
)
|
||
|
||
|
||
def format_json(text: str) -> str:
|
||
"""格式化 JSON 文本(用于美化输出)
|
||
|
||
Args:
|
||
text: JSON 字符串或任意文本
|
||
|
||
Returns:
|
||
格式化后的 JSON 字符串,或原文本(如果不是有效 JSON)
|
||
"""
|
||
try:
|
||
parsed = json.loads(text)
|
||
return json.dumps(parsed, ensure_ascii=False, indent=2)
|
||
except (json.JSONDecodeError, TypeError):
|
||
return text
|
||
|
||
|
||
def run_test(
|
||
index: int,
|
||
total: int,
|
||
test_case: TestCase,
|
||
ssl_ctx: ssl.SSLContext
|
||
) -> TestResult:
|
||
"""执行单个测试用例并打印结果
|
||
|
||
Args:
|
||
index: 测试序号
|
||
total: 总测试数
|
||
test_case: 测试用例对象
|
||
ssl_ctx: SSL 上下文
|
||
|
||
Returns:
|
||
TestResult 对象
|
||
"""
|
||
print(f"\n[{index}/{total}] {test_case.desc}")
|
||
print(f">>> {test_case.method} {test_case.url}")
|
||
if test_case.body is not None:
|
||
if isinstance(test_case.body, str):
|
||
print(test_case.body)
|
||
else:
|
||
print(format_json(json.dumps(test_case.body, ensure_ascii=False)))
|
||
|
||
if test_case.stream:
|
||
result = http_stream_request(
|
||
test_case.url,
|
||
test_case.headers,
|
||
test_case.body,
|
||
ssl_ctx
|
||
)
|
||
else:
|
||
result = http_request(
|
||
test_case.url,
|
||
test_case.method,
|
||
test_case.headers,
|
||
test_case.body,
|
||
ssl_ctx
|
||
)
|
||
|
||
if result.status is not None:
|
||
print(f"状态码: {result.status} | 耗时: {result.elapsed:.2f}s")
|
||
else:
|
||
print(f"请求失败 | 耗时: {result.elapsed:.2f}s")
|
||
|
||
if test_case.stream and result.status and result.status < 300:
|
||
# 流式响应按 SSE 行逐行输出
|
||
for line in result.response.split("\n"):
|
||
print(line)
|
||
else:
|
||
print(format_json(result.response))
|
||
|
||
# 执行响应验证
|
||
if test_case.validator and result.status and 200 <= result.status < 300:
|
||
is_valid, errors = test_case.validator(result.response)
|
||
if is_valid:
|
||
print("✓ 响应验证通过")
|
||
else:
|
||
print("✗ 响应验证失败:")
|
||
for error in errors:
|
||
print(f" - {error}")
|
||
|
||
return result
|
||
|
||
|
||
def run_test_suite(
|
||
cases: List[TestCase],
|
||
ssl_ctx: ssl.SSLContext,
|
||
title: str,
|
||
base_url: str,
|
||
model: str,
|
||
flags: Optional[List[str]] = None
|
||
) -> Tuple[int, int, int, int]:
|
||
"""执行测试套件并打印总结
|
||
|
||
Args:
|
||
cases: 测试用例列表
|
||
ssl_ctx: SSL 上下文
|
||
title: 测试标题
|
||
base_url: API 基础地址
|
||
model: 模型名称
|
||
flags: 扩展测试标记列表
|
||
|
||
Returns:
|
||
(总数, 成功数, 客户端错误数, 服务端错误数)
|
||
"""
|
||
total = len(cases)
|
||
count_success = 0
|
||
count_client_error = 0
|
||
count_server_error = 0
|
||
count_network_error = 0
|
||
|
||
print("=" * 60)
|
||
print(title)
|
||
print(f"目标: {base_url}")
|
||
print(f"模型: {model}")
|
||
print(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||
if flags:
|
||
print(f"用例: {total} 个 | 扩展: {', '.join(flags)}")
|
||
else:
|
||
print(f"用例: {total} 个")
|
||
print("=" * 60)
|
||
|
||
for i, test_case in enumerate(cases, 1):
|
||
result = run_test(i, total, test_case, ssl_ctx)
|
||
|
||
if result.error_type == ErrorType.SUCCESS:
|
||
count_success += 1
|
||
elif result.error_type == ErrorType.CLIENT:
|
||
count_client_error += 1
|
||
elif result.error_type == ErrorType.SERVER:
|
||
count_server_error += 1
|
||
else:
|
||
count_network_error += 1
|
||
|
||
print()
|
||
print("=" * 60)
|
||
print(f"测试完成 | 总计: {total} | 成功: {count_success} | "
|
||
f"客户端错误: {count_client_error} | 服务端错误: {count_server_error} | "
|
||
f"网络错误: {count_network_error}")
|
||
print("=" * 60)
|
||
|
||
return total, count_success, count_client_error, count_server_error
|
||
|
||
|
||
# ==================== 通用验证辅助函数 ====================
|
||
|
||
def check_required_fields(data: Dict[str, Any], required_fields: List[str]) -> Tuple[bool, List[str]]:
|
||
"""检查必需字段是否存在
|
||
|
||
Args:
|
||
data: 待检查的数据字典
|
||
required_fields: 必需字段列表
|
||
|
||
Returns:
|
||
(是否全部存在, 缺失字段列表)
|
||
"""
|
||
missing = []
|
||
for field in required_fields:
|
||
if field not in data:
|
||
missing.append(field)
|
||
return len(missing) == 0, missing
|
||
|
||
|
||
def check_field_type(value: Any, expected_type: Union[Type, tuple]) -> bool:
|
||
"""检查字段类型是否正确
|
||
|
||
Args:
|
||
value: 待检查的值
|
||
expected_type: 期望的类型(可以是类型元组)
|
||
|
||
Returns:
|
||
类型是否匹配
|
||
"""
|
||
if value is None:
|
||
return True # None值通常表示可选字段,允许
|
||
return isinstance(value, expected_type)
|
||
|
||
|
||
def check_enum_value(value: Any, allowed_values: List[Any]) -> bool:
|
||
"""检查值是否在允许的枚举值列表中
|
||
|
||
Args:
|
||
value: 待检查的值
|
||
allowed_values: 允许的值列表
|
||
|
||
Returns:
|
||
值是否合法
|
||
"""
|
||
if value is None:
|
||
return True # None值通常表示可选字段,允许
|
||
return value in allowed_values
|
||
|
||
|
||
def check_array_items_type(arr: List[Any], expected_item_type: Union[Type, tuple]) -> bool:
|
||
"""检查数组中所有元素的类型
|
||
|
||
Args:
|
||
arr: 待检查的数组
|
||
expected_item_type: 期望的元素类型
|
||
|
||
Returns:
|
||
所有元素类型是否匹配
|
||
"""
|
||
if not isinstance(arr, list):
|
||
return False
|
||
return all(check_field_type(item, expected_item_type) for item in arr)
|
||
|
||
|
||
def format_validation_errors(errors: List[str]) -> str:
|
||
"""格式化验证错误信息
|
||
|
||
Args:
|
||
errors: 错误信息列表
|
||
|
||
Returns:
|
||
格式化后的错误字符串
|
||
"""
|
||
if not errors:
|
||
return "验证通过"
|
||
return "验证失败:\n - " + "\n - ".join(errors)
|
||
|
||
|
||
def validate_response_structure(
|
||
response_text: str,
|
||
required_fields: List[str],
|
||
field_types: Optional[Dict[str, Union[Type, tuple]]] = None,
|
||
enum_values: Optional[Dict[str, List[Any]]] = None
|
||
) -> Tuple[bool, List[str]]:
|
||
"""验证响应结构(通用验证函数)
|
||
|
||
Args:
|
||
response_text: 响应文本
|
||
required_fields: 必需字段列表
|
||
field_types: 字段类型映射 {字段名: 期望类型}
|
||
enum_values: 枚举值映射 {字段名: 允许值列表}
|
||
|
||
Returns:
|
||
(是否验证通过, 错误信息列表)
|
||
"""
|
||
errors = []
|
||
|
||
# 尝试解析JSON
|
||
try:
|
||
data = json.loads(response_text)
|
||
except json.JSONDecodeError as e:
|
||
errors.append(f"响应不是有效的JSON: {e}")
|
||
return False, errors
|
||
|
||
# 检查必需字段
|
||
has_required, missing = check_required_fields(data, required_fields)
|
||
if not has_required:
|
||
errors.append(f"缺少必需字段: {', '.join(missing)}")
|
||
|
||
# 检查字段类型
|
||
if field_types:
|
||
for field, expected_type in field_types.items():
|
||
if field in data and not check_field_type(data[field], expected_type):
|
||
actual_type = type(data[field]).__name__
|
||
expected_name = expected_type.__name__ if isinstance(expected_type, type) else str(expected_type)
|
||
errors.append(f"字段 '{field}' 类型错误: 期望 {expected_name}, 实际 {actual_type}")
|
||
|
||
# 检查枚举值
|
||
if enum_values:
|
||
for field, allowed in enum_values.items():
|
||
if field in data and not check_enum_value(data[field], allowed):
|
||
errors.append(f"字段 '{field}' 值非法: {data[field]}, 允许值: {allowed}")
|
||
|
||
return len(errors) == 0, errors
|