summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorSam Chudnick <sam@chudnick.com>2022-07-04 12:32:16 -0400
committerSam Chudnick <sam@chudnick.com>2022-07-04 12:32:16 -0400
commitdbfb415edfbe1bc8db3a1272c28189785e623860 (patch)
treec4f519975cacd13123dcb15bec0c807ceb00f0f7 /server
parent755d7f5f94b720b028d085cf971c5935c130dec1 (diff)
Added options for certificate and key files
Added command line arguments and config file options to specify TLS certificate and TLS private key files.
Diffstat (limited to 'server')
-rwxr-xr-xserver/mfad.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/server/mfad.py b/server/mfad.py
index 18a048a..b9557bd 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,7 +52,9 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--database",type=str,help="Path to alternate database file") 55 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file")
56 parser.add_argument("--config",type=str,help="Alternate config file location",\ 58 parser.add_argument("--config",type=str,help="Alternate config file location",\
57 default="/etc/mfa/mfa.conf") 59 default="/etc/mfa/mfa.conf")
58 return parser.parse_args() 60 return parser.parse_args()
@@ -239,9 +241,9 @@ def handle_pam(db, conn, addr):
239 conn.send(str(decision).encode(FORMAT)) 241 conn.send(str(decision).encode(FORMAT))
240 242
241 243
242def listen_client(db, addr, port): 244def listen_client(db, addr, port, cert, key):
243 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 245 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
244 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") 246 context.load_cert_chain(certfile=cert, keyfile=key)
245 with socket.create_server((addr, port)) as sock: 247 with socket.create_server((addr, port)) as sock:
246 with context.wrap_socket(sock, server_side=True) as tls_socket: 248 with context.wrap_socket(sock, server_side=True) as tls_socket:
247 while True: 249 while True:
@@ -253,9 +255,9 @@ def listen_client(db, addr, port):
253 print("client: ssl handshake error") 255 print("client: ssl handshake error")
254 256
255 257
256def listen_pam(db, addr, port): 258def listen_pam(db, addr, port, cert, key):
257 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 259 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
258 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") 260 context.load_cert_chain(certfile=cert, keyfile=key)
259 with socket.create_server((addr,port)) as sock: 261 with socket.create_server((addr,port)) as sock:
260 with context.wrap_socket(sock, server_side=True) as tls_socket: 262 with context.wrap_socket(sock, server_side=True) as tls_socket:
261 while True: 263 while True:
@@ -296,6 +298,8 @@ def get_vars(args,confparser):
296 client_port = None 298 client_port = None
297 pam_port = None 299 pam_port = None
298 database = None 300 database = None
301 cert = None
302 key = None
299 303
300 # Set values from config file first 304 # Set values from config file first
301 if confparser.has_section("mfad"): 305 if confparser.has_section("mfad"):
@@ -303,37 +307,45 @@ def get_vars(args,confparser):
303 client_port = confparser.get("mfad","client-port",fallback=None) 307 client_port = confparser.get("mfad","client-port",fallback=None)
304 pam_port = confparser.get("mfad","pam-port",fallback=None) 308 pam_port = confparser.get("mfad","pam-port",fallback=None)
305 database = confparser.get("mfad","database",fallback=None) 309 database = confparser.get("mfad","database",fallback=None)
310 cert = confparser.get("mfad","cert",fallback=None)
311 key = confparser.get("mfad","key",fallback=None)
306 312
307 # Let command line args overwrite any values 313 # Let command line args overwrite any values
308 if args.address: 314 if args.address != None:
309 bind_addr = args.address 315 bind_addr = args.address
310 if args.client_port: 316 if args.client_port != None:
311 client_port = args.client_port 317 client_port = args.client_port
312 if args.pam_port: 318 if args.pam_port != None:
313 pam_port = args.pam_port 319 pam_port = args.pam_port
314 if args.database: 320 if args.database != None:
315 database = args.database 321 database = args.database
322 if args.cert != None:
323 cert = args.cert
324 if args.key != None:
325 key = args.key
316 326
317 # Exit if any value is null 327 # Exit if any value is null
318 if None in [bind_addr,client_port,pam_port,database]: 328 if None in [bind_addr,client_port,pam_port,database,cert,key]:
319 print("error: one or more items unspecified") 329 print("error: one or more items unspecified")
320 sys.exit(1) 330 sys.exit(1)
321 331
322 return bind_addr, int(client_port), int(pam_port), database 332 return bind_addr, int(client_port), int(pam_port), database, cert, key
323 333
324 334
325def main(): 335def main():
326 args = parse_arguments() 336 args = parse_arguments()
327 confparser = read_config(args.config) 337 confparser = read_config(args.config)
328 338
329 bind_addr, client_port, pam_port, db = get_vars(args,confparser) 339 bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser)
330 340
331 if not os.path.exists(db): 341 if not os.path.exists(db):
332 print("Creating DB") 342 print("Creating DB")
333 create_db(db) 343 create_db(db)
334 344
335 clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) 345 clients = threading.Thread(target=listen_client,
336 pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) 346 args=(db, bind_addr,client_port,cert,key))
347 pam = threading.Thread(target=listen_pam,
348 args=(db, bind_addr,pam_port,cert,key))
337 clients.start() 349 clients.start()
338 pam.start() 350 pam.start()
339 351