From 2e840e7c381f88425952c6fa9d68e0d433084a5a Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Mon, 4 Jul 2022 20:03:27 -0400 Subject: Support both TLS encrypted sessions and plaintext sessions Added support for both TLS and plaintext connections. Server can accept both types of connection simultaneously or in different combinations (i.e encrypted client and plaintext PAM). Added options for specifying dedicated TLS ports on server. Added --plain options for client and PAM to force plaintext connections, default is to use encrypted connections. Configuring encrypted client and PAM connections and plaintext server connections allows for use of a reverse proxy setup with something like nginx. This will avoid having to expose the MFA server directly in setups that traverse the internet. --- client/client.py | 44 +++++++++++++++++++++++++++++---- pam/pam_mfa.py | 46 ++++++++++++++++++++++++++++------- server/mfad.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 137 insertions(+), 27 deletions(-) diff --git a/client/client.py b/client/client.py index cc22d0b..0388073 100755 --- a/client/client.py +++ b/client/client.py @@ -26,6 +26,7 @@ def parse_arguments(): parser.add_argument("--config",type=str,help="Path to config file",\ default="/etc/mfa/mfa.conf") parser.add_argument("--key",type=str,help="Client connection key") + parser.add_argument("--plain",action="store_true",help="Connect without TLS") parser.add_argument("--insecure",action="store_true",\ help="Accept invalid TLS certificates") return parser.parse_args() @@ -37,7 +38,7 @@ def prompt_user(prompt): return result -def init_connection(mfa_server, client_port, client_key, insecure=False): +def init_connection_tls(mfa_server, client_port, client_key, insecure=False): # Attempts to connect to MFA server with provided address,port, and key. # Repeats attempt once a seconds until timeout is reached. # Returns socket or None if unable to connect @@ -70,6 +71,27 @@ def init_connection(mfa_server, client_port, client_key, insecure=False): return connection +def init_connection(mfa_server, client_port, client_key): + connection = None + timeout = 0 + timeout_length = 5 + sleep_length = 1 + while connection == None and timeout < timeout_length: + try: + connection = socket.create_connection((mfa_server,client_port)) + connection.send(client_key.encode(FORMAT)) + response = connection.recv(ACK_LENGTH).decode(FORMAT) + if response == ACK_MESSAGE: + print("connected to mfa server") + elif response == DISCONNECT_MESSAGE: + print("server terminated connection") + sys.exit(1) + except ConnectionError: + time.sleep(sleep_length) + timeout += sleep_length + return connection + + def read_config(config_file): parser = configparser.ConfigParser(inline_comment_prefixes="#") parser.read(config_file) @@ -84,6 +106,7 @@ def get_vars(args,confparser): server = None port = None key = None + plain = None insecure = None # Set values from config file first @@ -91,7 +114,13 @@ def get_vars(args,confparser): server = confparser.get("client","server",fallback=None) port = confparser.get("client","port",fallback=None) key = confparser.get("client","key",fallback=None) - insecure = bool(confparser.get("client","insecure",fallback=False)) + plain = confparser.get("client","plain",fallback=False) + insecure = confparser.get("client","insecure",fallback=False) + + if plain.lower() == "false": + plain = False + if insecure.lower() == "false": + insecure = False # Let command line args overwrite any values if args.server != None: @@ -100,6 +129,8 @@ def get_vars(args,confparser): port = args.port if args.key != None: key = args.key + if args.plain: + plain = args.plain if args.insecure: insecure = args.insecure @@ -108,7 +139,7 @@ def get_vars(args,confparser): print("error: one or more items unspecified") sys.exit(1) - return server,port,key,insecure + return server,port,key,plain,insecure def main(): @@ -116,7 +147,7 @@ def main(): args = parse_arguments() confparser = read_config(args.config) - mfa_server,client_port,client_key,insecure = get_vars(args,confparser) + mfa_server,client_port,client_key,plain,insecure = get_vars(args,confparser) # Exit if invalid key is provided if len(client_key) != KEY_LENGTH: @@ -124,7 +155,10 @@ def main(): sys.exit(1) # Open connection to server - conn = init_connection(mfa_server,client_port,client_key,insecure) + if plain: + conn = init_connection(mfa_server,client_port,client_key) + else: + conn = init_connection_tls(mfa_server,client_port,client_key,insecure) if conn == None: print("timed out attempting to connect to server") sys.exit(1) diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py index 5a5a112..a5105a2 100755 --- a/pam/pam_mfa.py +++ b/pam/pam_mfa.py @@ -34,11 +34,12 @@ def parse_arguments(): default="/etc/mfa/mfa.conf") parser.add_argument("--server",type=str,help="MFA server address") parser.add_argument("--port",type=str,help="MFA server PAM connection port") + parser.add_argument("--plain",action="store_true",help="Connect without TLS") parser.add_argument("--insecure",action="store_true", help="Accept invalid TLS certificates") return parser.parse_args() -def init_connection(mfa_server, pam_port, insecure): +def init_connection_tls(mfa_server, pam_port, insecure): # Attempts to connect to MFA server with provided address and port # Repeats connection attempts once per second until timeout is reached # Returns the socket if connection was successful or None otherwise @@ -52,7 +53,6 @@ def init_connection(mfa_server, pam_port, insecure): context.verify_mode = 0 while connection == None and timeout < timeout_length: try: - #connection = socket.create_connection((mfa_server,client_port)) connection = context.wrap_socket(socket.socket(socket.AF_INET), server_hostname=mfa_server) connection.connect((mfa_server,int(pam_port))) @@ -65,6 +65,24 @@ def init_connection(mfa_server, pam_port, insecure): return None +def init_connection(mfa_server, pam_port): + # Attempts to connect to MFA server with provided address and port + # Repeats connection attempts once per second until timeout is reached + # Returns the socket if connection was successful or None otherwise + connection = None + timeout = 0 + timeout_length = 5 + sleep_length = 1 + while connection == None and timeout < timeout_length: + try: + connection = socket.create_connection((mfa_server,pam_port)) + return connection + except (ConnectionError,ConnectionRefusedError): + time.sleep(sleep_length) + timeout += sleep_length + return None + + def read_config(config_file): # Read config file for server and port info # Return tuple (server,port) @@ -94,19 +112,28 @@ def get_vars(args,confparser): server = None port = None + plain = None insecure = None # Set values from config file first if confparser.has_section("pam"): server = confparser.get("pam","server",fallback=None) port = confparser.get("pam","port",fallback=None) - insecure = bool(confparser.get("pam","insecure",fallback=False)) + plain = confparser.get("client","plain",fallback=False) + insecure = confparser.get("client","insecure",fallback=False) + if plain.lower() == "false": + plain = False + if insecure.lower() == "false": + insecure = False + # Let command line args overwrite any values if args.server != None: server = args.server if args.port != None: port = args.port + if args.plain: + plain = args.plain if args.insecure: insecure = args.insecure @@ -115,7 +142,7 @@ def get_vars(args,confparser): print("error: one or more items unspecified") sys.exit(1) - return server,port,insecure + return server,port,plain,insecure def main(): @@ -125,7 +152,7 @@ def main(): # Get arguments args = parse_arguments() confparser = read_config(args.config) - mfa_server,pam_port,insecure = get_vars(args,confparser) + mfa_server,pam_port,plain,insecure = get_vars(args,confparser) user = args.user service = args.service @@ -144,12 +171,13 @@ def main(): hostname = args.host data = user + "," + hostname + "," + service - # Initalize connection to MFA server. Quit if unable to connect. - connection = init_connection(mfa_server,pam_port,insecure) + if plain: + connection = init_connection(mfa_server, pam_port) + else: + connection = init_connection_tls(mfa_server,pam_port,insecure) if connection == None: - print(failed) - sys.exit(1) + die("failed to connect") # Send authentication data to MFA server data_length = len(data) diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -52,6 +52,10 @@ def parse_arguments(): parser.add_argument("--address",type=str,help="Bind Address") parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") parser.add_argument("--client-port",type=int,help="Port for client connections") + parser.add_argument("--pam-tls-port",type=int, + help="Port to listen for encrypted PAM requests") + parser.add_argument("--client-tls-port",type=int, + help="Port for encrypted client connections") parser.add_argument("--database",type=str,help="Path to database file") parser.add_argument("--cert",type=str,help="TLS certificate file") parser.add_argument("--key",type=str,help="TLS private key file") @@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): return context -def listen_client(db, addr, port, tls_context): +def listen_client_tls(db, addr, port, tls_context): with socket.create_server((addr, port)) as sock: with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): print("client: ssl handshake error") -def listen_pam(db, addr, port, tls_context): +def listen_pam_tls(db, addr, port, tls_context): with socket.create_server((addr,port)) as sock: with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): print("pam: ssl handshake error") +def listen_client(db, addr, port): + with socket.create_server((addr, port)) as sock: + while True: + conn, addr = sock.accept() + thread = threading.Thread(target=handle_client,args=(db,conn,addr)) + thread.start() + + +def listen_pam(db, addr, port): + with socket.create_server((addr, port)) as sock: + while True: + conn, addr = sock.accept() + thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) + thread.start() + + ################################################################################ def create_db(db): @@ -312,12 +332,13 @@ def create_db(db): def get_vars(args,confparser): if not os.path.exists(args.config): - print("Unable to open config file") - sys.exit(1) + die("Unable to open config file") bind_addr = None client_port = None + client_tls_port = None pam_port = None + pam_tls_port = None database = None cert = None key = None @@ -328,6 +349,8 @@ def get_vars(args,confparser): bind_addr = confparser.get("mfad","address",fallback=None) client_port = confparser.get("mfad","client-port",fallback=None) pam_port = confparser.get("mfad","pam-port",fallback=None) + client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) + pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) database = confparser.get("mfad","database",fallback=None) cert = confparser.get("mfad","cert",fallback=None) key = confparser.get("mfad","key",fallback=None) @@ -340,6 +363,10 @@ def get_vars(args,confparser): client_port = args.client_port if args.pam_port != None: pam_port = args.pam_port + if args.client_tls_port != None: + client_tls_port = args.client_tls_port + if args.pam_tls_port != None: + pam_tls_port = args.pam_tls_port if args.database != None: database = args.database if args.cert != None: @@ -350,28 +377,49 @@ def get_vars(args,confparser): ciphers = args.ciphers # Exit if any value is null - if None in [bind_addr,client_port,pam_port,database,cert,key]: - print("error: one or more items unspecified") - sys.exit(1) + if None in [bind_addr,database,cert,key]: + die("error: one or more items unspecified") + + ports = [pam_port, client_port, pam_tls_port, client_tls_port] + for port in ports: + if port == None: + ports[ports.index(port)] = 0 - return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers + return bind_addr, ports, database, cert, key, ciphers def main(): args = parse_arguments() confparser = read_config(args.config) - bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) + bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) + pam_port = int(ports[0]) + client_port = int(ports[1]) + pam_tls_port = int(ports[2]) + client_tls_port = int(ports[3]) + + if pam_port == 0 and pam_tls_port == 0: + die("error: not listening for any PAM connections") + if client_port == 0 and client_tls_port == 0: + die("error: not listening for any client connections") + if not os.path.exists(db): print("Creating DB") create_db(db) context = get_tls_context(cert,key,ciphers) - clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) - pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) - clients.start() - pam.start() + + if pam_port != 0: + threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() + if client_port != 0: + threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() + if pam_tls_port != 0: + threading.Thread(target=listen_pam_tls, + args=(db,bind_addr,pam_tls_port,context)).start() + if client_tls_port != 0: + threading.Thread(target=listen_client_tls, + args=(db,bind_addr,client_tls_port,context)).start() if __name__ == '__main__': -- cgit v1.2.3