diff options
Diffstat (limited to 'server')
-rwxr-xr-x | server/mfad.py | 73 |
1 files changed, 64 insertions, 9 deletions
diff --git a/server/mfad.py b/server/mfad.py index d045e14..46fc0cc 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -7,6 +7,8 @@ import threading | |||
7 | import pyotp | 7 | import pyotp |
8 | import sqlite3 | 8 | import sqlite3 |
9 | import re | 9 | import re |
10 | import configparser | ||
11 | import argparse | ||
10 | 12 | ||
11 | ## Listens for authentication request from PAM module | 13 | ## Listens for authentication request from PAM module |
12 | ## Recevies connection from client | 14 | ## Recevies connection from client |
@@ -16,7 +18,7 @@ import re | |||
16 | ## Return pass or fail response to PAM moudle | 18 | ## Return pass or fail response to PAM moudle |
17 | 19 | ||
18 | 20 | ||
19 | DB_NAME = "mfa.db" | 21 | DB_NAME = "" |
20 | HEADER_LENGTH = 64 | 22 | HEADER_LENGTH = 64 |
21 | KEY_LENGTH = 64 | 23 | KEY_LENGTH = 64 |
22 | DISCONNECT_LENGTH = ACK_LENGTH = 3 | 24 | DISCONNECT_LENGTH = ACK_LENGTH = 3 |
@@ -41,6 +43,22 @@ CLIENT_SECRET_INDEX = 2 | |||
41 | # key and a tuple of (socket,(addr,port)) as the value | 43 | # key and a tuple of (socket,(addr,port)) as the value |
42 | client_connections = dict() | 44 | client_connections = dict() |
43 | 45 | ||
46 | def parse_arguments(): | ||
47 | parser = argparse.ArgumentParser() | ||
48 | parser.add_argument("--address",type=str,help="Bind Address") | ||
49 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | ||
50 | parser.add_argument("--client-port",type=int,help="Port for client connections") | ||
51 | parser.add_argument("--database",type=str,help="Path to alternate database file") | ||
52 | parser.add_argument("--config",type=str,help="Alternate config file location",\ | ||
53 | default="/etc/mfa/mfa.conf") | ||
54 | return parser.parse_args() | ||
55 | |||
56 | |||
57 | def read_config(config): | ||
58 | parser = configparser.ConfigParser(inline_comment_prefixes="#") | ||
59 | parser.read(config) | ||
60 | return parser | ||
61 | |||
44 | 62 | ||
45 | def eval_mfa(client_key, mfa_methods, client_response): | 63 | def eval_mfa(client_key, mfa_methods, client_response): |
46 | print("response: " + client_response) | 64 | print("response: " + client_response) |
@@ -228,14 +246,14 @@ def listen_pam(addr, port): | |||
228 | 246 | ||
229 | ################################################################################ | 247 | ################################################################################ |
230 | 248 | ||
231 | def create_db(): | 249 | def create_db(db): |
232 | with sqlite3.connect(DB_NAME) as conn: | 250 | with sqlite3.connect(db) as conn: |
233 | c = conn.cursor() | 251 | c = conn.cursor() |
234 | c.execute("""CREATE TABLE applications ( | 252 | c.execute("""CREATE TABLE applications ( |
235 | username text, | 253 | username text, |
236 | hostname text, | 254 | hostname text, |
237 | service text, | 255 | service text, |
238 | client_key text, | 256 | alias text, |
239 | mfa_methods text | 257 | mfa_methods text |
240 | )""") | 258 | )""") |
241 | c.execute("""CREATE TABLE clients ( | 259 | c.execute("""CREATE TABLE clients ( |
@@ -243,16 +261,53 @@ def create_db(): | |||
243 | key text, | 261 | key text, |
244 | totp_secret text | 262 | totp_secret text |
245 | )""") | 263 | )""") |
264 | conn.commit() | ||
265 | |||
266 | |||
267 | def get_vars(args,confparser): | ||
268 | if not os.path.exists(args.config): | ||
269 | print("Unable to open config file") | ||
270 | sys.exit(1) | ||
271 | |||
272 | bind_addr = None | ||
273 | client_port = None | ||
274 | pam_port = None | ||
275 | database = None | ||
276 | |||
277 | # Set values from config file first | ||
278 | if confparser.has_section("mfad"): | ||
279 | bind_addr = confparser.get("mfad","address",fallback=None) | ||
280 | client_port = confparser.get("mfad","client-port",fallback=None) | ||
281 | pam_port = confparser.get("mfad","pam-port",fallback=None) | ||
282 | database = confparser.get("mfad","database",fallback=None) | ||
283 | |||
284 | # Let command line args overwrite any values | ||
285 | if args.address: | ||
286 | bind_addr = args.address | ||
287 | if args.client_port: | ||
288 | client_port = args.client_port | ||
289 | if args.pam_port: | ||
290 | pam_port = args.pam_port | ||
291 | if args.database: | ||
292 | database = args.database | ||
293 | |||
294 | # Exit if any value is null | ||
295 | if None in [bind_addr,client_port,pam_port,database]: | ||
296 | print("error: one or more items unspecified") | ||
297 | sys.exit(1) | ||
298 | |||
299 | return bind_addr, int(client_port), int(pam_port), database | ||
246 | 300 | ||
247 | 301 | ||
248 | def main(): | 302 | def main(): |
249 | global connection_list | 303 | args = parse_arguments() |
250 | bind_addr = "127.0.0.1" | 304 | confparser = read_config(args.config) |
251 | pam_port = 8000 | 305 | |
252 | client_port = 8001 | 306 | bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) |
253 | 307 | ||
254 | if not os.path.exists(DB_NAME): | 308 | if not os.path.exists(DB_NAME): |
255 | create_db() | 309 | print("Creating DB") |
310 | create_db(DB_NAME) | ||
256 | 311 | ||
257 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) | 312 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) |
258 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) | 313 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) |