Advertisement
Guest User

A Non-Standard A-record Only Local DNS Cacher

a guest
Jan 2nd, 2019
167
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 27.82 KB | None | 0 0
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3.  
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Lesser General Public License as published
  6. # by the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12. # GNU General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Lesser General Public License
  15. # along with this program.  If not, see <https://www.gnu.org/licenses/>.
  16.  
  17. # Acknowledgment
  18. # This program uses the following open source projects:
  19. # BitString, by Scott Griffiths, released under MIT License
  20. # Pony ORM, by Pony ORM, LLC, released under Apache v2 License
  21. # Requests, by Kenneth Reitz, released under Apache v2 License
  22.  
  23.  
  24. from socket import socket, AF_INET, SOCK_DGRAM
  25. import queue
  26. from threading import Thread
  27. import requests
  28. from bitstring import BitArray, BitStream
  29. import pony.orm
  30. from time import sleep, time
  31. from collections import deque
  32.  
  33. SERVER_IP = '127.0.0.1'
  34. SERVER_PORT = 53
  35.  
  36. DOH_URL = 'https://1.1.1.1/dns-query'
  37.  
  38. PERIOD = 0.5  # In days
  39.  
  40. POSITIVE_CACHE_FILENAME = 'dns_pcache.sqlite'
  41. NEGATIVE_CACHE_FILENAME = 'dns_ncache.sqlite'
  42. SCHEDULE_FILE_FILENAME = 'cache_update.schedule'
  43.  
  44.  
  45. def main():
  46.     if PERIOD <= 0:
  47.         raise ValueError('PERIOD should be larger than 0.')
  48.     # Load records cache database
  49.     records_cache = pony.orm.Database()
  50.  
  51.     class RecordsCache(records_cache.Entity):
  52.         full_domain = pony.orm.Required(str)
  53.         query_type = pony.orm.Required(str)
  54.         data = pony.orm.Required(str)
  55.         used_in_current_period = pony.orm.Required(bool)
  56.         used_in_previous_period = pony.orm.Required(bool)
  57.         pony.orm.PrimaryKey(full_domain, query_type)
  58.     records_cache.bind(provider='sqlite', filename=POSITIVE_CACHE_FILENAME, create_db=True)
  59.     records_cache.generate_mapping(create_tables=True)
  60.  
  61.     # Create negative cache database
  62.     negative_cache = pony.orm.Database()
  63.  
  64.     class NegativeCache(negative_cache.Entity):
  65.         full_domain = pony.orm.PrimaryKey(str)
  66.     negative_cache.bind(provider='sqlite', filename=NEGATIVE_CACHE_FILENAME, create_db=True)
  67.     negative_cache.generate_mapping(create_tables=True)
  68.  
  69.     with pony.orm.db_session:
  70.         NegativeCache.select().delete(bulk=True)  # Delete all negative cache from last session
  71.  
  72.     try:
  73.         schedule_file = open(file=SCHEDULE_FILE_FILENAME, mode='r+')
  74.         schedule = schedule_file.read().rstrip()
  75.         try:
  76.             schedule = int(schedule)
  77.         except ValueError as e:
  78.             print(e)
  79.             schedule = int(time() + PERIOD * 86400)
  80.             schedule_file.seek(0)
  81.             schedule_file.truncate()
  82.             schedule_file.writelines(str(schedule)+'\n')
  83.     except FileNotFoundError:
  84.         schedule_file = open(file=SCHEDULE_FILE_FILENAME, mode='w+')
  85.         schedule = int(time() + PERIOD * 86400)
  86.         schedule_file.writelines(str(schedule)+'\n')
  87.     finally:
  88.         schedule_file.close()
  89.  
  90.     Thread(target=timer_job, args=(schedule, RecordsCache, NegativeCache)).start()  # Timer thread
  91.  
  92.     task_queue = queue.Queue()
  93.  
  94.     server_socket = socket(AF_INET, SOCK_DGRAM)
  95.     server_socket.bind((SERVER_IP, SERVER_PORT))
  96.  
  97.     for i in range(4):
  98.         Thread(target=worker, args=(task_queue, RecordsCache, NegativeCache, server_socket)).start()  # Worker threads
  99.  
  100.     while True:
  101.         query_data, client_address = server_socket.recvfrom(512)
  102.         task_package = (query_data, client_address)
  103.         task_queue.put(task_package)
  104.  
  105.  
  106. def worker(task_queue, RecordsCache, NegativeCache, server_socket):
  107.     requests_session = requests.Session()
  108.     while True:
  109.         query_data, client_address = task_queue.get(block=True)  # Using block=False will result in high cpu usage
  110.         try:
  111.             full_domain, query_type, id, rd, cd, question_section = read_query_stream(BitStream(query_data))
  112.             print('Received query:', full_domain, query_type)
  113.             if query_type != 'A':
  114.                 if query_type == 'ALL':
  115.                     query_type = 'A'
  116.                 else:
  117.                     raise NotImplementedError('Received query for {} record. The current server supports A records only.'.format(query_type))
  118.         except BaseException as e:
  119.             print(e)
  120.             task_queue.task_done()
  121.             continue  # Skip following procedures, go to next iteration
  122.         try:
  123.             cache_status, cached_data = fetch_cached_data(RecordsCache, NegativeCache, full_domain, query_type)
  124.         except BaseException as e:
  125.             print(e)
  126.             task_queue.task_done()
  127.             continue  # Skip following procedures, go to next iteration
  128.  
  129.         status_to_respond = ''
  130.         address_data = None
  131.         if cache_status == 'positive':  # There is cache and it is valid record
  132.             address_data = cached_data
  133.             response = construct_positive_response(id, rd, cd, question_section, address_data)
  134.             status_to_respond = 'NOERROR'
  135.         elif cache_status == 'nxdomain':
  136.             response = construct_nxdomain_response(id, rd, cd, question_section)
  137.             status_to_respond = 'NXDOMAIN'
  138.         elif cache_status == 'nocache':
  139.             try:
  140.                 remote_status, remote_data = fetch_remote_data(requests_session, full_domain, query_type)
  141.             except BaseException as e:
  142.                 # Return server failure
  143.                 print('Failed to get address from remote:', e)
  144.                 response = construct_servfail_response(id, rd, cd, question_section)
  145.                 status_to_respond = 'SERVFAIL'
  146.             else:
  147.                 if remote_status == 'noerror':  # Successfully retrieved address from remote
  148.                     address_data = remote_data
  149.                     Thread(target=error_output_beautifier, args=(cache_remote_positive_answer, (RecordsCache, full_domain, query_type, remote_data))).start()
  150.                     response = construct_positive_response(id, rd, cd, question_section, address_data)
  151.                     status_to_respond = 'NOERROR'
  152.                 elif remote_status == 'nodata':
  153.                     response = construct_nodata_response(id, rd, cd, question_section)
  154.                     status_to_respond = 'NODATA'
  155.                 elif remote_status == 'nxdomain':
  156.                     Thread(target=error_output_beautifier, args=(cache_remote_nxdomain_answer, (NegativeCache, full_domain))).start()
  157.                     response = construct_nxdomain_response(id, rd, cd, question_section)
  158.                     status_to_respond = 'NXDOMAIN'
  159.                 elif remote_status == 'servfail':
  160.                     response = construct_servfail_response(id, rd, cd, question_section)
  161.                     status_to_respond = 'SERVFAIL'
  162.                 else:
  163.                     raise NotImplementedError('Unrecognized remote_status.')
  164.         else:
  165.             raise Exception('Invalid cache_status.')
  166.  
  167.         print('Responding message:', full_domain, query_type, str(address_data), status_to_respond)
  168.         binary_response = response.tobytes()
  169.         server_socket.sendto(binary_response, client_address)
  170.         task_queue.task_done()
  171.  
  172.         # Return server failure when both having no cache and fetching from remote fails
  173.         # Return no such domain when there are negative cache or remote server says so
  174.         # Return no error with cached record or remotely fetched record
  175.         # Otherwise return server failure
  176.  
  177.  
  178. def read_query_stream(query_stream):      # query_stream should be a BitStream
  179.     # Read header
  180.     id, qr, opcode, aa, tc, rd, ra, z, ad, cd, rcode, qdcount, ancount, nscount, arcount = query_stream.readlist(
  181.         'bits:16, bool, uint:4, bool, bool, bits:1, bool, bits:1, bool, bits:1, uint:4, uint:16, uint:16, uint:16, uint:16')
  182.  
  183.     # Initial checks on whether the query is legitimate or supported
  184.     if qr:
  185.         raise ValueError('Value of QR should have been 0 as it should be a query.')
  186.     if opcode != 0:
  187.         raise NotImplementedError('Non-standard queries are not supported.')
  188.     if qdcount != 1:
  189.         raise ValueError('The current server only supports query with single question.')
  190.     if ancount != 0:
  191.         raise ValueError('There should be no answer records in query.')
  192.     if nscount != 0:
  193.         raise ValueError('There should be no authority records in query.')
  194.     if aa:
  195.         raise ValueError('Unexpected AA flag being set.')
  196.     if ra:
  197.         raise ValueError('Unexpected RA flag being set.')
  198.  
  199.     # Read question
  200.     questions_start_pos = query_stream.pos  # Used in copying the question section
  201.     domain_labels = []
  202.     while True:
  203.         label_flag = query_stream.read('bits:2')
  204.         if label_flag.bin == '00':
  205.             label_length = query_stream.read('uint:6')
  206.             if label_length > 0:
  207.                 label = query_stream.read('bytes:' + str(label_length))
  208.                 domain_labels.append(label)
  209.             else:
  210.                 break
  211.         else:
  212.             raise ValueError('Question name label flag unrecognized.')
  213.     full_domain = '.'.join([str(label, 'utf-8') for label in domain_labels])
  214.     if not full_domain.endswith('.'):
  215.         full_domain = full_domain + '.'
  216.     full_domain = full_domain.lower()  # Convert to lower case
  217.     query_type_value = query_stream.read('uint:16')
  218.     record_type_dict = {1: 'A', 2: 'NS', 3: 'MD', 4: 'MF', 5: 'CNAME', 6: 'SOA', 7: 'MB', 8: 'MG',
  219.                         9: 'MR', 10: 'NULL', 11: 'WKS', 12: 'PTR', 13: 'HINFO', 14: 'MINFO', 15: 'MX',
  220.                         16: 'TXT', 17: 'RP', 18: 'AFSDB', 19: 'X25', 20: 'ISDN', 21: 'RT', 22: 'NSAP',
  221.                         23: 'NSAP-PTR', 24: 'SIG', 25: 'KEY', 26: 'PX', 27: 'GPOS', 28: 'AAAA',
  222.                         29: 'LOC', 30: 'NXT', 31: 'EID', 32: 'NIMLOC', 33: 'SRV', 34: 'ATMA',
  223.                         35: 'NAPTR', 36: 'KX', 37: 'CERT', 38: 'A6', 39: 'DNAME', 40: 'SINK',
  224.                         41: 'OPT', 42: 'APL', 43: 'DS', 44: 'SSHFP', 45: 'IPSECKEY', 46: 'RRSIG',
  225.                         47: 'NSEC', 48: 'DNSKEY', 49: 'DHCID', 50: 'NSEC3', 51: 'NSEC3PARAM',
  226.                         52: 'TLSA', 53: 'SMIMEA', 55: 'HIP', 59: 'CDS', 60: 'CDSKEY',
  227.                         61: 'OPENGPGKEY', 99: 'SPF', 100: 'UINFO', 101: 'UID', 102: 'GID',
  228.                         103: 'UNSPEC', 249: 'TKEY', 250: 'TSIG', 251: 'IXFR', 252: 'AXFR',
  229.                         253: 'MAILB', 254: 'MAILA', 255: 'ALL', 256: 'URI', 257: 'CAA', 32768: 'TA',
  230.                         32769: 'DLV'}
  231.     query_type = record_type_dict[query_type_value]
  232.     query_class_value = query_stream.read('uint:16')
  233.     if query_class_value != 1:
  234.         raise NotImplementedError('The current server supports Internet class only.')
  235.     questions_end_pos = query_stream.pos
  236.     question_section = query_stream[questions_start_pos:questions_end_pos]
  237.  
  238.     # Ignore other sections
  239.     return full_domain, query_type, id, rd, cd, question_section
  240.  
  241.  
  242. def fetch_cached_data(RecordsCache, NegativeCache, full_domain, query_type):
  243.     with pony.orm.db_session:
  244.         cached_record = RecordsCache.get(full_domain=full_domain, query_type=query_type)
  245.         if cached_record:
  246.             cached_record.used_in_current_period = True
  247.             cached_address = cached_record.data
  248.             return 'positive', cached_address
  249.         if NegativeCache.exists(full_domain=full_domain):
  250.             return 'nxdomain', None
  251.         return 'nocache', None
  252.  
  253.  
  254. def fetch_remote_data(requests_session, full_domain, query_type):
  255.     result = requests_session.get(DOH_URL, params={'name': full_domain, 'type': query_type},
  256.                                   headers={'Accept': 'application/dns-json'}, timeout=10)
  257.     result_json = result.json()
  258.     status_code = int(result_json['Status'])  # 0 = no error, 1 = format error, 2 = server failure,
  259.                                               # 3 = no such domain, 4 = not implemented, 5 = refused
  260.     if status_code == 0:  # Received NOERROR
  261.         try:
  262.             answers = result_json['Answer']
  263.             for answer in answers:
  264.                 if query_type == 'A':
  265.                     if answer['type'] == 1:
  266.                         return 'noerror', answer['data']
  267.         except KeyError:
  268.             pass
  269.         # If there is no data in Answer section or no answer section
  270.         return 'nodata', None
  271.     elif status_code == 3:
  272.         return 'nxdomain', None
  273.     elif status_code == 2:
  274.         return 'servfail', None
  275.     else:
  276.         raise NotImplementedError('Unsupported status code received.')
  277.  
  278.  
  279. def cache_remote_positive_answer(RecordsCache, full_domain, query_type, remote_data):
  280.     with pony.orm.db_session:
  281.         # Automatic handling by ponyorm
  282.         RecordsCache(full_domain=full_domain, query_type=query_type, data=remote_data,
  283.                      used_in_current_period=True,
  284.                      used_in_previous_period=False)
  285.     print('Cached record:', full_domain, query_type, remote_data)
  286.  
  287.  
  288. def cache_remote_nxdomain_answer(NegativeCache, full_domain):
  289.     with pony.orm.db_session:
  290.         # Automatic handling by ponyorm
  291.         NegativeCache(full_domain=full_domain)
  292.     print('Cached NXDOMAIN:', full_domain)
  293.  
  294.  
  295. def construct_positive_response(id, rd, cd, question_section, address_data):
  296.     # See https://tools.ietf.org/html/rfc1035#section-4.1.1
  297.     # Also https://tools.ietf.org/html/rfc6895
  298.     response = BitArray()
  299.  
  300.     # Header section begins
  301.     header = BitArray()
  302.  
  303.     id = id                                # 16bits ID identifier
  304.     qr = BitArray('uint:1=1')              # 1bit specifying if it's query(0) or response(1)
  305.     opcode = BitArray('uint:4=0')          # 4bits specifying query kind. 0 = standard
  306.     aa = BitArray('uint:1=0')              # 1bit specifying if response is authorative(1)
  307.     tc = BitArray('uint:1=0')              # 1bit specifying if message is truncated(1)
  308.     rd = rd                                # Copy from query, 1bit specifying if recursion is desired(1) by sender
  309.     ra = BitArray('uint:1=0')              # 1bit specifying if recursion is available(1) from server
  310.     z = BitArray('uint:1=0')               # 1bit reserved for future use, zero
  311.     ad = BitArray('uint:1=0')              # 1bit stating all records in answer and authority sections are authentic (DNSSEC)
  312.     cd = cd                                # Copy from query, 1bit requesting no signature validation by upstream servers (DNSSEC)
  313.     rcode = BitArray('uint:4=0')           # 4bits stating query status, 0 = no error, 1 = format error,
  314.                                            # 2 = server failure, 3 = name error, 4 = not implemented,
  315.                                            # 5 = refused, 6-15 = reserved
  316.     qdcount = BitArray(uint=1, length=16)  # unsigned 16bits int specifying number of entries in question section
  317.     ancount = BitArray(uint=1, length=16)  # unsigned 16bits int specifying number of records in answer section
  318.     nscount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in authority section
  319.     arcount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in additional record section
  320.  
  321.     header.append(id)
  322.     header.append(qr)
  323.     header.append(opcode)
  324.     header.append(aa)
  325.     header.append(tc)
  326.     header.append(rd)
  327.     header.append(ra)
  328.     header.append(z)
  329.     header.append(ad)
  330.     header.append(cd)
  331.     header.append(rcode)
  332.     header.append(qdcount)
  333.     header.append(ancount)
  334.     header.append(nscount)
  335.     header.append(arcount)
  336.     response.append(header)
  337.     # Header section ends
  338.  
  339.     # Question section, copied from query
  340.     response.append(question_section)
  341.  
  342.     # Answer section begins
  343.     ttl = abs(int(PERIOD * 86400))
  344.     offset_starting_point = 12  # It can be 12 only assuming query only has one question
  345.  
  346.     rr_part = BitArray()
  347.     # NAME
  348.     rr_part.append(BitArray('uint:2=3'))  # 2bits, '11', flag for compression
  349.     rr_part.append(BitArray(uint=offset_starting_point, length=14))  # 14bit, offset specified in compression
  350.     # TYPE
  351.     rr_part.append(BitArray('uint:16=1'))  # A record
  352.     # CLASS
  353.     rr_part.append(BitArray('uint:16=1'))  # Internet Class
  354.     # TTL
  355.     rr_part.append(BitArray(uint=ttl, length=32))  # TTL in seconds
  356.     # RDLENGTH, specifying length of data
  357.     rr_part.append(BitArray('uint:16=4'))  # IP address in A record requires 4 octets(bytes)
  358.     # RDATA
  359.     ipaddress = [BitArray(uint=int(part), length=8) for part in address_data.split('.')]  # Like socket.inet_aton()
  360.     for part in ipaddress:
  361.         rr_part.append(part)
  362.     response.append(rr_part)
  363.     # Answer section ends
  364.  
  365.     # Ignore other sections
  366.     return response
  367.  
  368.  
  369. def construct_nodata_response(id, rd, cd, question_section):
  370.     # See https://tools.ietf.org/html/rfc1035#section-4.1.1
  371.     # Also https://tools.ietf.org/html/rfc6895
  372.     response = BitArray()
  373.  
  374.     # Header section begins
  375.     header = BitArray()
  376.  
  377.     id = id                                # 16bits ID identifier
  378.     qr = BitArray('uint:1=1')              # 1bit specifying if it's query(0) or response(1)
  379.     opcode = BitArray('uint:4=0')          # 4bits specifying query kind. 0 = standard
  380.     aa = BitArray('uint:1=0')              # 1bit specifying if response is authorative(1)
  381.     tc = BitArray('uint:1=0')              # 1bit specifying if message is truncated(1)
  382.     rd = rd                                # Copy from query, 1bit specifying if recursion is desired(1) by sender
  383.     ra = BitArray('uint:1=0')              # 1bit specifying if recursion is available(1) from server
  384.     z = BitArray('uint:1=0')               # 1bit reserved for future use, zero
  385.     ad = BitArray('uint:1=0')              # 1bit stating all records in answer and authority sections are authentic (DNSSEC)
  386.     cd = cd                                # Copy from query, 1bit requesting no signature validation by upstream servers (DNSSEC)
  387.     rcode = BitArray('uint:4=0')           # 4bits stating query status, 0 = no error, 1 = format error,
  388.                                            # 2 = server failure, 3 = name error, 4 = not implemented,
  389.                                            # 5 = refused, 6-15 = reserved
  390.                                            # In NODATA scenario, it should be 0
  391.     qdcount = BitArray(uint=1, length=16)  # unsigned 16bits int specifying number of entries in question section
  392.     ancount = BitArray(uint=0, length=16)  # unsigned 16bits int specifying number of records in answer section. In NODATA scenario, it should be 0
  393.     nscount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in authority section
  394.     arcount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in additional record section
  395.  
  396.     header.append(id)
  397.     header.append(qr)
  398.     header.append(opcode)
  399.     header.append(aa)
  400.     header.append(tc)
  401.     header.append(rd)
  402.     header.append(ra)
  403.     header.append(z)
  404.     header.append(ad)
  405.     header.append(cd)
  406.     header.append(rcode)
  407.     header.append(qdcount)
  408.     header.append(ancount)
  409.     header.append(nscount)
  410.     header.append(arcount)
  411.     response.append(header)
  412.     # Header section ends
  413.  
  414.     # Question section, copied from query
  415.     response.append(question_section)
  416.  
  417.     # No Answer section in NODATA scenario
  418.  
  419.     # Ignore other sections
  420.     return response
  421.  
  422.  
  423. def construct_nxdomain_response(id, rd, cd, question_section):
  424.     # See https://tools.ietf.org/html/rfc1035#section-4.1.1
  425.     # Also https://tools.ietf.org/html/rfc6895
  426.     response = BitArray()
  427.  
  428.     # Header section begins
  429.     header = BitArray()
  430.  
  431.     id = id                                # 16bits ID identifier
  432.     qr = BitArray('uint:1=1')              # 1bit specifying if it's query(0) or response(1)
  433.     opcode = BitArray('uint:4=0')          # 4bits specifying query kind. 0 = standard
  434.     aa = BitArray('uint:1=0')              # 1bit specifying if response is authorative(1)
  435.     tc = BitArray('uint:1=0')              # 1bit specifying if message is truncated(1)
  436.     rd = rd                                # Copy from query, 1bit specifying if recursion is desired(1) by sender
  437.     ra = BitArray('uint:1=0')              # 1bit specifying if recursion is available(1) from server
  438.     z = BitArray('uint:1=0')               # 1bit reserved for future use, zero
  439.     ad = BitArray('uint:1=0')              # 1bit stating all records in answer and authority sections are authentic (DNSSEC)
  440.     cd = cd                                # Copy from query, 1bit requesting no signature validation by upstream servers (DNSSEC)
  441.     rcode = BitArray('uint:4=3')           # 4bits stating query status, 0 = no error, 1 = format error,
  442.                                            # 2 = server failure, 3 = name error, 4 = not implemented,
  443.                                            # 5 = refused, 6-15 = reserved
  444.                                            # In NXDOMAIN scenario, it should be 3
  445.     qdcount = BitArray(uint=1, length=16)  # unsigned 16bits int specifying number of entries in question section
  446.     ancount = BitArray(uint=0, length=16)  # unsigned 16bits int specifying number of records in answer section
  447.     nscount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in authority section
  448.     arcount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in additional record section
  449.  
  450.     header.append(id)
  451.     header.append(qr)
  452.     header.append(opcode)
  453.     header.append(aa)
  454.     header.append(tc)
  455.     header.append(rd)
  456.     header.append(ra)
  457.     header.append(z)
  458.     header.append(ad)
  459.     header.append(cd)
  460.     header.append(rcode)
  461.     header.append(qdcount)
  462.     header.append(ancount)
  463.     header.append(nscount)
  464.     header.append(arcount)
  465.     response.append(header)
  466.     # Header section ends
  467.  
  468.     # Question section, copied from query
  469.     response.append(question_section)
  470.  
  471.     # No Answer section in NXDOMAIN scenario
  472.  
  473.     # Ignore other sections
  474.     return response
  475.  
  476.  
  477. def construct_servfail_response(id, rd, cd, question_section):
  478.     # See https://tools.ietf.org/html/rfc1035#section-4.1.1
  479.     # Also https://tools.ietf.org/html/rfc6895
  480.     response = BitArray()
  481.  
  482.     # Header section begins
  483.     header = BitArray()
  484.  
  485.     id = id                                # 16bits ID identifier
  486.     qr = BitArray('uint:1=1')              # 1bit specifying if it's query(0) or response(1)
  487.     opcode = BitArray('uint:4=0')          # 4bits specifying query kind. 0 = standard
  488.     aa = BitArray('uint:1=0')              # 1bit specifying if response is authorative(1)
  489.     tc = BitArray('uint:1=0')              # 1bit specifying if message is truncated(1)
  490.     rd = rd                                # Copy from query, 1bit specifying if recursion is desired(1) by sender
  491.     ra = BitArray('uint:1=0')              # 1bit specifying if recursion is available(1) from server
  492.     z = BitArray('uint:1=0')               # 1bit reserved for future use, zero
  493.     ad = BitArray('uint:1=0')              # 1bit stating all records in answer and authority sections are authentic (DNSSEC)
  494.     cd = cd                                # Copy from query, 1bit requesting no signature validation by upstream servers (DNSSEC)
  495.     rcode = BitArray('uint:4=2')           # 4bits stating query status, 0 = no error, 1 = format error,
  496.                                            # 2 = server failure, 3 = name error, 4 = not implemented,
  497.                                            # 5 = refused, 6-15 = reserved
  498.                                            # In SERVFAIL scenario, it should be 2
  499.     qdcount = BitArray(uint=1, length=16)  # unsigned 16bits int specifying number of entries in question section
  500.     ancount = BitArray(uint=0, length=16)  # unsigned 16bits int specifying number of records in answer section
  501.     nscount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in authority section
  502.     arcount = BitArray('uint:16=0')        # unsigned 16bits int specifying number of records in additional record section
  503.  
  504.     header.append(id)
  505.     header.append(qr)
  506.     header.append(opcode)
  507.     header.append(aa)
  508.     header.append(tc)
  509.     header.append(rd)
  510.     header.append(ra)
  511.     header.append(z)
  512.     header.append(ad)
  513.     header.append(cd)
  514.     header.append(rcode)
  515.     header.append(qdcount)
  516.     header.append(ancount)
  517.     header.append(nscount)
  518.     header.append(arcount)
  519.     response.append(header)
  520.     # Header section ends
  521.  
  522.     # Question section, copied from query
  523.     response.append(question_section)
  524.  
  525.     # No Answer section in SERVFAIL scenario
  526.  
  527.     # Ignore other sections
  528.     return response
  529.  
  530.  
  531. def change_period(RecordsCache, NegativeCache):
  532.     print('Starting to change period.')
  533.     with pony.orm.db_session:
  534.         # Delete cached records unused in both current period and previous period
  535.         RecordsCache.select(lambda record: (not record.used_in_current_period) and (not record.used_in_previous_period)).delete(bulk=True)
  536.  
  537.         # Update used_in_current_period and used_in_previous_period
  538.         cached_records = RecordsCache.select()
  539.         for record in cached_records:
  540.             record.used_in_previous_period = record.used_in_current_period
  541.             record.used_in_current_period = False
  542.  
  543.         # Delete all negative cache
  544.         NegativeCache.select().delete(bulk=True)
  545.     print('Unused records discarded, negative cache deleted, and period columns updated in database.')
  546.     update_cache(RecordsCache, NegativeCache)
  547.  
  548.  
  549. def update_cache(RecordsCache, NegativeCache):
  550.     with pony.orm.db_session:
  551.         cached_records = RecordsCache.select()
  552.         records_to_update = deque((record.full_domain, record.query_type) for record in cached_records)
  553.  
  554.     requests_session = requests.Session()
  555.     while True:
  556.         try:
  557.             record = records_to_update.popleft()
  558.         except IndexError:
  559.             break
  560.         else:
  561.             full_domain, query_type = record[0], record[1]
  562.             try:
  563.                 remote_status, remote_data = fetch_remote_data(requests_session, full_domain, query_type)
  564.             except BaseException as e:
  565.                 print(e)
  566.                 print('Cache update failed, re-enqueue task:', full_domain, query_type)
  567.                 records_to_update.append(record)
  568.             else:
  569.                 if remote_status == 'noerror':
  570.                     try:
  571.                         with pony.orm.db_session:
  572.                             RecordsCache.get(full_domain=full_domain, query_type=query_type).set(data=remote_data)
  573.                     except BaseException as e:
  574.                         print(e)
  575.                     else:
  576.                         print('Updated cache:', full_domain, query_type, remote_data)
  577.                 else:
  578.                     print('Cache update failed, re-enqueue task:', full_domain, query_type)
  579.                     records_to_update.append(record)
  580.             sleep(5)
  581.  
  582.  
  583. def update_schedule_file(schedule):
  584.     try:
  585.         with open(file=SCHEDULE_FILE_FILENAME, mode='w+') as schedule_file:
  586.             schedule_file.writelines(str(schedule)+'\n')
  587.     except BaseException as e:
  588.         print(e)
  589.  
  590.  
  591. def timer_job(schedule, RecordsCache, NegativeCache):
  592.     print('Timer job started')
  593.     while True:
  594.         print('Seconds to next period change:', int(schedule - time()))
  595.         sleep(int(schedule - time()))
  596.         change_period(RecordsCache, NegativeCache)
  597.         schedule = int(time() + PERIOD * 86400)
  598.         update_schedule_file(schedule)
  599.  
  600.  
  601. def error_output_beautifier(function, args):
  602.     try:
  603.         function(*args)
  604.     except BaseException as e:
  605.         print(e)
  606.  
  607.  
  608. if __name__ == '__main__':
  609.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement