diff options
| author | Sam Chudnick <sam@chudnick.com> | 2022-07-04 12:32:16 -0400 |
|---|---|---|
| committer | Sam Chudnick <sam@chudnick.com> | 2022-07-04 12:32:16 -0400 |
| commit | dbfb415edfbe1bc8db3a1272c28189785e623860 (patch) | |
| tree | c4f519975cacd13123dcb15bec0c807ceb00f0f7 | |
| parent | 755d7f5f94b720b028d085cf971c5935c130dec1 (diff) | |
Added options for certificate and key files
Added command line arguments and config file options to specify TLS
certificate and TLS private key files.
| -rwxr-xr-x | server/mfad.py | 40 |
1 files changed, 26 insertions, 14 deletions
diff --git a/server/mfad.py b/server/mfad.py index 18a048a..b9557bd 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
| @@ -52,7 +52,9 @@ def parse_arguments(): | |||
| 52 | parser.add_argument("--address",type=str,help="Bind Address") | 52 | parser.add_argument("--address",type=str,help="Bind Address") |
| 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") |
| 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") | 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") |
| 55 | parser.add_argument("--database",type=str,help="Path to alternate database file") | 55 | parser.add_argument("--database",type=str,help="Path to database file") |
| 56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | ||
| 57 | parser.add_argument("--key",type=str,help="TLS private key file") | ||
| 56 | parser.add_argument("--config",type=str,help="Alternate config file location",\ | 58 | parser.add_argument("--config",type=str,help="Alternate config file location",\ |
| 57 | default="/etc/mfa/mfa.conf") | 59 | default="/etc/mfa/mfa.conf") |
| 58 | return parser.parse_args() | 60 | return parser.parse_args() |
| @@ -239,9 +241,9 @@ def handle_pam(db, conn, addr): | |||
| 239 | conn.send(str(decision).encode(FORMAT)) | 241 | conn.send(str(decision).encode(FORMAT)) |
| 240 | 242 | ||
| 241 | 243 | ||
| 242 | def listen_client(db, addr, port): | 244 | def listen_client(db, addr, port, cert, key): |
| 243 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | 245 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
| 244 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") | 246 | context.load_cert_chain(certfile=cert, keyfile=key) |
| 245 | with socket.create_server((addr, port)) as sock: | 247 | with socket.create_server((addr, port)) as sock: |
| 246 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 248 | with context.wrap_socket(sock, server_side=True) as tls_socket: |
| 247 | while True: | 249 | while True: |
| @@ -253,9 +255,9 @@ def listen_client(db, addr, port): | |||
| 253 | print("client: ssl handshake error") | 255 | print("client: ssl handshake error") |
| 254 | 256 | ||
| 255 | 257 | ||
| 256 | def listen_pam(db, addr, port): | 258 | def listen_pam(db, addr, port, cert, key): |
| 257 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | 259 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
| 258 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") | 260 | context.load_cert_chain(certfile=cert, keyfile=key) |
| 259 | with socket.create_server((addr,port)) as sock: | 261 | with socket.create_server((addr,port)) as sock: |
| 260 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 262 | with context.wrap_socket(sock, server_side=True) as tls_socket: |
| 261 | while True: | 263 | while True: |
| @@ -296,6 +298,8 @@ def get_vars(args,confparser): | |||
| 296 | client_port = None | 298 | client_port = None |
| 297 | pam_port = None | 299 | pam_port = None |
| 298 | database = None | 300 | database = None |
| 301 | cert = None | ||
| 302 | key = None | ||
| 299 | 303 | ||
| 300 | # Set values from config file first | 304 | # Set values from config file first |
| 301 | if confparser.has_section("mfad"): | 305 | if confparser.has_section("mfad"): |
| @@ -303,37 +307,45 @@ def get_vars(args,confparser): | |||
| 303 | client_port = confparser.get("mfad","client-port",fallback=None) | 307 | client_port = confparser.get("mfad","client-port",fallback=None) |
| 304 | pam_port = confparser.get("mfad","pam-port",fallback=None) | 308 | pam_port = confparser.get("mfad","pam-port",fallback=None) |
| 305 | database = confparser.get("mfad","database",fallback=None) | 309 | database = confparser.get("mfad","database",fallback=None) |
| 310 | cert = confparser.get("mfad","cert",fallback=None) | ||
| 311 | key = confparser.get("mfad","key",fallback=None) | ||
| 306 | 312 | ||
| 307 | # Let command line args overwrite any values | 313 | # Let command line args overwrite any values |
| 308 | if args.address: | 314 | if args.address != None: |
| 309 | bind_addr = args.address | 315 | bind_addr = args.address |
| 310 | if args.client_port: | 316 | if args.client_port != None: |
| 311 | client_port = args.client_port | 317 | client_port = args.client_port |
| 312 | if args.pam_port: | 318 | if args.pam_port != None: |
| 313 | pam_port = args.pam_port | 319 | pam_port = args.pam_port |
| 314 | if args.database: | 320 | if args.database != None: |
| 315 | database = args.database | 321 | database = args.database |
| 322 | if args.cert != None: | ||
| 323 | cert = args.cert | ||
| 324 | if args.key != None: | ||
| 325 | key = args.key | ||
| 316 | 326 | ||
| 317 | # Exit if any value is null | 327 | # Exit if any value is null |
| 318 | if None in [bind_addr,client_port,pam_port,database]: | 328 | if None in [bind_addr,client_port,pam_port,database,cert,key]: |
| 319 | print("error: one or more items unspecified") | 329 | print("error: one or more items unspecified") |
| 320 | sys.exit(1) | 330 | sys.exit(1) |
| 321 | 331 | ||
| 322 | return bind_addr, int(client_port), int(pam_port), database | 332 | return bind_addr, int(client_port), int(pam_port), database, cert, key |
| 323 | 333 | ||
| 324 | 334 | ||
| 325 | def main(): | 335 | def main(): |
| 326 | args = parse_arguments() | 336 | args = parse_arguments() |
| 327 | confparser = read_config(args.config) | 337 | confparser = read_config(args.config) |
| 328 | 338 | ||
| 329 | bind_addr, client_port, pam_port, db = get_vars(args,confparser) | 339 | bind_addr, client_port, pam_port, db, cert, key = get_vars(args,confparser) |
| 330 | 340 | ||
| 331 | if not os.path.exists(db): | 341 | if not os.path.exists(db): |
| 332 | print("Creating DB") | 342 | print("Creating DB") |
| 333 | create_db(db) | 343 | create_db(db) |
| 334 | 344 | ||
| 335 | clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) | 345 | clients = threading.Thread(target=listen_client, |
| 336 | pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) | 346 | args=(db, bind_addr,client_port,cert,key)) |
| 347 | pam = threading.Thread(target=listen_pam, | ||
| 348 | args=(db, bind_addr,pam_port,cert,key)) | ||
| 337 | clients.start() | 349 | clients.start() |
| 338 | pam.start() | 350 | pam.start() |
| 339 | 351 | ||
