summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Chudnick <sam@chudnick.com>2022-07-04 12:32:16 -0400
committerSam Chudnick <sam@chudnick.com>2022-07-04 12:32:16 -0400
commitdbfb415edfbe1bc8db3a1272c28189785e623860 (patch)
treec4f519975cacd13123dcb15bec0c807ceb00f0f7
parent755d7f5f94b720b028d085cf971c5935c130dec1 (diff)
Added options for certificate and key files
Added command line arguments and config file options to specify TLS certificate and TLS private key files.
-rwxr-xr-xserver/mfad.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/server/mfad.py b/server/mfad.py
index 18a048a..b9557bd 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,7 +52,9 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--database",type=str,help="Path to alternate database file") 55 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file")
56 parser.add_argument("--config",type=str,help="Alternate config file location",\ 58 parser.add_argument("--config",type=str,help="Alternate config file location",\
57 default="/etc/mfa/mfa.conf") 59 default="/etc/mfa/mfa.conf")
58 return parser.parse_args() 60 return parser.parse_args()
@@ -239,9 +241,9 @@ def handle_pam(db, conn, addr):
239 conn.send(str(decision).encode(FORMAT)) 241 conn.send(str(decision).encode(FORMAT))
240 242
241 243
242def listen_client(db, addr, port): 244def listen_client(db, addr, port, cert, key):
243 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 245 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
244 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") 246 context.load_cert_chain(certfile=cert, keyfile=key)
245 with socket.create_server((addr, port)) as sock: 247 with socket.create_server((addr, port)) as sock:
246 with context.wrap_socket(sock, server_side=True) as tls_socket: 248 with context.wrap_socket(sock, server_side=True) as tls_socket:
247 while True: 249 while True:
@@ -253,9 +255,9 @@ def listen_client(db, addr, port):
253 print("client: ssl handshake error") 255 print("client: ssl handshake error")
254 256
255 257
256def listen_pam(db, addr, port): 258def listen_pam(db, addr, port, cert, key):
257 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 259 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
258 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") 260 context.load_cert_chain(certfile=cert, keyfile=key)
259 with socket.create_server((addr,port)) as sock: 261 with socket.create_server((addr,port)) as sock:
260 with context.wrap_socket(sock, server_side=True) as tls_socket: 262 with context.wrap_socket(sock, server_side=True) as tls_socket:
261 while True: 263 while True:
@@ -296,6 +298,8 @@ def get_vars(args,confparser):
296 client_port = None 298 client_port = None
297 pam_port = None 299 pam_port = None
298 database = None 300 database = None
301 cert = None
302 key = None
299 303
300 # Set values from config file first 304 # Set values from config file first
301 if confparser.has_section("mfad"): 305 if confparser.has_section("mfad"):
@@ -303,37 +307,45 @@ def get_vars(args,confparser):
303 client_port = confparser.get("mfad","client-port",fallback=None) 307 client_port = confparser.get("mfad","client-port",fallback=None)
304 pam_port = confparser.get("mfad","pam-port",fallback=None) 308 pam_port = confparser.get("mfad","pam-port",fallback=None)
305 database = confparser.get("mfad","database",fallback=None) 309 database = confparser.get("mfad","database",fallback=None)
310 cert = confparser.get("mfad","cert",fallback=None)
311 key = confparser.get("mfad","key",fallback=None)
306 312
307 # Let command line args overwrite any values 313 # Let command line args overwrite any values
308 if args.address: 314 if args.address != None:
309 bind_addr = args.address 315 bind_addr = args.address
310 if args.client_port: 316 if args.client_port != None:
311 client_port = args.client_port 317 client_port = args.client_port
312 if args.pam_port: 318 if args.pam_port != None:
313 pam_port = args.pam_port 319 pam_port = args.pam_port
314 if args.database: 320 if args.database != None:
315 database = args.database 321 database = args.database
322 if args.cert != None:
323 cert = args.cert
324 if args.key != None:
325 key = args.key
316 326
317 # Exit if any value is null 327 # Exit if any value is null
318 if None in [bind_addr,client_port,pam_port,database]: 328 if None in [bind_addr,client_port,pam_port,database,cert,key]:
319 print("error: one or more items unspecified") 329 print("error: one or more items unspecified")
320 sys.exit(1) 330 sys.exit(1)
321 331
322 return bind_addr, int(client_port), int(pam_port), database 332 return bind_addr, int(client_port), int(pam_port), database, cert, key
323 333
324 334
325def main(): 335def main():
326 args = parse_arguments() 336 args = parse_arguments()
327 confparser = read_config(args.config) 337 confparser = read_config(args.config)
328 338
329 bind_addr, client_port, pam_port, db = get_vars(args,confparser) 339 bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser)
330 340
331 if not os.path.exists(db): 341 if not os.path.exists(db):
332 print("Creating DB") 342 print("Creating DB")
333 create_db(db) 343 create_db(db)
334 344
335 clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) 345 clients = threading.Thread(target=listen_client,
336 pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) 346 args=(db, bind_addr,client_port,cert,key))
347 pam = threading.Thread(target=listen_pam,
348 args=(db, bind_addr,pam_port,cert,key))
337 clients.start() 349 clients.start()
338 pam.start() 350 pam.start()
339 351