diff options
author | Sam Chudnick <sam@chudnick.com> | 2022-07-04 20:03:27 -0400 |
---|---|---|
committer | Sam Chudnick <sam@chudnick.com> | 2022-07-04 20:03:27 -0400 |
commit | 2e840e7c381f88425952c6fa9d68e0d433084a5a (patch) | |
tree | 31f7888ade33fbc112bd2b7509aac5c39bb2af82 | |
parent | 46564f357c175c7a01a36422307f05b543a83190 (diff) |
Support both TLS encrypted sessions and plaintext sessions
Added support for both TLS and plaintext connections. Server can accept
both types of connection simultaneously or in different combinations
(i.e encrypted client and plaintext PAM). Added options for specifying
dedicated TLS ports on server. Added --plain options for client and PAM
to force plaintext connections, default is to use encrypted connections.
Configuring encrypted client and PAM connections and plaintext server
connections allows for use of a reverse proxy setup with something like
nginx. This will avoid having to expose the MFA server directly in setups
that traverse the internet.
-rwxr-xr-x | client/client.py | 44 | ||||
-rwxr-xr-x | pam/pam_mfa.py | 46 | ||||
-rwxr-xr-x | server/mfad.py | 74 |
3 files changed, 137 insertions, 27 deletions
diff --git a/client/client.py b/client/client.py index cc22d0b..0388073 100755 --- a/client/client.py +++ b/client/client.py | |||
@@ -26,6 +26,7 @@ def parse_arguments(): | |||
26 | parser.add_argument("--config",type=str,help="Path to config file",\ | 26 | parser.add_argument("--config",type=str,help="Path to config file",\ |
27 | default="/etc/mfa/mfa.conf") | 27 | default="/etc/mfa/mfa.conf") |
28 | parser.add_argument("--key",type=str,help="Client connection key") | 28 | parser.add_argument("--key",type=str,help="Client connection key") |
29 | parser.add_argument("--plain",action="store_true",help="Connect without TLS") | ||
29 | parser.add_argument("--insecure",action="store_true",\ | 30 | parser.add_argument("--insecure",action="store_true",\ |
30 | help="Accept invalid TLS certificates") | 31 | help="Accept invalid TLS certificates") |
31 | return parser.parse_args() | 32 | return parser.parse_args() |
@@ -37,7 +38,7 @@ def prompt_user(prompt): | |||
37 | return result | 38 | return result |
38 | 39 | ||
39 | 40 | ||
40 | def init_connection(mfa_server, client_port, client_key, insecure=False): | 41 | def init_connection_tls(mfa_server, client_port, client_key, insecure=False): |
41 | # Attempts to connect to MFA server with provided address,port, and key. | 42 | # Attempts to connect to MFA server with provided address,port, and key. |
42 | # Repeats attempt once a seconds until timeout is reached. | 43 | # Repeats attempt once a seconds until timeout is reached. |
43 | # Returns socket or None if unable to connect | 44 | # Returns socket or None if unable to connect |
@@ -70,6 +71,27 @@ def init_connection(mfa_server, client_port, client_key, insecure=False): | |||
70 | return connection | 71 | return connection |
71 | 72 | ||
72 | 73 | ||
74 | def init_connection(mfa_server, client_port, client_key): | ||
75 | connection = None | ||
76 | timeout = 0 | ||
77 | timeout_length = 5 | ||
78 | sleep_length = 1 | ||
79 | while connection == None and timeout < timeout_length: | ||
80 | try: | ||
81 | connection = socket.create_connection((mfa_server,client_port)) | ||
82 | connection.send(client_key.encode(FORMAT)) | ||
83 | response = connection.recv(ACK_LENGTH).decode(FORMAT) | ||
84 | if response == ACK_MESSAGE: | ||
85 | print("connected to mfa server") | ||
86 | elif response == DISCONNECT_MESSAGE: | ||
87 | print("server terminated connection") | ||
88 | sys.exit(1) | ||
89 | except ConnectionError: | ||
90 | time.sleep(sleep_length) | ||
91 | timeout += sleep_length | ||
92 | return connection | ||
93 | |||
94 | |||
73 | def read_config(config_file): | 95 | def read_config(config_file): |
74 | parser = configparser.ConfigParser(inline_comment_prefixes="#") | 96 | parser = configparser.ConfigParser(inline_comment_prefixes="#") |
75 | parser.read(config_file) | 97 | parser.read(config_file) |
@@ -84,6 +106,7 @@ def get_vars(args,confparser): | |||
84 | server = None | 106 | server = None |
85 | port = None | 107 | port = None |
86 | key = None | 108 | key = None |
109 | plain = None | ||
87 | insecure = None | 110 | insecure = None |
88 | 111 | ||
89 | # Set values from config file first | 112 | # Set values from config file first |
@@ -91,7 +114,13 @@ def get_vars(args,confparser): | |||
91 | server = confparser.get("client","server",fallback=None) | 114 | server = confparser.get("client","server",fallback=None) |
92 | port = confparser.get("client","port",fallback=None) | 115 | port = confparser.get("client","port",fallback=None) |
93 | key = confparser.get("client","key",fallback=None) | 116 | key = confparser.get("client","key",fallback=None) |
94 | insecure = bool(confparser.get("client","insecure",fallback=False)) | 117 | plain = confparser.get("client","plain",fallback=False) |
118 | insecure = confparser.get("client","insecure",fallback=False) | ||
119 | |||
120 | if plain.lower() == "false": | ||
121 | plain = False | ||
122 | if insecure.lower() == "false": | ||
123 | insecure = False | ||
95 | 124 | ||
96 | # Let command line args overwrite any values | 125 | # Let command line args overwrite any values |
97 | if args.server != None: | 126 | if args.server != None: |
@@ -100,6 +129,8 @@ def get_vars(args,confparser): | |||
100 | port = args.port | 129 | port = args.port |
101 | if args.key != None: | 130 | if args.key != None: |
102 | key = args.key | 131 | key = args.key |
132 | if args.plain: | ||
133 | plain = args.plain | ||
103 | if args.insecure: | 134 | if args.insecure: |
104 | insecure = args.insecure | 135 | insecure = args.insecure |
105 | 136 | ||
@@ -108,7 +139,7 @@ def get_vars(args,confparser): | |||
108 | print("error: one or more items unspecified") | 139 | print("error: one or more items unspecified") |
109 | sys.exit(1) | 140 | sys.exit(1) |
110 | 141 | ||
111 | return server,port,key,insecure | 142 | return server,port,key,plain,insecure |
112 | 143 | ||
113 | 144 | ||
114 | def main(): | 145 | def main(): |
@@ -116,7 +147,7 @@ def main(): | |||
116 | args = parse_arguments() | 147 | args = parse_arguments() |
117 | confparser = read_config(args.config) | 148 | confparser = read_config(args.config) |
118 | 149 | ||
119 | mfa_server,client_port,client_key,insecure = get_vars(args,confparser) | 150 | mfa_server,client_port,client_key,plain,insecure = get_vars(args,confparser) |
120 | 151 | ||
121 | # Exit if invalid key is provided | 152 | # Exit if invalid key is provided |
122 | if len(client_key) != KEY_LENGTH: | 153 | if len(client_key) != KEY_LENGTH: |
@@ -124,7 +155,10 @@ def main(): | |||
124 | sys.exit(1) | 155 | sys.exit(1) |
125 | 156 | ||
126 | # Open connection to server | 157 | # Open connection to server |
127 | conn = init_connection(mfa_server,client_port,client_key,insecure) | 158 | if plain: |
159 | conn = init_connection(mfa_server,client_port,client_key) | ||
160 | else: | ||
161 | conn = init_connection_tls(mfa_server,client_port,client_key,insecure) | ||
128 | if conn == None: | 162 | if conn == None: |
129 | print("timed out attempting to connect to server") | 163 | print("timed out attempting to connect to server") |
130 | sys.exit(1) | 164 | sys.exit(1) |
diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py index 5a5a112..a5105a2 100755 --- a/pam/pam_mfa.py +++ b/pam/pam_mfa.py | |||
@@ -34,11 +34,12 @@ def parse_arguments(): | |||
34 | default="/etc/mfa/mfa.conf") | 34 | default="/etc/mfa/mfa.conf") |
35 | parser.add_argument("--server",type=str,help="MFA server address") | 35 | parser.add_argument("--server",type=str,help="MFA server address") |
36 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") | 36 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") |
37 | parser.add_argument("--plain",action="store_true",help="Connect without TLS") | ||
37 | parser.add_argument("--insecure",action="store_true", | 38 | parser.add_argument("--insecure",action="store_true", |
38 | help="Accept invalid TLS certificates") | 39 | help="Accept invalid TLS certificates") |
39 | return parser.parse_args() | 40 | return parser.parse_args() |
40 | 41 | ||
41 | def init_connection(mfa_server, pam_port, insecure): | 42 | def init_connection_tls(mfa_server, pam_port, insecure): |
42 | # Attempts to connect to MFA server with provided address and port | 43 | # Attempts to connect to MFA server with provided address and port |
43 | # Repeats connection attempts once per second until timeout is reached | 44 | # Repeats connection attempts once per second until timeout is reached |
44 | # Returns the socket if connection was successful or None otherwise | 45 | # Returns the socket if connection was successful or None otherwise |
@@ -52,7 +53,6 @@ def init_connection(mfa_server, pam_port, insecure): | |||
52 | context.verify_mode = 0 | 53 | context.verify_mode = 0 |
53 | while connection == None and timeout < timeout_length: | 54 | while connection == None and timeout < timeout_length: |
54 | try: | 55 | try: |
55 | #connection = socket.create_connection((mfa_server,client_port)) | ||
56 | connection = context.wrap_socket(socket.socket(socket.AF_INET), | 56 | connection = context.wrap_socket(socket.socket(socket.AF_INET), |
57 | server_hostname=mfa_server) | 57 | server_hostname=mfa_server) |
58 | connection.connect((mfa_server,int(pam_port))) | 58 | connection.connect((mfa_server,int(pam_port))) |
@@ -65,6 +65,24 @@ def init_connection(mfa_server, pam_port, insecure): | |||
65 | return None | 65 | return None |
66 | 66 | ||
67 | 67 | ||
68 | def init_connection(mfa_server, pam_port): | ||
69 | # Attempts to connect to MFA server with provided address and port | ||
70 | # Repeats connection attempts once per second until timeout is reached | ||
71 | # Returns the socket if connection was successful or None otherwise | ||
72 | connection = None | ||
73 | timeout = 0 | ||
74 | timeout_length = 5 | ||
75 | sleep_length = 1 | ||
76 | while connection == None and timeout < timeout_length: | ||
77 | try: | ||
78 | connection = socket.create_connection((mfa_server,pam_port)) | ||
79 | return connection | ||
80 | except (ConnectionError,ConnectionRefusedError): | ||
81 | time.sleep(sleep_length) | ||
82 | timeout += sleep_length | ||
83 | return None | ||
84 | |||
85 | |||
68 | def read_config(config_file): | 86 | def read_config(config_file): |
69 | # Read config file for server and port info | 87 | # Read config file for server and port info |
70 | # Return tuple (server,port) | 88 | # Return tuple (server,port) |
@@ -94,19 +112,28 @@ def get_vars(args,confparser): | |||
94 | 112 | ||
95 | server = None | 113 | server = None |
96 | port = None | 114 | port = None |
115 | plain = None | ||
97 | insecure = None | 116 | insecure = None |
98 | 117 | ||
99 | # Set values from config file first | 118 | # Set values from config file first |
100 | if confparser.has_section("pam"): | 119 | if confparser.has_section("pam"): |
101 | server = confparser.get("pam","server",fallback=None) | 120 | server = confparser.get("pam","server",fallback=None) |
102 | port = confparser.get("pam","port",fallback=None) | 121 | port = confparser.get("pam","port",fallback=None) |
103 | insecure = bool(confparser.get("pam","insecure",fallback=False)) | 122 | plain = confparser.get("client","plain",fallback=False) |
123 | insecure = confparser.get("client","insecure",fallback=False) | ||
104 | 124 | ||
125 | if plain.lower() == "false": | ||
126 | plain = False | ||
127 | if insecure.lower() == "false": | ||
128 | insecure = False | ||
129 | |||
105 | # Let command line args overwrite any values | 130 | # Let command line args overwrite any values |
106 | if args.server != None: | 131 | if args.server != None: |
107 | server = args.server | 132 | server = args.server |
108 | if args.port != None: | 133 | if args.port != None: |
109 | port = args.port | 134 | port = args.port |
135 | if args.plain: | ||
136 | plain = args.plain | ||
110 | if args.insecure: | 137 | if args.insecure: |
111 | insecure = args.insecure | 138 | insecure = args.insecure |
112 | 139 | ||
@@ -115,7 +142,7 @@ def get_vars(args,confparser): | |||
115 | print("error: one or more items unspecified") | 142 | print("error: one or more items unspecified") |
116 | sys.exit(1) | 143 | sys.exit(1) |
117 | 144 | ||
118 | return server,port,insecure | 145 | return server,port,plain,insecure |
119 | 146 | ||
120 | 147 | ||
121 | def main(): | 148 | def main(): |
@@ -125,7 +152,7 @@ def main(): | |||
125 | # Get arguments | 152 | # Get arguments |
126 | args = parse_arguments() | 153 | args = parse_arguments() |
127 | confparser = read_config(args.config) | 154 | confparser = read_config(args.config) |
128 | mfa_server,pam_port,insecure = get_vars(args,confparser) | 155 | mfa_server,pam_port,plain,insecure = get_vars(args,confparser) |
129 | user = args.user | 156 | user = args.user |
130 | service = args.service | 157 | service = args.service |
131 | 158 | ||
@@ -144,12 +171,13 @@ def main(): | |||
144 | hostname = args.host | 171 | hostname = args.host |
145 | data = user + "," + hostname + "," + service | 172 | data = user + "," + hostname + "," + service |
146 | 173 | ||
147 | |||
148 | # Initalize connection to MFA server. Quit if unable to connect. | 174 | # Initalize connection to MFA server. Quit if unable to connect. |
149 | connection = init_connection(mfa_server,pam_port,insecure) | 175 | if plain: |
176 | connection = init_connection(mfa_server, pam_port) | ||
177 | else: | ||
178 | connection = init_connection_tls(mfa_server,pam_port,insecure) | ||
150 | if connection == None: | 179 | if connection == None: |
151 | print(failed) | 180 | die("failed to connect") |
152 | sys.exit(1) | ||
153 | 181 | ||
154 | # Send authentication data to MFA server | 182 | # Send authentication data to MFA server |
155 | data_length = len(data) | 183 | data_length = len(data) |
diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -52,6 +52,10 @@ def parse_arguments(): | |||
52 | parser.add_argument("--address",type=str,help="Bind Address") | 52 | parser.add_argument("--address",type=str,help="Bind Address") |
53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") |
54 | parser.add_argument("--client-port",type=int,help="Port for client connections") | 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") |
55 | parser.add_argument("--pam-tls-port",type=int, | ||
56 | help="Port to listen for encrypted PAM requests") | ||
57 | parser.add_argument("--client-tls-port",type=int, | ||
58 | help="Port for encrypted client connections") | ||
55 | parser.add_argument("--database",type=str,help="Path to database file") | 59 | parser.add_argument("--database",type=str,help="Path to database file") |
56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | 60 | parser.add_argument("--cert",type=str,help="TLS certificate file") |
57 | parser.add_argument("--key",type=str,help="TLS private key file") | 61 | parser.add_argument("--key",type=str,help="TLS private key file") |
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): | |||
266 | 270 | ||
267 | return context | 271 | return context |
268 | 272 | ||
269 | def listen_client(db, addr, port, tls_context): | 273 | def listen_client_tls(db, addr, port, tls_context): |
270 | with socket.create_server((addr, port)) as sock: | 274 | with socket.create_server((addr, port)) as sock: |
271 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 275 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
272 | while True: | 276 | while True: |
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): | |||
278 | print("client: ssl handshake error") | 282 | print("client: ssl handshake error") |
279 | 283 | ||
280 | 284 | ||
281 | def listen_pam(db, addr, port, tls_context): | 285 | def listen_pam_tls(db, addr, port, tls_context): |
282 | with socket.create_server((addr,port)) as sock: | 286 | with socket.create_server((addr,port)) as sock: |
283 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 287 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
284 | while True: | 288 | while True: |
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): | |||
290 | print("pam: ssl handshake error") | 294 | print("pam: ssl handshake error") |
291 | 295 | ||
292 | 296 | ||
297 | def listen_client(db, addr, port): | ||
298 | with socket.create_server((addr, port)) as sock: | ||
299 | while True: | ||
300 | conn, addr = sock.accept() | ||
301 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
302 | thread.start() | ||
303 | |||
304 | |||
305 | def listen_pam(db, addr, port): | ||
306 | with socket.create_server((addr, port)) as sock: | ||
307 | while True: | ||
308 | conn, addr = sock.accept() | ||
309 | thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) | ||
310 | thread.start() | ||
311 | |||
312 | |||
293 | ################################################################################ | 313 | ################################################################################ |
294 | 314 | ||
295 | def create_db(db): | 315 | def create_db(db): |
@@ -312,12 +332,13 @@ def create_db(db): | |||
312 | 332 | ||
313 | def get_vars(args,confparser): | 333 | def get_vars(args,confparser): |
314 | if not os.path.exists(args.config): | 334 | if not os.path.exists(args.config): |
315 | print("Unable to open config file") | 335 | die("Unable to open config file") |
316 | sys.exit(1) | ||
317 | 336 | ||
318 | bind_addr = None | 337 | bind_addr = None |
319 | client_port = None | 338 | client_port = None |
339 | client_tls_port = None | ||
320 | pam_port = None | 340 | pam_port = None |
341 | pam_tls_port = None | ||
321 | database = None | 342 | database = None |
322 | cert = None | 343 | cert = None |
323 | key = None | 344 | key = None |
@@ -328,6 +349,8 @@ def get_vars(args,confparser): | |||
328 | bind_addr = confparser.get("mfad","address",fallback=None) | 349 | bind_addr = confparser.get("mfad","address",fallback=None) |
329 | client_port = confparser.get("mfad","client-port",fallback=None) | 350 | client_port = confparser.get("mfad","client-port",fallback=None) |
330 | pam_port = confparser.get("mfad","pam-port",fallback=None) | 351 | pam_port = confparser.get("mfad","pam-port",fallback=None) |
352 | client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) | ||
353 | pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) | ||
331 | database = confparser.get("mfad","database",fallback=None) | 354 | database = confparser.get("mfad","database",fallback=None) |
332 | cert = confparser.get("mfad","cert",fallback=None) | 355 | cert = confparser.get("mfad","cert",fallback=None) |
333 | key = confparser.get("mfad","key",fallback=None) | 356 | key = confparser.get("mfad","key",fallback=None) |
@@ -340,6 +363,10 @@ def get_vars(args,confparser): | |||
340 | client_port = args.client_port | 363 | client_port = args.client_port |
341 | if args.pam_port != None: | 364 | if args.pam_port != None: |
342 | pam_port = args.pam_port | 365 | pam_port = args.pam_port |
366 | if args.client_tls_port != None: | ||
367 | client_tls_port = args.client_tls_port | ||
368 | if args.pam_tls_port != None: | ||
369 | pam_tls_port = args.pam_tls_port | ||
343 | if args.database != None: | 370 | if args.database != None: |
344 | database = args.database | 371 | database = args.database |
345 | if args.cert != None: | 372 | if args.cert != None: |
@@ -350,28 +377,49 @@ def get_vars(args,confparser): | |||
350 | ciphers = args.ciphers | 377 | ciphers = args.ciphers |
351 | 378 | ||
352 | # Exit if any value is null | 379 | # Exit if any value is null |
353 | if None in [bind_addr,client_port,pam_port,database,cert,key]: | 380 | if None in [bind_addr,database,cert,key]: |
354 | print("error: one or more items unspecified") | 381 | die("error: one or more items unspecified") |
355 | sys.exit(1) | 382 | |
383 | ports = [pam_port, client_port, pam_tls_port, client_tls_port] | ||
384 | for port in ports: | ||
385 | if port == None: | ||
386 | ports[ports.index(port)] = 0 | ||
356 | 387 | ||
357 | return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers | 388 | return bind_addr, ports, database, cert, key, ciphers |
358 | 389 | ||
359 | 390 | ||
360 | def main(): | 391 | def main(): |
361 | args = parse_arguments() | 392 | args = parse_arguments() |
362 | confparser = read_config(args.config) | 393 | confparser = read_config(args.config) |
363 | 394 | ||
364 | bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) | 395 | bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) |
396 | pam_port = int(ports[0]) | ||
397 | client_port = int(ports[1]) | ||
398 | pam_tls_port = int(ports[2]) | ||
399 | client_tls_port = int(ports[3]) | ||
400 | |||
401 | if pam_port == 0 and pam_tls_port == 0: | ||
402 | die("error: not listening for any PAM connections") | ||
403 | if client_port == 0 and client_tls_port == 0: | ||
404 | die("error: not listening for any client connections") | ||
405 | |||
365 | 406 | ||
366 | if not os.path.exists(db): | 407 | if not os.path.exists(db): |
367 | print("Creating DB") | 408 | print("Creating DB") |
368 | create_db(db) | 409 | create_db(db) |
369 | 410 | ||
370 | context = get_tls_context(cert,key,ciphers) | 411 | context = get_tls_context(cert,key,ciphers) |
371 | clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) | 412 | |
372 | pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) | 413 | if pam_port != 0: |
373 | clients.start() | 414 | threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() |
374 | pam.start() | 415 | if client_port != 0: |
416 | threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() | ||
417 | if pam_tls_port != 0: | ||
418 | threading.Thread(target=listen_pam_tls, | ||
419 | args=(db,bind_addr,pam_tls_port,context)).start() | ||
420 | if client_tls_port != 0: | ||
421 | threading.Thread(target=listen_client_tls, | ||
422 | args=(db,bind_addr,client_tls_port,context)).start() | ||
375 | 423 | ||
376 | 424 | ||
377 | if __name__ == '__main__': | 425 | if __name__ == '__main__': |