Add remote client/server clustering system

This commit is contained in:
RichardG867
2022-04-24 01:10:22 -03:00
parent 4ba4f14945
commit 738eaad75f

View File

@@ -16,11 +16,12 @@
# Copyright 2021 RichardG.
#
import getopt, os, multiprocessing, re, subprocess, sys
import getopt, os, pickle, multiprocessing, re, socket, subprocess, sys, threading
from . import analyzers, extractors, formatters, util
# Constants.
ANALYZER_MAX_CACHE_MB = 512
DEFAULT_REMOTE_PORT = 8620
# Extraction module.
@@ -166,8 +167,16 @@ def extract(dir_path, _, options):
# Start multiprocessing pool.
print('Starting extraction on directory {0}'.format(dir_number), end='', flush=True)
queue = multiprocessing.Queue(maxsize=options['threads'])
mp_pool = multiprocessing.Pool(options['threads'], initializer=extract_process, initargs=(queue, dir_number_path, next_dir_number_path, options['debug']))
queue_size = options['threads'] + len(options['remote_servers'])
queue = multiprocessing.Queue(maxsize=queue_size)
initargs = (queue, dir_number_path, next_dir_number_path, options['debug'])
mp_pool = multiprocessing.Pool(options['threads'], initializer=extract_process, initargs=initargs)
print(flush=True)
# Start remote clients.
remote_clients = []
for remote_server in options['remote_servers']:
remote_clients.append(RemoteClient(remote_server, 'x', initargs))
# Create next directory.
if not os.path.isdir(next_dir_number_path):
@@ -176,7 +185,6 @@ def extract(dir_path, _, options):
# Scan directory structure. I really wanted this to have file-level
# granularity, but IntelExtractor and InterleaveBIOSExtractor
# both require directory-level granularity for inspecting other files.
print(flush=True)
found_any_files = False
for scan_dir_path, scan_dir_names, scan_file_names in os.walk(dir_number_path):
if len(scan_file_names) > 0:
@@ -201,11 +209,17 @@ def extract(dir_path, _, options):
dir_number += 1
# Stop multiprocessing pool and wait for its workers to finish.
for _ in range(options['threads']):
for _ in range(queue_size):
queue.put(None)
mp_pool.close()
mp_pool.join()
# Stop remote clients and wait for them to finish.
for client in remote_clients:
client.close()
for client in remote_clients:
client.join()
# Create 0 directory if it doesn't exist.
print('Merging directories:', end=' ')
merge_dest_path = os.path.join(dir_path, '0')
@@ -592,6 +606,235 @@ def analyze(dir_path, formatter_args, options):
return 0
# Remote server module.
class RemoteClient:
	"""State and functions for communicating with a remote server.

	Wire protocol (as implemented here and in RemoteServerClient): on
	connect, the client sends the action character ('x' for extraction,
	'a' for analysis) followed by a newline, then the pickled initargs
	for the server's multiprocessing pool. The server replies with
	one-character acknowledgement lines: the action character once its
	pool is running, 'q' for each queued directory, and 'j' once its
	pool has finished after a close request."""

	def __init__(self, addr, action, initargs):
		"""Connect to the server at addr ("host" or "host:port"), request a
		pool for the given action, and start forwarding local queue items."""
		# Initialize state.
		self.action = action
		self.initargs = initargs[1:] # forwarded to the server's pool initializer
		self.queue = initargs[0] # local work queue shared with local pool workers
		self.sock = self.f = None
		# queue_lock is used as flow control: at most one unacknowledged
		# item is in flight; the client thread releases it on each 'q' ack.
		self.queue_lock = threading.Lock()
		self.write_lock = threading.Lock()
		self.close_event = threading.Event()
		self.close_event.clear()
		# Parse address:port.
		addr_split = addr.split(':')
		self.port = DEFAULT_REMOTE_PORT
		if len(addr_split) == 0: # defensive only: str.split never returns an empty list
			return
		elif len(addr_split) == 1:
			self.addr = addr_split[0]
		else:
			self.port = int(addr_split[1])
			self.addr = addr_split[0]
		# Start client thread.
		self.queue_thread = None
		self.client_thread = threading.Thread(target=self.client_thread_func)
		self.client_thread.daemon = True
		self.client_thread.start()

	def client_thread_func(self):
		"""Thread function for a remote client."""
		try:
			# Connect to server.
			print('Connecting to {0}:{1}\n'.format(self.addr, self.port), end='')
			self.sock = socket.create_connection((self.addr, self.port))
			self.f = self.sock.makefile('rwb')
			print('Connected to {0}:{1}\n'.format(self.addr, self.port), end='')
			# Start multiprocessing pool on the server side.
			self.f.write((self.action + '\n').encode('utf8', 'ignore'))
			self.f.write(pickle.dumps(self.initargs))
			self.f.flush()
			# Read responses from server.
			while True:
				try:
					line = self.f.readline().rstrip(b'\r\n')
				except Exception:
					break
				if not line: # EOF: server closed the connection
					break
				if line[0:1] in b'xa':
					# Multiprocessing pool started, now start the queue thread.
					self.queue_thread = threading.Thread(target=self.queue_thread_func)
					self.queue_thread.daemon = True
					self.queue_thread.start()
				elif line[0:1] == b'q':
					# Item acknowledged; allow queue thread to proceed.
					try:
						self.queue_lock.release()
					except RuntimeError: # lock was not held
						pass
				elif line[0:1] == b'j':
					# Server's pool has finished; we're done.
					break
		except Exception:
			# Connection failed or dropped; fall through to cleanup.
			pass
		finally:
			# Close connection.
			try:
				self.f.close()
			except Exception:
				pass
			try:
				self.sock.close()
			except Exception:
				pass
			# Always signal completion so join() cannot block forever, even
			# if the connection failed before a 'j' reply was received.
			self.close_event.set()
			print('Disconnected from {0}:{1}\n'.format(self.addr, self.port), end='')

	def queue_thread_func(self):
		"""Thread function to remove items from the local
		queue and push them to the remote server's queue."""
		while True:
			# Wait for the previous item to be acknowledged by the server.
			self.queue_lock.acquire()
			# Read queue item.
			item = self.queue.get()
			if item is None: # special item to stop the loop
				self.close()
				break
			# Send queue item to server: 'q' + dir path + NUL-separated file names.
			scan_dir_path, scan_file_names = item
			with self.write_lock:
				self.f.write(b'q' + scan_dir_path.encode('utf8', 'ignore'))
				for scan_file_name in scan_file_names:
					self.f.write(b'\x00' + scan_file_name.encode('utf8', 'ignore'))
				self.f.write(b'\n')
				self.f.flush()

	def close(self):
		"""Ask the server to finish up and close the connection."""
		# Write stop message. Best-effort: if the connection is already
		# gone, the client thread's cleanup path handles shutdown.
		with self.write_lock:
			try:
				self.f.write(b'j\n')
				self.f.flush()
			except Exception:
				return

	def join(self):
		"""Wait for the server connection to be closed."""
		self.close_event.wait()
class RemoteServerClient:
	"""State and functions for communicating with remote clients.

	Each accepted connection gets one instance, which runs a daemon
	thread parsing single-character commands from the client:
	'x'/'a' start an extraction/analysis pool, 'q' queues a directory,
	'j' drains the pool and ends the session. Each command is answered
	with an acknowledgement line echoing the command character."""

	def __init__(self, accept, options):
		"""Take an (socket, address) pair from socket.accept() and start
		servicing the client on a background thread."""
		# Initialize state.
		self.sock, self.addr = accept
		self.options = options
		self.queue = self.mp_pool = None
		self.write_lock = threading.Lock()
		self.queue_lock = threading.Lock()
		self.f = self.sock.makefile('rwb')
		# Start client thread.
		self.client_thread = threading.Thread(target=self.client_thread_func)
		self.client_thread.daemon = True
		self.client_thread.start()

	def client_thread_func(self):
		"""Thread function for a remote client."""
		print(self.addr, 'New connection')
		# Parse commands.
		while True:
			try:
				line = self.f.readline().rstrip(b'\r\n')
			except Exception:
				break
			if not line: # EOF: client disconnected
				break
			if line[0:1] in b'xa':
				# Start multiprocessing pool.
				# Compare a 1-byte slice: on Python 3, line[0] is an int, so
				# the previous (line[0] == b'x') was always False and this
				# always logged 'analysis'.
				print(self.addr, 'Starting pool for', (line[0:1] == b'x') and 'extraction' or 'analysis')
				self.queue = multiprocessing.Queue(maxsize=self.options['threads'])
				if line[0:1] == b'x':
					func = extract_process
				else:
					func = analyze_process
				# SECURITY: pickle.load on data received over the network can
				# execute arbitrary code; only expose this server to trusted peers.
				self.mp_pool = multiprocessing.Pool(self.options['threads'], initializer=func, initargs=(self.queue,) + pickle.load(self.f))
			elif line[0:1] == b'q':
				# Add directory to queue. Payload is the directory path plus
				# NUL-separated file names.
				file_list = [item.decode('utf8', 'ignore') for item in line[1:].split(b'\x00')]
				if self.options['debug']:
					print(self.addr, 'Queuing', file_list[0], 'with', len(file_list) - 1, 'files')
				if self.queue:
					self.queue.put((file_list[0], file_list[1:]))
				else:
					print(self.addr, 'Attempted queuing with no queue')
			elif line[0:1] == b'j':
				# Stop multiprocessing pool and wait for its workers to finish.
				print(self.addr, 'Waiting for pool')
				if self.mp_pool and self.queue:
					# One sentinel per worker tells each one to stop.
					for _ in range(self.options['threads']):
						self.queue.put(None)
					self.mp_pool.close()
					self.mp_pool.join()
					self.mp_pool = None
				else:
					print(self.addr, 'Attempted pool wait with no pool/queue')
			# Write acknowledgement.
			with self.write_lock:
				self.f.write(line[0:1] + b'\n')
				self.f.flush()
			# Stop if requested by the client.
			if line[0:1] == b'j':
				break
		# Close connection.
		print(self.addr, 'Closing connection')
		try:
			self.f.close()
		except Exception:
			pass
		try:
			self.sock.close()
		except Exception:
			pass
		# If the client dropped without sending 'j', still reap the pool.
		if self.mp_pool:
			self.mp_pool.close()
			self.mp_pool.join()
def remote_server(dir_path, formatter_args, options):
	"""Run this process as a remote worker node: listen on the configured
	port and hand each incoming connection to a RemoteServerClient, until
	interrupted with Ctrl+C. Returns 0."""
	# Bind a listening TCP socket on all interfaces.
	listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
	listener.bind(('', options['remote_port']))
	listener.listen(5)
	print('Listening on port', options['remote_port'])
	# Accept connections until interrupted; each client runs on its own thread.
	try:
		while True:
			RemoteServerClient(listener.accept(), options)
	except KeyboardInterrupt:
		pass
	# Shut the listener down.
	print('Closing server')
	listener.close()
	return 0
def main():
# Set default options.
mode = None
@@ -603,15 +846,17 @@ def main():
'hyperlink': False,
'threads': 0,
'docker-usage': False,
'remote_servers': [],
'remote_port': 0,
}
# Parse arguments.
args, remainder = getopt.gnu_getopt(sys.argv[1:], 'xadf:hnrt', ['extract', 'analyze', 'debug', 'format=', 'hyperlink', 'no-headers', 'array', 'threads', 'docker-usage'])
args, remainder = getopt.gnu_getopt(sys.argv[1:], 'xadf:hnrt', ['extract', 'analyze', 'debug', 'format=', 'hyperlink', 'no-headers', 'array', 'threads', 'remote=', 'remote-server', 'docker-usage'])
for opt, arg in args:
if opt in ('-x', '--extract'):
mode = 'extract'
mode = extract
elif opt in ('-a', '--analyze'):
mode = 'analyze'
mode = analyze
elif opt in ('-d', '--debug'):
options['debug'] = True
elif opt in ('-f', '--format'):
@@ -627,19 +872,28 @@ def main():
options['threads'] = int(arg)
except:
pass
elif opt == '--remote':
options['remote_servers'].append(arg)
elif opt == '--remote-server':
mode = remote_server
try:
options['remote_port'] = int(remainder[0])
except:
pass
remainder.append(None) # dummy
elif opt == '--docker-usage':
options['docker-usage'] = True
if len(remainder) > 0:
# Set default thread count.
# Set default numeric options.
if options['threads'] <= 0:
options['threads'] = options['debug'] and 1 or (os.cpu_count() or 4)
if options['remote_port'] <= 0:
options['remote_port'] = DEFAULT_REMOTE_PORT
# Run mode handler.
if mode == 'extract':
return extract(remainder[0], remainder[1:], options)
elif mode == 'analyze':
return analyze(remainder[0], remainder[1:], options)
if mode:
return mode(remainder[0], remainder[1:], options)
# Print usage.
if options['docker-usage']: