#!/usr/bin/env python3
"""MFA daemon.

- Listens for authentication requests from the PAM module
- Receives connections from clients
- Correlates authentication requests to client connections
- Sends an MFA prompt to the client
- Evaluates the response from the client
- Returns a pass or fail response to the PAM module
"""
import argparse
import configparser
import contextlib
import os
import re
import socket
import sqlite3
import sys
import threading
import time

import pyotp

# Path to the SQLite database; filled in from config/command line by main().
DB_NAME = ""
HEADER_LENGTH = 64
KEY_LENGTH = 64
DISCONNECT_LENGTH = ACK_LENGTH = 3
ACK_MESSAGE = "ACK"
DISCONNECT_MESSAGE = "BYE"
FORMAT = "utf-8"
AUTHED = 0
DENIED = 1

# DB object index constants
DB_USERNAME_INDEX = 0
DB_HOSTNAME_INDEX = 1
DB_SERVICE_INDEX = 2
DB_ALIAS_INDEX = 3
DB_MFAMETHODS_INDEX = 4

CLIENT_ALIAS_INDEX = 0
CLIENT_KEY_INDEX = 1
CLIENT_SECRET_INDEX = 2

# A TOTP response is accepted for validation only if it is exactly six ASCII
# digits ("\d" would also accept non-ASCII Unicode digits).
TOTP_REGEX = re.compile(r"[0-9]{6}")

# Stores connected clients as a dictionary with the client key as the dictionary
# key and a tuple of (socket, (addr, port)) as the value
client_connections = {}


def _db_connect(db=None):
    """Return a context manager yielding a connection that is closed on exit.

    ``sqlite3.Connection`` used directly in a ``with`` statement only commits
    or rolls back; it never closes the connection, so wrap it in
    ``contextlib.closing``. Defaults to the module-level ``DB_NAME``.
    """
    return contextlib.closing(sqlite3.connect(DB_NAME if db is None else db))


def _recv_exact(conn, length):
    """Read exactly ``length`` bytes from ``conn``.

    ``socket.recv`` may return fewer bytes than requested, and returns b""
    once the peer has closed the connection.

    Raises:
        ConnectionError: if the peer closes before ``length`` bytes arrive.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            raise ConnectionError("connection closed by peer")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _recv_message(conn):
    """Receive one message framed by a HEADER_LENGTH-byte length header.

    Raises:
        ConnectionError / OSError: on socket failure or early close.
        ValueError: if the header is not an integer or the payload is not
            valid FORMAT-encoded text.
    """
    header = _recv_exact(conn, HEADER_LENGTH).decode(FORMAT)
    # int() tolerates the space padding around the number.
    length = int(header)
    return _recv_exact(conn, length).decode(FORMAT)


def _send_message(conn, message):
    """Send ``message`` preceded by a space-padded HEADER_LENGTH-byte header.

    The header holds the length of the *encoded* payload in bytes so that
    non-ASCII text is framed correctly.
    """
    payload = message.encode(FORMAT)
    header = str(len(payload)).ljust(HEADER_LENGTH)
    conn.sendall(header.encode(FORMAT))
    conn.sendall(payload)


def parse_arguments():
    """Parse the daemon's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", type=str, help="Bind Address")
    parser.add_argument("--pam-port", type=int, help="Port to listen for PAM requests")
    parser.add_argument("--client-port", type=int, help="Port for client connections")
    parser.add_argument("--database", type=str, help="Path to alternate database file")
    parser.add_argument("--config", type=str, help="Alternate config file location",
                        default="/etc/mfa/mfa.conf")
    return parser.parse_args()


def read_config(config):
    """Read the INI-style config file at ``config``.

    A missing file is ignored by ``ConfigParser.read``; ``get_vars`` reports it.
    """
    parser = configparser.ConfigParser(inline_comment_prefixes=("#",))
    parser.read(config)
    return parser


def eval_mfa(client_key, mfa_methods, client_response):
    """Evaluate a client's MFA response against the allowed methods.

    Args:
        client_key: key identifying the client (used to look up its TOTP secret).
        mfa_methods: list of allowed method names, e.g. ["push", "totp"].
        client_response: the client's reply, or None if it never answered.

    Returns:
        AUTHED (0) if authenticated or DENIED (1) if denied.
    """
    # prompt_client returns None when no client answered in time.
    if not isinstance(client_response, str):
        print("No response received from client")
        return DENIED

    print("response: " + client_response)
    print("length: " + str(len(client_response)))
    print("methods: " + str(mfa_methods))

    if "push" in mfa_methods and client_response == "allow":
        return AUTHED
    # Only attempt to validate if the response is in valid TOTP format.
    if "totp" in mfa_methods and TOTP_REGEX.fullmatch(client_response):
        return validate_totp(client_key, client_response)
    return DENIED


def validate_totp(client_key, client_response):
    """Check ``client_response`` against the TOTP secret stored for ``client_key``.

    Returns AUTHED if the code verifies, otherwise DENIED (including when the
    client or its secret is missing from the database).
    """
    with _db_connect() as conn:
        c = conn.cursor()
        c.execute("SELECT * FROM clients WHERE key=?", (client_key,))
        client = c.fetchone()

    if client is None or not client[CLIENT_SECRET_INDEX]:
        return DENIED

    totp = pyotp.TOTP(client[CLIENT_SECRET_INDEX])
    print("Client Response: " + str(client_response))
    # Never log totp.now(): it is a currently valid one-time code.
    return AUTHED if totp.verify(client_response) else DENIED

################################################################################
# Client is registered by admin with secret key stored on server
# Client is provisioned with secret key and passes secret key to server on
# connection for identification
# Client key is used to identify client throughout communication process
# TODO: RSA public/private key pairs for proper authentication


def get_client_key(username, hostname, service):
    """Correlate a PAM request to a registered client.

    This is done by checking the PAM request against a preconfigured database
    mapping request info (username, hostname, service) to a client alias.

    Returns:
        A (client_key, mfa_methods) tuple, or None if no application matches
        or the matched alias has no registered client.
    """
    with _db_connect() as conn:
        c = conn.cursor()
        c.execute(
            "SELECT * FROM applications WHERE username=? AND hostname=? AND service=?",
            (username, hostname, service),
        )
        application = c.fetchone()
        if application is None:
            return None

        c.execute("SELECT * FROM clients WHERE alias=?",
                  (application[DB_ALIAS_INDEX],))
        client = c.fetchone()

    if client is None:
        # The application points at an alias no client is registered under.
        return None
    return (client[CLIENT_KEY_INDEX], application[DB_MFAMETHODS_INDEX])


def _drop_client(client_key, conn):
    """Forget ``conn`` as the connection for ``client_key`` and close it.

    The registry entry is only removed if it still refers to ``conn``, so a
    newer connection from a reconnecting client is left in place.
    """
    entry = client_connections.get(client_key)
    if entry is not None and entry[0] is conn:
        client_connections.pop(client_key, None)
    conn.close()


def prompt_client(client_key, user, host, service, methods, timeout=10):
    """Prompt the connected client for MFA and return its response.

    Polls once per second, for up to ``timeout`` attempts, for a live
    connection from the client. A connection that turns out to be dead or
    out of sync is dropped and polling continues in case the client
    reconnects.

    NOTE: once the prompt is sent, waiting for the reply has no timeout.

    Returns:
        The client's response string, or None if no client answered.
    """
    prompt_msg = ("Login approved for user '" + user +
                  "' attempting to access service '" + service +
                  "' on host '" + host + "'?\n" +
                  "Available methods: " + ", ".join(methods))

    timer = 0
    while timer < timeout:
        entry = client_connections.get(client_key)
        if entry is not None:
            conn = entry[0]
            # The client may be in the registry while its socket is dead.
            try:
                _send_message(conn, prompt_msg)
                return _recv_message(conn)
            except (OSError, ValueError) as err:
                print("Lost connection to client " + client_key + ": " + str(err))
                _drop_client(client_key, conn)
        timer += 1
        time.sleep(1)
    return None


def validate_client(client_key):
    """Return True if exactly one registered client has ``client_key``."""
    with _db_connect() as conn:
        c = conn.cursor()
        c.execute("SELECT * FROM clients WHERE key=?", (client_key,))
        matches = c.fetchall()

    if len(matches) > 1:
        print("A strange error has occurred")
    return len(matches) == 1


def handle_client(conn, addr):
    """Authenticate a newly connected client by key and register its socket."""
    # Receive key from client
    try:
        key = conn.recv(KEY_LENGTH).decode(FORMAT)
    except (OSError, UnicodeDecodeError):
        conn.close()
        return

    # Validate client
    if not validate_client(key):
        print("WARNING: client attempted to connect with invalid key")
        try:
            conn.sendall(DISCONNECT_MESSAGE.encode(FORMAT))
        except OSError:
            pass
        conn.close()
        return

    try:
        conn.sendall(ACK_MESSAGE.encode(FORMAT))
    except OSError:
        conn.close()
        return

    previous = client_connections.get(key)
    client_connections[key] = (conn, addr)
    # A reconnecting client replaces its stale socket; close the old one.
    if previous is not None and previous[0] is not conn:
        previous[0].close()
    print("client connected with key " + key)


def parse_pam_data(data):
    """Parse PAM data and return a (user, host, service) tuple."""
    return tuple(data.split(','))


def _send_decision(conn, decision):
    """Send AUTHED/DENIED to the PAM module, ignoring a vanished peer."""
    try:
        conn.sendall(str(decision).encode(FORMAT))
    except OSError as err:
        print("Unable to send decision to PAM module: " + str(err))


def handle_pam(conn, addr):
    """Serve one PAM request: correlate, prompt the client, reply 0/1.

    The reply is "0" for authenticated and "1" for denied. The PAM socket is
    always closed when the request is finished.
    """
    with conn:
        # Get request and data from PAM module
        try:
            pam_data = _recv_message(conn)
        except (OSError, ValueError) as err:
            print("Invalid request from PAM module: " + str(err))
            return
        print("Got pam_data: " + pam_data)

        fields = parse_pam_data(pam_data)
        if len(fields) != 3:
            print("Malformed pam_data, expected user,host,service")
            _send_decision(conn, DENIED)
            return
        user, host, service = fields

        # Correlate request to client
        lookup = get_client_key(user, host, service)
        if lookup is None:
            print("No applications found for user=" + user + " host=" + host
                  + " service=" + service)
            _send_decision(conn, DENIED)
            return
        client_key, mfa_methods = lookup
        mfa_methods = (mfa_methods or "").split()

        # Prompt client
        response = prompt_client(client_key, user, host, service, mfa_methods)

        # Evaluate response and return it to the PAM module
        decision = eval_mfa(client_key, mfa_methods, response)
        _send_decision(conn, decision)


def listen_client(addr, port):
    """Accept client connections forever, one handler thread per connection."""
    with socket.create_server((addr, port)) as server:
        while True:
            conn, client_addr = server.accept()
            thread = threading.Thread(target=handle_client, args=(conn, client_addr))
            thread.start()


def listen_pam(addr, port):
    """Accept PAM module connections forever, one handler thread per request."""
    with socket.create_server((addr, port)) as pam_server:
        while True:
            conn, pam_addr = pam_server.accept()
            thread = threading.Thread(target=handle_pam, args=(conn, pam_addr))
            thread.start()

################################################################################


def create_db(db):
    """Create the ``applications`` and ``clients`` tables in database ``db``."""
    with _db_connect(db) as conn:
        c = conn.cursor()
        c.execute("""CREATE TABLE applications (
                username text,
                hostname text,
                service text,
                alias text,
                mfa_methods text
                )""")
        c.execute("""CREATE TABLE clients (
                alias text,
                key text,
                totp_secret text
                )""")
        conn.commit()


def get_vars(args, confparser):
    """Resolve (bind_addr, client_port, pam_port, database).

    Values are read from the ``[mfad]`` config section and then overridden
    by any command-line arguments given. Exits with status 1 if the config
    file is missing, any value is unset, or a port is not an integer.
    """
    if not os.path.exists(args.config):
        print("Unable to open config file")
        sys.exit(1)

    bind_addr = client_port = pam_port = database = None

    # Set values from config file first
    if confparser.has_section("mfad"):
        bind_addr = confparser.get("mfad", "address", fallback=None)
        client_port = confparser.get("mfad", "client-port", fallback=None)
        pam_port = confparser.get("mfad", "pam-port", fallback=None)
        database = confparser.get("mfad", "database", fallback=None)

    # Let command line args override any values (compare to None so an
    # explicitly given falsy value is not silently ignored).
    if args.address is not None:
        bind_addr = args.address
    if args.client_port is not None:
        client_port = args.client_port
    if args.pam_port is not None:
        pam_port = args.pam_port
    if args.database is not None:
        database = args.database

    # Exit if any value is null
    if None in [bind_addr, client_port, pam_port, database]:
        print("error: one or more items unspecified")
        sys.exit(1)

    try:
        return bind_addr, int(client_port), int(pam_port), database
    except ValueError:
        print("error: ports must be integers")
        sys.exit(1)


def main():
    """Load configuration, ensure the database exists, start both listeners."""
    # DB_NAME is read by every database helper; without this declaration the
    # assignment below would create a local and leave the global as "".
    global DB_NAME
    args = parse_arguments()
    confparser = read_config(args.config)
    bind_addr, client_port, pam_port, DB_NAME = get_vars(args, confparser)

    if not os.path.exists(DB_NAME):
        print("Creating DB")
        create_db(DB_NAME)

    clients = threading.Thread(target=listen_client, args=(bind_addr, client_port))
    pam = threading.Thread(target=listen_pam, args=(bind_addr, pam_port))
    clients.start()
    pam.start()


if __name__ == '__main__':
    main()