diff options
Diffstat (limited to 'server')
-rwxr-xr-x | server/mfad.py | 74 |
1 files changed, 61 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -52,6 +52,10 @@ def parse_arguments(): | |||
52 | parser.add_argument("--address",type=str,help="Bind Address") | 52 | parser.add_argument("--address",type=str,help="Bind Address") |
53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") |
54 | parser.add_argument("--client-port",type=int,help="Port for client connections") | 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") |
55 | parser.add_argument("--pam-tls-port",type=int, | ||
56 | help="Port to listen for encrypted PAM requests") | ||
57 | parser.add_argument("--client-tls-port",type=int, | ||
58 | help="Port for encrypted client connections") | ||
55 | parser.add_argument("--database",type=str,help="Path to database file") | 59 | parser.add_argument("--database",type=str,help="Path to database file") |
56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | 60 | parser.add_argument("--cert",type=str,help="TLS certificate file") |
57 | parser.add_argument("--key",type=str,help="TLS private key file") | 61 | parser.add_argument("--key",type=str,help="TLS private key file") |
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): | |||
266 | 270 | ||
267 | return context | 271 | return context |
268 | 272 | ||
269 | def listen_client(db, addr, port, tls_context): | 273 | def listen_client_tls(db, addr, port, tls_context): |
270 | with socket.create_server((addr, port)) as sock: | 274 | with socket.create_server((addr, port)) as sock: |
271 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 275 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
272 | while True: | 276 | while True: |
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): | |||
278 | print("client: ssl handshake error") | 282 | print("client: ssl handshake error") |
279 | 283 | ||
280 | 284 | ||
281 | def listen_pam(db, addr, port, tls_context): | 285 | def listen_pam_tls(db, addr, port, tls_context): |
282 | with socket.create_server((addr,port)) as sock: | 286 | with socket.create_server((addr,port)) as sock: |
283 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 287 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
284 | while True: | 288 | while True: |
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): | |||
290 | print("pam: ssl handshake error") | 294 | print("pam: ssl handshake error") |
291 | 295 | ||
292 | 296 | ||
297 | def listen_client(db, addr, port): | ||
298 | with socket.create_server((addr, port)) as sock: | ||
299 | while True: | ||
300 | conn, addr = sock.accept() | ||
301 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
302 | thread.start() | ||
303 | |||
304 | |||
305 | def listen_pam(db, addr, port): | ||
306 | with socket.create_server((addr, port)) as sock: | ||
307 | while True: | ||
308 | conn, addr = sock.accept() | ||
309 | thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) | ||
310 | thread.start() | ||
311 | |||
312 | |||
293 | ################################################################################ | 313 | ################################################################################ |
294 | 314 | ||
295 | def create_db(db): | 315 | def create_db(db): |
@@ -312,12 +332,13 @@ def create_db(db): | |||
312 | 332 | ||
313 | def get_vars(args,confparser): | 333 | def get_vars(args,confparser): |
314 | if not os.path.exists(args.config): | 334 | if not os.path.exists(args.config): |
315 | print("Unable to open config file") | 335 | die("Unable to open config file") |
316 | sys.exit(1) | ||
317 | 336 | ||
318 | bind_addr = None | 337 | bind_addr = None |
319 | client_port = None | 338 | client_port = None |
339 | client_tls_port = None | ||
320 | pam_port = None | 340 | pam_port = None |
341 | pam_tls_port = None | ||
321 | database = None | 342 | database = None |
322 | cert = None | 343 | cert = None |
323 | key = None | 344 | key = None |
@@ -328,6 +349,8 @@ def get_vars(args,confparser): | |||
328 | bind_addr = confparser.get("mfad","address",fallback=None) | 349 | bind_addr = confparser.get("mfad","address",fallback=None) |
329 | client_port = confparser.get("mfad","client-port",fallback=None) | 350 | client_port = confparser.get("mfad","client-port",fallback=None) |
330 | pam_port = confparser.get("mfad","pam-port",fallback=None) | 351 | pam_port = confparser.get("mfad","pam-port",fallback=None) |
352 | client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) | ||
353 | pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) | ||
331 | database = confparser.get("mfad","database",fallback=None) | 354 | database = confparser.get("mfad","database",fallback=None) |
332 | cert = confparser.get("mfad","cert",fallback=None) | 355 | cert = confparser.get("mfad","cert",fallback=None) |
333 | key = confparser.get("mfad","key",fallback=None) | 356 | key = confparser.get("mfad","key",fallback=None) |
@@ -340,6 +363,10 @@ def get_vars(args,confparser): | |||
340 | client_port = args.client_port | 363 | client_port = args.client_port |
341 | if args.pam_port != None: | 364 | if args.pam_port != None: |
342 | pam_port = args.pam_port | 365 | pam_port = args.pam_port |
366 | if args.client_tls_port != None: | ||
367 | client_tls_port = args.client_tls_port | ||
368 | if args.pam_tls_port != None: | ||
369 | pam_tls_port = args.pam_tls_port | ||
343 | if args.database != None: | 370 | if args.database != None: |
344 | database = args.database | 371 | database = args.database |
345 | if args.cert != None: | 372 | if args.cert != None: |
@@ -350,28 +377,49 @@ def get_vars(args,confparser): | |||
350 | ciphers = args.ciphers | 377 | ciphers = args.ciphers |
351 | 378 | ||
352 | # Exit if any value is null | 379 | # Exit if any value is null |
353 | if None in [bind_addr,client_port,pam_port,database,cert,key]: | 380 | if None in [bind_addr,database,cert,key]: |
354 | print("error: one or more items unspecified") | 381 | die("error: one or more items unspecified") |
355 | sys.exit(1) | 382 | |
383 | ports = [pam_port, client_port, pam_tls_port, client_tls_port] | ||
384 | for port in ports: | ||
385 | if port == None: | ||
386 | ports[ports.index(port)] = 0 | ||
356 | 387 | ||
357 | return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers | 388 | return bind_addr, ports, database, cert, key, ciphers |
358 | 389 | ||
359 | 390 | ||
360 | def main(): | 391 | def main(): |
361 | args = parse_arguments() | 392 | args = parse_arguments() |
362 | confparser = read_config(args.config) | 393 | confparser = read_config(args.config) |
363 | 394 | ||
364 | bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) | 395 | bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) |
396 | pam_port = int(ports[0]) | ||
397 | client_port = int(ports[1]) | ||
398 | pam_tls_port = int(ports[2]) | ||
399 | client_tls_port = int(ports[3]) | ||
400 | |||
401 | if pam_port == 0 and pam_tls_port == 0: | ||
402 | die("error: not listening for any PAM connections") | ||
403 | if client_port == 0 and client_tls_port == 0: | ||
404 | die("error: not listening for any client connections") | ||
405 | |||
365 | 406 | ||
366 | if not os.path.exists(db): | 407 | if not os.path.exists(db): |
367 | print("Creating DB") | 408 | print("Creating DB") |
368 | create_db(db) | 409 | create_db(db) |
369 | 410 | ||
370 | context = get_tls_context(cert,key,ciphers) | 411 | context = get_tls_context(cert,key,ciphers) |
371 | clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) | 412 | |
372 | pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) | 413 | if pam_port != 0: |
373 | clients.start() | 414 | threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() |
374 | pam.start() | 415 | if client_port != 0: |
416 | threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() | ||
417 | if pam_tls_port != 0: | ||
418 | threading.Thread(target=listen_pam_tls, | ||
419 | args=(db,bind_addr,pam_tls_port,context)).start() | ||
420 | if client_tls_port != 0: | ||
421 | threading.Thread(target=listen_client_tls, | ||
422 | args=(db,bind_addr,client_tls_port,context)).start() | ||
375 | 423 | ||
376 | 424 | ||
377 | if __name__ == '__main__': | 425 | if __name__ == '__main__': |