import type { DnsResponse } from "./codec";

import { parseResponse } from "./codec";

/** Outcome of a DNS query: either a failure description or a parsed response. */
export type DnsQueryResult = DnsTransportError | DnsTransportResult;

/** A failed query; `error` carries a human-readable message. */
export interface DnsTransportError {
  error: string;
  ok: false;
}

/** A successful query. */
export interface DnsTransportResult {
  /** Raw DNS message bytes as received (for TCP, without the 2-byte length prefix). */
  data: Uint8Array;
  ok: true;
  /** Transport that delivered `data`. */
  protocolUsed: "tcp" | "udp";
  /** `data` as decoded by `parseResponse`. */
  response: DnsResponse;
}

/** Header ID and first question of an outgoing query, used to validate responses. */
interface QueryMeta {
  id: number;
  name: string;
  qclass: number;
  qtype: number;
}

/** Internal result of waiting on the UDP socket. */
type UdpOutcome =
  | { data: Uint8Array; type: "data" }
  | { error: string; type: "error" }
  | { type: "abort" };

/**
 * Sends an already-encoded DNS query to `server:port` and returns the parsed,
 * validated response.
 *
 * With `protocol: "udp"`, a response whose header has the truncated flag set
 * is retried over TCP when `tcpFallback` is true. If that retry fails, the
 * truncated UDP result is returned instead of the TCP error.
 *
 * @param server - Host name or address of the DNS server.
 * @param port - Destination port.
 * @param query - Wire-format DNS query message.
 * @param options - Transport selection, response size limit and abort signal.
 * @returns A `DnsQueryResult`; failures are reported via `ok: false` rather than thrown.
 */
export async function queryDns(
  server: string,
  port: number,
  query: Uint8Array,
  options: {
    maxResponseBytes: number;
    protocol: "tcp" | "udp";
    signal: AbortSignal;
    tcpFallback: boolean;
  },
): Promise<DnsQueryResult> {
  const { maxResponseBytes, protocol, signal, tcpFallback } = options;

  if (protocol === "tcp") {
    return queryTcp(server, port, query, signal, maxResponseBytes);
  }

  const udpResult = await queryUdp(server, port, query, signal, maxResponseBytes);
  if (!udpResult.ok) return udpResult;

  if (udpResult.response.header.flags.truncated && tcpFallback) {
    const tcpResult = await queryTcp(server, port, query, signal, maxResponseBytes);
    // A successful TCP result is already tagged protocolUsed: "tcp".
    if (tcpResult.ok) return tcpResult;
  }
  return udpResult;
}

/** Closes a socket, ignoring any error thrown by `close()` (best-effort cleanup). */
function closeQuietly(socket: { close(): unknown }): void {
  try {
    socket.close();
  } catch {
    /* best-effort */
  }
}

/** Concatenates `chunks` into one array of `totalBytes` bytes. */
function mergeChunks(chunks: Uint8Array[], totalBytes: number): Uint8Array {
  const result = new Uint8Array(totalBytes);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.byteLength;
  }
  return result;
}

/**
 * Parses `payload` and checks that it answers `query`.
 * Parse exceptions and validation mismatches are both returned as `ok: false`.
 */
function parseAndValidateResponse(
  query: Uint8Array,
  payload: Uint8Array,
  protocolUsed: "tcp" | "udp",
): DnsQueryResult {
  try {
    const response = parseResponse(payload);
    const validationError = validateResponseForQuery(query, response);
    if (validationError) {
      return { error: validationError, ok: false };
    }
    return { data: payload, ok: true, protocolUsed, response };
  } catch (e) {
    return { error: `DNS 响应解析失败: ${e instanceof Error ? e.message : String(e)}`, ok: false };
  }
}

/**
 * Performs one DNS exchange over TCP: a 2-byte big-endian length prefix
 * followed by the message, in both directions.
 *
 * Incoming bytes are buffered up to `maxResponseBytes + 2`. Once the announced
 * frame (or the size limit) is reached the socket is closed, and the `close`
 * handler produces the final result.
 */
async function queryTcp(
  server: string,
  port: number,
  query: Uint8Array,
  signal: AbortSignal,
  maxResponseBytes: number,
): Promise<DnsQueryResult> {
  const chunks: Uint8Array[] = [];
  let totalBytes = 0;
  let settled = false;
  let resolver: ((value: DnsQueryResult) => void) | undefined;
  const promise = new Promise<DnsQueryResult>((resolve) => {
    resolver = resolve;
  });
  // Only the first outcome wins; later events (close after abort, etc.) are ignored.
  const settle = (result: DnsQueryResult): void => {
    if (settled) return;
    settled = true;
    resolver?.(result);
  };

  // Decodes the length prefix from the buffered bytes; null until 2 bytes have arrived.
  const readFrame = (): { full: Uint8Array; payloadLen: number; respLen: number } | null => {
    if (totalBytes < 2) return null;
    const full = mergeChunks(chunks, totalBytes);
    const respLen = new DataView(full.buffer, full.byteOffset, full.byteLength).getUint16(0);
    return { full, payloadLen: Math.min(respLen, maxResponseBytes), respLen };
  };

  const socketHandlers: Record<string, (...args: unknown[]) => void> = {
    close() {
      if (settled) return;
      const frame = readFrame();
      if (!frame) {
        settle({ error: "TCP 连接关闭,未收到响应", ok: false });
        return;
      }
      const { full, payloadLen, respLen } = frame;
      if (totalBytes - 2 < payloadLen) {
        settle({ error: `TCP 响应不完整: 期望 ${respLen} 字节,收到 ${totalBytes - 2} 字节`, ok: false });
        return;
      }
      if (respLen > maxResponseBytes) {
        settle({ error: `TCP 响应超过 ${maxResponseBytes} 字节限制 (${respLen} bytes)`, ok: false });
        return;
      }
      settle(parseAndValidateResponse(query, full.subarray(2, 2 + payloadLen), "tcp"));
    },
    data(socket: unknown, data: unknown) {
      const buf = data instanceof Uint8Array ? data : new Uint8Array(data as ArrayBuffer);
      // Never buffer more than the size limit plus the 2-byte prefix.
      const room = Math.max(maxResponseBytes + 2 - totalBytes, 0);
      const accepted = buf.byteLength > room ? buf.subarray(0, room) : buf;
      if (accepted.byteLength > 0) {
        // Copy so the stored chunk does not alias the caller-owned buffer.
        chunks.push(new Uint8Array(accepted));
        totalBytes += accepted.byteLength;
      }
      const frame = readFrame();
      if (frame && totalBytes - 2 >= frame.payloadLen) {
        // Frame complete (or limit reached): close, and let close() settle the result.
        closeQuietly(socket as { close(): unknown });
      }
    },
    error(_socket: unknown, error: unknown) {
      settle({ error: error instanceof Error ? error.message : String(error), ok: false });
    },
    open() {
      // Required by Bun's socket handler contract; a successful connection is
      // signalled by Bun.connect() resolving.
    },
  };

  const onAbort = (): void => {
    settle({ error: "探测超时", ok: false });
  };
  signal.addEventListener("abort", onAbort, { once: true });

  try {
    const socket = await Bun.connect({
      hostname: server,
      port,
      socket: socketHandlers,
    });
    try {
      if (signal.aborted) {
        // The listener never fires if the signal was aborted before it was
        // attached, so settle explicitly rather than relying on a close event.
        onAbort();
        return await promise;
      }
      // Send prefix and message in one write so they are not split across
      // separate segments.
      const framed = new Uint8Array(2 + query.byteLength);
      new DataView(framed.buffer).setUint16(0, query.byteLength);
      framed.set(query, 2);
      socket.write(framed);
      return await promise;
    } finally {
      // Also covers a throwing write(), which previously leaked the socket.
      closeQuietly(socket);
    }
  } catch (error) {
    if (signal.aborted) {
      return { error: "探测超时", ok: false };
    }
    const message = error instanceof Error ? error.message : String(error);
    return { error: simplifyError(message), ok: false };
  } finally {
    signal.removeEventListener("abort", onAbort);
  }
}

/**
 * Performs one DNS exchange over a connected UDP socket. The first datagram
 * received (or the first error/abort) decides the result.
 */
async function queryUdp(
  server: string,
  port: number,
  query: Uint8Array,
  signal: AbortSignal,
  maxResponseBytes: number,
): Promise<DnsQueryResult> {
  // Declared before the socket exists so the handlers can never observe an
  // uninitialised `settle`.
  let settled = false;
  let resolver: ((value: UdpOutcome) => void) | undefined;
  const promise = new Promise<UdpOutcome>((resolve) => {
    resolver = resolve;
  });
  const settle = (outcome: UdpOutcome): void => {
    if (settled) return;
    settled = true;
    resolver?.(outcome);
  };
  const onAbort = (): void => {
    settle({ type: "abort" });
  };

  try {
    const socket = await Bun.udpSocket({
      connect: { hostname: server, port },
      socket: {
        data(_socket, data) {
          if (data.byteLength > maxResponseBytes) {
            settle({ error: `UDP 响应超过 ${maxResponseBytes} 字节限制 (${data.byteLength} bytes)`, type: "error" });
            return;
          }
          // Copy the datagram, as the TCP path does, so the result does not alias
          // the handler's buffer. NOTE(review): whether Bun reuses it is unconfirmed.
          settle({ data: new Uint8Array(data), type: "data" });
        },
        drain() {
          // Required by Bun's UDP socket handler contract; the DNS checker does
          // not care about drain events.
        },
        error(_socket, error) {
          settle({ error: error.message, type: "error" });
        },
      },
    });

    try {
      if (signal.aborted) {
        return { error: "探测已取消", ok: false };
      }
      signal.addEventListener("abort", onAbort, { once: true });
      socket.send(query);

      const outcome = await promise;
      if (outcome.type === "abort") {
        return { error: "探测超时", ok: false };
      }
      if (outcome.type === "error") {
        return { error: outcome.error || "UDP 查询失败", ok: false };
      }
      return parseAndValidateResponse(query, outcome.data, "udp");
    } finally {
      signal.removeEventListener("abort", onAbort);
      // Single cleanup point; also covers a throwing send().
      closeQuietly(socket);
    }
  } catch (error) {
    if (signal.aborted) {
      return { error: "探测超时", ok: false };
    }
    const message = error instanceof Error ? error.message : String(error);
    return { error: simplifyError(message), ok: false };
  }
}

/**
 * Extracts the header ID and the first question (name, type, class) from a
 * wire-format query. Returns null when the message is shorter than the 12-byte
 * header, the name runs past the end, or a label length has either of its top
 * two bits set (such bytes are not accepted in an outgoing query).
 */
function readQueryMeta(query: Uint8Array): null | QueryMeta {
  if (query.byteLength < 12) return null;
  const view = new DataView(query.buffer, query.byteOffset, query.byteLength);
  const decoder = new TextDecoder();
  const labels: string[] = [];
  let offset = 12;

  for (;;) {
    const len = query[offset];
    if (len === undefined) return null; // ran off the end before the root label
    offset++;
    if (len === 0) break;
    if ((len & 0xc0) !== 0 || offset + len > query.byteLength) return null;
    labels.push(decoder.decode(query.subarray(offset, offset + len)));
    offset += len;
  }

  // QTYPE and QCLASS (2 bytes each) must follow the name.
  if (offset + 4 > query.byteLength) return null;
  return {
    id: view.getUint16(0),
    name: labels.join("."),
    qclass: view.getUint16(offset + 2),
    qtype: view.getUint16(offset),
  };
}

/**
 * Maps verbose socket error messages onto short, stable descriptions.
 * Messages matching no known pattern are returned unchanged.
 */
function simplifyError(message: string): string {
  const lower = message.toLowerCase();
  const has = (...needles: string[]): boolean => needles.some((needle) => lower.includes(needle));

  if (has("econnrefused", "connection refused")) return "connection refused";
  // "enotfound" matches neither "enoent" nor "not found", so it is listed explicitly.
  if (has("enoent", "enotfound", "not found")) return "host not found";
  if (has("etimedout", "timed out")) return "timed out";
  if (has("econnreset", "reset")) return "connection reset";
  if (has("enetwork", "network")) return "network error";
  return message;
}

/**
 * Checks that `response` answers `query`: the header IDs must match, and if
 * the response carries a question it must equal the query's first question.
 *
 * @returns An error message, or null when the response is acceptable.
 */
function validateResponseForQuery(query: Uint8Array, response: DnsResponse): null | string {
  const meta = readQueryMeta(query);
  if (!meta) return "DNS 查询报文不完整";
  if (response.header.id !== meta.id) {
    return `DNS 响应 ID 不匹配: 期望 ${meta.id},实际 ${response.header.id}`;
  }
  // NOTE(review): the name comparison is exact (case-sensitive) and assumes
  // parseResponse yields dot-joined labels without a trailing dot — confirm.
  const question = response.questions[0];
  if (
    question &&
    (question.name !== meta.name || question.qtype !== meta.qtype || question.qclass !== meta.qclass)
  ) {
    return "DNS 响应 question 与查询不匹配";
  }
  return null;
}