summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xserver/mfad.py50
1 files changed, 37 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py
index b9557bd..a2a783b 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -55,6 +55,7 @@ def parse_arguments():
55 parser.add_argument("--database",type=str,help="Path to database file") 55 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file") 56 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file") 57 parser.add_argument("--key",type=str,help="TLS private key file")
58 parser.add_argument("--ciphers",type=str,help="TLS ciphers to use")
58 parser.add_argument("--config",type=str,help="Alternate config file location",\ 59 parser.add_argument("--config",type=str,help="Alternate config file location",\
59 default="/etc/mfa/mfa.conf") 60 default="/etc/mfa/mfa.conf")
60 return parser.parse_args() 61 return parser.parse_args()
@@ -241,11 +242,33 @@ def handle_pam(db, conn, addr):
241 conn.send(str(decision).encode(FORMAT)) 242 conn.send(str(decision).encode(FORMAT))
242 243
243 244
244def listen_client(db, addr, port, cert, key): 245def get_tls_context(cert, key, ciphers):
246 if not os.path.exists(cert):
247 die("error: cannot open cert file")
248 if not os.path.exists(key):
249 die("error: cannot open key file")
250
251 # Create context, load cert and key
245 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 252 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
246 context.load_cert_chain(certfile=cert, keyfile=key) 253 context.load_cert_chain(certfile=cert, keyfile=key)
254 context.minimum_version = ssl.TLSVersion.TLSv1_2
255 context.maximum_version = ssl.TLSVersion.TLSv1_3
256
257 if ciphers == None:
258 # Mozilla intermediate compatibility
259 ciphers = "ECDHE+ECDSA+AESGCM:ECDHE+aRSA+AESGCM:ECDHE+ECDSA+CHACHA20:ECDHE+aRSA+CHACHA20:DHE+aRSA+AESGCM:!aNULL:!eNULL"
260
261 # Set ciphers
262 try:
263 context.set_ciphers(ciphers)
264 except ssl.SSLError:
265 die("error: invalid cipherlist")
266
267 return context
268
269def listen_client(db, addr, port, tls_context):
247 with socket.create_server((addr, port)) as sock: 270 with socket.create_server((addr, port)) as sock:
248 with context.wrap_socket(sock, server_side=True) as tls_socket: 271 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
249 while True: 272 while True:
250 try: 273 try:
251 conn, addr = tls_socket.accept() 274 conn, addr = tls_socket.accept()
@@ -255,11 +278,9 @@ def listen_client(db, addr, port, cert, key):
255 print("client: ssl handshake error") 278 print("client: ssl handshake error")
256 279
257 280
258def listen_pam(db, addr, port, cert, key): 281def listen_pam(db, addr, port, tls_context):
259 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
260 context.load_cert_chain(certfile=cert, keyfile=key)
261 with socket.create_server((addr,port)) as sock: 282 with socket.create_server((addr,port)) as sock:
262 with context.wrap_socket(sock, server_side=True) as tls_socket: 283 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
263 while True: 284 while True:
264 try: 285 try:
265 conn, addr = tls_socket.accept() 286 conn, addr = tls_socket.accept()
@@ -300,6 +321,7 @@ def get_vars(args,confparser):
300 database = None 321 database = None
301 cert = None 322 cert = None
302 key = None 323 key = None
324 ciphers = None
303 325
304 # Set values from config file first 326 # Set values from config file first
305 if confparser.has_section("mfad"): 327 if confparser.has_section("mfad"):
@@ -309,6 +331,7 @@ def get_vars(args,confparser):
309 database = confparser.get("mfad","database",fallback=None) 331 database = confparser.get("mfad","database",fallback=None)
310 cert = confparser.get("mfad","cert",fallback=None) 332 cert = confparser.get("mfad","cert",fallback=None)
311 key = confparser.get("mfad","key",fallback=None) 333 key = confparser.get("mfad","key",fallback=None)
334 ciphers = confparser.get("mfad","ciphers",fallback=None)
312 335
313 # Let command line args overwrite any values 336 # Let command line args overwrite any values
314 if args.address != None: 337 if args.address != None:
@@ -323,29 +346,30 @@ def get_vars(args,confparser):
323 cert = args.cert 346 cert = args.cert
324 if args.key != None: 347 if args.key != None:
325 key = args.key 348 key = args.key
349 if args.ciphers != None:
350 ciphers = args.ciphers
326 351
327 # Exit if any value is null 352 # Exit if any value is null
328 if None in [bind_addr,client_port,pam_port,database,cert,key]: 353 if None in [bind_addr,client_port,pam_port,database,cert,key]:
329 print("error: one or more items unspecified") 354 print("error: one or more items unspecified")
330 sys.exit(1) 355 sys.exit(1)
331 356
332 return bind_addr, int(client_port), int(pam_port), database, cert, key 357 return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers
333 358
334 359
335def main(): 360def main():
336 args = parse_arguments() 361 args = parse_arguments()
337 confparser = read_config(args.config) 362 confparser = read_config(args.config)
338 363
339 bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser) 364 bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser)
340 365
341 if not os.path.exists(db): 366 if not os.path.exists(db):
342 print("Creating DB") 367 print("Creating DB")
343 create_db(db) 368 create_db(db)
344 369
345 clients = threading.Thread(target=listen_client, 370 context = get_tls_context(cert,key,ciphers)
346 args=(db, bind_addr,client_port,cert,key)) 371 clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context))
347 pam = threading.Thread(target=listen_pam, 372 pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context))
348 args=(db, bind_addr,pam_port,cert,key))
349 clients.start() 373 clients.start()
350 pam.start() 374 pam.start()
351 375