import { describe, expect, it } from "bun:test";
import { buildQuery } from "../../../../../src/server/checker/runner/dns/codec";
import { queryDns } from "../../../../../src/server/checker/runner/dns/transport";

/** One resource record to place in the answer section built by `buildResponse`. */
interface TestAnswer {
  /** Record class; defaults to 1 when omitted. */
  class?: number;
  name: string;
  rdata: Uint8Array;
  ttl: number;
  type: number;
}

/** One entry for the question section built by `buildResponse`. */
interface TestQuestion {
  name: string;
  /** Question class; defaults to 1 when omitted. */
  qclass?: number;
  qtype: number;
}

/** Per-connection receive buffer used by the TCP test server. */
interface TcpReadState {
  chunks: Uint8Array[];
  totalBytes: number;
}

/**
 * Encodes a dotted name as uncompressed DNS wire-format labels: each label is
 * preceded by its byte length and the sequence ends with a zero byte.
 * No label-length validation is performed; inputs are test fixtures.
 */
function buildName(name: string): Uint8Array {
  const encoder = new TextEncoder();
  const bytes: number[] = [];
  for (const label of name.split(".")) {
    const encoded = encoder.encode(label);
    bytes.push(encoded.length, ...encoded);
  }
  bytes.push(0);
  return new Uint8Array(bytes);
}

/**
 * Builds a raw DNS response message: a 12-byte header, the given questions,
 * then the given answers. Authority and additional sections are always empty.
 *
 * @param options.id Transaction ID to echo in the header.
 * @param options.flags.tc When true, sets the truncation (TC) bit.
 * @param options.rcode Response code, masked into the low 4 bits of the flags word.
 */
function buildResponse(options: {
  answers?: TestAnswer[];
  flags?: { tc?: boolean };
  id: number;
  questions?: TestQuestion[];
  rcode?: number;
}): Uint8Array {
  const questions = options.questions ?? [];
  const answers = options.answers ?? [];

  // 0x8000 is the QR (response) bit; 0x0200 is TC; RCODE occupies the low nibble.
  let flags = 0x8000;
  if (options.flags?.tc) flags |= 0x0200;
  flags |= (options.rcode ?? 0) & 0x000f;

  const header = new Uint8Array(12);
  const hv = new DataView(header.buffer);
  hv.setUint16(0, options.id);
  hv.setUint16(2, flags);
  hv.setUint16(4, questions.length); // question count
  hv.setUint16(6, answers.length); // answer count
  hv.setUint16(8, 0); // authority count
  hv.setUint16(10, 0); // additional count

  const parts: Uint8Array[] = [header];
  for (const q of questions) {
    const tail = new Uint8Array(4);
    const qv = new DataView(tail.buffer);
    qv.setUint16(0, q.qtype);
    qv.setUint16(2, q.qclass ?? 1);
    parts.push(buildName(q.name), tail);
  }
  for (const a of answers) {
    const rrHead = new Uint8Array(10);
    const rv = new DataView(rrHead.buffer);
    rv.setUint16(0, a.type);
    rv.setUint16(2, a.class ?? 1);
    rv.setUint32(4, a.ttl);
    rv.setUint16(8, a.rdata.length);
    parts.push(buildName(a.name), rrHead, a.rdata);
  }

  const totalBytes = parts.reduce((sum, part) => sum + part.byteLength, 0);
  return mergeChunks(parts, totalBytes);
}

/** Reads the 16-bit big-endian transaction ID from the first two bytes of a DNS message. */
function readQueryId(query: Uint8Array): number {
  return new DataView(query.buffer, query.byteOffset, query.byteLength).getUint16(0);
}

/**
 * Starts a loopback TCP test server speaking 2-byte length-prefixed DNS framing.
 * Bytes are buffered per connection until one complete query has arrived; the
 * server then replies with `respondWith(query)` using the same framing and
 * closes the connection.
 *
 * @param respondWith Produces the raw response message for a raw query message.
 * @param port Port to bind; 0 (the default) lets the OS choose one.
 * @returns The bound port and a `stop` function that shuts the listener down.
 */
function createTcpServer(respondWith: (query: Uint8Array) => Uint8Array, port = 0): { port: number; stop: () => void } {
  const states = new WeakMap<object, TcpReadState>();
  const server = Bun.listen({
    hostname: "127.0.0.1",
    port,
    socket: {
      data(socket, data) {
        const state = states.get(socket) ?? { chunks: [], totalBytes: 0 };
        // Copy the chunk: it is retained across `data` callbacks, so it must not alias
        // whatever buffer the runtime handed us.
        const chunk = new Uint8Array(data);
        state.chunks.push(chunk);
        state.totalBytes += chunk.byteLength;
        states.set(socket, state);

        const full = mergeChunks(state.chunks, state.totalBytes);
        if (full.byteLength < 2) return; // length prefix not complete yet
        const queryLength = new DataView(full.buffer, full.byteOffset, full.byteLength).getUint16(0);
        if (full.byteLength < queryLength + 2) return; // query body not complete yet

        states.delete(socket);
        const response = respondWith(full.subarray(2, 2 + queryLength));
        // Send the length prefix and body as a single frame in one write.
        const frame = new Uint8Array(2 + response.byteLength);
        new DataView(frame.buffer).setUint16(0, response.byteLength);
        frame.set(response, 2);
        socket.write(frame);
        socket.close();
      },
      error() {
        // Errors are deliberately ignored by this test server.
      },
      open() {
        // Required handler for Bun.listen; nothing to do on open.
      },
    },
  });
  return { port: server.port, stop: () => server.stop() };
}

/**
 * Starts a loopback UDP test server that answers every datagram with
 * `respondWith(query)`, sent back to the datagram's source address and port.
 *
 * @param respondWith Produces the raw response message for a raw query message.
 * @param port Port to bind; when omitted the OS chooses one.
 * @returns The bound port and a `close` function that releases the socket.
 */
async function createUdpServer(
  respondWith: (query: Uint8Array) => Uint8Array,
  port?: number,
): Promise<{ close: () => void; port: number }> {
  const socketHandlers = {
    data(
      sock: { send(data: Uint8Array, port: number, hostname: string): void },
      data: Uint8Array,
      remotePort: number,
      addr: string,
    ) {
      const query = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
      sock.send(respondWith(query), remotePort, addr);
    },
    drain() {
      // Required handler for Bun UDP sockets; nothing to do on drain.
    },
    error() {
      // Errors are deliberately ignored by this test server.
    },
  };
  const socket =
    port === undefined
      ? await Bun.udpSocket({ hostname: "127.0.0.1", socket: socketHandlers })
      : await Bun.udpSocket({ hostname: "127.0.0.1", port, socket: socketHandlers });
  return { close: () => socket.close(), port: socket.port };
}

/**
 * Creates an AbortSignal that aborts after `timeoutMs`.
 * The caller must invoke `cleanup` to cancel the pending timer.
 */
function makeSignal(timeoutMs: number): { cleanup: () => void; signal: AbortSignal } {
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);
  return { cleanup: () => clearTimeout(timer), signal: controller.signal };
}

/**
 * Runs `run` with a signal that aborts after `timeoutMs`. The timer is always
 * cleared, including when `run` rejects, so no timer outlives the test.
 */
async function withTimeoutSignal<T>(timeoutMs: number, run: (signal: AbortSignal) => Promise<T>): Promise<T> {
  const { cleanup, signal } = makeSignal(timeoutMs);
  try {
    return await run(signal);
  } finally {
    cleanup();
  }
}

/** Concatenates `chunks` into one new array of exactly `totalBytes` bytes. */
function mergeChunks(chunks: readonly Uint8Array[], totalBytes: number): Uint8Array {
  const result = new Uint8Array(totalBytes);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.byteLength;
  }
  return result;
}

describe("DNS transport", () => {
  it("executes TCP DNS query with length-prefixed response", async () => {
    const server = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    try {
      const result = await withTimeoutSignal(5000, (signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 4096,
          protocol: "tcp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(true);
      if (result.ok) {
        expect(result.protocolUsed).toBe("tcp");
        expect(result.response.answers[0]?.value).toBe("93.184.216.34");
      }
    } finally {
      server.stop();
    }
  });

  it("falls back from UDP to TCP when response is truncated", async () => {
    const tcpServer = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([1, 1, 1, 1]), ttl: 60, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    // UDP listener shares the TCP port number so the fallback reaches the TCP server.
    const udpServer = await createUdpServer(
      (query) => buildResponse({ flags: { tc: true }, id: readQueryId(query), questions: [{ name: "example.com", qtype: 1 }] }),
      tcpServer.port,
    );
    try {
      const result = await withTimeoutSignal(5000, (signal) =>
        queryDns("127.0.0.1", tcpServer.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 4096,
          protocol: "udp",
          signal,
          tcpFallback: true,
        }),
      );
      expect(result.ok).toBe(true);
      if (result.ok) {
        expect(result.protocolUsed).toBe("tcp");
        expect(result.response.answers[0]?.value).toBe("1.1.1.1");
      }
    } finally {
      udpServer.close();
      tcpServer.stop();
    }
  });

  it("rejects UDP responses larger than maxResponseBytes", async () => {
    const server = await createUdpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    try {
      const result = await withTimeoutSignal(5000, (signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 8,
          protocol: "udp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("超过");
    } finally {
      server.close();
    }
  });

  it("rejects response ID mismatch", async () => {
    const server = await createUdpServer((query) =>
      buildResponse({ id: (readQueryId(query) + 1) & 0xffff, questions: [{ name: "example.com", qtype: 1 }] }),
    );
    try {
      const result = await withTimeoutSignal(5000, (signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 4096,
          protocol: "udp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("ID 不匹配");
    } finally {
      server.close();
    }
  });
});