diff --git a/pdns_protobuf_receiver/receiver.py b/pdns_protobuf_receiver/receiver.py index 32abb7f..3934997 100644 --- a/pdns_protobuf_receiver/receiver.py +++ b/pdns_protobuf_receiver/receiver.py @@ -23,17 +23,20 @@ # SOFTWARE. import argparse +import binascii import logging import asyncio import socket import json import sys +from datetime import datetime, timezone + import dns.rdatatype +import dns.rdataclass +import dns.rdata import dns.rcode -from datetime import datetime, timezone - # wget https://raw.githubusercontent.com/PowerDNS/dnsmessage/master/dnsmessage.proto # wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.2/protoc-3.12.2-linux-x86_64.zip # python3 -m pip install protobuf @@ -61,31 +64,230 @@ PBDNSMESSAGE_SOCKETPROTOCOL = {1: "UDP", 2: "TCP"} +PBDNSMESSAGE_POLICYTYPE = { + 1: "UNKNOWN", + 2: "QNAME", + 3: "CLIENTIP", + 4: "RESPONSEIP", + 5: "NSDNAME", + 6: "NSDNAME", +} + + +def get_rdata_attributes(cls, exclude_methods=True): + """ + Extract attributes to be set in rdata Json Dict + + Extract from dnspython class the attributes that + will be used to populate the rdata record + """ + base_attrs = dir(type("dummy", (object,), {})) + this_cls_attrs = dir(cls) + res = [] + for attr in this_cls_attrs: + if base_attrs.count(attr) or (callable(getattr(cls, attr)) and exclude_methods): + continue + if attr in ["rdclass", "rdtype", "__slots__"]: + continue + res += [attr] + return res + + +def parse_pb_msg(dns_pb2, dns_msg): + """ + Parse Common Fields in Protobuf PowerDNS Messages + """ + pb_fields_map = { + "type": ["dns_message", PBDNSMESSAGE_TYPE], # 1 + "messageId": "message_id", # 2 + "serverIdentity": "server_identity", # 3 + "socketFamily": ["socket_family", PBDNSMESSAGE_SOCKETFAMILY], # 4 + "socketProtocol": ["socket_protocol", PBDNSMESSAGE_SOCKETPROTOCOL], # 5 + "from": "from_address", # 6 + "to": "to_address", # 7 + "inBytes": "bytes", # 8 + "id": "dns_id", # 11 + "originalRequestorSubnet": "original_requestor_subnet", # 14 + "requestorId": "requestor_id", # 15 + "initialRequestId": "initial_request_id", # 16 + "deviceId": "device_id", # 17 + "newlyObservedDomain": "nod", # 18 + "deviceName": "device_name", # 19 + "fromPort": "from_port", # 20 + "toPort": "to_port", # 21 + } + + # print(dns_pb2) + for key, val in pb_fields_map.items(): + if dns_pb2.HasField(key): + if key in ["from", "to"]: + addr = getattr(dns_pb2, key) + if len(addr) > 0: + if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET: + dns_msg[val] = socket.inet_ntop(socket.AF_INET, addr) + if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET6: + dns_msg[val] = socket.inet_ntop(socket.AF_INET6, addr) + elif key == "originalRequestorSubnet": + ors = getattr(dns_pb2, key) + if len(ors) == 4: + dns_msg[val] = socket.inet_ntop(socket.AF_INET, ors) + elif len(ors) == 16: + dns_msg[val] = socket.inet_ntop(socket.AF_INET6, ors) + elif isinstance(val, str): + if key in {"messageId", "initialRequestId"}: + dns_msg[val] = binascii.hexlify( + bytearray(getattr(dns_pb2, key)) + ).decode() + else: + res = getattr(dns_pb2, key) + if isinstance(res, bytes): + dns_msg[val] = res.decode() + else: + dns_msg[val] = res + else: + dns_msg[val[0]] = val[1][getattr(dns_pb2, key)] + + +def parse_pb_msg_query(dns_pb2, dns_msg): + """ + Parse Query Fields in Protobuf PowerDNS Messages + """ + pb_fields_map = {"qName": "name", "qType": "type", "qClass": "class"} + + query = dns_pb2.question + query_d = {} + for key, val in pb_fields_map.items(): + if query.HasField(key): + if key == "qType": + query_d[val] = dns.rdatatype.to_text(getattr(query, key)) + elif key == "qClass": + query_d[val] = dns.rdataclass.to_text(getattr(query, key)) + else: + query_d[val] = getattr(query, key) + + dns_msg["query"] = query_d + + +def parse_pb_msg_response(dns_pb2, dns_msg): + """ + Parse Response Fields in Protobuf PowerDNS Messages + """ + pb_fields_map = { + "rcode": "return_code", # 1 + "appliedPolicy": "applied_policy", # 3 + "tags": "tags", # 4 + "appliedPolicyType": ["applied_policy_type", PBDNSMESSAGE_POLICYTYPE], # 7 + "appliedPolicyTrigger": "applied_policy_trigger", # 8 + "appliedPolicyHit": "applied_policy_hit", # 9 + } + + dns_msg["response"] = {} + resp = dns_msg["response"] + + for key, val in pb_fields_map.items(): + if key == "tags": + try: + tags = [] + for i in getattr(dns_pb2.response, val): + tags.append(i) + if len(tags) > 0: + resp["tags"] = tags + except AttributeError: + pass + elif key == "rcode": + if dns_pb2.response.rcode == 65536: + dns_msg["response"]["return_code"] = "NETWORK_ERROR" + else: + dns_msg["response"]["return_code"] = dns.rcode.to_text( + dns_pb2.response.rcode + ) + else: + try: + assert dns_pb2.response.HasField(key) + if isinstance(val, str): + res = getattr(dns_pb2.response, key) + resp[val] = res + else: + resp[val[0]] = val[1][getattr(dns_pb2.response, key)] + except AssertionError: + # take into account fields map that may not + # exist due to pb message version + pass + + +def parse_pb_msg_rrs(dns_pb2, dns_msg): + """ + Parse RRS Fields in Protobuf PowerDNS Messages + """ + pb_fields_map = { + "name": "name", # 1 + "type": "type", # 2 + "class": "class", # 3 + "ttl": "ttl", # 4 + "rdata": "rdata", # 5 + "udr": "udr", # 6 + } + + rrs = [] + + for rr in dns_pb2.response.rrs: + rr_dict = {} + for key, val in pb_fields_map.items(): + res = getattr(rr, key) + if key == "rdata": + rr_dict[val] = {} + if rr.type in [dns.rdatatype.A, dns.rdatatype.AAAA]: + rdata = dns.rdata.from_wire( + rr_dict["class"], rr_dict["type"], res, 0, len(res) + ) + else: + try: + rdata = dns.rdata.from_text( + rr_dict["class"], rr_dict["type"], res.decode() + ) + except UnicodeDecodeError: + rdata = dns.rdata.from_wire( + rr_dict["class"], rr_dict["type"], res, 0, len(res) + ) + except dns.exception.SyntaxError as e: + # Fix for MX & SRV as info sent by Recursors does + # not contain the preference/priority/port (int value) + # like "10 mymx.e.com" for MX + # Stays here in case it is fixed upstream + if rr.type == dns.rdatatype.MX: + rr_dict[val]["exchange"] = str(res.decode()) + break + elif rr.type == dns.rdatatype.SRV: + rr_dict[val]["target"] = str(res.decode()) + break + else: + raise e + + for k in get_rdata_attributes(rdata): + if rr.type == dns.rdatatype.TXT: + text_list = getattr(rdata, k) + text_list = [str(i.decode()) for i in text_list] + rr_dict[val][k] = " ".join(text_list) + else: + rr_dict[val][k] = str(getattr(rdata, k)) + elif key == "class": + rr_dict[val] = dns.rdataclass.to_text(res) + elif key == "type": + rr_dict[val] = dns.rdatatype.to_text(res) + else: + rr_dict[val] = res + rrs.append(rr_dict) + + if len(rrs) > 0: + dns_msg["response"]["rrs"] = rrs + async def cb_onpayload(dns_pb2, payload, tcp_writer, debug_mode, loop): """on dnsmessage protobuf2""" dns_pb2.ParseFromString(payload) dns_msg = {} - dns_msg["dns_message"] = PBDNSMESSAGE_TYPE[dns_pb2.type] - dns_msg["socket_family"] = PBDNSMESSAGE_SOCKETFAMILY[dns_pb2.socketFamily] - dns_msg["socket protocol"] = PBDNSMESSAGE_SOCKETPROTOCOL[dns_pb2.socketProtocol] - - dns_msg["from_address"] = "0.0.0.0" - from_addr = getattr(dns_pb2, "from") - if len(from_addr): - if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET: - dns_msg["from_address"] = socket.inet_ntop(socket.AF_INET, from_addr) - if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET6: - dns_msg["from_address"] = socket.inet_ntop(socket.AF_INET6, from_addr) - - dns_msg["to_address"] = "0.0.0.0" - to_addr = getattr(dns_pb2, "to") - if len(to_addr): - if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET: - dns_msg["to_address"] = socket.inet_ntop(socket.AF_INET, to_addr) - if dns_pb2.socketFamily == PBDNSMessage.SocketFamily.INET6: - dns_msg["to_address"] = socket.inet_ntop(socket.AF_INET6, to_addr) + parse_pb_msg(dns_pb2, dns_msg) time_req = 0 time_rsp = 0 @@ -110,6 +312,9 @@ async def cb_onpayload(dns_pb2, payload, tcp_writer, debug_mode, loop): time_latency = round(float(time_rsp) - float(time_req), 6) + parse_pb_msg_response(dns_pb2, dns_msg) + parse_pb_msg_rrs(dns_pb2, dns_msg) + dns_msg["query_time"] = datetime.fromtimestamp( float(time_req), tz=timezone.utc ).isoformat() @@ -119,14 +324,7 @@ async def cb_onpayload(dns_pb2, payload, tcp_writer, debug_mode, loop): dns_msg["latency"] = time_latency - dns_msg["query_type"] = dns.rdatatype.to_text(dns_pb2.question.qType) - dns_msg["query_name"] = dns_pb2.question.qName - - if dns_pb2.response.rcode == 65536: - dns_msg["return_code"] = "NETWORK_ERROR" - else: - dns_msg["return_code"] = dns.rcode.to_text(dns_pb2.response.rcode) - dns_msg["bytes"] = dns_pb2.inBytes + parse_pb_msg_query(dns_pb2, dns_msg) dns_json = json.dumps(dns_msg) @@ -138,8 +336,8 @@ async def cb_onpayload(dns_pb2, payload, tcp_writer, debug_mode, loop): # exit if we lost the connection with the remote collector loop.stop() raise Exception("connection lost with remote") - else: - tcp_writer.write(dns_json.encode() + b"\n") + + tcp_writer.write(dns_json.encode() + b"\n") async def cb_onconnect(reader, writer, tcp_writer, debug_mode):