summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Chudnick <sam@chudnick.com>2022-07-04 12:24:59 -0400
committerSam Chudnick <sam@chudnick.com>2022-07-04 12:24:59 -0400
commit755d7f5f94b720b028d085cf971c5935c130dec1 (patch)
treef015e8929563e5302d2ba8e2ee7215d1231debdd
parent11a4a5edb9f0e22fe8355291942ed03c9765ced5 (diff)
Implemented TLS encrypted connections
Implemented TLS encrypted connections. Added command line argument and configuration file option to accept invalid (self-signed) certificates. Fixed a couple of unrelated issues.
-rwxr-xr-xclient/client.py32
-rwxr-xr-xpam/pam_mfa.py30
-rwxr-xr-xserver/mfad.py38
3 files changed, 73 insertions, 27 deletions
diff --git a/client/client.py b/client/client.py
index 70d85a0..cc22d0b 100755
--- a/client/client.py
+++ b/client/client.py
@@ -1,6 +1,7 @@
1#!/usr/bin/env python3 1#!/usr/bin/env python3
2 2
3import socket 3import socket
4import ssl
4import time 5import time
5import argparse 6import argparse
6import sys 7import sys
@@ -25,6 +26,8 @@ def parse_arguments():
25 parser.add_argument("--config",type=str,help="Path to config file",\ 26 parser.add_argument("--config",type=str,help="Path to config file",\
26 default="/etc/mfa/mfa.conf") 27 default="/etc/mfa/mfa.conf")
27 parser.add_argument("--key",type=str,help="Client connection key") 28 parser.add_argument("--key",type=str,help="Client connection key")
29 parser.add_argument("--insecure",action="store_true",\
30 help="Accept invalid TLS certificates")
28 return parser.parse_args() 31 return parser.parse_args()
29 32
30def prompt_user(prompt): 33def prompt_user(prompt):
@@ -34,7 +37,7 @@ def prompt_user(prompt):
34 return result 37 return result
35 38
36 39
37def init_connection(mfa_server, client_port, client_key): 40def init_connection(mfa_server, client_port, client_key, insecure=False):
38 # Attempts to connect to MFA server with provided address,port, and key. 41 # Attempts to connect to MFA server with provided address,port, and key.
39 # Repeats attempt once a seconds until timeout is reached. 42 # Repeats attempt once a seconds until timeout is reached.
40 # Returns socket or None if unable to connect 43 # Returns socket or None if unable to connect
@@ -42,9 +45,16 @@ def init_connection(mfa_server, client_port, client_key):
42 timeout = 0 45 timeout = 0
43 timeout_length = 5 46 timeout_length = 5
44 sleep_length = 1 47 sleep_length = 1
48 context = ssl.create_default_context()
49 if insecure:
50 context.check_hostname = False
51 context.verify_mode = 0
45 while connection == None and timeout < timeout_length: 52 while connection == None and timeout < timeout_length:
46 try: 53 try:
47 connection = socket.create_connection((mfa_server,client_port)) 54 #connection = socket.create_connection((mfa_server,client_port))
55 connection = context.wrap_socket(socket.socket(socket.AF_INET),
56 server_hostname=mfa_server)
57 connection.connect((mfa_server,int(client_port)))
48 connection.send(client_key.encode(FORMAT)) 58 connection.send(client_key.encode(FORMAT))
49 response = connection.recv(ACK_LENGTH).decode(FORMAT) 59 response = connection.recv(ACK_LENGTH).decode(FORMAT)
50 if response == ACK_MESSAGE: 60 if response == ACK_MESSAGE:
@@ -55,6 +65,8 @@ def init_connection(mfa_server, client_port, client_key):
55 except ConnectionError: 65 except ConnectionError:
56 time.sleep(sleep_length) 66 time.sleep(sleep_length)
57 timeout += sleep_length 67 timeout += sleep_length
68 except ssl.SSLCertVerificationError:
69 die("error: server presented invalid certificate")
58 return connection 70 return connection
59 71
60 72
@@ -72,27 +84,31 @@ def get_vars(args,confparser):
72 server = None 84 server = None
73 port = None 85 port = None
74 key = None 86 key = None
87 insecure = None
75 88
76 # Set values from config file first 89 # Set values from config file first
77 if confparser.has_section("client"): 90 if confparser.has_section("client"):
78 server = confparser.get("client","server",fallback=None) 91 server = confparser.get("client","server",fallback=None)
79 port = confparser.get("client","port",fallback=None) 92 port = confparser.get("client","port",fallback=None)
80 key = confparser.get("client","key",fallback=None) 93 key = confparser.get("client","key",fallback=None)
94 insecure = bool(confparser.get("client","insecure",fallback=False))
81 95
82 # Let command line args overwrite any values 96 # Let command line args overwrite any values
83 if args.server: 97 if args.server != None:
84 server = args.server 98 server = args.server
85 if args.port: 99 if args.port != None:
86 port = args.port 100 port = args.port
87 if args.key: 101 if args.key != None:
88 key = args.key 102 key = args.key
103 if args.insecure:
104 insecure = args.insecure
89 105
90 # Exit if any value is null 106 # Exit if any value is null
91 if None in [server,port,key]: 107 if None in [server,port,key]:
92 print("error: one or more items unspecified") 108 print("error: one or more items unspecified")
93 sys.exit(1) 109 sys.exit(1)
94 110
95 return server,port,key 111 return server,port,key,insecure
96 112
97 113
98def main(): 114def main():
@@ -100,7 +116,7 @@ def main():
100 args = parse_arguments() 116 args = parse_arguments()
101 confparser = read_config(args.config) 117 confparser = read_config(args.config)
102 118
103 mfa_server,client_port,client_key = get_vars(args,confparser) 119 mfa_server,client_port,client_key,insecure = get_vars(args,confparser)
104 120
105 # Exit if invalid key is provided 121 # Exit if invalid key is provided
106 if len(client_key) != KEY_LENGTH: 122 if len(client_key) != KEY_LENGTH:
@@ -108,7 +124,7 @@ def main():
108 sys.exit(1) 124 sys.exit(1)
109 125
110 # Open connection to server 126 # Open connection to server
111 conn = init_connection(mfa_server,client_port,client_key) 127 conn = init_connection(mfa_server,client_port,client_key,insecure)
112 if conn == None: 128 if conn == None:
113 print("timed out attempting to connect to server") 129 print("timed out attempting to connect to server")
114 sys.exit(1) 130 sys.exit(1)
diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py
index 85d0a82..5a5a112 100755
--- a/pam/pam_mfa.py
+++ b/pam/pam_mfa.py
@@ -1,5 +1,6 @@
1#!/usr/bin/python3 1#!/usr/bin/python3
2import socket 2import socket
3import ssl
3import argparse 4import argparse
4import time 5import time
5import sys 6import sys
@@ -33,9 +34,11 @@ def parse_arguments():
33 default="/etc/mfa/mfa.conf") 34 default="/etc/mfa/mfa.conf")
34 parser.add_argument("--server",type=str,help="MFA server address") 35 parser.add_argument("--server",type=str,help="MFA server address")
35 parser.add_argument("--port",type=str,help="MFA server PAM connection port") 36 parser.add_argument("--port",type=str,help="MFA server PAM connection port")
37 parser.add_argument("--insecure",action="store_true",
38 help="Accept invalid TLS certificates")
36 return parser.parse_args() 39 return parser.parse_args()
37 40
38def init_connection(mfa_server, pam_port): 41def init_connection(mfa_server, pam_port, insecure):
39 # Attempts to connect to MFA server with provided address and port 42 # Attempts to connect to MFA server with provided address and port
40 # Repeats connection attempts once per second until timeout is reached 43 # Repeats connection attempts once per second until timeout is reached
41 # Returns the socket if connection was successful or None otherwise 44 # Returns the socket if connection was successful or None otherwise
@@ -43,13 +46,22 @@ def init_connection(mfa_server, pam_port):
43 timeout = 0 46 timeout = 0
44 timeout_length = 5 47 timeout_length = 5
45 sleep_length = 1 48 sleep_length = 1
49 context = ssl.create_default_context()
50 if insecure:
51 context.check_hostname = False
52 context.verify_mode = 0
46 while connection == None and timeout < timeout_length: 53 while connection == None and timeout < timeout_length:
47 try: 54 try:
48 connection = socket.create_connection((mfa_server,pam_port)) 55 #connection = socket.create_connection((mfa_server,client_port))
56 connection = context.wrap_socket(socket.socket(socket.AF_INET),
57 server_hostname=mfa_server)
58 connection.connect((mfa_server,int(pam_port)))
49 return connection 59 return connection
50 except (ConnectionError,ConnectionRefusedError): 60 except (ConnectionError,ConnectionRefusedError):
51 time.sleep(sleep_length) 61 time.sleep(sleep_length)
52 timeout += sleep_length 62 timeout += sleep_length
63 except ssl.SSLCertVerificationError:
64 die("error: server presented invalid certificate")
53 return None 65 return None
54 66
55 67
@@ -82,24 +94,28 @@ def get_vars(args,confparser):
82 94
83 server = None 95 server = None
84 port = None 96 port = None
97 insecure = None
85 98
86 # Set values from config file first 99 # Set values from config file first
87 if confparser.has_section("pam"): 100 if confparser.has_section("pam"):
88 server = confparser.get("pam","server",fallback=None) 101 server = confparser.get("pam","server",fallback=None)
89 port = confparser.get("pam","port",fallback=None) 102 port = confparser.get("pam","port",fallback=None)
103 insecure = bool(confparser.get("pam","insecure",fallback=False))
90 104
91 # Let command line args overwrite any values 105 # Let command line args overwrite any values
92 if args.server: 106 if args.server != None:
93 server = args.server 107 server = args.server
94 if args.port: 108 if args.port != None:
95 port = args.port 109 port = args.port
110 if args.insecure:
111 insecure = args.insecure
96 112
97 # Exit if any value is null 113 # Exit if any value is null
98 if None in [server,port]: 114 if None in [server,port]:
99 print("error: one or more items unspecified") 115 print("error: one or more items unspecified")
100 sys.exit(1) 116 sys.exit(1)
101 117
102 return server,port 118 return server,port,insecure
103 119
104 120
105def main(): 121def main():
@@ -109,7 +125,7 @@ def main():
109 # Get arguments 125 # Get arguments
110 args = parse_arguments() 126 args = parse_arguments()
111 confparser = read_config(args.config) 127 confparser = read_config(args.config)
112 mfa_server,pam_port = get_vars(args,confparser) 128 mfa_server,pam_port,insecure = get_vars(args,confparser)
113 user = args.user 129 user = args.user
114 service = args.service 130 service = args.service
115 131
@@ -130,7 +146,7 @@ def main():
130 146
131 147
132 # Initalize connection to MFA server. Quit if unable to connect. 148 # Initalize connection to MFA server. Quit if unable to connect.
133 connection = init_connection(mfa_server,pam_port) 149 connection = init_connection(mfa_server,pam_port,insecure)
134 if connection == None: 150 if connection == None:
135 print(failed) 151 print(failed)
136 sys.exit(1) 152 sys.exit(1)
diff --git a/server/mfad.py b/server/mfad.py
index 17d2585..18a048a 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -1,5 +1,6 @@
1#!/usr/bin/env python3 1#!/usr/bin/env python3
2import socket 2import socket
3import ssl
3import os 4import os
4import sys 5import sys
5import time 6import time
@@ -211,8 +212,9 @@ def parse_pam_data(data):
211def handle_pam(db, conn, addr): 212def handle_pam(db, conn, addr):
212 # Get request and data from PAM module 213 # Get request and data from PAM module
213 header = conn.recv(HEADER_LENGTH).decode(FORMAT) 214 header = conn.recv(HEADER_LENGTH).decode(FORMAT)
214 if header == "": 215 if len(header) != HEADER_LENGTH:
215 die("error: lost connection to pam module") 216 conn.close()
217 die("error: invalid data from PAM module")
216 data_length = int(header) 218 data_length = int(header)
217 pam_data = conn.recv(data_length).decode(FORMAT) 219 pam_data = conn.recv(data_length).decode(FORMAT)
218 print("Got pam_data: " + pam_data) 220 print("Got pam_data: " + pam_data)
@@ -238,19 +240,31 @@ def handle_pam(db, conn, addr):
238 240
239 241
240def listen_client(db, addr, port): 242def listen_client(db, addr, port):
241 with socket.create_server((addr, port)) as server: 243 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
242 while True: 244 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem")
243 conn, addr = server.accept() 245 with socket.create_server((addr, port)) as sock:
244 thread = threading.Thread(target=handle_client,args=(db, conn,addr)) 246 with context.wrap_socket(sock, server_side=True) as tls_socket:
245 thread.start() 247 while True:
248 try:
249 conn, addr = tls_socket.accept()
250 thread = threading.Thread(target=handle_client,args=(db,conn,addr))
251 thread.start()
252 except ssl.SSLError:
253 print("client: ssl handshake error")
246 254
247 255
248def listen_pam(db, addr, port): 256def listen_pam(db, addr, port):
249 with socket.create_server((addr,port)) as pam_server: 257 context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
250 while True: 258 context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem")
251 conn, addr = pam_server.accept() 259 with socket.create_server((addr,port)) as sock:
252 thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) 260 with context.wrap_socket(sock, server_side=True) as tls_socket:
253 thread.start() 261 while True:
262 try:
263 conn, addr = tls_socket.accept()
264 thread = threading.Thread(target=handle_pam,args=(db, conn,addr))
265 thread.start()
266 except ssl.SSLError:
267 print("pam: ssl handshake error")
254 268
255 269
256################################################################################ 270################################################################################