diff options
| author | Sam Chudnick <sam@chudnick.com> | 2022-07-04 12:24:59 -0400 | 
|---|---|---|
| committer | Sam Chudnick <sam@chudnick.com> | 2022-07-04 12:24:59 -0400 | 
| commit | 755d7f5f94b720b028d085cf971c5935c130dec1 (patch) | |
| tree | f015e8929563e5302d2ba8e2ee7215d1231debdd | |
| parent | 11a4a5edb9f0e22fe8355291942ed03c9765ced5 (diff) | |
Implemented TLS encrypted connections
Implemented TLS encrypted connections. Added command line argument and
configuration file option to accept invalid (self-signed) certificates.
Fixed a couple of unrelated issues.
| -rwxr-xr-x | client/client.py | 32 | ||||
| -rwxr-xr-x | pam/pam_mfa.py | 30 | ||||
| -rwxr-xr-x | server/mfad.py | 38 | 
3 files changed, 73 insertions, 27 deletions
| diff --git a/client/client.py b/client/client.py index 70d85a0..cc22d0b 100755 --- a/client/client.py +++ b/client/client.py | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 | 
| 2 | 2 | ||
| 3 | import socket | 3 | import socket | 
| 4 | import ssl | ||
| 4 | import time | 5 | import time | 
| 5 | import argparse | 6 | import argparse | 
| 6 | import sys | 7 | import sys | 
| @@ -25,6 +26,8 @@ def parse_arguments(): | |||
| 25 | parser.add_argument("--config",type=str,help="Path to config file",\ | 26 | parser.add_argument("--config",type=str,help="Path to config file",\ | 
| 26 | default="/etc/mfa/mfa.conf") | 27 | default="/etc/mfa/mfa.conf") | 
| 27 | parser.add_argument("--key",type=str,help="Client connection key") | 28 | parser.add_argument("--key",type=str,help="Client connection key") | 
| 29 | parser.add_argument("--insecure",action="store_true",\ | ||
| 30 | help="Accept invalid TLS certificates") | ||
| 28 | return parser.parse_args() | 31 | return parser.parse_args() | 
| 29 | 32 | ||
| 30 | def prompt_user(prompt): | 33 | def prompt_user(prompt): | 
| @@ -34,7 +37,7 @@ def prompt_user(prompt): | |||
| 34 | return result | 37 | return result | 
| 35 | 38 | ||
| 36 | 39 | ||
| 37 | def init_connection(mfa_server, client_port, client_key): | 40 | def init_connection(mfa_server, client_port, client_key, insecure=False): | 
| 38 | # Attempts to connect to MFA server with provided address,port, and key. | 41 | # Attempts to connect to MFA server with provided address,port, and key. | 
| 39 | # Repeats attempt once a seconds until timeout is reached. | 42 | # Repeats attempt once a seconds until timeout is reached. | 
| 40 | # Returns socket or None if unable to connect | 43 | # Returns socket or None if unable to connect | 
| @@ -42,9 +45,16 @@ def init_connection(mfa_server, client_port, client_key): | |||
| 42 | timeout = 0 | 45 | timeout = 0 | 
| 43 | timeout_length = 5 | 46 | timeout_length = 5 | 
| 44 | sleep_length = 1 | 47 | sleep_length = 1 | 
| 48 | context = ssl.create_default_context() | ||
| 49 | if insecure: | ||
| 50 | context.check_hostname = False | ||
| 51 | context.verify_mode = 0 | ||
| 45 | while connection == None and timeout < timeout_length: | 52 | while connection == None and timeout < timeout_length: | 
| 46 | try: | 53 | try: | 
| 47 | connection = socket.create_connection((mfa_server,client_port)) | 54 | #connection = socket.create_connection((mfa_server,client_port)) | 
| 55 | connection = context.wrap_socket(socket.socket(socket.AF_INET), | ||
| 56 | server_hostname=mfa_server) | ||
| 57 | connection.connect((mfa_server,int(client_port))) | ||
| 48 | connection.send(client_key.encode(FORMAT)) | 58 | connection.send(client_key.encode(FORMAT)) | 
| 49 | response = connection.recv(ACK_LENGTH).decode(FORMAT) | 59 | response = connection.recv(ACK_LENGTH).decode(FORMAT) | 
| 50 | if response == ACK_MESSAGE: | 60 | if response == ACK_MESSAGE: | 
| @@ -55,6 +65,8 @@ def init_connection(mfa_server, client_port, client_key): | |||
| 55 | except ConnectionError: | 65 | except ConnectionError: | 
| 56 | time.sleep(sleep_length) | 66 | time.sleep(sleep_length) | 
| 57 | timeout += sleep_length | 67 | timeout += sleep_length | 
| 68 | except ssl.SSLCertVerificationError: | ||
| 69 | die("error: server presented invalid certificate") | ||
| 58 | return connection | 70 | return connection | 
| 59 | 71 | ||
| 60 | 72 | ||
| @@ -72,27 +84,31 @@ def get_vars(args,confparser): | |||
| 72 | server = None | 84 | server = None | 
| 73 | port = None | 85 | port = None | 
| 74 | key = None | 86 | key = None | 
| 87 | insecure = None | ||
| 75 | 88 | ||
| 76 | # Set values from config file first | 89 | # Set values from config file first | 
| 77 | if confparser.has_section("client"): | 90 | if confparser.has_section("client"): | 
| 78 | server = confparser.get("client","server",fallback=None) | 91 | server = confparser.get("client","server",fallback=None) | 
| 79 | port = confparser.get("client","port",fallback=None) | 92 | port = confparser.get("client","port",fallback=None) | 
| 80 | key = confparser.get("client","key",fallback=None) | 93 | key = confparser.get("client","key",fallback=None) | 
| 94 | insecure = bool(confparser.get("client","insecure",fallback=False)) | ||
| 81 | 95 | ||
| 82 | # Let command line args overwrite any values | 96 | # Let command line args overwrite any values | 
| 83 | if args.server: | 97 | if args.server != None: | 
| 84 | server = args.server | 98 | server = args.server | 
| 85 | if args.port: | 99 | if args.port != None: | 
| 86 | port = args.port | 100 | port = args.port | 
| 87 | if args.key: | 101 | if args.key != None: | 
| 88 | key = args.key | 102 | key = args.key | 
| 103 | if args.insecure: | ||
| 104 | insecure = args.insecure | ||
| 89 | 105 | ||
| 90 | # Exit if any value is null | 106 | # Exit if any value is null | 
| 91 | if None in [server,port,key]: | 107 | if None in [server,port,key]: | 
| 92 | print("error: one or more items unspecified") | 108 | print("error: one or more items unspecified") | 
| 93 | sys.exit(1) | 109 | sys.exit(1) | 
| 94 | 110 | ||
| 95 | return server,port,key | 111 | return server,port,key,insecure | 
| 96 | 112 | ||
| 97 | 113 | ||
| 98 | def main(): | 114 | def main(): | 
| @@ -100,7 +116,7 @@ def main(): | |||
| 100 | args = parse_arguments() | 116 | args = parse_arguments() | 
| 101 | confparser = read_config(args.config) | 117 | confparser = read_config(args.config) | 
| 102 | 118 | ||
| 103 | mfa_server,client_port,client_key = get_vars(args,confparser) | 119 | mfa_server,client_port,client_key,insecure = get_vars(args,confparser) | 
| 104 | 120 | ||
| 105 | # Exit if invalid key is provided | 121 | # Exit if invalid key is provided | 
| 106 | if len(client_key) != KEY_LENGTH: | 122 | if len(client_key) != KEY_LENGTH: | 
| @@ -108,7 +124,7 @@ def main(): | |||
| 108 | sys.exit(1) | 124 | sys.exit(1) | 
| 109 | 125 | ||
| 110 | # Open connection to server | 126 | # Open connection to server | 
| 111 | conn = init_connection(mfa_server,client_port,client_key) | 127 | conn = init_connection(mfa_server,client_port,client_key,insecure) | 
| 112 | if conn == None: | 128 | if conn == None: | 
| 113 | print("timed out attempting to connect to server") | 129 | print("timed out attempting to connect to server") | 
| 114 | sys.exit(1) | 130 | sys.exit(1) | 
| diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py index 85d0a82..5a5a112 100755 --- a/pam/pam_mfa.py +++ b/pam/pam_mfa.py | |||
| @@ -1,5 +1,6 @@ | |||
| 1 | #!/usr/bin/python3 | 1 | #!/usr/bin/python3 | 
| 2 | import socket | 2 | import socket | 
| 3 | import ssl | ||
| 3 | import argparse | 4 | import argparse | 
| 4 | import time | 5 | import time | 
| 5 | import sys | 6 | import sys | 
| @@ -33,9 +34,11 @@ def parse_arguments(): | |||
| 33 | default="/etc/mfa/mfa.conf") | 34 | default="/etc/mfa/mfa.conf") | 
| 34 | parser.add_argument("--server",type=str,help="MFA server address") | 35 | parser.add_argument("--server",type=str,help="MFA server address") | 
| 35 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") | 36 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") | 
| 37 | parser.add_argument("--insecure",action="store_true", | ||
| 38 | help="Accept invalid TLS certificates") | ||
| 36 | return parser.parse_args() | 39 | return parser.parse_args() | 
| 37 | 40 | ||
| 38 | def init_connection(mfa_server, pam_port): | 41 | def init_connection(mfa_server, pam_port, insecure): | 
| 39 | # Attempts to connect to MFA server with provided address and port | 42 | # Attempts to connect to MFA server with provided address and port | 
| 40 | # Repeats connection attempts once per second until timeout is reached | 43 | # Repeats connection attempts once per second until timeout is reached | 
| 41 | # Returns the socket if connection was successful or None otherwise | 44 | # Returns the socket if connection was successful or None otherwise | 
| @@ -43,13 +46,22 @@ def init_connection(mfa_server, pam_port): | |||
| 43 | timeout = 0 | 46 | timeout = 0 | 
| 44 | timeout_length = 5 | 47 | timeout_length = 5 | 
| 45 | sleep_length = 1 | 48 | sleep_length = 1 | 
| 49 | context = ssl.create_default_context() | ||
| 50 | if insecure: | ||
| 51 | context.check_hostname = False | ||
| 52 | context.verify_mode = 0 | ||
| 46 | while connection == None and timeout < timeout_length: | 53 | while connection == None and timeout < timeout_length: | 
| 47 | try: | 54 | try: | 
| 48 | connection = socket.create_connection((mfa_server,pam_port)) | 55 | #connection = socket.create_connection((mfa_server,client_port)) | 
| 56 | connection = context.wrap_socket(socket.socket(socket.AF_INET), | ||
| 57 | server_hostname=mfa_server) | ||
| 58 | connection.connect((mfa_server,int(pam_port))) | ||
| 49 | return connection | 59 | return connection | 
| 50 | except (ConnectionError,ConnectionRefusedError): | 60 | except (ConnectionError,ConnectionRefusedError): | 
| 51 | time.sleep(sleep_length) | 61 | time.sleep(sleep_length) | 
| 52 | timeout += sleep_length | 62 | timeout += sleep_length | 
| 63 | except ssl.SSLCertVerificationError: | ||
| 64 | die("error: server presented invalid certificate") | ||
| 53 | return None | 65 | return None | 
| 54 | 66 | ||
| 55 | 67 | ||
| @@ -82,24 +94,28 @@ def get_vars(args,confparser): | |||
| 82 | 94 | ||
| 83 | server = None | 95 | server = None | 
| 84 | port = None | 96 | port = None | 
| 97 | insecure = None | ||
| 85 | 98 | ||
| 86 | # Set values from config file first | 99 | # Set values from config file first | 
| 87 | if confparser.has_section("pam"): | 100 | if confparser.has_section("pam"): | 
| 88 | server = confparser.get("pam","server",fallback=None) | 101 | server = confparser.get("pam","server",fallback=None) | 
| 89 | port = confparser.get("pam","port",fallback=None) | 102 | port = confparser.get("pam","port",fallback=None) | 
| 103 | insecure = bool(confparser.get("pam","insecure",fallback=False)) | ||
| 90 | 104 | ||
| 91 | # Let command line args overwrite any values | 105 | # Let command line args overwrite any values | 
| 92 | if args.server: | 106 | if args.server != None: | 
| 93 | server = args.server | 107 | server = args.server | 
| 94 | if args.port: | 108 | if args.port != None: | 
| 95 | port = args.port | 109 | port = args.port | 
| 110 | if args.insecure: | ||
| 111 | insecure = args.insecure | ||
| 96 | 112 | ||
| 97 | # Exit if any value is null | 113 | # Exit if any value is null | 
| 98 | if None in [server,port]: | 114 | if None in [server,port]: | 
| 99 | print("error: one or more items unspecified") | 115 | print("error: one or more items unspecified") | 
| 100 | sys.exit(1) | 116 | sys.exit(1) | 
| 101 | 117 | ||
| 102 | return server,port | 118 | return server,port,insecure | 
| 103 | 119 | ||
| 104 | 120 | ||
| 105 | def main(): | 121 | def main(): | 
| @@ -109,7 +125,7 @@ def main(): | |||
| 109 | # Get arguments | 125 | # Get arguments | 
| 110 | args = parse_arguments() | 126 | args = parse_arguments() | 
| 111 | confparser = read_config(args.config) | 127 | confparser = read_config(args.config) | 
| 112 | mfa_server,pam_port = get_vars(args,confparser) | 128 | mfa_server,pam_port,insecure = get_vars(args,confparser) | 
| 113 | user = args.user | 129 | user = args.user | 
| 114 | service = args.service | 130 | service = args.service | 
| 115 | 131 | ||
| @@ -130,7 +146,7 @@ def main(): | |||
| 130 | 146 | ||
| 131 | 147 | ||
| 132 | # Initalize connection to MFA server. Quit if unable to connect. | 148 | # Initalize connection to MFA server. Quit if unable to connect. | 
| 133 | connection = init_connection(mfa_server,pam_port) | 149 | connection = init_connection(mfa_server,pam_port,insecure) | 
| 134 | if connection == None: | 150 | if connection == None: | 
| 135 | print(failed) | 151 | print(failed) | 
| 136 | sys.exit(1) | 152 | sys.exit(1) | 
| diff --git a/server/mfad.py b/server/mfad.py index 17d2585..18a048a 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
| @@ -1,5 +1,6 @@ | |||
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 | 
| 2 | import socket | 2 | import socket | 
| 3 | import ssl | ||
| 3 | import os | 4 | import os | 
| 4 | import sys | 5 | import sys | 
| 5 | import time | 6 | import time | 
| @@ -211,8 +212,9 @@ def parse_pam_data(data): | |||
| 211 | def handle_pam(db, conn, addr): | 212 | def handle_pam(db, conn, addr): | 
| 212 | # Get request and data from PAM module | 213 | # Get request and data from PAM module | 
| 213 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) | 214 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) | 
| 214 | if header == "": | 215 | if len(header) != HEADER_LENGTH: | 
| 215 | die("error: lost connection to pam module") | 216 | conn.close() | 
| 217 | die("error: invalid data from PAM module") | ||
| 216 | data_length = int(header) | 218 | data_length = int(header) | 
| 217 | pam_data = conn.recv(data_length).decode(FORMAT) | 219 | pam_data = conn.recv(data_length).decode(FORMAT) | 
| 218 | print("Got pam_data: " + pam_data) | 220 | print("Got pam_data: " + pam_data) | 
| @@ -238,19 +240,31 @@ def handle_pam(db, conn, addr): | |||
| 238 | 240 | ||
| 239 | 241 | ||
| 240 | def listen_client(db, addr, port): | 242 | def listen_client(db, addr, port): | 
| 241 | with socket.create_server((addr, port)) as server: | 243 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | 
| 242 | while True: | 244 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") | 
| 243 | conn, addr = server.accept() | 245 | with socket.create_server((addr, port)) as sock: | 
| 244 | thread = threading.Thread(target=handle_client,args=(db, conn,addr)) | 246 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 
| 245 | thread.start() | 247 | while True: | 
| 248 | try: | ||
| 249 | conn, addr = tls_socket.accept() | ||
| 250 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
| 251 | thread.start() | ||
| 252 | except ssl.SSLError: | ||
| 253 | print("client: ssl handshake error") | ||
| 246 | 254 | ||
| 247 | 255 | ||
| 248 | def listen_pam(db, addr, port): | 256 | def listen_pam(db, addr, port): | 
| 249 | with socket.create_server((addr,port)) as pam_server: | 257 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | 
| 250 | while True: | 258 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") | 
| 251 | conn, addr = pam_server.accept() | 259 | with socket.create_server((addr,port)) as sock: | 
| 252 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) | 260 | with context.wrap_socket(sock, server_side=True) as tls_socket: | 
| 253 | thread.start() | 261 | while True: | 
| 262 | try: | ||
| 263 | conn, addr = tls_socket.accept() | ||
| 264 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) | ||
| 265 | thread.start() | ||
| 266 | except ssl.SSLError: | ||
| 267 | print("pam: ssl handshake error") | ||
| 254 | 268 | ||
| 255 | 269 | ||
| 256 | ################################################################################ | 270 | ################################################################################ | 
