From dbfb415edfbe1bc8db3a1272c28189785e623860 Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Mon, 4 Jul 2022 12:32:16 -0400 Subject: Added options for certificate and key files Added command line arguments and config file options to specify TLS certificate and TLS private key files. --- server/mfad.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/server/mfad.py b/server/mfad.py index 18a048a..b9557bd 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -52,7 +52,9 @@ def parse_arguments(): parser.add_argument("--address",type=str,help="Bind Address") parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") parser.add_argument("--client-port",type=int,help="Port for client connections") - parser.add_argument("--database",type=str,help="Path to alternate database file") + parser.add_argument("--database",type=str,help="Path to database file") + parser.add_argument("--cert",type=str,help="TLS certificate file") + parser.add_argument("--key",type=str,help="TLS private key file") parser.add_argument("--config",type=str,help="Alternate config file location",\ default="/etc/mfa/mfa.conf") return parser.parse_args() @@ -239,9 +241,9 @@ def handle_pam(db, conn, addr): conn.send(str(decision).encode(FORMAT)) -def listen_client(db, addr, port): +def listen_client(db, addr, port, cert, key): context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") + context.load_cert_chain(certfile=cert, keyfile=key) with socket.create_server((addr, port)) as sock: with context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -253,9 +255,9 @@ def listen_client(db, addr, port): print("client: ssl handshake error") -def listen_pam(db, addr, port): +def listen_pam(db, addr, port, cert, key): context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") + context.load_cert_chain(certfile=cert, keyfile=key) with socket.create_server((addr,port)) as sock: with context.wrap_socket(sock, server_side=True) as tls_socket: while True: @@ -296,6 +298,8 @@ def get_vars(args,confparser): client_port = None pam_port = None database = None + cert = None + key = None # Set values from config file first if confparser.has_section("mfad"): @@ -303,37 +307,45 @@ def get_vars(args,confparser): client_port = confparser.get("mfad","client-port",fallback=None) pam_port = confparser.get("mfad","pam-port",fallback=None) database = confparser.get("mfad","database",fallback=None) + cert = confparser.get("mfad","cert",fallback=None) + key = confparser.get("mfad","key",fallback=None) # Let command line args overwrite any values - if args.address: + if args.address != None: bind_addr = args.address - if args.client_port: + if args.client_port != None: client_port = args.client_port - if args.pam_port: + if args.pam_port != None: pam_port = args.pam_port - if args.database: + if args.database != None: database = args.database + if args.cert != None: + cert = args.cert + if args.key != None: + key = args.key # Exit if any value is null - if None in [bind_addr,client_port,pam_port,database]: + if None in [bind_addr,client_port,pam_port,database,cert,key]: print("error: one or more items unspecified") sys.exit(1) - return bind_addr, int(client_port), int(pam_port), database + return bind_addr, int(client_port), int(pam_port), database, cert, key def main(): args = parse_arguments() confparser = read_config(args.config) - bind_addr, client_port, pam_port, db = get_vars(args,confparser) + bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser) if not os.path.exists(db): print("Creating DB") create_db(db) - clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) - pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) + clients = threading.Thread(target=listen_client, + args=(db, bind_addr,client_port,cert,key)) + pam = threading.Thread(target=listen_pam, + args=(db, bind_addr,pam_port,cert,key)) clients.start() pam.start() -- cgit v1.2.3