diff --git a/biostools/__main__.py b/biostools/__main__.py
index 8a8e3af..eecd7d8 100644
--- a/biostools/__main__.py
+++ b/biostools/__main__.py
@@ -16,11 +16,12 @@
 
 # Copyright 2021 RichardG.
 #
-import getopt, os, multiprocessing, re, subprocess, sys
+import getopt, os, pickle, multiprocessing, re, socket, subprocess, sys, threading
 from . import analyzers, extractors, formatters, util
 
 # Constants.
 ANALYZER_MAX_CACHE_MB = 512
+DEFAULT_REMOTE_PORT = 8620
 
 
 # Extraction module.
@@ -166,8 +167,16 @@ def extract(dir_path, _, options):
 
         # Start multiprocessing pool.
         print('Starting extraction on directory {0}'.format(dir_number), end='', flush=True)
-        queue = multiprocessing.Queue(maxsize=options['threads'])
-        mp_pool = multiprocessing.Pool(options['threads'], initializer=extract_process, initargs=(queue, dir_number_path, next_dir_number_path, options['debug']))
+        queue_size = options['threads'] + len(options['remote_servers'])
+        queue = multiprocessing.Queue(maxsize=queue_size)
+        initargs = (queue, dir_number_path, next_dir_number_path, options['debug'])
+        mp_pool = multiprocessing.Pool(options['threads'], initializer=extract_process, initargs=initargs)
+        print(flush=True)
+
+        # Start remote clients.
+        remote_clients = []
+        for remote_server in options['remote_servers']:
+            remote_clients.append(RemoteClient(remote_server, 'x', initargs))
 
         # Create next directory.
         if not os.path.isdir(next_dir_number_path):
@@ -176,7 +185,6 @@ def extract(dir_path, _, options):
         # Scan directory structure. I really wanted this to have file-level
         # granularity, but IntelExtractor and InterleaveBIOSExtractor
         # both require directory-level granularity for inspecting other files.
-        print(flush=True)
         found_any_files = False
         for scan_dir_path, scan_dir_names, scan_file_names in os.walk(dir_number_path):
             if len(scan_file_names) > 0:
@@ -201,11 +209,17 @@ def extract(dir_path, _, options):
         dir_number += 1
 
         # Stop multiprocessing pool and wait for its workers to finish.
-        for _ in range(options['threads']):
+        for _ in range(queue_size):
             queue.put(None)
         mp_pool.close()
         mp_pool.join()
 
+        # Stop remote clients and wait for them to finish.
+        for client in remote_clients:
+            client.close()
+        for client in remote_clients:
+            client.join()
+
     # Create 0 directory if it doesn't exist.
     print('Merging directories:', end=' ')
     merge_dest_path = os.path.join(dir_path, '0')
@@ -592,6 +606,290 @@ def analyze(dir_path, formatter_args, options):
     return 0
 
 
+# Remote server module.
+
+class RemoteClient:
+    """Client side of the remote protocol.
+
+    A background thread connects to the server, sends the action letter
+    followed by the pickled pool arguments, and then reads one-line
+    acknowledgements. Once the server acknowledges the pool start, a
+    second thread moves items from the local queue to the server, one
+    item per 'q' acknowledgement.
+    """
+
+    def __init__(self, addr, action, initargs):
+        """Parse the server address and start the client thread.
+
+        addr is 'host' or 'host:port'. action is the command letter sent
+        to the server to start its pool ('x' or 'a'). initargs is the local
+        pool's initargs tuple: its first element is the local queue, and
+        the remaining elements are pickled and sent to the server.
+        """
+
+        # Initialize state.
+        self.action = action
+        self.initargs = initargs[1:]
+        self.queue = initargs[0]
+
+        self.sock = self.f = None
+        self.queue_lock = threading.Lock()
+        self.write_lock = threading.Lock()
+        self.close_event = threading.Event()
+
+        # Parse address:port. str.split() always returns at least one
+        # element, so the host part is always present.
+        addr_split = addr.split(':')
+        self.addr = addr_split[0]
+        if len(addr_split) > 1:
+            self.port = int(addr_split[1])
+        else:
+            self.port = DEFAULT_REMOTE_PORT
+
+        # Start client thread.
+        self.queue_thread = None
+        self.client_thread = threading.Thread(target=self.client_thread_func)
+        self.client_thread.daemon = True
+        self.client_thread.start()
+
+    def client_thread_func(self):
+        """Thread function for a remote client.
+
+        Sets close_event on every exit path, so that join() returns even
+        if the connection could not be established or was lost.
+        """
+
+        try:
+            # Connect to server.
+            print('Connecting to {0}:{1}\n'.format(self.addr, self.port), end='')
+            self.sock = socket.create_connection((self.addr, self.port))
+            self.f = self.sock.makefile('rwb')
+            print('Connected to {0}:{1}\n'.format(self.addr, self.port), end='')
+
+            # Start multiprocessing pool.
+            self.f.write((self.action + '\n').encode('utf8', 'ignore'))
+            self.f.write(pickle.dumps(self.initargs))
+            self.f.flush()
+
+            # Read responses from server.
+            while True:
+                line = self.f.readline().rstrip(b'\r\n')
+                if not line:
+                    break
+
+                if line[0:1] in b'xa':
+                    # Multiprocessing pool started, now start the queue thread.
+                    self.queue_thread = threading.Thread(target=self.queue_thread_func)
+                    self.queue_thread.daemon = True
+                    self.queue_thread.start()
+                elif line[0:1] == b'q':
+                    # Allow queue thread to proceed.
+                    try:
+                        self.queue_lock.release()
+                    except RuntimeError:
+                        pass # lock was not held
+                elif line[0:1] == b'j':
+                    # We're done.
+                    break
+        except (OSError, ValueError) as e:
+            print('Connection to {0}:{1} failed: {2}\n'.format(self.addr, self.port, e), end='')
+        finally:
+            # Close connection. Either object is still None if connecting failed.
+            for closeable in (self.f, self.sock):
+                try:
+                    closeable.close()
+                except Exception:
+                    pass
+            print('Disconnected from {0}:{1}\n'.format(self.addr, self.port), end='')
+
+            # Unblock join() no matter how the thread ended.
+            self.close_event.set()
+
+    def queue_thread_func(self):
+        """Thread function to remove items from the local
+        queue and push them to the remote server's queue."""
+
+        while True:
+            # Wait for the queue to be available. The lock starts out
+            # released and is released again on each 'q' acknowledgement,
+            # so at most one item is awaiting acknowledgement at a time.
+            self.queue_lock.acquire()
+
+            # Read queue item.
+            item = self.queue.get()
+            if item is None: # special item to stop the loop
+                self.close()
+                break
+
+            # Send queue item to server.
+            scan_dir_path, scan_file_names = item
+            with self.write_lock:
+                self.f.write(b'q' + scan_dir_path.encode('utf8', 'ignore'))
+                for scan_file_name in scan_file_names:
+                    self.f.write(b'\x00' + scan_file_name.encode('utf8', 'ignore'))
+                self.f.write(b'\n')
+                self.f.flush()
+
+    def close(self):
+        """Ask the server to stop its pool and end the connection.
+
+        Best effort: does nothing if the connection was never
+        established or is already closed.
+        """
+
+        # Write stop message.
+        with self.write_lock:
+            try:
+                self.f.write(b'j\n')
+                self.f.flush()
+            except (AttributeError, OSError, ValueError):
+                pass
+
+    def join(self):
+        """Wait for the server connection to be closed."""
+        self.close_event.wait()
+
+class RemoteServerClient:
+    """Server side of the remote protocol, one instance per connection.
+
+    Commands are lines starting with a command letter: 'x' or 'a'
+    (followed by pickled pool arguments) starts an extraction or analysis
+    pool, 'q' queues a directory path and its NUL-separated file names,
+    and 'j' stops the pool and ends the connection. Every command is
+    acknowledged by echoing its letter back.
+    """
+
+    def __init__(self, accept, options):
+        """Take over an accepted connection and start its handler thread.
+
+        accept is the (socket, address) pair returned by socket.accept().
+        """
+
+        # Initialize state.
+        self.sock, self.addr = accept
+        self.options = options
+        self.queue = self.mp_pool = None
+        self.write_lock = threading.Lock()
+        self.queue_lock = threading.Lock()
+
+        self.f = self.sock.makefile('rwb')
+
+        # Start client thread.
+        self.client_thread = threading.Thread(target=self.client_thread_func)
+        self.client_thread.daemon = True
+        self.client_thread.start()
+
+    def _stop_pool(self):
+        """Stop the multiprocessing pool and wait for its workers to finish.
+
+        Returns False if there was no pool/queue to stop.
+        """
+
+        if not (self.mp_pool and self.queue):
+            return False
+
+        # Queue one None per worker, like extract() does to stop its pool.
+        for _ in range(self.options['threads']):
+            self.queue.put(None)
+        self.mp_pool.close()
+        self.mp_pool.join()
+        self.mp_pool = self.queue = None
+        return True
+
+    def client_thread_func(self):
+        """Thread function for a remote client."""
+
+        print(self.addr, 'New connection')
+
+        try:
+            # Parse commands.
+            while True:
+                line = self.f.readline().rstrip(b'\r\n')
+                if not line:
+                    break
+                command = line[0:1]
+
+                if command in b'xa':
+                    # Start multiprocessing pool, stopping any previous one first.
+                    self._stop_pool()
+                    # Compare a 1-byte slice: indexing bytes yields an int.
+                    print(self.addr, 'Starting pool for', 'extraction' if command == b'x' else 'analysis')
+                    self.queue = multiprocessing.Queue(maxsize=self.options['threads'])
+                    if command == b'x':
+                        func = extract_process
+                    else:
+                        func = analyze_process
+                    # NOTE: pickle.load() can run arbitrary code sent by the
+                    # peer; only expose this port to trusted clients.
+                    initargs = (self.queue,) + pickle.load(self.f)
+                    self.mp_pool = multiprocessing.Pool(self.options['threads'], initializer=func, initargs=initargs)
+                elif command == b'q':
+                    # Add directory to queue.
+                    file_list = [item.decode('utf8', 'ignore') for item in line[1:].split(b'\x00')]
+                    if self.options['debug']:
+                        print(self.addr, 'Queuing', file_list[0], 'with', len(file_list) - 1, 'files')
+                    if self.queue:
+                        self.queue.put((file_list[0], file_list[1:]))
+                    else:
+                        print(self.addr, 'Attempted queuing with no queue')
+                elif command == b'j':
+                    # Stop multiprocessing pool and wait for its workers to finish.
+                    print(self.addr, 'Waiting for pool')
+                    if not self._stop_pool():
+                        print(self.addr, 'Attempted pool wait with no pool/queue')
+
+                # Write acknowledgement.
+                with self.write_lock:
+                    self.f.write(command + b'\n')
+                    self.f.flush()
+
+                # Stop if requested by the client.
+                if command == b'j':
+                    break
+        except (OSError, ValueError) as e:
+            print(self.addr, 'Connection error:', e)
+        finally:
+            # Close connection.
+            print(self.addr, 'Closing connection')
+            for closeable in (self.f, self.sock):
+                try:
+                    closeable.close()
+                except Exception:
+                    pass
+
+            # If the client went away without sending 'j', the workers
+            # still need their stop markers before the pool can be joined.
+            self._stop_pool()
+
+def remote_server(dir_path, formatter_args, options):
+    """Main function for remote server mode.
+
+    Listens on options['remote_port'] and hands each accepted connection
+    to a RemoteServerClient until interrupted. dir_path and formatter_args
+    are unused; main() passes the same arguments to every mode handler.
+    """
+
+    # Create server and listen for connections.
+    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    try:
+        server.bind(('', options['remote_port']))
+        server.listen(5)
+
+        print('Listening on port', options['remote_port'])
+
+        # Receive connections.
+        while True:
+            RemoteServerClient(server.accept(), options)
+    except KeyboardInterrupt:
+        pass
+    finally:
+        # Close server.
+        print('Closing server')
+        server.close()
+
+    return 0
+
+
 def main():
     # Set default options.
     mode = None
@@ -603,15 +901,17 @@ def main():
         'hyperlink': False,
         'threads': 0,
         'docker-usage': False,
+        'remote_servers': [],
+        'remote_port': 0,
     }
 
     # Parse arguments.
-    args, remainder = getopt.gnu_getopt(sys.argv[1:], 'xadf:hnrt', ['extract', 'analyze', 'debug', 'format=', 'hyperlink', 'no-headers', 'array', 'threads', 'docker-usage'])
+    args, remainder = getopt.gnu_getopt(sys.argv[1:], 'xadf:hnrt', ['extract', 'analyze', 'debug', 'format=', 'hyperlink', 'no-headers', 'array', 'threads', 'remote=', 'remote-server', 'docker-usage'])
     for opt, arg in args:
         if opt in ('-x', '--extract'):
-            mode = 'extract'
+            mode = extract
         elif opt in ('-a', '--analyze'):
-            mode = 'analyze'
+            mode = analyze
         elif opt in ('-d', '--debug'):
             options['debug'] = True
         elif opt in ('-f', '--format'):
@@ -627,19 +927,28 @@ def main():
                 options['threads'] = int(arg)
             except:
                 pass
+        elif opt == '--remote':
+            options['remote_servers'].append(arg)
+        elif opt == '--remote-server':
+            mode = remote_server
+            try:
+                options['remote_port'] = int(remainder[0])
+            except:
+                pass
+            remainder.append(None) # dummy
         elif opt == '--docker-usage':
             options['docker-usage'] = True
 
     if len(remainder) > 0:
-        # Set default thread count.
+        # Set default numeric options.
         if options['threads'] <= 0:
             options['threads'] = options['debug'] and 1 or (os.cpu_count() or 4)
+        if options['remote_port'] <= 0:
+            options['remote_port'] = DEFAULT_REMOTE_PORT
 
         # Run mode handler.
-        if mode == 'extract':
-            return extract(remainder[0], remainder[1:], options)
-        elif mode == 'analyze':
-            return analyze(remainder[0], remainder[1:], options)
+        if mode:
+            return mode(remainder[0], remainder[1:], options)
 
     # Print usage.
     if options['docker-usage']: