diff options
| -rwxr-xr-x | client/client.py | 44 | ||||
| -rwxr-xr-x | pam/pam_mfa.py | 46 | ||||
| -rwxr-xr-x | server/mfad.py | 74 |
3 files changed, 137 insertions, 27 deletions
diff --git a/client/client.py b/client/client.py index cc22d0b..0388073 100755 --- a/client/client.py +++ b/client/client.py | |||
| @@ -26,6 +26,7 @@ def parse_arguments(): | |||
| 26 | parser.add_argument("--config",type=str,help="Path to config file",\ | 26 | parser.add_argument("--config",type=str,help="Path to config file",\ |
| 27 | default="/etc/mfa/mfa.conf") | 27 | default="/etc/mfa/mfa.conf") |
| 28 | parser.add_argument("--key",type=str,help="Client connection key") | 28 | parser.add_argument("--key",type=str,help="Client connection key") |
| 29 | parser.add_argument("--plain",action="store_true",help="Connect without TLS") | ||
| 29 | parser.add_argument("--insecure",action="store_true",\ | 30 | parser.add_argument("--insecure",action="store_true",\ |
| 30 | help="Accept invalid TLS certificates") | 31 | help="Accept invalid TLS certificates") |
| 31 | return parser.parse_args() | 32 | return parser.parse_args() |
| @@ -37,7 +38,7 @@ def prompt_user(prompt): | |||
| 37 | return result | 38 | return result |
| 38 | 39 | ||
| 39 | 40 | ||
| 40 | def init_connection(mfa_server, client_port, client_key, insecure=False): | 41 | def init_connection_tls(mfa_server, client_port, client_key, insecure=False): |
| 41 | # Attempts to connect to MFA server with provided address,port, and key. | 42 | # Attempts to connect to MFA server with provided address,port, and key. |
| 42 | # Repeats attempt once a seconds until timeout is reached. | 43 | # Repeats attempt once a seconds until timeout is reached. |
| 43 | # Returns socket or None if unable to connect | 44 | # Returns socket or None if unable to connect |
| @@ -70,6 +71,27 @@ def init_connection(mfa_server, client_port, client_key, insecure=False): | |||
| 70 | return connection | 71 | return connection |
| 71 | 72 | ||
| 72 | 73 | ||
| 74 | def init_connection(mfa_server, client_port, client_key): | ||
| 75 | connection = None | ||
| 76 | timeout = 0 | ||
| 77 | timeout_length = 5 | ||
| 78 | sleep_length = 1 | ||
| 79 | while connection == None and timeout < timeout_length: | ||
| 80 | try: | ||
| 81 | connection = socket.create_connection((mfa_server,client_port)) | ||
| 82 | connection.send(client_key.encode(FORMAT)) | ||
| 83 | response = connection.recv(ACK_LENGTH).decode(FORMAT) | ||
| 84 | if response == ACK_MESSAGE: | ||
| 85 | print("connected to mfa server") | ||
| 86 | elif response == DISCONNECT_MESSAGE: | ||
| 87 | print("server terminated connection") | ||
| 88 | sys.exit(1) | ||
| 89 | except ConnectionError: | ||
| 90 | time.sleep(sleep_length) | ||
| 91 | timeout += sleep_length | ||
| 92 | return connection | ||
| 93 | |||
| 94 | |||
| 73 | def read_config(config_file): | 95 | def read_config(config_file): |
| 74 | parser = configparser.ConfigParser(inline_comment_prefixes="#") | 96 | parser = configparser.ConfigParser(inline_comment_prefixes="#") |
| 75 | parser.read(config_file) | 97 | parser.read(config_file) |
| @@ -84,6 +106,7 @@ def get_vars(args,confparser): | |||
| 84 | server = None | 106 | server = None |
| 85 | port = None | 107 | port = None |
| 86 | key = None | 108 | key = None |
| 109 | plain = None | ||
| 87 | insecure = None | 110 | insecure = None |
| 88 | 111 | ||
| 89 | # Set values from config file first | 112 | # Set values from config file first |
| @@ -91,7 +114,13 @@ def get_vars(args,confparser): | |||
| 91 | server = confparser.get("client","server",fallback=None) | 114 | server = confparser.get("client","server",fallback=None) |
| 92 | port = confparser.get("client","port",fallback=None) | 115 | port = confparser.get("client","port",fallback=None) |
| 93 | key = confparser.get("client","key",fallback=None) | 116 | key = confparser.get("client","key",fallback=None) |
| 94 | insecure = bool(confparser.get("client","insecure",fallback=False)) | 117 | plain = confparser.get("client","plain",fallback=False) |
| 118 | insecure = confparser.get("client","insecure",fallback=False) | ||
| 119 | |||
| 120 | if plain.lower() == "false": | ||
| 121 | plain = False | ||
| 122 | if insecure.lower() == "false": | ||
| 123 | insecure = False | ||
| 95 | 124 | ||
| 96 | # Let command line args overwrite any values | 125 | # Let command line args overwrite any values |
| 97 | if args.server != None: | 126 | if args.server != None: |
| @@ -100,6 +129,8 @@ def get_vars(args,confparser): | |||
| 100 | port = args.port | 129 | port = args.port |
| 101 | if args.key != None: | 130 | if args.key != None: |
| 102 | key = args.key | 131 | key = args.key |
| 132 | if args.plain: | ||
| 133 | plain = args.plain | ||
| 103 | if args.insecure: | 134 | if args.insecure: |
| 104 | insecure = args.insecure | 135 | insecure = args.insecure |
| 105 | 136 | ||
| @@ -108,7 +139,7 @@ def get_vars(args,confparser): | |||
| 108 | print("error: one or more items unspecified") | 139 | print("error: one or more items unspecified") |
| 109 | sys.exit(1) | 140 | sys.exit(1) |
| 110 | 141 | ||
| 111 | return server,port,key,insecure | 142 | return server,port,key,plain,insecure |
| 112 | 143 | ||
| 113 | 144 | ||
| 114 | def main(): | 145 | def main(): |
| @@ -116,7 +147,7 @@ def main(): | |||
| 116 | args = parse_arguments() | 147 | args = parse_arguments() |
| 117 | confparser = read_config(args.config) | 148 | confparser = read_config(args.config) |
| 118 | 149 | ||
| 119 | mfa_server,client_port,client_key,insecure = get_vars(args,confparser) | 150 | mfa_server,client_port,client_key,plain,insecure = get_vars(args,confparser) |
| 120 | 151 | ||
| 121 | # Exit if invalid key is provided | 152 | # Exit if invalid key is provided |
| 122 | if len(client_key) != KEY_LENGTH: | 153 | if len(client_key) != KEY_LENGTH: |
| @@ -124,7 +155,10 @@ def main(): | |||
| 124 | sys.exit(1) | 155 | sys.exit(1) |
| 125 | 156 | ||
| 126 | # Open connection to server | 157 | # Open connection to server |
| 127 | conn = init_connection(mfa_server,client_port,client_key,insecure) | 158 | if plain: |
| 159 | conn = init_connection(mfa_server,client_port,client_key) | ||
| 160 | else: | ||
| 161 | conn = init_connection_tls(mfa_server,client_port,client_key,insecure) | ||
| 128 | if conn == None: | 162 | if conn == None: |
| 129 | print("timed out attempting to connect to server") | 163 | print("timed out attempting to connect to server") |
| 130 | sys.exit(1) | 164 | sys.exit(1) |
diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py index 5a5a112..a5105a2 100755 --- a/pam/pam_mfa.py +++ b/pam/pam_mfa.py | |||
| @@ -34,11 +34,12 @@ def parse_arguments(): | |||
| 34 | default="/etc/mfa/mfa.conf") | 34 | default="/etc/mfa/mfa.conf") |
| 35 | parser.add_argument("--server",type=str,help="MFA server address") | 35 | parser.add_argument("--server",type=str,help="MFA server address") |
| 36 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") | 36 | parser.add_argument("--port",type=str,help="MFA server PAM connection port") |
| 37 | parser.add_argument("--plain",action="store_true",help="Connect without TLS") | ||
| 37 | parser.add_argument("--insecure",action="store_true", | 38 | parser.add_argument("--insecure",action="store_true", |
| 38 | help="Accept invalid TLS certificates") | 39 | help="Accept invalid TLS certificates") |
| 39 | return parser.parse_args() | 40 | return parser.parse_args() |
| 40 | 41 | ||
| 41 | def init_connection(mfa_server, pam_port, insecure): | 42 | def init_connection_tls(mfa_server, pam_port, insecure): |
| 42 | # Attempts to connect to MFA server with provided address and port | 43 | # Attempts to connect to MFA server with provided address and port |
| 43 | # Repeats connection attempts once per second until timeout is reached | 44 | # Repeats connection attempts once per second until timeout is reached |
| 44 | # Returns the socket if connection was successful or None otherwise | 45 | # Returns the socket if connection was successful or None otherwise |
| @@ -52,7 +53,6 @@ def init_connection(mfa_server, pam_port, insecure): | |||
| 52 | context.verify_mode = 0 | 53 | context.verify_mode = 0 |
| 53 | while connection == None and timeout < timeout_length: | 54 | while connection == None and timeout < timeout_length: |
| 54 | try: | 55 | try: |
| 55 | #connection = socket.create_connection((mfa_server,client_port)) | ||
| 56 | connection = context.wrap_socket(socket.socket(socket.AF_INET), | 56 | connection = context.wrap_socket(socket.socket(socket.AF_INET), |
| 57 | server_hostname=mfa_server) | 57 | server_hostname=mfa_server) |
| 58 | connection.connect((mfa_server,int(pam_port))) | 58 | connection.connect((mfa_server,int(pam_port))) |
| @@ -65,6 +65,24 @@ def init_connection(mfa_server, pam_port, insecure): | |||
| 65 | return None | 65 | return None |
| 66 | 66 | ||
| 67 | 67 | ||
| 68 | def init_connection(mfa_server, pam_port): | ||
| 69 | # Attempts to connect to MFA server with provided address and port | ||
| 70 | # Repeats connection attempts once per second until timeout is reached | ||
| 71 | # Returns the socket if connection was successful or None otherwise | ||
| 72 | connection = None | ||
| 73 | timeout = 0 | ||
| 74 | timeout_length = 5 | ||
| 75 | sleep_length = 1 | ||
| 76 | while connection == None and timeout < timeout_length: | ||
| 77 | try: | ||
| 78 | connection = socket.create_connection((mfa_server,pam_port)) | ||
| 79 | return connection | ||
| 80 | except (ConnectionError,ConnectionRefusedError): | ||
| 81 | time.sleep(sleep_length) | ||
| 82 | timeout += sleep_length | ||
| 83 | return None | ||
| 84 | |||
| 85 | |||
| 68 | def read_config(config_file): | 86 | def read_config(config_file): |
| 69 | # Read config file for server and port info | 87 | # Read config file for server and port info |
| 70 | # Return tuple (server,port) | 88 | # Return tuple (server,port) |
| @@ -94,19 +112,28 @@ def get_vars(args,confparser): | |||
| 94 | 112 | ||
| 95 | server = None | 113 | server = None |
| 96 | port = None | 114 | port = None |
| 115 | plain = None | ||
| 97 | insecure = None | 116 | insecure = None |
| 98 | 117 | ||
| 99 | # Set values from config file first | 118 | # Set values from config file first |
| 100 | if confparser.has_section("pam"): | 119 | if confparser.has_section("pam"): |
| 101 | server = confparser.get("pam","server",fallback=None) | 120 | server = confparser.get("pam","server",fallback=None) |
| 102 | port = confparser.get("pam","port",fallback=None) | 121 | port = confparser.get("pam","port",fallback=None) |
| 103 | insecure = bool(confparser.get("pam","insecure",fallback=False)) | 122 | plain = confparser.get("client","plain",fallback=False) |
| 123 | insecure = confparser.get("client","insecure",fallback=False) | ||
| 104 | 124 | ||
| 125 | if plain.lower() == "false": | ||
| 126 | plain = False | ||
| 127 | if insecure.lower() == "false": | ||
| 128 | insecure = False | ||
| 129 | |||
| 105 | # Let command line args overwrite any values | 130 | # Let command line args overwrite any values |
| 106 | if args.server != None: | 131 | if args.server != None: |
| 107 | server = args.server | 132 | server = args.server |
| 108 | if args.port != None: | 133 | if args.port != None: |
| 109 | port = args.port | 134 | port = args.port |
| 135 | if args.plain: | ||
| 136 | plain = args.plain | ||
| 110 | if args.insecure: | 137 | if args.insecure: |
| 111 | insecure = args.insecure | 138 | insecure = args.insecure |
| 112 | 139 | ||
| @@ -115,7 +142,7 @@ def get_vars(args,confparser): | |||
| 115 | print("error: one or more items unspecified") | 142 | print("error: one or more items unspecified") |
| 116 | sys.exit(1) | 143 | sys.exit(1) |
| 117 | 144 | ||
| 118 | return server,port,insecure | 145 | return server,port,plain,insecure |
| 119 | 146 | ||
| 120 | 147 | ||
| 121 | def main(): | 148 | def main(): |
| @@ -125,7 +152,7 @@ def main(): | |||
| 125 | # Get arguments | 152 | # Get arguments |
| 126 | args = parse_arguments() | 153 | args = parse_arguments() |
| 127 | confparser = read_config(args.config) | 154 | confparser = read_config(args.config) |
| 128 | mfa_server,pam_port,insecure = get_vars(args,confparser) | 155 | mfa_server,pam_port,plain,insecure = get_vars(args,confparser) |
| 129 | user = args.user | 156 | user = args.user |
| 130 | service = args.service | 157 | service = args.service |
| 131 | 158 | ||
| @@ -144,12 +171,13 @@ def main(): | |||
| 144 | hostname = args.host | 171 | hostname = args.host |
| 145 | data = user + "," + hostname + "," + service | 172 | data = user + "," + hostname + "," + service |
| 146 | 173 | ||
| 147 | |||
| 148 | # Initalize connection to MFA server. Quit if unable to connect. | 174 | # Initalize connection to MFA server. Quit if unable to connect. |
| 149 | connection = init_connection(mfa_server,pam_port,insecure) | 175 | if plain: |
| 176 | connection = init_connection(mfa_server, pam_port) | ||
| 177 | else: | ||
| 178 | connection = init_connection_tls(mfa_server,pam_port,insecure) | ||
| 150 | if connection == None: | 179 | if connection == None: |
| 151 | print(failed) | 180 | die("failed to connect") |
| 152 | sys.exit(1) | ||
| 153 | 181 | ||
| 154 | # Send authentication data to MFA server | 182 | # Send authentication data to MFA server |
| 155 | data_length = len(data) | 183 | data_length = len(data) |
diff --git a/server/mfad.py b/server/mfad.py index a2a783b..cc5073b 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
| @@ -52,6 +52,10 @@ def parse_arguments(): | |||
| 52 | parser.add_argument("--address",type=str,help="Bind Address") | 52 | parser.add_argument("--address",type=str,help="Bind Address") |
| 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | 53 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") |
| 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") | 54 | parser.add_argument("--client-port",type=int,help="Port for client connections") |
| 55 | parser.add_argument("--pam-tls-port",type=int, | ||
| 56 | help="Port to listen for encrypted PAM requests") | ||
| 57 | parser.add_argument("--client-tls-port",type=int, | ||
| 58 | help="Port for encrypted client connections") | ||
| 55 | parser.add_argument("--database",type=str,help="Path to database file") | 59 | parser.add_argument("--database",type=str,help="Path to database file") |
| 56 | parser.add_argument("--cert",type=str,help="TLS certificate file") | 60 | parser.add_argument("--cert",type=str,help="TLS certificate file") |
| 57 | parser.add_argument("--key",type=str,help="TLS private key file") | 61 | parser.add_argument("--key",type=str,help="TLS private key file") |
| @@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers): | |||
| 266 | 270 | ||
| 267 | return context | 271 | return context |
| 268 | 272 | ||
| 269 | def listen_client(db, addr, port, tls_context): | 273 | def listen_client_tls(db, addr, port, tls_context): |
| 270 | with socket.create_server((addr, port)) as sock: | 274 | with socket.create_server((addr, port)) as sock: |
| 271 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 275 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
| 272 | while True: | 276 | while True: |
| @@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context): | |||
| 278 | print("client: ssl handshake error") | 282 | print("client: ssl handshake error") |
| 279 | 283 | ||
| 280 | 284 | ||
| 281 | def listen_pam(db, addr, port, tls_context): | 285 | def listen_pam_tls(db, addr, port, tls_context): |
| 282 | with socket.create_server((addr,port)) as sock: | 286 | with socket.create_server((addr,port)) as sock: |
| 283 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: | 287 | with tls_context.wrap_socket(sock, server_side=True) as tls_socket: |
| 284 | while True: | 288 | while True: |
| @@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context): | |||
| 290 | print("pam: ssl handshake error") | 294 | print("pam: ssl handshake error") |
| 291 | 295 | ||
| 292 | 296 | ||
| 297 | def listen_client(db, addr, port): | ||
| 298 | with socket.create_server((addr, port)) as sock: | ||
| 299 | while True: | ||
| 300 | conn, addr = sock.accept() | ||
| 301 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
| 302 | thread.start() | ||
| 303 | |||
| 304 | |||
| 305 | def listen_pam(db, addr, port): | ||
| 306 | with socket.create_server((addr, port)) as sock: | ||
| 307 | while True: | ||
| 308 | conn, addr = sock.accept() | ||
| 309 | thread = threading.Thread(target=handle_pam,args=(db,conn,addr)) | ||
| 310 | thread.start() | ||
| 311 | |||
| 312 | |||
| 293 | ################################################################################ | 313 | ################################################################################ |
| 294 | 314 | ||
| 295 | def create_db(db): | 315 | def create_db(db): |
| @@ -312,12 +332,13 @@ def create_db(db): | |||
| 312 | 332 | ||
| 313 | def get_vars(args,confparser): | 333 | def get_vars(args,confparser): |
| 314 | if not os.path.exists(args.config): | 334 | if not os.path.exists(args.config): |
| 315 | print("Unable to open config file") | 335 | die("Unable to open config file") |
| 316 | sys.exit(1) | ||
| 317 | 336 | ||
| 318 | bind_addr = None | 337 | bind_addr = None |
| 319 | client_port = None | 338 | client_port = None |
| 339 | client_tls_port = None | ||
| 320 | pam_port = None | 340 | pam_port = None |
| 341 | pam_tls_port = None | ||
| 321 | database = None | 342 | database = None |
| 322 | cert = None | 343 | cert = None |
| 323 | key = None | 344 | key = None |
| @@ -328,6 +349,8 @@ def get_vars(args,confparser): | |||
| 328 | bind_addr = confparser.get("mfad","address",fallback=None) | 349 | bind_addr = confparser.get("mfad","address",fallback=None) |
| 329 | client_port = confparser.get("mfad","client-port",fallback=None) | 350 | client_port = confparser.get("mfad","client-port",fallback=None) |
| 330 | pam_port = confparser.get("mfad","pam-port",fallback=None) | 351 | pam_port = confparser.get("mfad","pam-port",fallback=None) |
| 352 | client_tls_port = confparser.get("mfad","client-tls-port",fallback=None) | ||
| 353 | pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None) | ||
| 331 | database = confparser.get("mfad","database",fallback=None) | 354 | database = confparser.get("mfad","database",fallback=None) |
| 332 | cert = confparser.get("mfad","cert",fallback=None) | 355 | cert = confparser.get("mfad","cert",fallback=None) |
| 333 | key = confparser.get("mfad","key",fallback=None) | 356 | key = confparser.get("mfad","key",fallback=None) |
| @@ -340,6 +363,10 @@ def get_vars(args,confparser): | |||
| 340 | client_port = args.client_port | 363 | client_port = args.client_port |
| 341 | if args.pam_port != None: | 364 | if args.pam_port != None: |
| 342 | pam_port = args.pam_port | 365 | pam_port = args.pam_port |
| 366 | if args.client_tls_port != None: | ||
| 367 | client_tls_port = args.client_tls_port | ||
| 368 | if args.pam_tls_port != None: | ||
| 369 | pam_tls_port = args.pam_tls_port | ||
| 343 | if args.database != None: | 370 | if args.database != None: |
| 344 | database = args.database | 371 | database = args.database |
| 345 | if args.cert != None: | 372 | if args.cert != None: |
| @@ -350,28 +377,49 @@ def get_vars(args,confparser): | |||
| 350 | ciphers = args.ciphers | 377 | ciphers = args.ciphers |
| 351 | 378 | ||
| 352 | # Exit if any value is null | 379 | # Exit if any value is null |
| 353 | if None in [bind_addr,client_port,pam_port,database,cert,key]: | 380 | if None in [bind_addr,database,cert,key]: |
| 354 | print("error: one or more items unspecified") | 381 | die("error: one or more items unspecified") |
| 355 | sys.exit(1) | 382 | |
| 383 | ports = [pam_port, client_port, pam_tls_port, client_tls_port] | ||
| 384 | for port in ports: | ||
| 385 | if port == None: | ||
| 386 | ports[ports.index(port)] = 0 | ||
| 356 | 387 | ||
| 357 | return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers | 388 | return bind_addr, ports, database, cert, key, ciphers |
| 358 | 389 | ||
| 359 | 390 | ||
| 360 | def main(): | 391 | def main(): |
| 361 | args = parse_arguments() | 392 | args = parse_arguments() |
| 362 | confparser = read_config(args.config) | 393 | confparser = read_config(args.config) |
| 363 | 394 | ||
| 364 | bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) | 395 | bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser) |
| 396 | pam_port = int(ports[0]) | ||
| 397 | client_port = int(ports[1]) | ||
| 398 | pam_tls_port = int(ports[2]) | ||
| 399 | client_tls_port = int(ports[3]) | ||
| 400 | |||
| 401 | if pam_port == 0 and pam_tls_port == 0: | ||
| 402 | die("error: not listening for any PAM connections") | ||
| 403 | if client_port == 0 and client_tls_port == 0: | ||
| 404 | die("error: not listening for any client connections") | ||
| 405 | |||
| 365 | 406 | ||
| 366 | if not os.path.exists(db): | 407 | if not os.path.exists(db): |
| 367 | print("Creating DB") | 408 | print("Creating DB") |
| 368 | create_db(db) | 409 | create_db(db) |
| 369 | 410 | ||
| 370 | context = get_tls_context(cert,key,ciphers) | 411 | context = get_tls_context(cert,key,ciphers) |
| 371 | clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) | 412 | |
| 372 | pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) | 413 | if pam_port != 0: |
| 373 | clients.start() | 414 | threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start() |
| 374 | pam.start() | 415 | if client_port != 0: |
| 416 | threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start() | ||
| 417 | if pam_tls_port != 0: | ||
| 418 | threading.Thread(target=listen_pam_tls, | ||
| 419 | args=(db,bind_addr,pam_tls_port,context)).start() | ||
| 420 | if client_tls_port != 0: | ||
| 421 | threading.Thread(target=listen_client_tls, | ||
| 422 | args=(db,bind_addr,client_tls_port,context)).start() | ||
| 375 | 423 | ||
| 376 | 424 | ||
| 377 | if __name__ == '__main__': | 425 | if __name__ == '__main__': |
