From 2e840e7c381f88425952c6fa9d68e0d433084a5a Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Mon, 4 Jul 2022 20:03:27 -0400 Subject: Support both TLS encrypted sessions and plaintext sessions Added support for both TLS and plaintext connections. Server can accept both types of connection simultaneously or in different combinations (i.e encrypted client and plaintext PAM). Added options for specifying dedicated TLS ports on server. Added --plain options for client and PAM to force plaintext connections, default is to use encrypted connections. Configuring encrypted client and PAM connections and plaintext server connections allows for use of a reverse proxy setup with something like nginx. This will avoid having to expose the MFA server directly in setups that traverse the internet. --- server/mfad.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 13 deletions(-) (limited to 'server') diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -52,6 +52,10 @@ def parse_arguments(): parser.add_argument("--address",type=str,help="Bind Address") parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") parser.add_argument("--client-port",type=int,help="Port for client connections") + parser.add_argument("--pam-tls-port",type=int, + help="Port to listen for encrypted PAM requests") + parser.add_argument("--client-tls-port",type=int, + help="Port for encrypted client connections") parser.add_argument("--database",type=str,help="Path to database file") parser.add_argument("--cert",type=str,help="TLS certificate file") parser.add_argument("--key",type=str,help="TLS private key file") @@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): return context -def listen_client(db, addr, port, tls_context): +def listen_client_tls(db, addr, port, tls_context): with socket.create_server((addr, port)) as sock: with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): print("client: ssl handshake error") -def listen_pam(db, addr, port, tls_context): +def listen_pam_tls(db, addr, port, tls_context): with socket.create_server((addr,port)) as sock: with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): print("pam: ssl handshake error") +def listen_client(db, addr, port): + with socket.create_server((addr, port)) as sock: + while True: + conn, addr = sock.accept() + thread = threading.Thread(target=handle_client,args=(db,conn,addr)) + thread.start() + + +def listen_pam(db, addr, port): + with socket.create_server((addr, port)) as sock: + while True: + conn, addr = sock.accept() + thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) + thread.start() + + ################################################################################ def create_db(db): @@ -312,12 +332,13 @@ def create_db(db): def get_vars(args,confparser): if not os.path.exists(args.config): - print("Unable to open config file") - sys.exit(1) + die("Unable to open config file") bind_addr = None client_port = None + client_tls_port = None pam_port = None + pam_tls_port = None database = None cert = None key = None @@ -328,6 +349,8 @@ def get_vars(args,confparser): bind_addr = confparser.get("mfad","address",fallback=None) client_port = confparser.get("mfad","client-port",fallback=None) pam_port = confparser.get("mfad","pam-port",fallback=None) + client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) + pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) database = confparser.get("mfad","database",fallback=None) cert = confparser.get("mfad","cert",fallback=None) key = confparser.get("mfad","key",fallback=None) @@ -340,6 +363,10 @@ def get_vars(args,confparser): client_port = args.client_port if args.pam_port != None: pam_port = args.pam_port + if args.client_tls_port != None: + client_tls_port = args.client_tls_port + if args.pam_tls_port != None: + pam_tls_port = args.pam_tls_port if args.database != None: database = args.database if args.cert != None: @@ -350,28 +377,49 @@ def get_vars(args,confparser): ciphers = args.ciphers # Exit if any value is null - if None in [bind_addr,client_port,pam_port,database,cert,key]: - print("error: one or more items unspecified") - sys.exit(1) + if None in [bind_addr,database,cert,key]: + die("error: one or more items unspecified") + + ports = [pam_port, client_port, pam_tls_port, client_tls_port] + for port in ports: + if port == None: + ports[ports.index(port)] = 0 - return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers + return bind_addr, ports, database, cert, key, ciphers def main(): args = parse_arguments() confparser = read_config(args.config) - bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) + bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) + pam_port = int(ports[0]) + client_port = int(ports[1]) + pam_tls_port = int(ports[2]) + client_tls_port = int(ports[3]) + + if pam_port == 0 and pam_tls_port == 0: + die("error: not listening for any PAM connections") + if client_port == 0 and client_tls_port == 0: + die("error: not listening for any client connections") + if not os.path.exists(db): print("Creating DB") create_db(db) context = get_tls_context(cert,key,ciphers) - clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) - pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) - clients.start() - pam.start() + + if pam_port != 0: + threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() + if client_port != 0: + threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() + if pam_tls_port != 0: + threading.Thread(target=listen_pam_tls, + args=(db,bind_addr,pam_tls_port,context)).start() + if client_tls_port != 0: + threading.Thread(target=listen_client_tls, + args=(db,bind_addr,client_tls_port,context)).start() if __name__ == '__main__': -- cgit v1.2.3