// Tests for the DNS transport layer (`queryDns` over UDP and TCP).
import { describe, expect, it } from "bun:test";
|
|
|
|
import { buildQuery } from "../../../../../src/server/checker/runner/dns/codec";
|
|
import { queryDns } from "../../../../../src/server/checker/runner/dns/transport";
|
|
|
|
/**
 * Encodes a dotted domain name as a sequence of length-prefixed labels
 * followed by a single terminating zero byte.
 *
 * Empty labels (produced by the empty string, a trailing dot, or consecutive
 * dots) are skipped, so the output always ends with exactly one zero byte
 * instead of an embedded zero-length label followed by a second terminator.
 * Label length is not validated here.
 *
 * @param name - Dotted name such as `"example.com"`; `""` encodes the root name.
 * @returns The encoded name bytes.
 */
function buildName(name: string): Uint8Array {
  // One encoder is enough for every label; no need to allocate per iteration.
  const encoder = new TextEncoder();
  const parts: number[] = [];
  for (const label of name.split(".")) {
    // A zero-length label would be read as the name terminator.
    if (label === "") continue;
    const encoded = encoder.encode(label);
    parts.push(encoded.length, ...encoded);
  }
  parts.push(0);
  return new Uint8Array(parts);
}
|
|
|
|
/**
 * Assembles a raw DNS response message for use by the test servers.
 *
 * Layout: a 12-byte header, then each question (encoded name, type, class),
 * then each answer record (encoded name, type, class, TTL, RDATA length,
 * RDATA). All multi-byte fields are written big-endian.
 *
 * @param options.id - Transaction ID written at the start of the header.
 * @param options.flags - When `tc` is true, bit 0x0200 is set in the flags word.
 * @param options.rcode - Response code; only its low 4 bits are used (default 0).
 * @param options.questions - Question entries; `qclass` defaults to 1.
 * @param options.answers - Answer records; `class` defaults to 1.
 * @returns The concatenated message bytes.
 */
function buildResponse(options: {
  answers?: Array<{ class?: number; name: string; rdata: Uint8Array; ttl: number; type: number }>;
  flags?: { tc?: boolean };
  id: number;
  questions?: Array<{ name: string; qclass?: number; qtype: number }>;
  rcode?: number;
}): Uint8Array {
  const questionList = options.questions ?? [];
  const answerList = options.answers ?? [];

  // 0x8000 is always set (QR bit per RFC 1035: this message is a response).
  const flagWord = 0x8000 | (options.flags?.tc ? 0x0200 : 0) | ((options.rcode ?? 0) & 0x000f);

  // Header: id, flags, question count, answer count, and two zeroed counts.
  const header = new Uint8Array(12);
  const headerView = new DataView(header.buffer);
  [options.id, flagWord, questionList.length, answerList.length, 0, 0].forEach((field, index) => {
    headerView.setUint16(index * 2, field);
  });

  const questionParts = questionList.flatMap((question) => {
    const tail = new Uint8Array(4);
    const tailView = new DataView(tail.buffer);
    tailView.setUint16(0, question.qtype);
    tailView.setUint16(2, question.qclass ?? 1);
    return [buildName(question.name), tail];
  });

  const answerParts = answerList.flatMap((answer) => {
    const fixed = new Uint8Array(10);
    const fixedView = new DataView(fixed.buffer);
    fixedView.setUint16(0, answer.type);
    fixedView.setUint16(2, answer.class ?? 1);
    fixedView.setUint32(4, answer.ttl);
    fixedView.setUint16(8, answer.rdata.length);
    return [buildName(answer.name), fixed, answer.rdata];
  });

  // Concatenate every section into one contiguous buffer.
  const sections = [header, ...questionParts, ...answerParts];
  let size = 0;
  for (const section of sections) size += section.length;

  const message = new Uint8Array(size);
  let cursor = 0;
  sections.forEach((section) => {
    message.set(section, cursor);
    cursor += section.length;
  });
  return message;
}
|
|
|
|
/**
 * Starts a TCP server on 127.0.0.1 that answers one length-prefixed query per
 * connection.
 *
 * Incoming bytes are buffered per socket until a complete frame has arrived
 * (a 2-byte big-endian length followed by that many bytes). The framed query
 * is passed to `respondWith`; the returned bytes are written back behind
 * their own 2-byte length prefix and the socket is then closed.
 *
 * @param respondWith - Maps raw query bytes to raw response bytes.
 * @param port - Port to listen on (default 0).
 * @returns The port reported by the listener and a `stop` function for it.
 */
function createTcpServer(respondWith: (query: Uint8Array) => Uint8Array, port = 0): { port: number; stop: () => void } {
  type Pending = { chunks: Uint8Array[]; totalBytes: number };
  const pendingBySocket = new WeakMap<object, Pending>();

  const listener = Bun.listen({
    hostname: "127.0.0.1",
    port,
    socket: {
      open() {
        // Handler required by Bun.listen; nothing to do here.
      },
      error() {
        // The test server ignores errors.
      },
      data(socket, data) {
        const key = socket as object;
        let pending = pendingBySocket.get(key);
        if (pending === undefined) {
          pending = { chunks: [], totalBytes: 0 };
          pendingBySocket.set(key, pending);
        }

        const piece = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
        pending.chunks.push(piece);
        pending.totalBytes += piece.byteLength;

        // Wait until the length prefix and the whole framed query are buffered.
        const buffered = mergeChunks(pending.chunks, pending.totalBytes);
        if (buffered.byteLength < 2) return;
        const bufferedView = new DataView(buffered.buffer, buffered.byteOffset, buffered.byteLength);
        const frameEnd = bufferedView.getUint16(0) + 2;
        if (buffered.byteLength < frameEnd) return;

        const reply = respondWith(buffered.subarray(2, frameEnd));
        const replyPrefix = new Uint8Array(2);
        new DataView(replyPrefix.buffer).setUint16(0, reply.byteLength);
        socket.write(replyPrefix);
        socket.write(reply);
        socket.close();
      },
    },
  });

  return {
    port: listener.port,
    stop: () => listener.stop(),
  };
}
|
|
|
|
/**
 * Opens a UDP socket on 127.0.0.1 that replies to every datagram with the
 * bytes produced by `respondWith`, sent back to the datagram's sender.
 *
 * @param respondWith - Maps raw query bytes to raw response bytes.
 * @param port - Port to bind; when omitted, no port is passed to Bun.
 * @returns The port reported by the socket and a `close` function for it.
 */
async function createUdpServer(
  respondWith: (query: Uint8Array) => Uint8Array,
  port?: number,
): Promise<{ close: () => void; port: number }> {
  const hostname = "127.0.0.1";
  const handlers = {
    error() {
      // The test server ignores errors.
    },
    drain() {
      // Required entry in Bun's UDP socket handler object.
    },
    data(
      sock: { send(data: Uint8Array, port: number, hostname: string): void },
      data: Uint8Array,
      remotePort: number,
      addr: string,
    ) {
      const incoming = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
      const reply = respondWith(incoming);
      sock.send(reply, remotePort, addr);
    },
  };

  // The `port` key is left out entirely when no port was requested.
  const bound = await (port === undefined
    ? Bun.udpSocket({ hostname, socket: handlers })
    : Bun.udpSocket({ hostname, port, socket: handlers }));

  return {
    close: () => bound.close(),
    port: bound.port,
  };
}
|
|
|
|
/**
 * Creates an `AbortSignal` that is aborted after `timeoutMs` milliseconds.
 *
 * @param timeoutMs - Delay before the signal is aborted.
 * @returns The signal and a `cleanup` function that cancels the pending
 *   timer so it cannot fire later.
 */
function makeSignal(timeoutMs: number): { cleanup: () => void; signal: AbortSignal } {
  const abortController = new AbortController();
  const handle = setTimeout(() => {
    abortController.abort();
  }, timeoutMs);
  const cleanup = (): void => {
    clearTimeout(handle);
  };
  return { cleanup, signal: abortController.signal };
}
|
|
|
|
/**
 * Copies the given chunks back-to-back into one new buffer.
 *
 * @param chunks - Byte chunks in the order they should appear.
 * @param totalBytes - Size of the buffer to allocate for the result.
 * @returns A new `Uint8Array` of length `totalBytes` with the chunks copied
 *   in from offset 0.
 */
function mergeChunks(chunks: Uint8Array[], totalBytes: number): Uint8Array {
  const merged = new Uint8Array(totalBytes);
  // The accumulator is the write position for the next chunk.
  chunks.reduce((cursor, piece) => {
    merged.set(piece, cursor);
    return cursor + piece.byteLength;
  }, 0);
  return merged;
}
|
|
|
|
// Integration tests for `queryDns` against in-process TCP/UDP servers on 127.0.0.1.
describe("DNS transport", () => {
  /** Reads the 16-bit big-endian transaction ID from the first two bytes of a raw query. */
  const readQueryId = (query: Uint8Array): number =>
    new DataView(query.buffer, query.byteOffset, query.byteLength).getUint16(0);

  it("executes TCP DNS query with length-prefixed response", async () => {
    const server = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    const { cleanup, signal } = makeSignal(5000);
    try {
      const result = await queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
        maxResponseBytes: 4096,
        protocol: "tcp",
        signal,
        tcpFallback: false,
      });
      expect(result.ok).toBe(true);
      if (result.ok) {
        expect(result.protocolUsed).toBe("tcp");
        expect(result.response.answers[0]?.value).toBe("93.184.216.34");
      }
    } finally {
      // Clear the abort timer even when the query rejects, so it cannot outlive the test.
      cleanup();
      server.stop();
    }
  });

  it("falls back from UDP to TCP when response is truncated", async () => {
    const tcpServer = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([1, 1, 1, 1]), ttl: 60, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    let udpServer: Awaited<ReturnType<typeof createUdpServer>> | undefined;
    const { cleanup, signal } = makeSignal(5000);
    try {
      // The UDP server uses the TCP server's port number, so the single port
      // passed to queryDns reaches both. It is created inside the try so a bind
      // failure still stops the TCP server.
      udpServer = await createUdpServer(
        (query) =>
          buildResponse({
            flags: { tc: true },
            id: readQueryId(query),
            questions: [{ name: "example.com", qtype: 1 }],
          }),
        tcpServer.port,
      );

      const result = await queryDns("127.0.0.1", tcpServer.port, buildQuery("example.com", 1, true), {
        maxResponseBytes: 4096,
        protocol: "udp",
        signal,
        tcpFallback: true,
      });
      expect(result.ok).toBe(true);
      if (result.ok) {
        expect(result.protocolUsed).toBe("tcp");
        expect(result.response.answers[0]?.value).toBe("1.1.1.1");
      }
    } finally {
      cleanup();
      udpServer?.close();
      tcpServer.stop();
    }
  });

  it("rejects UDP responses larger than maxResponseBytes", async () => {
    const server = await createUdpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    const { cleanup, signal } = makeSignal(5000);
    try {
      const result = await queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
        // Smaller than the 12-byte header alone, so any response exceeds it.
        maxResponseBytes: 8,
        protocol: "udp",
        signal,
        tcpFallback: false,
      });
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("超过");
    } finally {
      cleanup();
      server.close();
    }
  });

  it("rejects response ID mismatch", async () => {
    const server = await createUdpServer((query) =>
      buildResponse({
        // Reply with a deliberately different transaction ID (wrapping at 16 bits).
        id: (readQueryId(query) + 1) & 0xffff,
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    const { cleanup, signal } = makeSignal(5000);
    try {
      const result = await queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
        maxResponseBytes: 4096,
        protocol: "udp",
        signal,
        tcpFallback: false,
      });
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("ID 不匹配");
    } finally {
      cleanup();
      server.close();
    }
  });
});
|