import { randomBytes } from 'crypto';
import * as dnsPacket from 'dns-packet';
import IResolvablePromise from '@secret-agent/interfaces/IResolvablePromise';
import { createPromise } from '@secret-agent/commons/utils';
import MitmSocket from '@secret-agent/mitm-socket/index';
import { CanceledPromiseError } from '@secret-agent/commons/interfaces/IPendingWaitEvent';
import IDnsSettings from '@secret-agent/interfaces/IDnsSettings';
import { IBoundLog } from '@secret-agent/interfaces/ILog';
import Log from '@secret-agent/commons/Logger';
import EventSubscriber from '@secret-agent/commons/EventSubscriber';
import RequestSession from '../handlers/RequestSession';
// Module-scoped logger; each DnsOverTlsSocket derives a per-session child logger from it.
const log = Log(module).log;
/**
 * DNS-over-TLS client that resolves A records on behalf of a RequestSession.
 *
 * A MitmSocket to the configured `dnsOverTlsConnection` (port 853 unless
 * configured otherwise) is opened lazily on the first lookup. Queries are
 * written with `dnsPacket.streamEncode`. Responses are read back as a stream of
 * messages, each preceded by a 2-byte big-endian length, and are matched to
 * their callers by the 16-bit DNS message id. If the socket emits `eof` while
 * this instance is still open, it reconnects and re-sends every in-flight query.
 */
export default class DnsOverTlsSocket {
  /** Host of the configured DNS-over-TLS server (undefined if none is configured). */
  public get host(): string {
    return this.dnsSettings.dnsOverTlsConnection?.host;
  }

  /**
   * Whether this socket can serve more lookups. This requires that the underlying
   * MitmSocket reports itself reusable and that no close has been requested or
   * observed. Returns false instead of throwing in two cases: before the first
   * lookup has created a socket, and after close() has released it.
   */
  public get isActive(): boolean {
    if (this.isClosing || !this.mitmSocket) return false;
    return this.mitmSocket.isReusable();
  }

  private readonly dnsSettings: IDnsSettings;
  private mitmSocket: MitmSocket;
  /** Connection attempt shared by concurrent lookups; replaced when reconnecting after eof. */
  private isConnected: Promise<void>;
  /** In-flight queries keyed by DNS message id. */
  private pending = new Map<
    number,
    { host: string; resolvable: IResolvablePromise<IDnsResponse> }
  >();
  /** Bytes received but not yet consumed as complete length-prefixed messages. */
  private buffer: Buffer = null;
  private isClosing = false;
  private readonly onClose?: () => void;
  private requestSession: RequestSession | undefined;
  private logger: IBoundLog;
  private eventSubscriber = new EventSubscriber();

  /**
   * @param dnsSettings - settings whose `dnsOverTlsConnection` (host, port, servername) selects the server
   * @param requestSession - session whose id tags logs, and whose request agent's socketSession is used to connect
   * @param onClose - invoked from close(), or when the underlying socket emits 'close'
   */
  constructor(dnsSettings: IDnsSettings, requestSession: RequestSession, onClose?: () => void) {
    this.requestSession = requestSession;
    this.logger = log.createChild({ sessionId: requestSession.sessionId });
    this.dnsSettings = dnsSettings;
    this.onClose = onClose;
  }

  /**
   * Looks up the A records for `host`, connecting first if this is the first lookup.
   *
   * Each query promise is created with `createPromise(5e3)`; presumably it rejects
   * if no answer arrives within 5s (TODO confirm).
   * NOTE(review): if the first connect fails, `isConnected` keeps the rejected
   * promise, so every later lookup on this instance rejects too.
   */
  public async lookupARecords(host: string): Promise<IDnsResponse> {
    if (!this.isConnected) {
      this.isConnected = this.connect();
    }
    await this.isConnected;
    return this.getDnsResponse(host);
  }

  /**
   * Closes the socket, removes all event listeners and notifies `onClose`.
   * It is a no-op when repeated, or after the underlying socket has already
   * emitted 'close' (both set `isClosing`). In-flight queries are not rejected
   * here; they are left to their timeouts.
   */
  public close(): void {
    if (this.isClosing) return;
    this.isClosing = true;
    this.mitmSocket?.close();
    this.eventSubscriber.close();
    this.requestSession = null;
    this.mitmSocket = null;
    if (this.onClose) this.onClose();
  }

  /**
   * Opens a TLS MitmSocket to the configured server (10s connect timeout) and wires
   * the data/close/eof handlers. On `eof` the old socket is closed, a new one is
   * opened, and all in-flight queries are re-sent on it.
   */
  protected async connect(): Promise<void> {
    const { host, port, servername } = this.dnsSettings.dnsOverTlsConnection || {};
    this.mitmSocket = new MitmSocket(this.requestSession?.sessionId, {
      host,
      servername,
      port: String(port ?? 853),
      isSsl: true,
      keepAlive: true,
      debug: true,
    });
    await this.mitmSocket.connect(this.requestSession.requestAgent.socketSession, 10e3);

    this.eventSubscriber.on(this.mitmSocket.socket, 'data', this.onData.bind(this));
    const onClose = this.eventSubscriber.on(this.mitmSocket.socket, 'close', () => {
      this.isClosing = true;
      if (this.onClose) this.onClose();
    });
    this.eventSubscriber.on(this.mitmSocket, 'eof', async () => {
      // Detach the close listener first: closing the old socket to reconnect must
      // not mark this instance as closed.
      this.eventSubscriber.off(onClose);
      if (this.isClosing) return;
      this.mitmSocket.close();
      try {
        this.isConnected = this.connect();
        await this.isConnected;
        this.resendPending();
      } catch (error) {
        this.logger.info('Error re-connecting to dns', {
          error,
        });
      }
    });
  }

  /**
   * Re-issues every in-flight query on the current socket. Each original caller's
   * promise is resolved with the promise for its re-sent query.
   */
  private resendPending(): void {
    // Snapshot and clear first. getDnsResponse() inserts into `pending`, and a Map
    // iterator also visits entries added during iteration, so walking the live map
    // would keep re-sending the freshly added queries forever.
    const inFlight = [...this.pending.values()];
    this.pending.clear();
    for (const entry of inFlight) {
      entry.resolvable.resolve(this.getDnsResponse(entry.host));
    }
  }

  /** Sends an `IN`/`A` query for `host` and returns a promise for the matching response. */
  private getDnsResponse(host: string): Promise<IDnsResponse> {
    const id = this.query({
      name: host,
      class: 'IN',
      type: 'A',
    });
    const resolvable = createPromise<IDnsResponse>(5e3);
    const entry = { host, resolvable };
    this.pending.set(id, entry);

    // Forget the entry once its promise settles (answer, timeout or cancel), so that
    // unanswered queries don't accumulate in `pending`. The identity check avoids
    // removing a newer query that has since been registered under the same id.
    const forget = (): void => {
      if (this.pending.get(id) === entry) this.pending.delete(id);
    };
    void resolvable.promise.then(forget, forget);

    return resolvable.promise;
  }

  /** Rejects every in-flight query with a CanceledPromiseError, then closes this socket. */
  private disconnect(): void {
    const inFlight = [...this.pending.values()];
    this.pending.clear();
    for (const entry of inFlight) {
      entry.resolvable.reject(new CanceledPromiseError('Disconnecting Dns Socket'));
    }
    this.close();
  }

  /**
   * Writes a DNS query containing `questions`, with the recursion-desired flag set.
   * @returns the message id used; it is unique among in-flight queries
   */
  private query(...questions: IQuestion[]): number {
    // Responses are routed back by id alone, so reusing an in-flight id could hand
    // one caller the answer meant for another.
    let id: number;
    do {
      id = randomBytes(2).readUInt16BE(0);
    } while (this.pending.has(id));

    const dnsQuery = dnsPacket.streamEncode({
      flags: dnsPacket.RECURSION_DESIRED,
      id,
      questions,
      type: 'query',
    });
    this.mitmSocket.socket.write(dnsQuery);
    return id;
  }

  /**
   * Appends incoming bytes to `buffer` and dispatches every complete message in it.
   * Each message is a 2-byte big-endian length followed by that many bytes.
   */
  private onData(data: Buffer): void {
    this.buffer =
      this.buffer === null ? Buffer.from(data) : Buffer.concat([this.buffer, data]);

    while (this.buffer.byteLength > 2) {
      const messageLength = this.getMessageLength();
      // Anything shorter than a 12-byte DNS header means the stream is corrupt.
      if (messageLength < 12) {
        return this.disconnect();
      }
      const frameLength = messageLength + 2;
      // Incomplete message: wait for more data.
      if (this.buffer.byteLength < frameLength) return;

      const frame = this.buffer.subarray(0, frameLength);
      this.buffer = this.buffer.subarray(frameLength);

      let decoded: IDnsResponse;
      try {
        decoded = dnsPacket.streamDecode(frame) as IDnsResponse;
      } catch (error) {
        // The length prefix was valid, so framing is still in sync. Drop this one
        // message (its caller will hit its timeout) rather than letting the
        // exception escape the socket 'data' handler.
        this.logger.info('Error decoding dns response', { error });
        continue;
      }
      this.pending.get(decoded.id)?.resolvable?.resolve(decoded);
      this.pending.delete(decoded.id);
    }
  }

  /** Length prefix of the next buffered message, or undefined if fewer than 2 bytes are buffered. */
  private getMessageLength(): number | undefined {
    if (this.buffer.byteLength >= 2) {
      return this.buffer.readUInt16BE(0);
    }
    return undefined;
  }
}
/**
 * One entry of a DNS query's question section.
 * Queries built in this file use `{ name: host, class: 'IN', type: 'A' }`.
 */
interface IQuestion {
  /** Domain name being looked up. */
  name: string;
  /** Record class (this file sends `'IN'`). */
  class: string;
  /** Record type (this file sends `'A'`). */
  type: string;
}
/**
 * A resource record from a decoded DNS response, as produced by dns-packet.
 */
interface IAnswer {
  /** Owner name of the record. */
  name: string;
  /** Record class, e.g. `'IN'`. */
  class: string;
  /** Record type, e.g. `'A'`. */
  type: string;
  /** Time-to-live reported by the server. */
  ttl: number;
  /** Record payload; for A records presumably the IPv4 address string — confirm for other types. */
  data: string;
  /** `flush` flag as decoded by dns-packet. */
  flush: boolean;
}
/**
 * A decoded DNS response message, as returned by `dnsPacket.streamDecode`.
 * This class reads only `id`, to match a response to its query. The whole
 * object is handed to callers of `lookupARecords`.
 */
interface IDnsResponse {
  // --- header ---
  id: number;
  type: string;
  opcode: string;
  rcode: string;
  flags: number;

  // --- individual header flag bits, as decoded by dns-packet ---
  flag_qr: boolean;
  flag_aa: boolean;
  flag_tc: boolean;
  flag_rd: boolean;
  flag_ra: boolean;
  flag_z: boolean;
  flag_ad: boolean;
  flag_cd: boolean;

  // --- sections ---
  questions: IQuestion[];
  answers: IAnswer[];
  // NOTE(review): dns-packet appears to decode these two sections as resource-record
  // objects (like `answers`), not strings. Confirm before relying on these types.
  authorities: string[];
  additionals: string[];
}