diff options
author | Sam Chudnick <sam@chudnick.com> | 2022-07-04 20:03:27 -0400 |
---|---|---|
committer | Sam Chudnick <sam@chudnick.com> | 2022-07-04 20:03:27 -0400 |
commit | 2e840e7c381f88425952c6fa9d68e0d433084a5a (patch) | |
tree | 31f7888ade33fbc112bd2b7509aac5c39bb2af82 /server | |
parent | 46564f357c175c7a01a36422307f05b543a83190 (diff) |
Support both TLS encrypted sessions and plaintext sessions
Added support for both TLS and plaintext connections. Server can accept
both types of connection simultaneously or in different combinations
(i.e encrypted client and plaintext PAM). Added options for specifying
dedicated TLS ports on server. Added --plain options for client and PAM
to force plaintext connections, default is to use encrypted connections.
Configuring encrypted client and PAM connections and plaintext server
connections allows for use of a reverse proxy setup with something like
nginx. This will avoid having to expose the MFA server directly in setups
that traverse the internet.
Diffstat (limited to 'server')
-rwxr-xr-x | server/mfad.py | 74 |
1 files changed, 61 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -52,6 +52,10 @@ def parse_arguments(): | |||
52 | parser.add_argument("--address",type=str,help="Bind Address") | 52 | parser.add_argument("--address",type=str,help="Bind Address") |
53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") |
54 | parser.add_argument("--client-port",type=int,help="Port for client connections") | 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") |
55 | parser.add_argument("--pam-tls-port",type=int, | ||
56 | help="Port to listen for encrypted PAM requests") | ||
57 | parser.add_argument("--client-tls-port",type=int, | ||
58 | help="Port for encrypted client connections") | ||
55 | parser.add_argument("--database",type=str,help="Path to database file") | 59 | parser.add_argument("--database",type=str,help="Path to database file") |
56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | 60 | parser.add_argument("--cert",type=str,help="TLS certificate file") |
57 | parser.add_argument("--key",type=str,help="TLS private key file") | 61 | parser.add_argument("--key",type=str,help="TLS private key file") |
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): | |||
266 | 270 | ||
267 | return context | 271 | return context |
268 | 272 | ||
269 | def listen_client(db, addr, port, tls_context): | 273 | def listen_client_tls(db, addr, port, tls_context): |
270 | with socket.create_server((addr, port)) as sock: | 274 | with socket.create_server((addr, port)) as sock: |
271 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 275 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
272 | while True: | 276 | while True: |
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): | |||
278 | print("client: ssl handshake error") | 282 | print("client: ssl handshake error") |
279 | 283 | ||
280 | 284 | ||
281 | def listen_pam(db, addr, port, tls_context): | 285 | def listen_pam_tls(db, addr, port, tls_context): |
282 | with socket.create_server((addr,port)) as sock: | 286 | with socket.create_server((addr,port)) as sock: |
283 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 287 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
284 | while True: | 288 | while True: |
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): | |||
290 | print("pam: ssl handshake error") | 294 | print("pam: ssl handshake error") |
291 | 295 | ||
292 | 296 | ||
297 | def listen_client(db, addr, port): | ||
298 | with socket.create_server((addr, port)) as sock: | ||
299 | while True: | ||
300 | conn, addr = sock.accept() | ||
301 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
302 | thread.start() | ||
303 | |||
304 | |||
305 | def listen_pam(db, addr, port): | ||
306 | with socket.create_server((addr, port)) as sock: | ||
307 | while True: | ||
308 | conn, addr = sock.accept() | ||
309 | thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) | ||
310 | thread.start() | ||
311 | |||
312 | |||
293 | ################################################################################ | 313 | ################################################################################ |
294 | 314 | ||
295 | def create_db(db): | 315 | def create_db(db): |
@@ -312,12 +332,13 @@ def create_db(db): | |||
312 | 332 | ||
313 | def get_vars(args,confparser): | 333 | def get_vars(args,confparser): |
314 | if not os.path.exists(args.config): | 334 | if not os.path.exists(args.config): |
315 | print("Unable to open config file") | 335 | die("Unable to open config file") |
316 | sys.exit(1) | ||
317 | 336 | ||
318 | bind_addr = None | 337 | bind_addr = None |
319 | client_port = None | 338 | client_port = None |
339 | client_tls_port = None | ||
320 | pam_port = None | 340 | pam_port = None |
341 | pam_tls_port = None | ||
321 | database = None | 342 | database = None |
322 | cert = None | 343 | cert = None |
323 | key = None | 344 | key = None |
@@ -328,6 +349,8 @@ def get_vars(args,confparser): | |||
328 | bind_addr = confparser.get("mfad","address",fallback=None) | 349 | bind_addr = confparser.get("mfad","address",fallback=None) |
329 | client_port = confparser.get("mfad","client-port",fallback=None) | 350 | client_port = confparser.get("mfad","client-port",fallback=None) |
330 | pam_port = confparser.get("mfad","pam-port",fallback=None) | 351 | pam_port = confparser.get("mfad","pam-port",fallback=None) |
352 | client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) | ||
353 | pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) | ||
331 | database = confparser.get("mfad","database",fallback=None) | 354 | database = confparser.get("mfad","database",fallback=None) |
332 | cert = confparser.get("mfad","cert",fallback=None) | 355 | cert = confparser.get("mfad","cert",fallback=None) |
333 | key = confparser.get("mfad","key",fallback=None) | 356 | key = confparser.get("mfad","key",fallback=None) |
@@ -340,6 +363,10 @@ def get_vars(args,confparser): | |||
340 | client_port = args.client_port | 363 | client_port = args.client_port |
341 | if args.pam_port != None: | 364 | if args.pam_port != None: |
342 | pam_port = args.pam_port | 365 | pam_port = args.pam_port |
366 | if args.client_tls_port != None: | ||
367 | client_tls_port = args.client_tls_port | ||
368 | if args.pam_tls_port != None: | ||
369 | pam_tls_port = args.pam_tls_port | ||
343 | if args.database != None: | 370 | if args.database != None: |
344 | database = args.database | 371 | database = args.database |
345 | if args.cert != None: | 372 | if args.cert != None: |
@@ -350,28 +377,49 @@ def get_vars(args,confparser): | |||
350 | ciphers = args.ciphers | 377 | ciphers = args.ciphers |
351 | 378 | ||
352 | # Exit if any value is null | 379 | # Exit if any value is null |
353 | if None in [bind_addr,client_port,pam_port,database,cert,key]: | 380 | if None in [bind_addr,database,cert,key]: |
354 | print("error: one or more items unspecified") | 381 | die("error: one or more items unspecified") |
355 | sys.exit(1) | 382 | |
383 | ports = [pam_port, client_port, pam_tls_port, client_tls_port] | ||
384 | for port in ports: | ||
385 | if port == None: | ||
386 | ports[ports.index(port)] = 0 | ||
356 | 387 | ||
357 | return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers | 388 | return bind_addr, ports, database, cert, key, ciphers |
358 | 389 | ||
359 | 390 | ||
360 | def main(): | 391 | def main(): |
361 | args = parse_arguments() | 392 | args = parse_arguments() |
362 | confparser = read_config(args.config) | 393 | confparser = read_config(args.config) |
363 | 394 | ||
364 | bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) | 395 | bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) |
396 | pam_port = int(ports[0]) | ||
397 | client_port = int(ports[1]) | ||
398 | pam_tls_port = int(ports[2]) | ||
399 | client_tls_port = int(ports[3]) | ||
400 | |||
401 | if pam_port == 0 and pam_tls_port == 0: | ||
402 | die("error: not listening for any PAM connections") | ||
403 | if client_port == 0 and client_tls_port == 0: | ||
404 | die("error: not listening for any client connections") | ||
405 | |||
365 | 406 | ||
366 | if not os.path.exists(db): | 407 | if not os.path.exists(db): |
367 | print("Creating DB") | 408 | print("Creating DB") |
368 | create_db(db) | 409 | create_db(db) |
369 | 410 | ||
370 | context = get_tls_context(cert,key,ciphers) | 411 | context = get_tls_context(cert,key,ciphers) |
371 | clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) | 412 | |
372 | pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) | 413 | if pam_port != 0: |
373 | clients.start() | 414 | threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() |
374 | pam.start() | 415 | if client_port != 0: |
416 | threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() | ||
417 | if pam_tls_port != 0: | ||
418 | threading.Thread(target=listen_pam_tls, | ||
419 | args=(db,bind_addr,pam_tls_port,context)).start() | ||
420 | if client_tls_port != 0: | ||
421 | threading.Thread(target=listen_client_tls, | ||
422 | args=(db,bind_addr,client_tls_port,context)).start() | ||
375 | 423 | ||
376 | 424 | ||
377 | if __name__ == '__main__': | 425 | if __name__ == '__main__': |