summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xclient/client.py44
-rwxr-xr-xpam/pam_mfa.py46
-rwxr-xr-xserver/mfad.py74
3 files changed, 137 insertions, 27 deletions
diff --git a/client/client.py b/client/client.py
index cc22d0b..0388073 100755
--- a/client/client.py
+++ b/client/client.py
@@ -26,6 +26,7 @@ def parse_arguments():
26 parser.add_argument("--config",type=str,help="Path to config file",\ 26 parser.add_argument("--config",type=str,help="Path to config file",\
27 default="/etc/mfa/mfa.conf") 27 default="/etc/mfa/mfa.conf")
28 parser.add_argument("--key",type=str,help="Client connection key") 28 parser.add_argument("--key",type=str,help="Client connection key")
29 parser.add_argument("--plain",action="store_true",help="Connect without TLS")
29 parser.add_argument("--insecure",action="store_true",\ 30 parser.add_argument("--insecure",action="store_true",\
30 help="Accept invalid TLS certificates") 31 help="Accept invalid TLS certificates")
31 return parser.parse_args() 32 return parser.parse_args()
@@ -37,7 +38,7 @@ def prompt_user(prompt):
37 return result 38 return result
38 39
39 40
40def init_connection(mfa_server, client_port, client_key, insecure=False): 41def init_connection_tls(mfa_server, client_port, client_key, insecure=False):
41 # Attempts to connect to MFA server with provided address,port, and key. 42 # Attempts to connect to MFA server with provided address,port, and key.
42 # Repeats attempt once a seconds until timeout is reached. 43 # Repeats attempt once a seconds until timeout is reached.
43 # Returns socket or None if unable to connect 44 # Returns socket or None if unable to connect
@@ -70,6 +71,27 @@ def init_connection(mfa_server, client_port, client_key, insecure=False):
70 return connection 71 return connection
71 72
72 73
74def init_connection(mfa_server, client_port, client_key):
75 connection = None
76 timeout = 0
77 timeout_length = 5
78 sleep_length = 1
79 while connection == None and timeout < timeout_length:
80 try:
81 connection = socket.create_connection((mfa_server,client_port))
82 connection.send(client_key.encode(FORMAT))
83 response = connection.recv(ACK_LENGTH).decode(FORMAT)
84 if response == ACK_MESSAGE:
85 print("connected to mfa server")
86 elif response == DISCONNECT_MESSAGE:
87 print("server terminated connection")
88 sys.exit(1)
89 except ConnectionError:
90 time.sleep(sleep_length)
91 timeout += sleep_length
92 return connection
93
94
73def read_config(config_file): 95def read_config(config_file):
74 parser = configparser.ConfigParser(inline_comment_prefixes="#") 96 parser = configparser.ConfigParser(inline_comment_prefixes="#")
75 parser.read(config_file) 97 parser.read(config_file)
@@ -84,6 +106,7 @@ def get_vars(args,confparser):
84 server = None 106 server = None
85 port = None 107 port = None
86 key = None 108 key = None
109 plain = None
87 insecure = None 110 insecure = None
88 111
89 # Set values from config file first 112 # Set values from config file first
@@ -91,7 +114,13 @@ def get_vars(args,confparser):
91 server = confparser.get("client","server",fallback=None) 114 server = confparser.get("client","server",fallback=None)
92 port = confparser.get("client","port",fallback=None) 115 port = confparser.get("client","port",fallback=None)
93 key = confparser.get("client","key",fallback=None) 116 key = confparser.get("client","key",fallback=None)
94 insecure = bool(confparser.get("client","insecure",fallback=False)) 117 plain = confparser.get("client","plain",fallback=False)
118 insecure = confparser.get("client","insecure",fallback=False)
119
120 if plain.lower() == "false":
121 plain = False
122 if insecure.lower() == "false":
123 insecure = False
95 124
96 # Let command line args overwrite any values 125 # Let command line args overwrite any values
97 if args.server != None: 126 if args.server != None:
@@ -100,6 +129,8 @@ def get_vars(args,confparser):
100 port = args.port 129 port = args.port
101 if args.key != None: 130 if args.key != None:
102 key = args.key 131 key = args.key
132 if args.plain:
133 plain = args.plain
103 if args.insecure: 134 if args.insecure:
104 insecure = args.insecure 135 insecure = args.insecure
105 136
@@ -108,7 +139,7 @@ def get_vars(args,confparser):
108 print("error: one or more items unspecified") 139 print("error: one or more items unspecified")
109 sys.exit(1) 140 sys.exit(1)
110 141
111 return server,port,key,insecure 142 return server,port,key,plain,insecure
112 143
113 144
114def main(): 145def main():
@@ -116,7 +147,7 @@ def main():
116 args = parse_arguments() 147 args = parse_arguments()
117 confparser = read_config(args.config) 148 confparser = read_config(args.config)
118 149
119 mfa_server,client_port,client_key,insecure = get_vars(args,confparser) 150 mfa_server,client_port,client_key,plain,insecure = get_vars(args,confparser)
120 151
121 # Exit if invalid key is provided 152 # Exit if invalid key is provided
122 if len(client_key) != KEY_LENGTH: 153 if len(client_key) != KEY_LENGTH:
@@ -124,7 +155,10 @@ def main():
124 sys.exit(1) 155 sys.exit(1)
125 156
126 # Open connection to server 157 # Open connection to server
127 conn = init_connection(mfa_server,client_port,client_key,insecure) 158 if plain:
159 conn = init_connection(mfa_server,client_port,client_key)
160 else:
161 conn = init_connection_tls(mfa_server,client_port,client_key,insecure)
128 if conn == None: 162 if conn == None:
129 print("timed out attempting to connect to server") 163 print("timed out attempting to connect to server")
130 sys.exit(1) 164 sys.exit(1)
diff --git a/pam/pam_mfa.py b/pam/pam_mfa.py
index 5a5a112..a5105a2 100755
--- a/pam/pam_mfa.py
+++ b/pam/pam_mfa.py
@@ -34,11 +34,12 @@ def parse_arguments():
34 default="/etc/mfa/mfa.conf") 34 default="/etc/mfa/mfa.conf")
35 parser.add_argument("--server",type=str,help="MFA server address") 35 parser.add_argument("--server",type=str,help="MFA server address")
36 parser.add_argument("--port",type=str,help="MFA server PAM connection port") 36 parser.add_argument("--port",type=str,help="MFA server PAM connection port")
37 parser.add_argument("--plain",action="store_true",help="Connect without TLS")
37 parser.add_argument("--insecure",action="store_true", 38 parser.add_argument("--insecure",action="store_true",
38 help="Accept invalid TLS certificates") 39 help="Accept invalid TLS certificates")
39 return parser.parse_args() 40 return parser.parse_args()
40 41
41def init_connection(mfa_server, pam_port, insecure): 42def init_connection_tls(mfa_server, pam_port, insecure):
42 # Attempts to connect to MFA server with provided address and port 43 # Attempts to connect to MFA server with provided address and port
43 # Repeats connection attempts once per second until timeout is reached 44 # Repeats connection attempts once per second until timeout is reached
44 # Returns the socket if connection was successful or None otherwise 45 # Returns the socket if connection was successful or None otherwise
@@ -52,7 +53,6 @@ def init_connection(mfa_server, pam_port, insecure):
52 context.verify_mode = 0 53 context.verify_mode = 0
53 while connection == None and timeout < timeout_length: 54 while connection == None and timeout < timeout_length:
54 try: 55 try:
55 #connection = socket.create_connection((mfa_server,client_port))
56 connection = context.wrap_socket(socket.socket(socket.AF_INET), 56 connection = context.wrap_socket(socket.socket(socket.AF_INET),
57 server_hostname=mfa_server) 57 server_hostname=mfa_server)
58 connection.connect((mfa_server,int(pam_port))) 58 connection.connect((mfa_server,int(pam_port)))
@@ -65,6 +65,24 @@ def init_connection(mfa_server, pam_port, insecure):
65 return None 65 return None
66 66
67 67
68def init_connection(mfa_server, pam_port):
69 # Attempts to connect to MFA server with provided address and port
70 # Repeats connection attempts once per second until timeout is reached
71 # Returns the socket if connection was successful or None otherwise
72 connection = None
73 timeout = 0
74 timeout_length = 5
75 sleep_length = 1
76 while connection == None and timeout < timeout_length:
77 try:
78 connection = socket.create_connection((mfa_server,pam_port))
79 return connection
80 except (ConnectionError,ConnectionRefusedError):
81 time.sleep(sleep_length)
82 timeout += sleep_length
83 return None
84
85
68def read_config(config_file): 86def read_config(config_file):
69 # Read config file for server and port info 87 # Read config file for server and port info
70 # Return tuple (server,port) 88 # Return tuple (server,port)
@@ -94,19 +112,28 @@ def get_vars(args,confparser):
94 112
95 server = None 113 server = None
96 port = None 114 port = None
115 plain = None
97 insecure = None 116 insecure = None
98 117
99 # Set values from config file first 118 # Set values from config file first
100 if confparser.has_section("pam"): 119 if confparser.has_section("pam"):
101 server = confparser.get("pam","server",fallback=None) 120 server = confparser.get("pam","server",fallback=None)
102 port = confparser.get("pam","port",fallback=None) 121 port = confparser.get("pam","port",fallback=None)
103 insecure = bool(confparser.get("pam","insecure",fallback=False)) 122 plain = confparser.get("client","plain",fallback=False)
123 insecure = confparser.get("client","insecure",fallback=False)
104 124
125 if plain.lower() == "false":
126 plain = False
127 if insecure.lower() == "false":
128 insecure = False
129
105 # Let command line args overwrite any values 130 # Let command line args overwrite any values
106 if args.server != None: 131 if args.server != None:
107 server = args.server 132 server = args.server
108 if args.port != None: 133 if args.port != None:
109 port = args.port 134 port = args.port
135 if args.plain:
136 plain = args.plain
110 if args.insecure: 137 if args.insecure:
111 insecure = args.insecure 138 insecure = args.insecure
112 139
@@ -115,7 +142,7 @@ def get_vars(args,confparser):
115 print("error: one or more items unspecified") 142 print("error: one or more items unspecified")
116 sys.exit(1) 143 sys.exit(1)
117 144
118 return server,port,insecure 145 return server,port,plain,insecure
119 146
120 147
121def main(): 148def main():
@@ -125,7 +152,7 @@ def main():
125 # Get arguments 152 # Get arguments
126 args = parse_arguments() 153 args = parse_arguments()
127 confparser = read_config(args.config) 154 confparser = read_config(args.config)
128 mfa_server,pam_port,insecure = get_vars(args,confparser) 155 mfa_server,pam_port,plain,insecure = get_vars(args,confparser)
129 user = args.user 156 user = args.user
130 service = args.service 157 service = args.service
131 158
@@ -144,12 +171,13 @@ def main():
144 hostname = args.host 171 hostname = args.host
145 data = user + "," + hostname + "," + service 172 data = user + "," + hostname + "," + service
146 173
147
148 # Initalize connection to MFA server. Quit if unable to connect. 174 # Initalize connection to MFA server. Quit if unable to connect.
149 connection = init_connection(mfa_server,pam_port,insecure) 175 if plain:
176 connection = init_connection(mfa_server, pam_port)
177 else:
178 connection = init_connection_tls(mfa_server,pam_port,insecure)
150 if connection == None: 179 if connection == None:
151 print(failed) 180 die("failed to connect")
152 sys.exit(1)
153 181
154 # Send authentication data to MFA server 182 # Send authentication data to MFA server
155 data_length = len(data) 183 data_length = len(data)
diff --git a/server/mfad.py b/server/mfad.py
index a2a783b..cc5073b 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -52,6 +52,10 @@ def parse_arguments():
52 parser.add_argument("--address",type=str,help="Bind Address") 52 parser.add_argument("--address",type=str,help="Bind Address")
53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") 53 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
54 parser.add_argument("--client-port",type=int,help="Port for client connections") 54 parser.add_argument("--client-port",type=int,help="Port for client connections")
55 parser.add_argument("--pam-tls-port",type=int,
56 help="Port to listen for encrypted PAM requests")
57 parser.add_argument("--client-tls-port",type=int,
58 help="Port for encrypted client connections")
55 parser.add_argument("--database",type=str,help="Path to database file") 59 parser.add_argument("--database",type=str,help="Path to database file")
56 parser.add_argument("--cert",type=str,help="TLS certificate file") 60 parser.add_argument("--cert",type=str,help="TLS certificate file")
57 parser.add_argument("--key",type=str,help="TLS private key file") 61 parser.add_argument("--key",type=str,help="TLS private key file")
@@ -266,7 +270,7 @@ def get_tls_context(cert, key, ciphers):
266 270
267 return context 271 return context
268 272
269def listen_client(db, addr, port, tls_context): 273def listen_client_tls(db, addr, port, tls_context):
270 with socket.create_server((addr, port)) as sock: 274 with socket.create_server((addr, port)) as sock:
271 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 275 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
272 while True: 276 while True:
@@ -278,7 +282,7 @@ def listen_client(db, addr, port, tls_context):
278 print("client: ssl handshake error") 282 print("client: ssl handshake error")
279 283
280 284
281def listen_pam(db, addr, port, tls_context): 285def listen_pam_tls(db, addr, port, tls_context):
282 with socket.create_server((addr,port)) as sock: 286 with socket.create_server((addr,port)) as sock:
283 with tls_context.wrap_socket(sock, server_side=True) as tls_socket: 287 with tls_context.wrap_socket(sock, server_side=True) as tls_socket:
284 while True: 288 while True:
@@ -290,6 +294,22 @@ def listen_pam(db, addr, port, tls_context):
290 print("pam: ssl handshake error") 294 print("pam: ssl handshake error")
291 295
292 296
297def listen_client(db, addr, port):
298 with socket.create_server((addr, port)) as sock:
299 while True:
300 conn, addr = sock.accept()
301 thread = threading.Thread(target=handle_client,args=(db,conn,addr))
302 thread.start()
303
304
305def listen_pam(db, addr, port):
306 with socket.create_server((addr, port)) as sock:
307 while True:
308 conn, addr = sock.accept()
309 thread = threading.Thread(target=handle_pam,args=(db,conn,addr))
310 thread.start()
311
312
293################################################################################ 313################################################################################
294 314
295def create_db(db): 315def create_db(db):
@@ -312,12 +332,13 @@ def create_db(db):
312 332
313def get_vars(args,confparser): 333def get_vars(args,confparser):
314 if not os.path.exists(args.config): 334 if not os.path.exists(args.config):
315 print("Unable to open config file") 335 die("Unable to open config file")
316 sys.exit(1)
317 336
318 bind_addr = None 337 bind_addr = None
319 client_port = None 338 client_port = None
339 client_tls_port = None
320 pam_port = None 340 pam_port = None
341 pam_tls_port = None
321 database = None 342 database = None
322 cert = None 343 cert = None
323 key = None 344 key = None
@@ -328,6 +349,8 @@ def get_vars(args,confparser):
328 bind_addr = confparser.get("mfad","address",fallback=None) 349 bind_addr = confparser.get("mfad","address",fallback=None)
329 client_port = confparser.get("mfad","client-port",fallback=None) 350 client_port = confparser.get("mfad","client-port",fallback=None)
330 pam_port = confparser.get("mfad","pam-port",fallback=None) 351 pam_port = confparser.get("mfad","pam-port",fallback=None)
352 client_tls_port = confparser.get("mfad","client-tls-port",fallback=None)
353 pam_tls_port = confparser.get("mfad","pam-tls-port",fallback=None)
331 database = confparser.get("mfad","database",fallback=None) 354 database = confparser.get("mfad","database",fallback=None)
332 cert = confparser.get("mfad","cert",fallback=None) 355 cert = confparser.get("mfad","cert",fallback=None)
333 key = confparser.get("mfad","key",fallback=None) 356 key = confparser.get("mfad","key",fallback=None)
@@ -340,6 +363,10 @@ def get_vars(args,confparser):
340 client_port = args.client_port 363 client_port = args.client_port
341 if args.pam_port != None: 364 if args.pam_port != None:
342 pam_port = args.pam_port 365 pam_port = args.pam_port
366 if args.client_tls_port != None:
367 client_tls_port = args.client_tls_port
368 if args.pam_tls_port != None:
369 pam_tls_port = args.pam_tls_port
343 if args.database != None: 370 if args.database != None:
344 database = args.database 371 database = args.database
345 if args.cert != None: 372 if args.cert != None:
@@ -350,28 +377,49 @@ def get_vars(args,confparser):
350 ciphers = args.ciphers 377 ciphers = args.ciphers
351 378
352 # Exit if any value is null 379 # Exit if any value is null
353 if None in [bind_addr,client_port,pam_port,database,cert,key]: 380 if None in [bind_addr,database,cert,key]:
354 print("error: one or more items unspecified") 381 die("error: one or more items unspecified")
355 sys.exit(1) 382
383 ports = [pam_port, client_port, pam_tls_port, client_tls_port]
384 for port in ports:
385 if port == None:
386 ports[ports.index(port)] = 0
356 387
357 return bind_addr, int(client_port), int(pam_port), database, cert, key, ciphers 388 return bind_addr, ports, database, cert, key, ciphers
358 389
359 390
360def main(): 391def main():
361 args = parse_arguments() 392 args = parse_arguments()
362 confparser = read_config(args.config) 393 confparser = read_config(args.config)
363 394
364 bind_addr,client_port,pam_port,db,cert,key,ciphers = get_vars(args,confparser) 395 bind_addr,ports,db,cert,key,ciphers = get_vars(args,confparser)
396 pam_port = int(ports[0])
397 client_port = int(ports[1])
398 pam_tls_port = int(ports[2])
399 client_tls_port = int(ports[3])
400
401 if pam_port == 0 and pam_tls_port == 0:
402 die("error: not listening for any PAM connections")
403 if client_port == 0 and client_tls_port == 0:
404 die("error: not listening for any client connections")
405
365 406
366 if not os.path.exists(db): 407 if not os.path.exists(db):
367 print("Creating DB") 408 print("Creating DB")
368 create_db(db) 409 create_db(db)
369 410
370 context = get_tls_context(cert,key,ciphers) 411 context = get_tls_context(cert,key,ciphers)
371 clients = threading.Thread(target=listen_client,args=(db,bind_addr,client_port,context)) 412
372 pam = threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port,context)) 413 if pam_port != 0:
373 clients.start() 414 threading.Thread(target=listen_pam,args=(db,bind_addr,pam_port)).start()
374 pam.start() 415 if client_port != 0:
416 threading.Thread(target=listen_client,args=(db,bind_addr,client_port)).start()
417 if pam_tls_port != 0:
418 threading.Thread(target=listen_pam_tls,
419 args=(db,bind_addr,pam_tls_port,context)).start()
420 if client_tls_port != 0:
421 threading.Thread(target=listen_client_tls,
422 args=(db,bind_addr,client_tls_port,context)).start()
375 423
376 424
377if __name__ == '__main__': 425if __name__ == '__main__':