summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rwxr-xr-xserver/mfad.py73
1 files changed, 64 insertions, 9 deletions
diff --git a/server/mfad.py b/server/mfad.py
index d045e14..46fc0cc 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -7,6 +7,8 @@ import threading
7import pyotp 7import pyotp
8import sqlite3 8import sqlite3
9import re 9import re
10import configparser
11import argparse
10 12
11## Listens for authentication request from PAM module 13## Listens for authentication request from PAM module
12## Recevies connection from client 14## Recevies connection from client
@@ -16,7 +18,7 @@ import re
16## Return pass or fail response to PAM moudle 18## Return pass or fail response to PAM moudle
17 19
18 20
19DB_NAME = "mfa.db" 21DB_NAME = ""
20HEADER_LENGTH = 64 22HEADER_LENGTH = 64
21KEY_LENGTH = 64 23KEY_LENGTH = 64
22DISCONNECT_LENGTH = ACK_LENGTH = 3 24DISCONNECT_LENGTH = ACK_LENGTH = 3
@@ -41,6 +43,22 @@ CLIENT_SECRET_INDEX = 2
41# key and a tuple of (socket,(addr,port)) as the value 43# key and a tuple of (socket,(addr,port)) as the value
42client_connections = dict() 44client_connections = dict()
43 45
46def parse_arguments():
47 parser = argparse.ArgumentParser()
48 parser.add_argument("--address",type=str,help="Bind Address")
49 parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests")
50 parser.add_argument("--client-port",type=int,help="Port for client connections")
51 parser.add_argument("--database",type=str,help="Path to alternate database file")
52 parser.add_argument("--config",type=str,help="Alternate config file location",\
53 default="/etc/mfa/mfa.conf")
54 return parser.parse_args()
55
56
57def read_config(config):
58 parser = configparser.ConfigParser(inline_comment_prefixes="#")
59 parser.read(config)
60 return parser
61
44 62
45def eval_mfa(client_key, mfa_methods, client_response): 63def eval_mfa(client_key, mfa_methods, client_response):
46 print("response: " + client_response) 64 print("response: " + client_response)
@@ -228,14 +246,14 @@ def listen_pam(addr, port):
228 246
229################################################################################ 247################################################################################
230 248
231def create_db(): 249def create_db(db):
232 with sqlite3.connect(DB_NAME) as conn: 250 with sqlite3.connect(db) as conn:
233 c = conn.cursor() 251 c = conn.cursor()
234 c.execute("""CREATE TABLE applications ( 252 c.execute("""CREATE TABLE applications (
235 username text, 253 username text,
236 hostname text, 254 hostname text,
237 service text, 255 service text,
238 client_key text, 256 alias text,
239 mfa_methods text 257 mfa_methods text
240 )""") 258 )""")
241 c.execute("""CREATE TABLE clients ( 259 c.execute("""CREATE TABLE clients (
@@ -243,16 +261,53 @@ def create_db():
243 key text, 261 key text,
244 totp_secret text 262 totp_secret text
245 )""") 263 )""")
264 conn.commit()
265
266
267def get_vars(args,confparser):
268 if not os.path.exists(args.config):
269 print("Unable to open config file")
270 sys.exit(1)
271
272 bind_addr = None
273 client_port = None
274 pam_port = None
275 database = None
276
277 # Set values from config file first
278 if confparser.has_section("mfad"):
279 bind_addr = confparser.get("mfad","address",fallback=None)
280 client_port = confparser.get("mfad","client-port",fallback=None)
281 pam_port = confparser.get("mfad","pam-port",fallback=None)
282 database = confparser.get("mfad","database",fallback=None)
283
284 # Let command line args overwrite any values
285 if args.address:
286 bind_addr = args.address
287 if args.client_port:
288 client_port = args.client_port
289 if args.pam_port:
290 pam_port = args.pam_port
291 if args.database:
292 database = args.database
293
294 # Exit if any value is null
295 if None in [bind_addr,client_port,pam_port,database]:
296 print("error: one or more items unspecified")
297 sys.exit(1)
298
299 return bind_addr, int(client_port), int(pam_port), database
246 300
247 301
248def main(): 302def main():
249 global connection_list 303 args = parse_arguments()
250 bind_addr = "127.0.0.1" 304 confparser = read_config(args.config)
251 pam_port = 8000 305
252 client_port = 8001 306 bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser)
253 307
254 if not os.path.exists(DB_NAME): 308 if not os.path.exists(DB_NAME):
255 create_db() 309 print("Creating DB")
310 create_db(DB_NAME)
256 311
257 clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) 312 clients = threading.Thread(target=listen_client,args=(bind_addr,client_port))
258 pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) 313 pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port))