From ce3c9f1e849b871db2fa91b5aa030e8ea471a7ca Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Sun, 3 Jul 2022 05:45:21 -0400 Subject: Fixed issue caused by non-static database location Pass database location as argument where needed now that location is not static. --- server/mfad.py | 59 +++++++++++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 25 deletions(-) (limited to 'server') diff --git a/server/mfad.py b/server/mfad.py index 46fc0cc..17d2585 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -18,7 +18,6 @@ import argparse ## Return pass or fail response to PAM moudle -DB_NAME = "" HEADER_LENGTH = 64 KEY_LENGTH = 64 DISCONNECT_LENGTH = ACK_LENGTH = 3 @@ -43,6 +42,10 @@ CLIENT_SECRET_INDEX = 2 # key and a tuple of (socket,(addr,port)) as the value client_connections = dict() +def die(msg): + print(msg) + sys.exit(1) + def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--address",type=str,help="Bind Address") @@ -60,7 +63,7 @@ def read_config(config): return parser -def eval_mfa(client_key, mfa_methods, client_response): +def eval_mfa(db, client_key, mfa_methods, client_response): print("response: " + client_response) print("length: " + str(len(client_response))) print("methods: " + str(mfa_methods)) @@ -74,13 +77,13 @@ def eval_mfa(client_key, mfa_methods, client_response): totp_regex = re.compile(totp_format) matched = totp_regex.match(client_response) if matched: - return validate_totp(client_key, client_response) + return validate_totp(db, client_key, client_response) return DENIED -def validate_totp(client_key, client_response): +def validate_totp(db, client_key, client_response): secret = "" - with sqlite3.connect(DB_NAME) as conn: + with sqlite3.connect(db) as conn: c = conn.cursor() c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) client = c.fetchone() @@ -103,7 +106,7 @@ def validate_totp(client_key, client_response): # //TODO RSA public/private key pairs for proper authentication -def get_client_key(username,hostname,service): +def get_client_key(db, username,hostname,service): # Correlates a PAM request to a registered client # This is done by checking the PAM request against a preconfigured # database mapping request info (username,hostname,etc...) to clients @@ -111,7 +114,7 @@ def get_client_key(username,hostname,service): application = None client = None - with sqlite3.connect(DB_NAME) as conn: + with sqlite3.connect(db) as conn: c = conn.cursor() c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? AND service=?""",(username,hostname,service)) @@ -151,7 +154,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): conn.send(length_msg.encode(FORMAT)) conn.send(prompt_msg.encode(FORMAT)) # receive response - response_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) + header = conn.recv(HEADER_LENGTH).decode(FORMAT) + if header == "": + die("error: lost connection to client") + response_length = int(header) response = conn.recv(response_length).decode(FORMAT) return response except BrokenPipeError: @@ -165,10 +171,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): return 0 -def validate_client(client_key): +def validate_client(db, client_key): # Validates a client client = None - with sqlite3.connect(DB_NAME) as conn: + with sqlite3.connect(db) as conn: c = conn.cursor() c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) client = c.fetchall() @@ -184,11 +190,11 @@ def validate_client(client_key): -def handle_client(conn, addr): +def handle_client(db, conn, addr): # Receive key from client key = conn.recv(KEY_LENGTH).decode(FORMAT) # Validate client - if not validate_client(key): + if not validate_client(db, key): print("WARNING: client attempted to connect with invalid key") conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) conn.close() @@ -202,15 +208,18 @@ def parse_pam_data(data): # Parses pam data and returns (user,host,service) tuple return tuple(data.split(',')) -def handle_pam(conn, addr): +def handle_pam(db, conn, addr): # Get request and data from PAM module - data_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) + header = conn.recv(HEADER_LENGTH).decode(FORMAT) + if header == "": + die("error: lost connection to pam module") + data_length = int(header) pam_data = conn.recv(data_length).decode(FORMAT) print("Got pam_data: " + pam_data) user,host,service = parse_pam_data(pam_data) # Correlate request to client - client_key,mfa_methods = get_client_key(user,host,service) + client_key,mfa_methods = get_client_key(db, user,host,service) mfa_methods = mfa_methods.split(' ') if client_key == None: print("No applications found for user="+user+" host="+host+" service="+service) @@ -221,26 +230,26 @@ def handle_pam(conn, addr): response = prompt_client(client_key,user,host,service,mfa_methods) # Evaluate Response - decision = eval_mfa(client_key, mfa_methods, response) + decision = eval_mfa(db, client_key, mfa_methods, response) # Return response to PAM module # Respone will either be 0 for authenticated and 1 for denied conn.send(str(decision).encode(FORMAT)) -def listen_client(addr, port): +def listen_client(db, addr, port): with socket.create_server((addr, port)) as server: while True: conn, addr = server.accept() - thread = threading.Thread(target=handle_client,args=(conn,addr)) + thread = threading.Thread(target=handle_client,args=(db, conn,addr)) thread.start() -def listen_pam(addr, port): +def listen_pam(db, addr, port): with socket.create_server((addr,port)) as pam_server: while True: conn, addr = pam_server.accept() - thread = threading.Thread(target=handle_pam,args=(conn,addr)) + thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) thread.start() @@ -303,14 +312,14 @@ def main(): args = parse_arguments() confparser = read_config(args.config) - bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) + bind_addr, client_port, pam_port, db = get_vars(args,confparser) - if not os.path.exists(DB_NAME): + if not os.path.exists(db): print("Creating DB") - create_db(DB_NAME) + create_db(db) - clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) - pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) + clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) + pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) clients.start() pam.start() -- cgit v1.2.3