summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rwxr-xr-xserver/mfad.py74
1 files changed, 61 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py
index a2a783b..cc5073b 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,6 +52,10 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--pam-tls-port",type=int,
56 help="Port to listen for encrypted PAM requests")
57 parser.add_argument("--client-tls-port",type=int,
58 help="Port for encrypted client connections")
55 parser.add_argument("--database",type=str,help="Path to database file") 59 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file") 60 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file") 61 parser.add_argument("--key",type=str,help="TLS private key file")
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers):
266 270
267 return context 271 return context
268 272
269def listen_client(db, addr, port, tls_context): 273def listen_client_tls(db, addr, port, tls_context):
270 with socket.create_server((addr, port)) as sock: 274 with socket.create_server((addr, port)) as sock:
271 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 275 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
272 while True: 276 while True:
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context):
278 print("client: ssl handshake error") 282 print("client: ssl handshake error")
279 283
280 284
281def listen_pam(db, addr, port, tls_context): 285def listen_pam_tls(db, addr, port, tls_context):
282 with socket.create_server((addr,port)) as sock: 286 with socket.create_server((addr,port)) as sock:
283 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 287 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
284 while True: 288 while True:
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context):
290 print("pam: ssl handshake error") 294 print("pam: ssl handshake error")
291 295
292 296
297def listen_client(db, addr, port):
298 with socket.create_server((addr, port)) as sock:
299 while True:
300 conn, addr = sock.accept()
301 thread = threading.Thread(target=handle_client,args=(db,conn,addr))
302 thread.start()
303
304
305def listen_pam(db, addr, port):
306 with socket.create_server((addr, port)) as sock:
307 while True:
308 conn, addr = sock.accept()
309 thread = threading.Thread(target=handle_pam,args=(db,conn,addr))
310 thread.start()
311
312
293################################################################################ 313################################################################################
294 314
295def create_db(db): 315def create_db(db):
@@ -312,12 +332,13 @@ def create_db(db):
312 332
313def get_vars(args,confparser): 333def get_vars(args,confparser):
314 if not os.path.exists(args.config): 334 if not os.path.exists(args.config):
315 print("Unable to open config file") 335 die("Unable to open config file")
316 sys.exit(1)
317 336
318 bind_addr = None 337 bind_addr = None
319 client_port = None 338 client_port = None
339 client_tls_port = None
320 pam_port = None 340 pam_port = None
341 pam_tls_port = None
321 database = None 342 database = None
322 cert = None 343 cert = None
323 key = None 344 key = None
@@ -328,6 +349,8 @@ def get_vars(args,confparser):
328 bind_addr = confparser.get("mfad","address",fallback=None) 349 bind_addr = confparser.get("mfad","address",fallback=None)
329 client_port = confparser.get("mfad","client-port",fallback=None) 350 client_port = confparser.get("mfad","client-port",fallback=None)
330 pam_port = confparser.get("mfad","pam-port",fallback=None) 351 pam_port = confparser.get("mfad","pam-port",fallback=None)
352 client_tls_port = confparser.get("mfad","client-tls-port",fallback=None)
353 pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None)
331 database = confparser.get("mfad","database",fallback=None) 354 database = confparser.get("mfad","database",fallback=None)
332 cert = confparser.get("mfad","cert",fallback=None) 355 cert = confparser.get("mfad","cert",fallback=None)
333 key = confparser.get("mfad","key",fallback=None) 356 key = confparser.get("mfad","key",fallback=None)
@@ -340,6 +363,10 @@ def get_vars(args,confparser):
340 client_port = args.client_port 363 client_port = args.client_port
341 if args.pam_port != None: 364 if args.pam_port != None:
342 pam_port = args.pam_port 365 pam_port = args.pam_port
366 if args.client_tls_port != None:
367 client_tls_port = args.client_tls_port
368 if args.pam_tls_port != None:
369 pam_tls_port = args.pam_tls_port
343 if args.database != None: 370 if args.database != None:
344 database = args.database 371 database = args.database
345 if args.cert != None: 372 if args.cert != None:
@@ -350,28 +377,49 @@ def get_vars(args,confparser):
350 ciphers = args.ciphers 377 ciphers = args.ciphers
351 378
352 # Exit if any value is null 379 # Exit if any value is null
353 if None in [bind_addr,client_port,pam_port,database,cert,key]: 380 if None in [bind_addr,database,cert,key]:
354 print("error: one or more items unspecified") 381 die("error: one or more items unspecified")
355 sys.exit(1) 382
383 ports = [pam_port, client_port, pam_tls_port, client_tls_port]
384 for port in ports:
385 if port == None:
386 ports[ports.index(port)] = 0
356 387
357 return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers 388 return bind_addr, ports, database, cert, key, ciphers
358 389
359 390
360def main(): 391def main():
361 args = parse_arguments() 392 args = parse_arguments()
362 confparser = read_config(args.config) 393 confparser = read_config(args.config)
363 394
364 bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) 395 bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser)
396 pam_port = int(ports[0])
397 client_port = int(ports[1])
398 pam_tls_port = int(ports[2])
399 client_tls_port = int(ports[3])
400
401 if pam_port == 0 and pam_tls_port == 0:
402 die("error: not listening for any PAM connections")
403 if client_port == 0 and client_tls_port == 0:
404 die("error: not listening for any client connections")
405
365 406
366 if not os.path.exists(db): 407 if not os.path.exists(db):
367 print("Creating DB") 408 print("Creating DB")
368 create_db(db) 409 create_db(db)
369 410
370 context = get_tls_context(cert,key,ciphers) 411 context = get_tls_context(cert,key,ciphers)
371 clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) 412
372 pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) 413 if pam_port != 0:
373 clients.start() 414 threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start()
374 pam.start() 415 if client_port != 0:
416 threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start()
417 if pam_tls_port != 0:
418 threading.Thread(target=listen_pam_tls,
419 args=(db,bind_addr,pam_tls_port,context)).start()
420 if client_tls_port != 0:
421 threading.Thread(target=listen_client_tls,
422 args=(db,bind_addr,client_tls_port,context)).start()
375 423
376 424
377if __name__ == '__main__': 425if __name__ == '__main__':