#!/usr/local/bin/python3.10
#
# Copyright 2019-2021 OARC, Inc.
# Copyright 2017-2018 Akamai Technologies
# Copyright 2006-2016 Nominum, Inc.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dns.message
import dns.rrset
import dns.flags
import dns.name
import pcapy
import socket
import sys
import struct
from optparse import OptionParser


# Module metadata.
__author__ = "Nominum, Inc."
__version__ = "1.0.2.0"
__date__ = "2007-05-14"

# struct layout used to unpack the first 20 bytes of an IP header into
# (vhl, tos, tlen, ipid, fragoff, ttl, proto, cksum, srcip, dstip).
IPHeader = "!BBHHHBBHLL"

# Header sizes, in bytes.
IPHDRLEN = 20   # IP header without options
UDPHDRLEN = 8   # UDP header

# Datalink type values compared against pcap.datalink().
LINKTYPE_C_HDLC = 104
LINKTYPE_ETHERNET = 1

# Running tally of questions seen, keyed by the question's rdtype.
qtypecount = dict()


def _udp_payload(frame, ip_start):
    """Return the UDP payload carried by a captured frame, or None.

    frame is the raw captured frame and ip_start is the offset at which
    the IP header begins.  None is returned when the frame is too short
    to hold an IP header, the header is not a valid IPv4 header, or the
    carried protocol is not UDP.
    """
    ip_packet = frame[ip_start:]
    if len(ip_packet) < IPHDRLEN:
        return None
    (vhl, _tos, _tlen, _ipid, _fragoff, _ttl, proto, _cksum, _srcip, _dstip) = \
        struct.unpack(IPHeader, ip_packet[:IPHDRLEN])

    # The low nibble is the header length in 32-bit words; it covers any
    # IP options.  A length below 20 bytes or a version other than 4
    # means this is not a usable IPv4 header, so the offsets below would
    # be meaningless.
    ihl = (vhl & 0xF) * 4
    if (vhl >> 4) != 4 or ihl < IPHDRLEN:
        return None
    if proto != socket.IPPROTO_UDP:
        return None
    # Skip the IP header (with options) and the fixed-size UDP header.
    return ip_packet[ihl + UDPHDRLEN:]


def _print_statistics(counts):
    """Print a per-qtype table of counts followed by the grand total."""
    total = 0
    print("Statistics:")
    for rdtype in sorted(counts):
        count = counts[rdtype]
        print("    %10s:\t%d" % (dns.rdatatype.to_text(rdtype), count))
        total += count
    print("-------------------------")
    print("         TOTAL:\t%d" % total)


def main(argv):
    """Extract DNS questions from a pcap capture and print qtype statistics.

    argv is the list of command-line arguments, without the program
    name.  Packets are read from the file given with -i (or from '-'
    when no file is given).  For every UDP packet that parses as a DNS
    message and passes the RD/QR filters, each question is written as
    "<name> <qtype>" to the -o file (or stdout).  Counts are accumulated
    in the module-level qtypecount dict and a summary table is printed
    to stdout at the end.
    """
    parser = OptionParser(usage="%prog [options]",
                          version="%prog " + __version__)
    parser.add_option("-i", "--input", dest="fin",
                      help="name of tcpdump file to parse", metavar="FILE")
    parser.add_option("-o", "--output", dest="fout",
                      help="file in which to save parsed DNS queries",
                      metavar="FILE")
    parser.add_option("-r", "--recursion", dest="recurse", action="store_true",
                      default=False,
                      help="Keep queries whose RD flag is 0 (default: discard)")
    parser.add_option("-R", "--responses", dest="responses",
                      action="store_true", default=False,
                      help="Parse query responses instead of queries")
    # Parse the arguments we were handed rather than implicitly falling
    # back to sys.argv.
    (opts, _args) = parser.parse_args(argv)

    pcap = pcapy.open_offline(opts.fin or '-')
    # C-HDLC frames carry a 4-byte link header; everything else is
    # treated as having a 14-byte one.
    ip_start = 4 if pcap.datalink() == LINKTYPE_C_HDLC else 14

    outfile = open(opts.fout, "w") if opts.fout else sys.stdout
    try:
        while True:
            # Best effort: any error while reading is treated as the end
            # of the capture.
            try:
                captured = pcap.next()
            except Exception:
                break
            if captured[0] is None:
                break

            payload = _udp_payload(captured[1], ip_start)
            if payload is None:
                continue

            # Anything that does not parse as a DNS message is skipped.
            try:
                msg = dns.message.from_wire(payload)
            except Exception:
                continue

            # Unless -r was given, drop messages whose RD bit is clear.
            # The bit must be tested on the message's own flags.
            if not opts.recurse and not (msg.flags & dns.flags.RD):
                continue

            # QR set means response; keep only the kind that was asked for.
            is_response = bool(msg.flags & dns.flags.QR)
            if is_response != bool(opts.responses):
                continue

            for query in msg.question:  # handle multiple queries per packet
                fqdn = query.name.to_text()
                qtype = dns.rdatatype.to_text(query.rdtype)
                outfile.write("%s %s\n" % (fqdn, qtype))
                qtypecount[query.rdtype] = qtypecount.get(query.rdtype, 0) + 1
    finally:
        # Never close stdout, but make sure a real output file is closed
        # even if processing fails part-way through.
        if outfile is not sys.stdout:
            outfile.close()

    _print_statistics(qtypecount)

# When executed as a script, hand main() the command-line arguments with
# the program name stripped off.
if __name__ == "__main__":
    main(sys.argv[1:])
