From 46564f357c175c7a01a36422307f05b543a83190 Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Mon, 4 Jul 2022 13:44:26 -0400 Subject: Added option to specify TLS ciphers Added a command line argument and config file option to set the TLS ciphers that the server will use. Set to Mozilla intermediate compatibility by default. --- server/mfad.py | 50 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/server/mfad.py b/server/mfad.py index b9557bd..a2a783b 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -55,6 +55,7 @@ def parse_arguments(): parser.add_argument("--database",type=str,help="Path to database file") parser.add_argument("--cert",type=str,help="TLS certificate file") parser.add_argument("--key",type=str,help="TLS private key file") + parser.add_argument("--ciphers",type=str,help="TLS ciphers to use") parser.add_argument("--config",type=str,help="Alternate config file location",\ default="/etc/mfa/mfa.conf") return parser.parse_args() @@ -241,11 +242,33 @@ def handle_pam(db, conn, addr): conn.send(str(decision).encode(FORMAT)) -def listen_client(db, addr, port, cert, key): +def get_tls_context(cert, key, ciphers): + if not os.path.exists(cert): + die("error: cannot open cert file") + if not os.path.exists(key): + die("error: cannot open key file") + + # Create context, load cert and key context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain(certfile=cert, keyfile=key) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.maximum_version = ssl.TLSVersion.TLSv1_3 + + if ciphers == None: + # Mozilla intermediate compatibility + ciphers = "ECDHE+ECDSA+AESGCM:ECDHE+aRSA+AESGCM:ECDHE+ECDSA+CHACHA20:ECDHE+aRSA+CHACHA20:DHE+aRSA+AESGCM:!aNULL:!eNULL" + + # Set ciphers + try: + context.set_ciphers(ciphers) + except ssl.SSLError: + die("error: invalid cipherlist") + + return context + +def listen_client(db, addr, port, tls_context): with socket.create_server((addr, port)) as sock: - with context.wrap_socket(sock, server_side=True) as tls_socket: + with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: try: conn, addr = tls_socket.accept() @@ -255,11 +278,9 @@ def listen_client(db, addr, port, cert, key): print("client: ssl handshake error") -def listen_pam(db, addr, port, cert, key): - context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.load_cert_chain(certfile=cert, keyfile=key) +def listen_pam(db, addr, port, tls_context): with socket.create_server((addr,port)) as sock: - with context.wrap_socket(sock, server_side=True) as tls_socket: + with tls_context.wrap_socket(sock, server_side=True) as tls_socket: while True: try: conn, addr = tls_socket.accept() @@ -300,6 +321,7 @@ def get_vars(args,confparser): database = None cert = None key = None + ciphers = None # Set values from config file first if confparser.has_section("mfad"): @@ -309,6 +331,7 @@ def get_vars(args,confparser): database = confparser.get("mfad","database",fallback=None) cert = confparser.get("mfad","cert",fallback=None) key = confparser.get("mfad","key",fallback=None) + ciphers = confparser.get("mfad","ciphers",fallback=None) # Let command line args overwrite any values if args.address != None: @@ -323,29 +346,30 @@ def get_vars(args,confparser): cert = args.cert if args.key != None: key = args.key + if args.ciphers != None: + ciphers = args.ciphers # Exit if any value is null if None in [bind_addr,client_port,pam_port,database,cert,key]: print("error: one or more items unspecified") sys.exit(1) - return bind_addr, int(client_port), int(pam_port), database, cert, key + return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers def main(): args = parse_arguments() confparser = read_config(args.config) - bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser) + bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) if not os.path.exists(db): print("Creating DB") create_db(db) - - clients = threading.Thread(target=listen_client, - args=(db, bind_addr,client_port,cert,key)) - pam = threading.Thread(target=listen_pam, - args=(db, bind_addr,pam_port,cert,key)) + + context = get_tls_context(cert,key,ciphers) + clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) + pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) clients.start() pam.start() -- cgit v1.2.3