// DiAL/tests/server/checker/runner/dns/transport.test.ts
// Tests for the DNS transport: TCP queries, UDP→TCP truncation fallback,
// response-size limits and response-ID validation.

import { describe, expect, it } from "bun:test";
import { buildQuery } from "../../../../../src/server/checker/runner/dns/codec";
import { queryDns } from "../../../../../src/server/checker/runner/dns/transport";
/**
 * Encode a dotted domain name into DNS wire format (RFC 1035 §3.1): a sequence
 * of length-prefixed labels terminated by a zero-length root label.
 *
 * A single trailing dot (fully-qualified form) is accepted, and the root name
 * ("" or ".") encodes as a lone terminator byte, so no spurious empty label is
 * ever emitted.
 *
 * @param name - Domain name such as "example.com" or "example.com.".
 * @returns The encoded name bytes.
 * @throws {RangeError} If a label is empty (e.g. "a..b") or longer than 63 bytes.
 */
function buildName(name: string): Uint8Array {
  const encoder = new TextEncoder();
  // Strip one trailing dot so "example.com." and "example.com" encode identically.
  const trimmed = name.endsWith(".") ? name.slice(0, -1) : name;
  const parts: number[] = [];
  if (trimmed !== "") {
    for (const label of trimmed.split(".")) {
      const encoded = encoder.encode(label);
      // A zero-length label would be read as the terminator; >63 does not fit the length byte's label range.
      if (encoded.length === 0 || encoded.length > 63) {
        throw new RangeError(`invalid DNS label length ${encoded.length} in "${name}"`);
      }
      parts.push(encoded.length, ...encoded);
    }
  }
  parts.push(0); // root label terminates the name
  return new Uint8Array(parts);
}
/**
 * Assemble a synthetic DNS response message for the test servers.
 *
 * The flags word always has the QR bit (0x8000) set. The TC bit (0x0200) is
 * added when `flags.tc` is truthy, and the low four bits carry `rcode`. Question
 * entries are written first, then answer records, each in the order given. The
 * authority and additional counts are zero.
 *
 * @param options.id - Transaction ID to place in the header.
 * @param options.questions - Question entries; `qclass` defaults to 1.
 * @param options.answers - Answer records; `class` defaults to 1.
 * @param options.flags - Optional header flags (only `tc` is honoured).
 * @param options.rcode - Response code, masked to 4 bits; defaults to 0.
 * @returns The complete wire-format message.
 */
function buildResponse(options: {
  answers?: Array<{ class?: number; name: string; rdata: Uint8Array; ttl: number; type: number }>;
  flags?: { tc?: boolean };
  id: number;
  questions?: Array<{ name: string; qclass?: number; qtype: number }>;
  rcode?: number;
}): Uint8Array {
  const questionList = options.questions ?? [];
  const answerList = options.answers ?? [];
  const headerFlags = 0x8000 | (options.flags?.tc ? 0x0200 : 0) | ((options.rcode ?? 0) & 0x000f);

  const header = new Uint8Array(12);
  const headerView = new DataView(header.buffer);
  headerView.setUint16(0, options.id);
  headerView.setUint16(2, headerFlags);
  headerView.setUint16(4, questionList.length);
  headerView.setUint16(6, answerList.length);
  // Offsets 8 (NSCOUNT) and 10 (ARCOUNT) remain zero from allocation.

  const sections: Uint8Array[] = [header];

  questionList.forEach((question) => {
    const encodedName = buildName(question.name);
    const tail = new Uint8Array(4);
    const tailView = new DataView(tail.buffer);
    tailView.setUint16(0, question.qtype);
    tailView.setUint16(2, question.qclass ?? 1);
    sections.push(encodedName, tail);
  });

  answerList.forEach((answer) => {
    const encodedName = buildName(answer.name);
    const fixedFields = new Uint8Array(10);
    const fieldView = new DataView(fixedFields.buffer);
    fieldView.setUint16(0, answer.type);
    fieldView.setUint16(2, answer.class ?? 1);
    fieldView.setUint32(4, answer.ttl);
    fieldView.setUint16(8, answer.rdata.length);
    sections.push(encodedName, fixedFields, answer.rdata);
  });

  const message = new Uint8Array(sections.reduce((sum, section) => sum + section.length, 0));
  let cursor = 0;
  for (let i = 0; i < sections.length; i++) {
    message.set(sections[i]!, cursor);
    cursor += sections[i]!.length;
  }
  return message;
}
/**
 * Start a throwaway TCP DNS server on 127.0.0.1 for tests.
 *
 * Each connection is expected to carry one query framed with a 2-byte
 * big-endian length prefix. Bytes are buffered per socket until the whole
 * query has arrived. `respondWith` is then called with the unframed query,
 * its result is sent back with the same 2-byte framing, and the connection is
 * closed.
 *
 * @param respondWith - Builds the raw (unframed) DNS response for a query.
 * @param port - Port to listen on; 0 (the default) lets the OS choose.
 * @returns The bound port and a `stop` function that shuts the listener down.
 */
function createTcpServer(respondWith: (query: Uint8Array) => Uint8Array, port = 0): { port: number; stop: () => void } {
  const states = new WeakMap<object, { chunks: Uint8Array[]; totalBytes: number }>();
  const server = Bun.listen({
    hostname: "127.0.0.1",
    port,
    socket: {
      data(socket, data) {
        const key = socket as object;
        const state = states.get(key) ?? { chunks: [], totalBytes: 0 };
        // Copy rather than alias `data.buffer`, so buffered bytes stay intact even
        // if the runtime were to reuse its receive buffer between callbacks.
        const chunk = new Uint8Array(data);
        state.chunks.push(chunk);
        state.totalBytes += chunk.byteLength;
        states.set(key, state);

        const full = mergeChunks(state.chunks, state.totalBytes);
        if (full.byteLength < 2) return; // length prefix not complete yet
        const queryLength = new DataView(full.buffer, full.byteOffset, full.byteLength).getUint16(0);
        if (full.byteLength < queryLength + 2) return; // query body not complete yet

        const response = respondWith(full.subarray(2, 2 + queryLength));
        // Emit prefix + body as one frame in a single write instead of two
        // separate writes, so the framing cannot be split between them.
        const frame = new Uint8Array(2 + response.byteLength);
        new DataView(frame.buffer).setUint16(0, response.byteLength);
        frame.set(response, 2);
        socket.write(frame);
        states.delete(key);
        socket.close();
      },
      error() {
        // Test server: errors are deliberately ignored.
      },
      open() {
        // Required handler for Bun.listen; nothing to do on connect.
      },
    },
  });
  return { port: server.port, stop: () => server.stop() };
}
/**
 * Start a throwaway UDP DNS server on 127.0.0.1 for tests.
 *
 * Every received datagram is handed to `respondWith`, and the bytes it returns
 * are sent back to the datagram's sender.
 *
 * @param respondWith - Builds the raw DNS response for a received query.
 * @param port - Port to bind; when omitted the OS chooses one.
 * @returns The bound port and a `close` function that releases the socket.
 */
async function createUdpServer(
  respondWith: (query: Uint8Array) => Uint8Array,
  port?: number,
): Promise<{ close: () => void; port: number }> {
  const handlers = {
    data(
      peer: { send(data: Uint8Array, port: number, hostname: string): void },
      packet: Uint8Array,
      remotePort: number,
      addr: string,
    ) {
      const reply = respondWith(new Uint8Array(packet.buffer, packet.byteOffset, packet.byteLength));
      peer.send(reply, remotePort, addr);
    },
    drain() {
      // Required by Bun's UDP socket handler interface; nothing to flush.
    },
    error() {
      // Test server: errors are deliberately ignored.
    },
  };
  const udp =
    port !== undefined
      ? await Bun.udpSocket({ hostname: "127.0.0.1", port, socket: handlers })
      : await Bun.udpSocket({ hostname: "127.0.0.1", socket: handlers });
  const close = (): void => udp.close();
  return { port: udp.port, close };
}
/**
 * Create an AbortSignal that aborts after `timeoutMs` milliseconds.
 *
 * Unlike `AbortSignal.timeout`, the pending timer can be cancelled: call
 * `cleanup` once the guarded operation is done so the abort never fires.
 */
function makeSignal(timeoutMs: number): { cleanup: () => void; signal: AbortSignal } {
  const controller = new AbortController();
  const pending = setTimeout(() => {
    controller.abort();
  }, timeoutMs);
  const cleanup = (): void => {
    clearTimeout(pending);
  };
  return { signal: controller.signal, cleanup };
}
/**
 * Concatenate `chunks` into one freshly allocated buffer of `totalBytes` bytes.
 *
 * `totalBytes` is expected to equal the summed chunk lengths. If it is larger,
 * the tail stays zero-filled. If it is smaller, `Uint8Array#set` throws a
 * RangeError.
 */
function mergeChunks(chunks: Uint8Array[], totalBytes: number): Uint8Array {
  const merged = new Uint8Array(totalBytes);
  chunks.reduce((writeAt, piece) => {
    merged.set(piece, writeAt);
    return writeAt + piece.byteLength;
  }, 0);
  return merged;
}
/** Read the 16-bit transaction ID from the first two bytes of a DNS message. */
function readQueryId(query: Uint8Array): number {
  return new DataView(query.buffer, query.byteOffset, query.byteLength).getUint16(0);
}

/**
 * Run `run` with an AbortSignal that fires after `timeoutMs` milliseconds.
 * The timer is always cleared afterwards, including when `run` rejects.
 */
async function withTimeoutSignal<T>(run: (signal: AbortSignal) => Promise<T>, timeoutMs = 5000): Promise<T> {
  const { cleanup, signal } = makeSignal(timeoutMs);
  try {
    return await run(signal);
  } finally {
    cleanup();
  }
}

describe("DNS transport", () => {
  it("executes TCP DNS query with length-prefixed response", async () => {
    const server = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    try {
      const result = await withTimeoutSignal((signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 4096,
          protocol: "tcp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(true);
      if (result.ok) {
        expect(result.protocolUsed).toBe("tcp");
        expect(result.response.answers[0]!.value).toBe("93.184.216.34");
      }
    } finally {
      server.stop();
    }
  });

  it("falls back from UDP to TCP when response is truncated", async () => {
    const tcpServer = createTcpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([1, 1, 1, 1]), ttl: 60, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    try {
      // UDP server shares the TCP port number because queryDns targets a single port.
      // NOTE(review): this bind can fail if another process holds that UDP port.
      const udpServer = await createUdpServer(
        (query) => buildResponse({ flags: { tc: true }, id: readQueryId(query), questions: [{ name: "example.com", qtype: 1 }] }),
        tcpServer.port,
      );
      try {
        const result = await withTimeoutSignal((signal) =>
          queryDns("127.0.0.1", tcpServer.port, buildQuery("example.com", 1, true), {
            maxResponseBytes: 4096,
            protocol: "udp",
            signal,
            tcpFallback: true,
          }),
        );
        expect(result.ok).toBe(true);
        if (result.ok) {
          expect(result.protocolUsed).toBe("tcp");
          expect(result.response.answers[0]!.value).toBe("1.1.1.1");
        }
      } finally {
        udpServer.close();
      }
    } finally {
      tcpServer.stop();
    }
  });

  it("rejects UDP responses larger than maxResponseBytes", async () => {
    const server = await createUdpServer((query) =>
      buildResponse({
        answers: [{ name: "example.com", rdata: new Uint8Array([93, 184, 216, 34]), ttl: 300, type: 1 }],
        id: readQueryId(query),
        questions: [{ name: "example.com", qtype: 1 }],
      }),
    );
    try {
      const result = await withTimeoutSignal((signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 8,
          protocol: "udp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("超过");
    } finally {
      server.close();
    }
  });

  it("rejects response ID mismatch", async () => {
    const server = await createUdpServer((query) =>
      // Deliberately answer with the wrong ID (query ID + 1, wrapped to 16 bits).
      buildResponse({ id: (readQueryId(query) + 1) & 0xffff, questions: [{ name: "example.com", qtype: 1 }] }),
    );
    try {
      const result = await withTimeoutSignal((signal) =>
        queryDns("127.0.0.1", server.port, buildQuery("example.com", 1, true), {
          maxResponseBytes: 4096,
          protocol: "udp",
          signal,
          tcpFallback: false,
        }),
      );
      expect(result.ok).toBe(false);
      if (!result.ok) expect(result.error).toContain("ID 不匹配");
    } finally {
      server.close();
    }
  });
});