From 755d7f5f94b720b028d085cf971c5935c130dec1 Mon Sep 17 00:00:00 2001 From: Sam Chudnick Date: Mon, 4 Jul 2022 12:24:59 -0400 Subject: Implemented TLS encrypted connections Implemented TLS encrypted connections. Added command line argument and configuration file option to accept invalid (self-signed) certificates. Fixed a couple of unrelated issues. --- client/client.py | 32 ++++++++++++++++++++++++-------- pam/pam_mfa.py | 30 +++++++++++++++++++++++------- server/mfad.py | 38 ++++++++++++++++++++++++++------------ 3 files changed, 73 insertions(+), 27 deletions(-) diff --git a/client/client.py b/client/client.py index 70d85a0..cc22d0b 100755 --- a/client/client.py +++ b/client/client.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import socket +import ssl import time import argparse import sys @@ -25,6 +26,8 @@ def parse_arguments(): parser.add_argument("--config",type=str,help="Path to config file",\ default="/etc/mfa/mfa.conf") parser.add_argument("--key",type=str,help="Client connection key") + parser.add_argument("--insecure",action="store_true",\ + help="Accept invalid TLS certificates") return parser.parse_args() def prompt_user(prompt): @@ -34,7 +37,7 @@ def prompt_user(prompt): return result -def init_connection(mfa_server, client_port, client_key): +def init_connection(mfa_server, client_port, client_key, insecure=False): # Attempts to connect to MFA server with provided address,port, and key. # Repeats attempt once a seconds until timeout is reached. # Returns socket or None if unable to connect @@ -42,9 +45,16 @@ def init_connection(mfa_server, client_port, client_key): timeout = 0 timeout_length = 5 sleep_length = 1 + context = ssl.create_default_context() + if insecure: + context.check_hostname = False + context.verify_mode = 0 while connection == None and timeout < timeout_length: try: - connection = socket.create_connection((mfa_server,client_port)) + #connection = socket.create_connection((mfa_server,client_port)) + connection = context.wrap_socket(socket.socket(socket.AF_INET), + server_hostname=mfa_server) + connection.connect((mfa_server,int(client_port))) connection.send(client_key.encode(FORMAT)) response = connection.recv(ACK_LENGTH).decode(FORMAT) if response == ACK_MESSAGE: @@ -55,6 +65,8 @@ def init_connection(mfa_server, client_port, client_key): except ConnectionError: time.sleep(sleep_length) timeout += sleep_length + except ssl.SSLCertVerificationError: + die("error: server presented invalid certificate") return connection @@ -72,27 +84,31 @@ def get_vars(args,confparser): server = None port = None key = None + insecure = None # Set values from config file first if confparser.has_section("client"): server = confparser.get("client","server",fallback=None) port = confparser.get("client","port",fallback=None) key = confparser.get("client","key",fallback=None) + insecure = bool(confparser.get("client","insecure",fallback=False)) # Let command line args overwrite any values - if args.server: + if args.server != None: server = args.server - if args.port: + if args.port != None: port = args.port - if args.key: + if args.key != None: key = args.key + if args.insecure: + insecure = args.insecure # Exit if any value is null if None in [server,port,key]: print("error: one or more items unspecified") sys.exit(1) - return server,port,key + return server,port,key,insecure def main(): @@ -100,7 +116,7 @@ def main(): args = parse_arguments() confparser = read_config(args.config) - mfa_server,client_port,client_key = get_vars(args,confparser) + mfa_server,client_port,client_key,insecure = get_vars(args,confparser) # Exit if invalid key is provided if len(client_key) != KEY_LENGTH: @@ -108,7 +124,7 @@ def main(): sys.exit(1) # Open connection to server - conn = init_connection(mfa_server,client_port,client_key) + conn = init_connection(mfa_server,client_port,client_key,insecure) if conn == None: print("timed out attempting to connect to server") sys.exit(1) diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py index 85d0a82..5a5a112 100755 --- a/pam/pam_mfa.py +++ b/pam/pam_mfa.py @@ -1,5 +1,6 @@ #!/usr/bin/python3 import socket +import ssl import argparse import time import sys @@ -33,9 +34,11 @@ def parse_arguments(): default="/etc/mfa/mfa.conf") parser.add_argument("--server",type=str,help="MFA server address") parser.add_argument("--port",type=str,help="MFA server PAM connection port") + parser.add_argument("--insecure",action="store_true", + help="Accept invalid TLS certificates") return parser.parse_args() -def init_connection(mfa_server, pam_port): +def init_connection(mfa_server, pam_port, insecure): # Attempts to connect to MFA server with provided address and port # Repeats connection attempts once per second until timeout is reached # Returns the socket if connection was successful or None otherwise @@ -43,13 +46,22 @@ def init_connection(mfa_server, pam_port): timeout = 0 timeout_length = 5 sleep_length = 1 + context = ssl.create_default_context() + if insecure: + context.check_hostname = False + context.verify_mode = 0 while connection == None and timeout < timeout_length: try: - connection = socket.create_connection((mfa_server,pam_port)) + #connection = socket.create_connection((mfa_server,client_port)) + connection = context.wrap_socket(socket.socket(socket.AF_INET), + server_hostname=mfa_server) + connection.connect((mfa_server,int(pam_port))) return connection except (ConnectionError,ConnectionRefusedError): time.sleep(sleep_length) timeout += sleep_length + except ssl.SSLCertVerificationError: + die("error: server presented invalid certificate") return None @@ -82,24 +94,28 @@ def get_vars(args,confparser): server = None port = None + insecure = None # Set values from config file first if confparser.has_section("pam"): server = confparser.get("pam","server",fallback=None) port = confparser.get("pam","port",fallback=None) + insecure = bool(confparser.get("pam","insecure",fallback=False)) # Let command line args overwrite any values - if args.server: + if args.server != None: server = args.server - if args.port: + if args.port != None: port = args.port + if args.insecure: + insecure = args.insecure # Exit if any value is null if None in [server,port]: print("error: one or more items unspecified") sys.exit(1) - return server,port + return server,port,insecure def main(): @@ -109,7 +125,7 @@ def main(): # Get arguments args = parse_arguments() confparser = read_config(args.config) - mfa_server,pam_port = get_vars(args,confparser) + mfa_server,pam_port,insecure = get_vars(args,confparser) user = args.user service = args.service @@ -130,7 +146,7 @@ def main(): # Initalize connection to MFA server. Quit if unable to connect. - connection = init_connection(mfa_server,pam_port) + connection = init_connection(mfa_server,pam_port,insecure) if connection == None: print(failed) sys.exit(1) diff --git a/server/mfad.py b/server/mfad.py index 17d2585..18a048a 100755 --- a/server/mfad.py +++ b/server/mfad.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import socket +import ssl import os import sys import time @@ -211,8 +212,9 @@ def parse_pam_data(data): def handle_pam(db, conn, addr): # Get request and data from PAM module header = conn.recv(HEADER_LENGTH).decode(FORMAT) - if header == "": - die("error: lost connection to pam module") + if len(header) != HEADER_LENGTH: + conn.close() + die("error: invalid data from PAM module") data_length = int(header) pam_data = conn.recv(data_length).decode(FORMAT) print("Got pam_data: " + pam_data) @@ -238,19 +240,31 @@ def handle_pam(db, conn, addr): def listen_client(db, addr, port): - with socket.create_server((addr, port)) as server: - while True: - conn, addr = server.accept() - thread = threading.Thread(target=handle_client,args=(db, conn,addr)) - thread.start() + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") + with socket.create_server((addr, port)) as sock: + with context.wrap_socket(sock, server_side=True) as tls_socket: + while True: + try: + conn, addr = tls_socket.accept() + thread = threading.Thread(target=handle_client,args=(db,conn,addr)) + thread.start() + except ssl.SSLError: + print("client: ssl handshake error") def listen_pam(db, addr, port): - with socket.create_server((addr,port)) as pam_server: - while True: - conn, addr = pam_server.accept() - thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) - thread.start() + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") + with socket.create_server((addr,port)) as sock: + with context.wrap_socket(sock, server_side=True) as tls_socket: + while True: + try: + conn, addr = tls_socket.accept() + thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) + thread.start() + except ssl.SSLError: + print("pam: ssl handshake error") ################################################################################ -- cgit v1.2.3