diff options
-rwxr-xr-x | server/mfad.py | 50 |
1 files changed, 37 insertions, 13 deletions
diff --git a/server/mfad.py b/server/mfad.py index b9557bd..a2a783b 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -55,6 +55,7 @@ def parse_arguments(): | |||
55 | parser.add_argument("--database",type=str,help="Path to database file") | 55 | parser.add_argument("--database",type=str,help="Path to database file") |
56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | 56 | parser.add_argument("--cert",type=str,help="TLS certificate file") |
57 | parser.add_argument("--key",type=str,help="TLS private key file") | 57 | parser.add_argument("--key",type=str,help="TLS private key file") |
58 | parser.add_argument("--ciphers",type=str,help="TLS ciphers to use") | ||
58 | parser.add_argument("--config",type=str,help="Alternate config file location",\ | 59 | parser.add_argument("--config",type=str,help="Alternate config file location",\ |
59 | default="/etc/mfa/mfa.conf") | 60 | default="/etc/mfa/mfa.conf") |
60 | return parser.parse_args() | 61 | return parser.parse_args() |
@@ -241,11 +242,33 @@ def handle_pam(db, conn, addr): | |||
241 | conn.send(str(decision).encode(FORMAT)) | 242 | conn.send(str(decision).encode(FORMAT)) |
242 | 243 | ||
243 | 244 | ||
244 | def listen_client(db, addr, port, cert, key): | 245 | def get_tls_context(cert, key, ciphers): |
246 | if not os.path.exists(cert): | ||
247 | die("error: cannot open cert file") | ||
248 | if not os.path.exists(key): | ||
249 | die("error: cannot open key file") | ||
250 | |||
251 | # Create context, load cert and key | ||
245 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | 252 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
246 | context.load_cert_chain(certfile=cert, keyfile=key) | 253 | context.load_cert_chain(certfile=cert, keyfile=key) |
254 | context.minimum_version = ssl.TLSVersion.TLSv1_2 | ||
255 | context.maximum_version = ssl.TLSVersion.TLSv1_3 | ||
256 | |||
257 | if ciphers == None: | ||
258 | # Mozilla intermediate compatibility | ||
259 | ciphers = "ECDHE+ECDSA+AESGCM:ECDHE+aRSA+AESGCM:ECDHE+ECDSA+CHACHA20:ECDHE+aRSA+CHACHA20:DHE+aRSA+AESGCM:!aNULL:!eNULL" | ||
260 | |||
261 | # Set ciphers | ||
262 | try: | ||
263 | context.set_ciphers(ciphers) | ||
264 | except ssl.SSLError: | ||
265 | die("error: invalid cipherlist") | ||
266 | |||
267 | return context | ||
268 | |||
269 | def listen_client(db, addr, port, tls_context): | ||
247 | with socket.create_server((addr, port)) as sock: | 270 | with socket.create_server((addr, port)) as sock: |
248 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 271 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
249 | while True: | 272 | while True: |
250 | try: | 273 | try: |
251 | conn, addr = tls_socket.accept() | 274 | conn, addr = tls_socket.accept() |
@@ -255,11 +278,9 @@ def listen_client(db, addr, port, cert, key): | |||
255 | print("client: ssl handshake error") | 278 | print("client: ssl handshake error") |
256 | 279 | ||
257 | 280 | ||
258 | def listen_pam(db, addr, port, cert, key): | 281 | def listen_pam(db, addr, port, tls_context): |
259 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | ||
260 | context.load_cert_chain(certfile=cert, keyfile=key) | ||
261 | with socket.create_server((addr,port)) as sock: | 282 | with socket.create_server((addr,port)) as sock: |
262 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 283 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
263 | while True: | 284 | while True: |
264 | try: | 285 | try: |
265 | conn, addr = tls_socket.accept() | 286 | conn, addr = tls_socket.accept() |
@@ -300,6 +321,7 @@ def get_vars(args,confparser): | |||
300 | database = None | 321 | database = None |
301 | cert = None | 322 | cert = None |
302 | key = None | 323 | key = None |
324 | ciphers = None | ||
303 | 325 | ||
304 | # Set values from config file first | 326 | # Set values from config file first |
305 | if confparser.has_section("mfad"): | 327 | if confparser.has_section("mfad"): |
@@ -309,6 +331,7 @@ def get_vars(args,confparser): | |||
309 | database = confparser.get("mfad","database",fallback=None) | 331 | database = confparser.get("mfad","database",fallback=None) |
310 | cert = confparser.get("mfad","cert",fallback=None) | 332 | cert = confparser.get("mfad","cert",fallback=None) |
311 | key = confparser.get("mfad","key",fallback=None) | 333 | key = confparser.get("mfad","key",fallback=None) |
334 | ciphers = confparser.get("mfad","ciphers",fallback=None) | ||
312 | 335 | ||
313 | # Let command line args overwrite any values | 336 | # Let command line args overwrite any values |
314 | if args.address != None: | 337 | if args.address != None: |
@@ -323,29 +346,30 @@ def get_vars(args,confparser): | |||
323 | cert = args.cert | 346 | cert = args.cert |
324 | if args.key != None: | 347 | if args.key != None: |
325 | key = args.key | 348 | key = args.key |
349 | if args.ciphers != None: | ||
350 | ciphers = args.ciphers | ||
326 | 351 | ||
327 | # Exit if any value is null | 352 | # Exit if any value is null |
328 | if None in [bind_addr,client_port,pam_port,database,cert,key]: | 353 | if None in [bind_addr,client_port,pam_port,database,cert,key]: |
329 | print("error: one or more items unspecified") | 354 | print("error: one or more items unspecified") |
330 | sys.exit(1) | 355 | sys.exit(1) |
331 | 356 | ||
332 | return bind_addr, int(client_port), int(pam_port), database, cert, key | 357 | return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers |
333 | 358 | ||
334 | 359 | ||
335 | def main(): | 360 | def main(): |
336 | args = parse_arguments() | 361 | args = parse_arguments() |
337 | confparser = read_config(args.config) | 362 | confparser = read_config(args.config) |
338 | 363 | ||
339 | bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser) | 364 | bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) |
340 | 365 | ||
341 | if not os.path.exists(db): | 366 | if not os.path.exists(db): |
342 | print("Creating DB") | 367 | print("Creating DB") |
343 | create_db(db) | 368 | create_db(db) |
344 | 369 | ||
345 | clients = threading.Thread(target=listen_client, | 370 | context = get_tls_context(cert,key,ciphers) |
346 | args=(db, bind_addr,client_port,cert,key)) | 371 | clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) |
347 | pam = threading.Thread(target=listen_pam, | 372 | pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) |
348 | args=(db, bind_addr,pam_port,cert,key)) | ||
349 | clients.start() | 373 | clients.start() |
350 | pam.start() | 374 | pam.start() |
351 | 375 | ||