350 lines
10 KiB
TypeScript
350 lines
10 KiB
TypeScript
import type { DnsResponse } from "./codec";
|
||
|
||
import { parseResponse } from "./codec";
|
||
|
||
/**
 * Outcome of a DNS query as returned by `queryDns`: either a parsed response
 * (`ok: true`) or a failure description (`ok: false`). Discriminate on `ok`.
 */
export type DnsQueryResult = DnsTransportError | DnsTransportResult;
|
||
|
||
/** Failure result of a DNS query. */
export interface DnsTransportError {
  /** Human-readable reason: connection/socket error, timeout, oversized, incomplete or invalid response. */
  error: string;
  /** Discriminant: always `false` for failures. */
  ok: false;
}
|
||
|
||
/** Successful DNS query: the raw response message together with its parsed form. */
export interface DnsTransportResult {
  /** Raw DNS response message (for TCP, the 2-byte length prefix is already stripped). */
  data: Uint8Array;
  /** Discriminant: always `true` for successes. */
  ok: true;
  /** Transport that actually delivered `data`. */
  protocolUsed: "tcp" | "udp";
  /** `data` decoded by `parseResponse` and checked against the query's ID and question. */
  response: DnsResponse;
}
|
||
|
||
/** Fields read from a wire-format DNS query that a matching response must echo. */
interface QueryMeta {
  /** Header transaction ID (first two bytes of the message). */
  id: number;
  /** QNAME of the first question: labels decoded as UTF-8 and joined with "."; "" for the root name. */
  name: string;
  /** QCLASS of the first question. */
  qclass: number;
  /** QTYPE of the first question. */
  qtype: number;
}
|
||
|
||
/**
 * Sends a wire-format DNS query to `server:port` and returns the parsed response.
 *
 * With `protocol: "tcp"` the query goes straight over TCP. Otherwise it is sent over
 * UDP first; when the UDP reply carries the TC (truncated) flag and `tcpFallback` is
 * enabled, the query is retried over TCP. If that retry fails, the truncated UDP
 * result is returned rather than the TCP error.
 *
 * @param server  Hostname or IP address of the DNS server.
 * @param port    Port of the DNS server.
 * @param query   Complete wire-format DNS query message.
 * @param options `maxResponseBytes` caps the accepted response size, `protocol` picks
 *                the first transport, `signal` cancels the probe, and `tcpFallback`
 *                enables the TCP retry for truncated UDP replies.
 * @returns A result object; the transports report failures as `ok: false` instead of throwing.
 */
export async function queryDns(
  server: string,
  port: number,
  query: Uint8Array,
  options: {
    maxResponseBytes: number;
    protocol: "tcp" | "udp";
    signal: AbortSignal;
    tcpFallback: boolean;
  },
): Promise<DnsQueryResult> {
  const { maxResponseBytes, protocol, signal, tcpFallback } = options;

  if (protocol === "tcp") return queryTcp(server, port, query, signal, maxResponseBytes);

  const viaUdp = await queryUdp(server, port, query, signal, maxResponseBytes);
  const needsTcpRetry = viaUdp.ok && viaUdp.response.header.flags.truncated && tcpFallback;
  if (!needsTcpRetry) return viaUdp;

  const viaTcp = await queryTcp(server, port, query, signal, maxResponseBytes);
  // Prefer the complete TCP answer; fall back to the truncated UDP one if TCP failed.
  return viaTcp.ok ? { ...viaTcp, protocolUsed: "tcp" } : viaUdp;
}
|
||
|
||
/**
 * Concatenates `chunks`, in order, into one freshly allocated buffer of `totalBytes` bytes.
 *
 * `totalBytes` is expected to equal the summed `byteLength` of the chunks. Any surplus
 * stays zero-filled, and a shortfall makes `Uint8Array#set` throw a `RangeError`.
 */
function mergeChunks(chunks: Uint8Array[], totalBytes: number): Uint8Array {
  const merged = new Uint8Array(totalBytes);
  chunks.reduce((writeAt, piece) => {
    merged.set(piece, writeAt);
    return writeAt + piece.byteLength;
  }, 0);
  return merged;
}
|
||
|
||
/**
 * Decodes `payload` as a DNS response and checks that it answers `query`.
 *
 * Exceptions raised while decoding or validating are converted into an `ok: false`
 * result, so this function never throws.
 *
 * @param query        The wire-format query that was sent.
 * @param payload      Raw response message (any TCP length prefix already stripped).
 * @param protocolUsed Transport that delivered `payload`; copied into the success result.
 */
function parseAndValidateResponse(query: Uint8Array, payload: Uint8Array, protocolUsed: "tcp" | "udp"): DnsQueryResult {
  let outcome: DnsQueryResult;
  try {
    const response = parseResponse(payload);
    const mismatch = validateResponseForQuery(query, response);
    outcome = mismatch
      ? { error: mismatch, ok: false }
      : { data: payload, ok: true, protocolUsed, response };
  } catch (e) {
    outcome = { error: `DNS 响应解析失败: ${e instanceof Error ? e.message : String(e)}`, ok: false };
  }
  return outcome;
}
|
||
|
||
/**
 * Sends `query` to `server:port` over TCP and resolves with the parsed, validated response.
 *
 * Uses the two-octet length-prefixed framing of DNS over TCP (RFC 1035 §4.2.2). The
 * promise never rejects: connection errors, aborts, oversized or incomplete responses and
 * parse/validation failures are all reported as `{ ok: false, error }`.
 *
 * @param server           Hostname or IP address of the DNS server.
 * @param port             TCP port of the DNS server.
 * @param query            Wire-format DNS query, without the TCP length prefix.
 * @param signal           Aborting it (also before the call) ends the query with "探测超时".
 * @param maxResponseBytes Largest accepted response message, excluding the length prefix.
 */
async function queryTcp(
  server: string,
  port: number,
  query: Uint8Array,
  signal: AbortSignal,
  maxResponseBytes: number,
): Promise<DnsQueryResult> {
  // "abort" is dispatched only once, so a listener added to an already-aborted signal never
  // runs; without this check a successful connect would wait forever on an unsettled promise.
  if (signal.aborted) return { error: "探测超时", ok: false };

  // The TCP length prefix is an unsigned 16-bit field; DataView#setUint16 would silently
  // wrap a larger length and send a corrupt frame.
  if (query.byteLength > 0xffff) {
    return { error: `DNS 查询报文超过 65535 字节上限 (${query.byteLength} bytes)`, ok: false };
  }

  const chunks: Uint8Array[] = [];
  let totalBytes = 0;
  // Message length announced by the 2-byte prefix; set once at least 2 bytes have arrived.
  let announcedLength: number | undefined;

  let settled = false;
  let resolver: ((value: DnsQueryResult) => void) | undefined;
  const promise = new Promise<DnsQueryResult>((resolve) => {
    resolver = resolve;
  });

  // The first outcome wins; later ones (e.g. the close event after a completed read) are ignored.
  const settle = (result: DnsQueryResult) => {
    if (settled) return;
    settled = true;
    resolver!(result);
  };

  // Settles as soon as the buffered bytes decide the outcome: the prefix announces a message
  // larger than maxResponseBytes, or the whole message has arrived. Otherwise keeps waiting.
  const settleIfFrameDecided = () => {
    if (settled || totalBytes < 2) return;
    const buffered = mergeChunks(chunks, totalBytes);
    const respLen = new DataView(buffered.buffer, buffered.byteOffset, buffered.byteLength).getUint16(0);
    announcedLength = respLen;
    if (respLen > maxResponseBytes) {
      settle({ error: `TCP 响应超过 ${maxResponseBytes} 字节限制 (${respLen} bytes)`, ok: false });
    } else if (totalBytes - 2 >= respLen) {
      settle(parseAndValidateResponse(query, buffered.subarray(2, 2 + respLen), "tcp"));
    }
  };

  const socketHandlers: Record<string, (...args: unknown[]) => void> = {
    close() {
      // Use whatever arrived before the connection closed; if it is not a full frame, say why.
      settleIfFrameDecided();
      if (announcedLength === undefined) {
        settle({ error: "TCP 连接关闭,未收到响应", ok: false });
      } else {
        settle({ error: `TCP 响应不完整: 期望 ${announcedLength} 字节,收到 ${totalBytes - 2} 字节`, ok: false });
      }
    },
    data(_socket: unknown, data: unknown) {
      const incoming = data instanceof Uint8Array ? data : new Uint8Array(data as ArrayBuffer);
      // Buffer at most one maximal frame (2-byte prefix + maxResponseBytes); drop any excess.
      const room = maxResponseBytes + 2 - totalBytes;
      const kept = incoming.byteLength > room ? incoming.subarray(0, room) : incoming;
      if (kept.byteLength > 0) {
        // Copy so buffered bytes never alias the runtime's input buffer
        // (NOTE(review): presumably reused between events — confirm against Bun docs).
        chunks.push(new Uint8Array(kept));
        totalBytes += kept.byteLength;
      }
      settleIfFrameDecided();
    },
    error(_socket: unknown, error: unknown) {
      settle({ error: error instanceof Error ? error.message : String(error), ok: false });
    },
    open() {
      // Required by Bun's socket handler; a successful connection is signalled by Bun.connect() resolving.
    },
  };

  const onAbort = () => {
    settle({ error: "探测超时", ok: false });
  };
  signal.addEventListener("abort", onAbort, { once: true });

  try {
    const socket = await Bun.connect({
      hostname: server,
      port,
      socket: socketHandlers,
    });

    try {
      if (signal.aborted) {
        // Aborted while connecting: onAbort normally settled already; settle again
        // defensively so the await below can never hang.
        settle({ error: "探测超时", ok: false });
      } else {
        // RFC 7766 §8: hand the length prefix and the message to TCP in a single write so
        // they are likely to travel in one segment.
        const frame = new Uint8Array(2 + query.byteLength);
        new DataView(frame.buffer).setUint16(0, query.byteLength);
        frame.set(query, 2);
        socket.write(frame);
      }
      return await promise;
    } finally {
      // Always release the connection, including when write() throws.
      try {
        socket.close();
      } catch {
        /* best-effort */
      }
    }
  } catch (error) {
    if (signal.aborted) {
      return { error: "探测超时", ok: false };
    }
    const message = error instanceof Error ? error.message : String(error);
    return { error: simplifyError(message), ok: false };
  } finally {
    signal.removeEventListener("abort", onAbort);
  }
}
|
||
|
||
/**
 * Sends `query` to `server:port` as a single UDP datagram and resolves with the parsed,
 * validated reply.
 *
 * The first datagram (or socket error) received decides the result. The promise never
 * rejects: socket errors, aborts, oversized replies and parse/validation failures are all
 * reported as `{ ok: false, error }`.
 *
 * @param server           Hostname or IP address of the DNS server.
 * @param port             UDP port of the DNS server.
 * @param query            Wire-format DNS query.
 * @param signal           Aborting it ends the wait with "探测超时"; if it is already aborted
 *                         once the socket exists, the result is "探测已取消".
 * @param maxResponseBytes Largest accepted reply datagram, in bytes.
 */
async function queryUdp(
  server: string,
  port: number,
  query: Uint8Array,
  signal: AbortSignal,
  maxResponseBytes: number,
): Promise<DnsQueryResult> {
  // How the wait for the reply datagram ended.
  type Outcome = { data: Uint8Array; type: "data" } | { error?: string; type: "error" } | { type: "abort" };

  // Declared before the socket exists so the socket handlers never reference it early.
  let settled = false;
  let resolver: ((value: Outcome) => void) | undefined;
  const promise = new Promise<Outcome>((resolve) => {
    resolver = resolve;
  });
  // The first outcome wins; anything reported afterwards is ignored.
  const settle = (outcome: Outcome) => {
    if (settled) return;
    settled = true;
    resolver!(outcome);
  };

  // Closing is best-effort: the socket may already have been closed on another path.
  const closeQuietly = (sock: { close(): void }) => {
    try {
      sock.close();
    } catch {
      /* best-effort */
    }
  };

  try {
    const socket = await Bun.udpSocket({
      connect: { hostname: server, port },
      socket: {
        data(sock, data) {
          if (data.byteLength > maxResponseBytes) {
            settle({ error: `UDP 响应超过 ${maxResponseBytes} 字节限制 (${data.byteLength} bytes)`, type: "error" });
          } else {
            settle({ data: new Uint8Array(data.buffer, data.byteOffset, data.byteLength), type: "data" });
          }
          closeQuietly(sock);
        },
        drain() {
          // Required by Bun's UDP socket handler; the DNS checker ignores drain events.
        },
        error(sock, error) {
          settle({ error: error.message, type: "error" });
          closeQuietly(sock);
        },
      },
    });

    const onAbort = () => {
      settle({ type: "abort" });
      closeQuietly(socket);
    };

    try {
      if (signal.aborted) {
        return { error: "探测已取消", ok: false };
      }

      signal.addEventListener("abort", onAbort, { once: true });
      socket.send(query);

      const outcome = await promise;
      if (outcome.type === "error") {
        return { error: outcome.error ?? "UDP 查询失败", ok: false };
      }
      if (outcome.type === "abort") {
        return { error: "探测超时", ok: false };
      }
      return parseAndValidateResponse(query, outcome.data, "udp");
    } finally {
      signal.removeEventListener("abort", onAbort);
      // Every settle path already closed the socket; this covers the early-abort return
      // and a throwing send(), which previously leaked the socket.
      if (!settled) closeQuietly(socket);
    }
  } catch (error) {
    if (signal.aborted) {
      return { error: "探测超时", ok: false };
    }
    const message = error instanceof Error ? error.message : String(error);
    return { error: simplifyError(message), ok: false };
  }
}
|
||
|
||
/**
 * Extracts the header ID and the first question (QNAME, QTYPE, QCLASS) from a
 * wire-format DNS query.
 *
 * Only uncompressed names are accepted. A label-length byte with either of its two top
 * bits set (a compression pointer or reserved label type) makes the query invalid.
 *
 * @param query Wire-format DNS message: 12-byte header followed by the question section.
 * @returns The parsed metadata, or `null` if the message is too short or malformed.
 *          Labels are decoded as UTF-8 and joined with "."; the root name yields "".
 */
function readQueryMeta(query: Uint8Array): null | QueryMeta {
  if (query.byteLength < 12) return null;

  const view = new DataView(query.buffer, query.byteOffset, query.byteLength);
  const id = view.getUint16(0);
  // One decoder for every label instead of allocating a new one per loop iteration.
  const decoder = new TextDecoder();
  const labels: string[] = [];
  let offset = 12; // QNAME starts right after the fixed-size header

  while (true) {
    if (offset >= query.byteLength) return null;
    const len = query[offset]!;
    offset++;
    if (len === 0) break; // zero-length root label terminates QNAME
    if ((len & 0xc0) !== 0 || offset + len > query.byteLength) return null;
    labels.push(decoder.decode(query.subarray(offset, offset + len)));
    offset += len;
  }

  if (offset + 4 > query.byteLength) return null;
  const qtype = view.getUint16(offset);
  const qclass = view.getUint16(offset + 2);
  return { id, name: labels.join("."), qclass, qtype };
}
|
||
|
||
/**
 * Maps a low-level socket/resolver error message to a short, stable description.
 *
 * Matching is case-insensitive and looks for both errno-style codes (e.g. `ECONNREFUSED`)
 * and their human-readable phrasing; the first matching category wins. Messages that
 * match no category are returned unchanged.
 */
function simplifyError(message: string): string {
  const lower = message.toLowerCase();
  const mentions = (...needles: string[]) => needles.some((needle) => lower.includes(needle));

  if (mentions("econnrefused", "connection refused")) return "connection refused";
  // getaddrinfo failures carry the code ENOTFOUND (e.g. "getaddrinfo ENOTFOUND host"),
  // which does not contain the substring "not found".
  if (mentions("enotfound", "enoent", "not found")) return "host not found";
  if (mentions("etimedout", "timed out")) return "timed out";
  if (mentions("econnreset", "reset")) return "connection reset";
  // There is no ENETWORK errno; unreachable/down network codes are ENETUNREACH,
  // EHOSTUNREACH and ENETDOWN. "enetwork" is kept for backward compatibility.
  if (mentions("enetwork", "enetunreach", "ehostunreach", "enetdown", "network")) return "network error";
  return message;
}
|
||
|
||
/**
 * Checks that a parsed DNS response belongs to the query that was sent.
 *
 * The header ID must match the query's. If the response echoes a question, its name,
 * QTYPE and QCLASS must match the query's first question. Names are compared ignoring
 * ASCII letter case, as RFC 1035 §2.3.3 and RFC 4343 require for domain names. A response
 * without a question section is accepted on the ID check alone.
 *
 * @param query    The wire-format query that was sent.
 * @param response The decoded response to check.
 * @returns `null` when the response matches, otherwise a human-readable reason.
 */
function validateResponseForQuery(query: Uint8Array, response: DnsResponse): null | string {
  const meta = readQueryMeta(query);
  if (!meta) return "DNS 查询报文不完整";
  if (response.header.id !== meta.id) {
    return `DNS 响应 ID 不匹配: 期望 ${meta.id},实际 ${response.header.id}`;
  }

  // Only A-Z/a-z are case-insensitive in DNS names (RFC 4343); all other bytes must match exactly.
  const foldAsciiCase = (name: string) => name.replace(/[A-Z]/g, (c) => c.toLowerCase());

  const question = response.questions[0];
  if (
    question &&
    (foldAsciiCase(question.name) !== foldAsciiCase(meta.name) ||
      question.qtype !== meta.qtype ||
      question.qclass !== meta.qclass)
  ) {
    return "DNS 响应 question 与查询不匹配";
  }

  return null;
}
|