summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Chudnick <sam@chudnick.com>2022-07-04 20:03:27 -0400
committerSam Chudnick <sam@chudnick.com>2022-07-04 20:03:27 -0400
commit2e840e7c381f88425952c6fa9d68e0d433084a5a (patch)
tree31f7888ade33fbc112bd2b7509aac5c39bb2af82
parent46564f357c175c7a01a36422307f05b543a83190 (diff)
Support both TLS encrypted sessions and plaintext sessions
Added support for both TLS and plaintext connections. Server can accept both types of connection simultaneously or in different combinations (i.e encrypted client and plaintext PAM). Added options for specifying dedicated TLS ports on server. Added --plain options for client and PAM to force plaintext connections, default is to use encrypted connections. Configuring encrypted client and PAM connections and plaintext server connections allows for use of a reverse proxy setup with something like nginx. This will avoid having to expose the MFA server directly in setups that traverse the internet.
-rwxr-xr-xclient/client.py44
-rwxr-xr-xpam/pam_mfa.py46
-rwxr-xr-xserver/mfad.py74
3 files changed, 137 insertions, 27 deletions
diff --git a/client/client.py b/client/client.py
index cc22d0b..0388073 100755
--- a/client/client.py
+++ b/client/client.py
@@ -26,6 +26,7 @@ def parse_arguments():
26 parser.add_argument("--config",type=str,help="Path to config file",\ 26 parser.add_argument("--config",type=str,help="Path to config file",\
27 default="/etc/mfa/mfa.conf") 27 default="/etc/mfa/mfa.conf")
28 parser.add_argument("--key",type=str,help="Client connection key") 28 parser.add_argument("--key",type=str,help="Client connection key")
29 parser.add_argument("--plain",action="store_true",help="Connect without TLS")
29 parser.add_argument("--insecure",action="store_true",\ 30 parser.add_argument("--insecure",action="store_true",\
30 help="Accept invalid TLS certificates") 31 help="Accept invalid TLS certificates")
31 return parser.parse_args() 32 return parser.parse_args()
@@ -37,7 +38,7 @@ def prompt_user(prompt):
37 return result 38 return result
38 39
39 40
40def init_connection(mfa_server, client_port, client_key, insecure=False): 41def init_connection_tls(mfa_server, client_port, client_key, insecure=False):
41 # Attempts to connect to MFA server with provided address,port, and key. 42 # Attempts to connect to MFA server with provided address,port, and key.
42 # Repeats attempt once a seconds until timeout is reached. 43 # Repeats attempt once a seconds until timeout is reached.
43 # Returns socket or None if unable to connect 44 # Returns socket or None if unable to connect
@@ -70,6 +71,27 @@ def init_connection(mfa_server, client_port, client_key, insecure=False):
70 return connection 71 return connection
71 72
72 73
74def init_connection(mfa_server, client_port, client_key):
75 connection = None
76 timeout = 0
77 timeout_length = 5
78 sleep_length = 1
79 while connection == None and timeout < timeout_length:
80 try:
81 connection = socket.create_connection((mfa_server,client_port))
82 connection.send(client_key.encode(FORMAT))
83 response = connection.recv(ACK_LENGTH).decode(FORMAT)
84 if response == ACK_MESSAGE:
85 print("connected to mfa server")
86 elif response == DISCONNECT_MESSAGE:
87 print("server terminated connection")
88 sys.exit(1)
89 except ConnectionError:
90 time.sleep(sleep_length)
91 timeout += sleep_length
92 return connection
93
94
73def read_config(config_file): 95def read_config(config_file):
74 parser = configparser.ConfigParser(inline_comment_prefixes="#") 96 parser = configparser.ConfigParser(inline_comment_prefixes="#")
75 parser.read(config_file) 97 parser.read(config_file)
@@ -84,6 +106,7 @@ def get_vars(args,confparser):
84 server = None 106 server = None
85 port = None 107 port = None
86 key = None 108 key = None
109 plain = None
87 insecure = None 110 insecure = None
88 111
89 # Set values from config file first 112 # Set values from config file first
@@ -91,7 +114,13 @@ def get_vars(args,confparser):
91 server = confparser.get("client","server",fallback=None) 114 server = confparser.get("client","server",fallback=None)
92 port = confparser.get("client","port",fallback=None) 115 port = confparser.get("client","port",fallback=None)
93 key = confparser.get("client","key",fallback=None) 116 key = confparser.get("client","key",fallback=None)
94 insecure = bool(confparser.get("client","insecure",fallback=False)) 117 plain = confparser.get("client","plain",fallback=False)
118 insecure = confparser.get("client","insecure",fallback=False)
119
120 if plain.lower() == "false":
121 plain = False
122 if insecure.lower() == "false":
123 insecure = False
95 124
96 # Let command line args overwrite any values 125 # Let command line args overwrite any values
97 if args.server != None: 126 if args.server != None:
@@ -100,6 +129,8 @@ def get_vars(args,confparser):
100 port = args.port 129 port = args.port
101 if args.key != None: 130 if args.key != None:
102 key = args.key 131 key = args.key
132 if args.plain:
133 plain = args.plain
103 if args.insecure: 134 if args.insecure:
104 insecure = args.insecure 135 insecure = args.insecure
105 136
@@ -108,7 +139,7 @@ def get_vars(args,confparser):
108 print("error: one or more items unspecified") 139 print("error: one or more items unspecified")
109 sys.exit(1) 140 sys.exit(1)
110 141
111 return server,port,key,insecure 142 return server,port,key,plain,insecure
112 143
113 144
114def main(): 145def main():
@@ -116,7 +147,7 @@ def main():
116 args = parse_arguments() 147 args = parse_arguments()
117 confparser = read_config(args.config) 148 confparser = read_config(args.config)
118 149
119 mfa_server,client_port,client_key,insecure = get_vars(args,confparser) 150 mfa_server,client_port,client_key,plain,insecure = get_vars(args,confparser)
120 151
121 # Exit if invalid key is provided 152 # Exit if invalid key is provided
122 if len(client_key) != KEY_LENGTH: 153 if len(client_key) != KEY_LENGTH:
@@ -124,7 +155,10 @@ def main():
124 sys.exit(1) 155 sys.exit(1)
125 156
126 # Open connection to server 157 # Open connection to server
127 conn = init_connection(mfa_server,client_port,client_key,insecure) 158 if plain:
159 conn = init_connection(mfa_server,client_port,client_key)
160 else:
161 conn = init_connection_tls(mfa_server,client_port,client_key,insecure)
128 if conn == None: 162 if conn == None:
129 print("timed out attempting to connect to server") 163 print("timed out attempting to connect to server")
130 sys.exit(1) 164 sys.exit(1)
diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py
index 5a5a112..a5105a2 100755
--- a/pam/pam_mfa.py
+++ b/pam/pam_mfa.py
@@ -34,11 +34,12 @@ def parse_arguments():
34 default="/etc/mfa/mfa.conf") 34 default="/etc/mfa/mfa.conf")
35 parser.add_argument("--server",type=str,help="MFA server address") 35 parser.add_argument("--server",type=str,help="MFA server address")
36 parser.add_argument("--port",type=str,help="MFA server PAM connection port") 36 parser.add_argument("--port",type=str,help="MFA server PAM connection port")
37 parser.add_argument("--plain",action="store_true",help="Connect without TLS")
37 parser.add_argument("--insecure",action="store_true", 38 parser.add_argument("--insecure",action="store_true",
38 help="Accept invalid TLS certificates") 39 help="Accept invalid TLS certificates")
39 return parser.parse_args() 40 return parser.parse_args()
40 41
41def init_connection(mfa_server, pam_port, insecure): 42def init_connection_tls(mfa_server, pam_port, insecure):
42 # Attempts to connect to MFA server with provided address and port 43 # Attempts to connect to MFA server with provided address and port
43 # Repeats connection attempts once per second until timeout is reached 44 # Repeats connection attempts once per second until timeout is reached
44 # Returns the socket if connection was successful or None otherwise 45 # Returns the socket if connection was successful or None otherwise
@@ -52,7 +53,6 @@ def init_connection(mfa_server, pam_port, insecure):
52 context.verify_mode = 0 53 context.verify_mode = 0
53 while connection == None and timeout < timeout_length: 54 while connection == None and timeout < timeout_length:
54 try: 55 try:
55 #connection = socket.create_connection((mfa_server,client_port))
56 connection = context.wrap_socket(socket.socket(socket.AF_INET), 56 connection = context.wrap_socket(socket.socket(socket.AF_INET),
57 server_hostname=mfa_server) 57 server_hostname=mfa_server)
58 connection.connect((mfa_server,int(pam_port))) 58 connection.connect((mfa_server,int(pam_port)))
@@ -65,6 +65,24 @@ def init_connection(mfa_server, pam_port, insecure):
65 return None 65 return None
66 66
67 67
68def init_connection(mfa_server, pam_port):
69 # Attempts to connect to MFA server with provided address and port
70 # Repeats connection attempts once per second until timeout is reached
71 # Returns the socket if connection was successful or None otherwise
72 connection = None
73 timeout = 0
74 timeout_length = 5
75 sleep_length = 1
76 while connection == None and timeout < timeout_length:
77 try:
78 connection = socket.create_connection((mfa_server,pam_port))
79 return connection
80 except (ConnectionError,ConnectionRefusedError):
81 time.sleep(sleep_length)
82 timeout += sleep_length
83 return None
84
85
68def read_config(config_file): 86def read_config(config_file):
69 # Read config file for server and port info 87 # Read config file for server and port info
70 # Return tuple (server,port) 88 # Return tuple (server,port)
@@ -94,19 +112,28 @@ def get_vars(args,confparser):
94 112
95 server = None 113 server = None
96 port = None 114 port = None
115 plain = None
97 insecure = None 116 insecure = None
98 117
99 # Set values from config file first 118 # Set values from config file first
100 if confparser.has_section("pam"): 119 if confparser.has_section("pam"):
101 server = confparser.get("pam","server",fallback=None) 120 server = confparser.get("pam","server",fallback=None)
102 port = confparser.get("pam","port",fallback=None) 121 port = confparser.get("pam","port",fallback=None)
103 insecure = bool(confparser.get("pam","insecure",fallback=False)) 122 plain = confparser.get("client","plain",fallback=False)
123 insecure = confparser.get("client","insecure",fallback=False)
104 124
125 if plain.lower() == "false":
126 plain = False
127 if insecure.lower() == "false":
128 insecure = False
129
105 # Let command line args overwrite any values 130 # Let command line args overwrite any values
106 if args.server != None: 131 if args.server != None:
107 server = args.server 132 server = args.server
108 if args.port != None: 133 if args.port != None:
109 port = args.port 134 port = args.port
135 if args.plain:
136 plain = args.plain
110 if args.insecure: 137 if args.insecure:
111 insecure = args.insecure 138 insecure = args.insecure
112 139
@@ -115,7 +142,7 @@ def get_vars(args,confparser):
115 print("error: one or more items unspecified") 142 print("error: one or more items unspecified")
116 sys.exit(1) 143 sys.exit(1)
117 144
118 return server,port,insecure 145 return server,port,plain,insecure
119 146
120 147
121def main(): 148def main():
@@ -125,7 +152,7 @@ def main():
125 # Get arguments 152 # Get arguments
126 args = parse_arguments() 153 args = parse_arguments()
127 confparser = read_config(args.config) 154 confparser = read_config(args.config)
128 mfa_server,pam_port,insecure = get_vars(args,confparser) 155 mfa_server,pam_port,plain,insecure = get_vars(args,confparser)
129 user = args.user 156 user = args.user
130 service = args.service 157 service = args.service
131 158
@@ -144,12 +171,13 @@ def main():
144 hostname = args.host 171 hostname = args.host
145 data = user + "," + hostname + "," + service 172 data = user + "," + hostname + "," + service
146 173
147
148 # Initalize connection to MFA server. Quit if unable to connect. 174 # Initalize connection to MFA server. Quit if unable to connect.
149 connection = init_connection(mfa_server,pam_port,insecure) 175 if plain:
176 connection = init_connection(mfa_server, pam_port)
177 else:
178 connection = init_connection_tls(mfa_server,pam_port,insecure)
150 if connection == None: 179 if connection == None:
151 print(failed) 180 die("failed to connect")
152 sys.exit(1)
153 181
154 # Send authentication data to MFA server 182 # Send authentication data to MFA server
155 data_length = len(data) 183 data_length = len(data)
diff --git a/server/mfad.py b/server/mfad.py
index a2a783b..cc5073b 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,6 +52,10 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--pam-tls-port",type=int,
56 help="Port to listen for encrypted PAM requests")
57 parser.add_argument("--client-tls-port",type=int,
58 help="Port for encrypted client connections")
55 parser.add_argument("--database",type=str,help="Path to database file") 59 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file") 60 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file") 61 parser.add_argument("--key",type=str,help="TLS private key file")
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers):
266 270
267 return context 271 return context
268 272
269def listen_client(db, addr, port, tls_context): 273def listen_client_tls(db, addr, port, tls_context):
270 with socket.create_server((addr, port)) as sock: 274 with socket.create_server((addr, port)) as sock:
271 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 275 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
272 while True: 276 while True:
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context):
278 print("client: ssl handshake error") 282 print("client: ssl handshake error")
279 283
280 284
281def listen_pam(db, addr, port, tls_context): 285def listen_pam_tls(db, addr, port, tls_context):
282 with socket.create_server((addr,port)) as sock: 286 with socket.create_server((addr,port)) as sock:
283 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 287 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
284 while True: 288 while True:
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context):
290 print("pam: ssl handshake error") 294 print("pam: ssl handshake error")
291 295
292 296
297def listen_client(db, addr, port):
298 with socket.create_server((addr, port)) as sock:
299 while True:
300 conn, addr = sock.accept()
301 thread = threading.Thread(target=handle_client,args=(db,conn,addr))
302 thread.start()
303
304
305def listen_pam(db, addr, port):
306 with socket.create_server((addr, port)) as sock:
307 while True:
308 conn, addr = sock.accept()
309 thread = threading.Thread(target=handle_pam,args=(db,conn,addr))
310 thread.start()
311
312
293################################################################################ 313################################################################################
294 314
295def create_db(db): 315def create_db(db):
@@ -312,12 +332,13 @@ def create_db(db):
312 332
313def get_vars(args,confparser): 333def get_vars(args,confparser):
314 if not os.path.exists(args.config): 334 if not os.path.exists(args.config):
315 print("Unable to open config file") 335 die("Unable to open config file")
316 sys.exit(1)
317 336
318 bind_addr = None 337 bind_addr = None
319 client_port = None 338 client_port = None
339 client_tls_port = None
320 pam_port = None 340 pam_port = None
341 pam_tls_port = None
321 database = None 342 database = None
322 cert = None 343 cert = None
323 key = None 344 key = None
@@ -328,6 +349,8 @@ def get_vars(args,confparser):
328 bind_addr = confparser.get("mfad","address",fallback=None) 349 bind_addr = confparser.get("mfad","address",fallback=None)
329 client_port = confparser.get("mfad","client-port",fallback=None) 350 client_port = confparser.get("mfad","client-port",fallback=None)
330 pam_port = confparser.get("mfad","pam-port",fallback=None) 351 pam_port = confparser.get("mfad","pam-port",fallback=None)
352 client_tls_port = confparser.get("mfad","client-tls-port",fallback=None)
353 pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None)
331 database = confparser.get("mfad","database",fallback=None) 354 database = confparser.get("mfad","database",fallback=None)
332 cert = confparser.get("mfad","cert",fallback=None) 355 cert = confparser.get("mfad","cert",fallback=None)
333 key = confparser.get("mfad","key",fallback=None) 356 key = confparser.get("mfad","key",fallback=None)
@@ -340,6 +363,10 @@ def get_vars(args,confparser):
340 client_port = args.client_port 363 client_port = args.client_port
341 if args.pam_port != None: 364 if args.pam_port != None:
342 pam_port = args.pam_port 365 pam_port = args.pam_port
366 if args.client_tls_port != None:
367 client_tls_port = args.client_tls_port
368 if args.pam_tls_port != None:
369 pam_tls_port = args.pam_tls_port
343 if args.database != None: 370 if args.database != None:
344 database = args.database 371 database = args.database
345 if args.cert != None: 372 if args.cert != None:
@@ -350,28 +377,49 @@ def get_vars(args,confparser):
350 ciphers = args.ciphers 377 ciphers = args.ciphers
351 378
352 # Exit if any value is null 379 # Exit if any value is null
353 if None in [bind_addr,client_port,pam_port,database,cert,key]: 380 if None in [bind_addr,database,cert,key]:
354 print("error: one or more items unspecified") 381 die("error: one or more items unspecified")
355 sys.exit(1) 382
383 ports = [pam_port, client_port, pam_tls_port, client_tls_port]
384 for port in ports:
385 if port == None:
386 ports[ports.index(port)] = 0
356 387
357 return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers 388 return bind_addr, ports, database, cert, key, ciphers
358 389
359 390
360def main(): 391def main():
361 args = parse_arguments() 392 args = parse_arguments()
362 confparser = read_config(args.config) 393 confparser = read_config(args.config)
363 394
364 bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) 395 bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser)
396 pam_port = int(ports[0])
397 client_port = int(ports[1])
398 pam_tls_port = int(ports[2])
399 client_tls_port = int(ports[3])
400
401 if pam_port == 0 and pam_tls_port == 0:
402 die("error: not listening for any PAM connections")
403 if client_port == 0 and client_tls_port == 0:
404 die("error: not listening for any client connections")
405
365 406
366 if not os.path.exists(db): 407 if not os.path.exists(db):
367 print("Creating DB") 408 print("Creating DB")
368 create_db(db) 409 create_db(db)
369 410
370 context = get_tls_context(cert,key,ciphers) 411 context = get_tls_context(cert,key,ciphers)
371 clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) 412
372 pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) 413 if pam_port != 0:
373 clients.start() 414 threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start()
374 pam.start() 415 if client_port != 0:
416 threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start()
417 if pam_tls_port != 0:
418 threading.Thread(target=listen_pam_tls,
419 args=(db,bind_addr,pam_tls_port,context)).start()
420 if client_tls_port != 0:
421 threading.Thread(target=listen_client_tls,
422 args=(db,bind_addr,client_tls_port,context)).start()
375 423
376 424
377if __name__ == '__main__': 425if __name__ == '__main__':