Skip to content

DnsPlugin

DnsPlugin

DnsPlugin(verifier)

Bases: BasePlugin

DNS interception plugin.

Patches socket.getaddrinfo and socket.gethostbyname at the module level. When dnspython is available, also patches dns.resolver.resolve and dns.resolver.Resolver.resolve.

Uses reference counting so nested sandboxes work correctly.

Source code in src/tripwire/plugins/dns_plugin.py
def __init__(self, verifier: StrictVerifier) -> None:
    """Bind the plugin to *verifier* and start with no registered mocks.

    Args:
        verifier: Forwarded unchanged to ``BasePlugin.__init__``.
    """
    super().__init__(verifier)
    # Held by the registration and reporting methods while they touch
    # ``_queues``.
    self._registry_lock: threading.Lock = threading.Lock()
    # Registered mocks: one deque per "<operation>:<name>" key, holding
    # configs in the order they were registered.
    self._queues: dict[str, deque[DnsMockConfig]] = {}

mock_getaddrinfo

mock_getaddrinfo(hostname, *, returns, raises=None, required=True)

Register a mock for socket.getaddrinfo for the given hostname.

Source code in src/tripwire/plugins/dns_plugin.py
def mock_getaddrinfo(
    self,
    hostname: str,
    *,
    returns: Any,  # noqa: ANN401
    raises: BaseException | None = None,
    required: bool = True,
) -> None:
    """Queue a mock for ``socket.getaddrinfo`` on *hostname*.

    Mocks for the same hostname accumulate in one queue, in registration
    order.

    Args:
        hostname: Host name the mock is filed under.
        returns: Result stored on the mock.
        raises: Optional exception stored on the mock.
        required: When true, a mock still sitting in its queue is reported
            by ``get_unused_mocks``.
    """
    mock = DnsMockConfig(
        operation="getaddrinfo",
        hostname=hostname,
        returns=returns,
        raises=raises,
        required=required,
    )
    key = f"getaddrinfo:{hostname}"
    with self._registry_lock:
        # Create the per-hostname queue on first use, then enqueue.
        self._queues.setdefault(key, deque()).append(mock)

mock_gethostbyname

mock_gethostbyname(hostname, *, returns, raises=None, required=True)

Register a mock for socket.gethostbyname for the given hostname.

Source code in src/tripwire/plugins/dns_plugin.py
def mock_gethostbyname(
    self,
    hostname: str,
    *,
    returns: Any,  # noqa: ANN401
    raises: BaseException | None = None,
    required: bool = True,
) -> None:
    """Queue a mock for ``socket.gethostbyname`` on *hostname*.

    Mocks for the same hostname accumulate in one queue, in registration
    order.

    Args:
        hostname: Host name the mock is filed under.
        returns: Result stored on the mock.
        raises: Optional exception stored on the mock.
        required: When true, a mock still sitting in its queue is reported
            by ``get_unused_mocks``.
    """
    mock = DnsMockConfig(
        operation="gethostbyname",
        hostname=hostname,
        returns=returns,
        raises=raises,
        required=required,
    )
    key = f"gethostbyname:{hostname}"
    with self._registry_lock:
        # Create the per-hostname queue on first use, then enqueue.
        self._queues.setdefault(key, deque()).append(mock)

mock_resolve

mock_resolve(qname, rdtype, *, returns, raises=None, required=True)

Register a mock for dns.resolver.resolve for the given qname/rdtype.

Only available when dnspython is installed.

Source code in src/tripwire/plugins/dns_plugin.py
def mock_resolve(
    self,
    qname: str,
    rdtype: str,
    *,
    returns: Any,  # noqa: ANN401
    raises: BaseException | None = None,
    required: bool = True,
) -> None:
    """Queue a mock for ``dns.resolver.resolve`` on *qname*.

    Intended for use when dnspython is installed; this method itself does
    not check for it.

    Args:
        qname: Query name the mock is filed under.
        rdtype: Record type of the query being mocked (see note below).
        returns: Result stored on the mock.
        raises: Optional exception stored on the mock.
        required: When true, a mock still sitting in its queue is reported
            by ``get_unused_mocks``.
    """
    # NOTE(review): rdtype is neither stored on the config nor part of the
    # queue key, so mocks for different record types of the same qname share
    # a single queue here. Confirm this is intended.
    mock = DnsMockConfig(
        operation="resolve",
        hostname=qname,
        returns=returns,
        raises=raises,
        required=required,
    )
    key = f"resolve:{qname}"
    with self._registry_lock:
        # Create the per-qname queue on first use, then enqueue.
        self._queues.setdefault(key, deque()).append(mock)

install_patches

install_patches()

Install DNS interception patches.

Source code in src/tripwire/plugins/dns_plugin.py
def install_patches(self) -> None:
    """Replace the resolver entry points with the module-level interceptors.

    The current ``socket.getaddrinfo`` and ``socket.gethostbyname`` are
    stashed on ``DnsPlugin`` so ``restore_patches`` can put them back. When
    dnspython is available, ``dns.resolver.resolve`` and
    ``dns.resolver.Resolver.resolve`` are stashed and replaced the same way.
    """
    # NOTE(review): originals are captured unconditionally. A second call
    # without an intervening restore_patches would stash the interceptors as
    # "originals". Presumably the base class's reference counting prevents
    # that; confirm.
    DnsPlugin._original_getaddrinfo = socket.getaddrinfo
    DnsPlugin._original_gethostbyname = socket.gethostbyname
    setattr(socket, "getaddrinfo", _patched_getaddrinfo)
    setattr(socket, "gethostbyname", _patched_gethostbyname)

    if not _DNSPYTHON_AVAILABLE:
        return

    resolver_module = dns.resolver
    DnsPlugin._original_resolve = resolver_module.resolve
    DnsPlugin._original_resolver_resolve = resolver_module.Resolver.resolve
    setattr(resolver_module, "resolve", _patched_module_resolve)
    setattr(resolver_module.Resolver, "resolve", _patched_resolver_resolve)

restore_patches

restore_patches()

Restore original DNS functions.

Source code in src/tripwire/plugins/dns_plugin.py
def restore_patches(self) -> None:
    """Reinstate the functions stashed by ``install_patches`` and clear them.

    Each stash slot is handled on its own. A slot that is ``None`` is left
    alone, and the dnspython slots are only considered when dnspython is
    available.
    """
    saved_getaddrinfo = DnsPlugin._original_getaddrinfo
    if saved_getaddrinfo is not None:
        setattr(socket, "getaddrinfo", saved_getaddrinfo)
        DnsPlugin._original_getaddrinfo = None

    saved_gethostbyname = DnsPlugin._original_gethostbyname
    if saved_gethostbyname is not None:
        setattr(socket, "gethostbyname", saved_gethostbyname)
        DnsPlugin._original_gethostbyname = None

    if not _DNSPYTHON_AVAILABLE:
        return

    saved_resolve = DnsPlugin._original_resolve
    if saved_resolve is not None:
        setattr(dns.resolver, "resolve", saved_resolve)
        DnsPlugin._original_resolve = None

    saved_resolver_method = DnsPlugin._original_resolver_resolve
    if saved_resolver_method is not None:
        setattr(dns.resolver.Resolver, "resolve", saved_resolver_method)
        DnsPlugin._original_resolver_resolve = None

matches

matches(interaction, expected)

Field-by-field comparison with dirty-equals support.

Source code in src/tripwire/plugins/dns_plugin.py
def matches(self, interaction: Interaction, expected: dict[str, Any]) -> bool:
    """Report whether every expected field equals the recorded one.

    Each expected value is compared with ``!=`` against
    ``interaction.details.get(key)``, so a key missing from the details
    compares as ``None``. Expected values may be dirty-equals matchers. Any
    exception raised while reading details or comparing is treated as a
    mismatch.
    """
    try:
        return not any(
            wanted != interaction.details.get(field)
            for field, wanted in expected.items()
        )
    except Exception:
        return False

get_unused_mocks

get_unused_mocks()

Return all DnsMockConfig with required=True still in any queue.

Source code in src/tripwire/plugins/dns_plugin.py
def get_unused_mocks(self) -> list[DnsMockConfig]:
    """Collect every ``required`` mock that is still queued.

    The snapshot is taken while holding the registry lock. Results follow
    queue iteration order, then position within each queue.
    """
    with self._registry_lock:
        return [
            mock
            for pending in self._queues.values()
            for mock in pending
            if mock.required
        ]

assert_getaddrinfo

assert_getaddrinfo(host, port, family, type, proto)

Typed helper: assert the next getaddrinfo interaction.

Source code in src/tripwire/plugins/dns_plugin.py
def assert_getaddrinfo(
    self,
    host: str,
    port: Any,  # noqa: ANN401
    family: int,
    type: int,  # noqa: A002
    proto: int,
) -> None:
    """Typed helper: assert the next getaddrinfo interaction for *host*.

    Builds a ``_DnsSentinel`` for ``"dns:getaddrinfo:<host>"`` and passes it
    with the call arguments to the current test verifier's
    ``assert_interaction``.
    """
    from tripwire._context import _get_test_verifier_or_raise  # noqa: PLC0415

    fields = {
        "host": host,
        "port": port,
        "family": family,
        "type": type,
        "proto": proto,
    }
    sentinel = _DnsSentinel(f"dns:getaddrinfo:{host}")
    # Presumably raises when no test verifier is active (per its name).
    _get_test_verifier_or_raise().assert_interaction(sentinel, **fields)

assert_gethostbyname

assert_gethostbyname(hostname)

Typed helper: assert the next gethostbyname interaction.

Source code in src/tripwire/plugins/dns_plugin.py
def assert_gethostbyname(
    self,
    hostname: str,
) -> None:
    """Typed helper: assert the next gethostbyname interaction for *hostname*.

    Builds a ``_DnsSentinel`` for ``"dns:gethostbyname:<hostname>"`` and
    passes it to the current test verifier's ``assert_interaction``.
    """
    from tripwire._context import _get_test_verifier_or_raise  # noqa: PLC0415

    sentinel = _DnsSentinel(f"dns:gethostbyname:{hostname}")
    # Presumably raises when no test verifier is active (per its name).
    verifier = _get_test_verifier_or_raise()
    verifier.assert_interaction(sentinel, hostname=hostname)

assert_resolve

assert_resolve(qname, rdtype)

Typed helper: assert the next resolve interaction.

Source code in src/tripwire/plugins/dns_plugin.py
def assert_resolve(
    self,
    qname: str,
    rdtype: str,
) -> None:
    """Typed helper: assert the next resolve interaction for *qname*.

    The sentinel id is ``"dns:resolve:<qname>"``. ``rdtype`` is not part of
    the id, but it is passed as a field to ``assert_interaction``.
    """
    from tripwire._context import _get_test_verifier_or_raise  # noqa: PLC0415

    sentinel = _DnsSentinel(f"dns:resolve:{qname}")
    # Presumably raises when no test verifier is active (per its name).
    verifier = _get_test_verifier_or_raise()
    verifier.assert_interaction(sentinel, qname=qname, rdtype=rdtype)