summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorSam Chudnick <sam@chudnick.com>2022-07-04 20:03:27 -0400
committerSam Chudnick <sam@chudnick.com>2022-07-04 20:03:27 -0400
commit2e840e7c381f88425952c6fa9d68e0d433084a5a (patch)
tree31f7888ade33fbc112bd2b7509aac5c39bb2af82 /server
parent46564f357c175c7a01a36422307f05b543a83190 (diff)
Support both TLS encrypted sessions and plaintext sessions
Added support for both TLS and plaintext connections. Server can accept both types of connection simultaneously or in different combinations (i.e encrypted client and plaintext PAM). Added options for specifying dedicated TLS ports on server. Added --plain options for client and PAM to force plaintext connections, default is to use encrypted connections. Configuring encrypted client and PAM connections and plaintext server connections allows for use of a reverse proxy setup with something like nginx. This will avoid having to expose the MFA server directly in setups that traverse the internet.
Diffstat (limited to 'server')
-rwxr-xr-xserver/mfad.py74
1 files changed, 61 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py
index a2a783b..cc5073b 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,6 +52,10 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--pam-tls-port",type=int,
56 help="Port to listen for encrypted PAM requests")
57 parser.add_argument("--client-tls-port",type=int,
58 help="Port for encrypted client connections")
55 parser.add_argument("--database",type=str,help="Path to database file") 59 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file") 60 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file") 61 parser.add_argument("--key",type=str,help="TLS private key file")
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers):
266 270
267 return context 271 return context
268 272
269def listen_client(db, addr, port, tls_context): 273def listen_client_tls(db, addr, port, tls_context):
270 with socket.create_server((addr, port)) as sock: 274 with socket.create_server((addr, port)) as sock:
271 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 275 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
272 while True: 276 while True:
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context):
278 print("client: ssl handshake error") 282 print("client: ssl handshake error")
279 283
280 284
281def listen_pam(db, addr, port, tls_context): 285def listen_pam_tls(db, addr, port, tls_context):
282 with socket.create_server((addr,port)) as sock: 286 with socket.create_server((addr,port)) as sock:
283 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 287 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
284 while True: 288 while True:
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context):
290 print("pam: ssl handshake error") 294 print("pam: ssl handshake error")
291 295
292 296
297def listen_client(db, addr, port):
298 with socket.create_server((addr, port)) as sock:
299 while True:
300 conn, addr = sock.accept()
301 thread = threading.Thread(target=handle_client,args=(db,conn,addr))
302 thread.start()
303
304
305def listen_pam(db, addr, port):
306 with socket.create_server((addr, port)) as sock:
307 while True:
308 conn, addr = sock.accept()
309 thread = threading.Thread(target=handle_pam,args=(db,conn,addr))
310 thread.start()
311
312
293################################################################################ 313################################################################################
294 314
295def create_db(db): 315def create_db(db):
@@ -312,12 +332,13 @@ def create_db(db):
312 332
313def get_vars(args,confparser): 333def get_vars(args,confparser):
314 if not os.path.exists(args.config): 334 if not os.path.exists(args.config):
315 print("Unable to open config file") 335 die("Unable to open config file")
316 sys.exit(1)
317 336
318 bind_addr = None 337 bind_addr = None
319 client_port = None 338 client_port = None
339 client_tls_port = None
320 pam_port = None 340 pam_port = None
341 pam_tls_port = None
321 database = None 342 database = None
322 cert = None 343 cert = None
323 key = None 344 key = None
@@ -328,6 +349,8 @@ def get_vars(args,confparser):
328 bind_addr = confparser.get("mfad","address",fallback=None) 349 bind_addr = confparser.get("mfad","address",fallback=None)
329 client_port = confparser.get("mfad","client-port",fallback=None) 350 client_port = confparser.get("mfad","client-port",fallback=None)
330 pam_port = confparser.get("mfad","pam-port",fallback=None) 351 pam_port = confparser.get("mfad","pam-port",fallback=None)
352 client_tls_port = confparser.get("mfad","client-tls-port",fallback=None)
353 pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None)
331 database = confparser.get("mfad","database",fallback=None) 354 database = confparser.get("mfad","database",fallback=None)
332 cert = confparser.get("mfad","cert",fallback=None) 355 cert = confparser.get("mfad","cert",fallback=None)
333 key = confparser.get("mfad","key",fallback=None) 356 key = confparser.get("mfad","key",fallback=None)
@@ -340,6 +363,10 @@ def get_vars(args,confparser):
340 client_port = args.client_port 363 client_port = args.client_port
341 if args.pam_port != None: 364 if args.pam_port != None:
342 pam_port = args.pam_port 365 pam_port = args.pam_port
366 if args.client_tls_port != None:
367 client_tls_port = args.client_tls_port
368 if args.pam_tls_port != None:
369 pam_tls_port = args.pam_tls_port
343 if args.database != None: 370 if args.database != None:
344 database = args.database 371 database = args.database
345 if args.cert != None: 372 if args.cert != None:
@@ -350,28 +377,49 @@ def get_vars(args,confparser):
350 ciphers = args.ciphers 377 ciphers = args.ciphers
351 378
352 # Exit if any value is null 379 # Exit if any value is null
353 if None in [bind_addr,client_port,pam_port,database,cert,key]: 380 if None in [bind_addr,database,cert,key]:
354 print("error: one or more items unspecified") 381 die("error: one or more items unspecified")
355 sys.exit(1) 382
383 ports = [pam_port, client_port, pam_tls_port, client_tls_port]
384 for port in ports:
385 if port == None:
386 ports[ports.index(port)] = 0
356 387
357 return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers 388 return bind_addr, ports, database, cert, key, ciphers
358 389
359 390
360def main(): 391def main():
361 args = parse_arguments() 392 args = parse_arguments()
362 confparser = read_config(args.config) 393 confparser = read_config(args.config)
363 394
364 bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) 395 bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser)
396 pam_port = int(ports[0])
397 client_port = int(ports[1])
398 pam_tls_port = int(ports[2])
399 client_tls_port = int(ports[3])
400
401 if pam_port == 0 and pam_tls_port == 0:
402 die("error: not listening for any PAM connections")
403 if client_port == 0 and client_tls_port == 0:
404 die("error: not listening for any client connections")
405
365 406
366 if not os.path.exists(db): 407 if not os.path.exists(db):
367 print("Creating DB") 408 print("Creating DB")
368 create_db(db) 409 create_db(db)
369 410
370 context = get_tls_context(cert,key,ciphers) 411 context = get_tls_context(cert,key,ciphers)
371 clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) 412
372 pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) 413 if pam_port != 0:
373 clients.start() 414 threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start()
374 pam.start() 415 if client_port != 0:
416 threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start()
417 if pam_tls_port != 0:
418 threading.Thread(target=listen_pam_tls,
419 args=(db,bind_addr,pam_tls_port,context)).start()
420 if client_tls_port != 0:
421 threading.Thread(target=listen_client_tls,
422 args=(db,bind_addr,client_tls_port,context)).start()
375 423
376 424
377if __name__ == '__main__': 425if __name__ == '__main__':