diff options
Diffstat (limited to 'server')
| -rwxr-xr-x | server/mfad.py | 73 |
1 files changed, 64 insertions, 9 deletions
diff --git a/server/mfad.py b/server/mfad.py index d045e14..46fc0cc 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
| @@ -7,6 +7,8 @@ import threading | |||
| 7 | import pyotp | 7 | import pyotp |
| 8 | import sqlite3 | 8 | import sqlite3 |
| 9 | import re | 9 | import re |
| 10 | import configparser | ||
| 11 | import argparse | ||
| 10 | 12 | ||
| 11 | ## Listens for authentication request from PAM module | 13 | ## Listens for authentication request from PAM module |
| 12 | ## Recevies connection from client | 14 | ## Recevies connection from client |
| @@ -16,7 +18,7 @@ import re | |||
| 16 | ## Return pass or fail response to PAM moudle | 18 | ## Return pass or fail response to PAM moudle |
| 17 | 19 | ||
| 18 | 20 | ||
| 19 | DB_NAME = "mfa.db" | 21 | DB_NAME = "" |
| 20 | HEADER_LENGTH = 64 | 22 | HEADER_LENGTH = 64 |
| 21 | KEY_LENGTH = 64 | 23 | KEY_LENGTH = 64 |
| 22 | DISCONNECT_LENGTH = ACK_LENGTH = 3 | 24 | DISCONNECT_LENGTH = ACK_LENGTH = 3 |
| @@ -41,6 +43,22 @@ CLIENT_SECRET_INDEX = 2 | |||
| 41 | # key and a tuple of (socket,(addr,port)) as the value | 43 | # key and a tuple of (socket,(addr,port)) as the value |
| 42 | client_connections = dict() | 44 | client_connections = dict() |
| 43 | 45 | ||
| 46 | def parse_arguments(): | ||
| 47 | parser = argparse.ArgumentParser() | ||
| 48 | parser.add_argument("--address",type=str,help="Bind Address") | ||
| 49 | parser.add_argument("--pam-port",type=int,help="Port to listen for PAM requests") | ||
| 50 | parser.add_argument("--client-port",type=int,help="Port for client connections") | ||
| 51 | parser.add_argument("--database",type=str,help="Path to alternate database file") | ||
| 52 | parser.add_argument("--config",type=str,help="Alternate config file location",\ | ||
| 53 | default="/etc/mfa/mfa.conf") | ||
| 54 | return parser.parse_args() | ||
| 55 | |||
| 56 | |||
| 57 | def read_config(config): | ||
| 58 | parser = configparser.ConfigParser(inline_comment_prefixes="#") | ||
| 59 | parser.read(config) | ||
| 60 | return parser | ||
| 61 | |||
| 44 | 62 | ||
| 45 | def eval_mfa(client_key, mfa_methods, client_response): | 63 | def eval_mfa(client_key, mfa_methods, client_response): |
| 46 | print("response: " + client_response) | 64 | print("response: " + client_response) |
| @@ -228,14 +246,14 @@ def listen_pam(addr, port): | |||
| 228 | 246 | ||
| 229 | ################################################################################ | 247 | ################################################################################ |
| 230 | 248 | ||
| 231 | def create_db(): | 249 | def create_db(db): |
| 232 | with sqlite3.connect(DB_NAME) as conn: | 250 | with sqlite3.connect(db) as conn: |
| 233 | c = conn.cursor() | 251 | c = conn.cursor() |
| 234 | c.execute("""CREATE TABLE applications ( | 252 | c.execute("""CREATE TABLE applications ( |
| 235 | username text, | 253 | username text, |
| 236 | hostname text, | 254 | hostname text, |
| 237 | service text, | 255 | service text, |
| 238 | client_key text, | 256 | alias text, |
| 239 | mfa_methods text | 257 | mfa_methods text |
| 240 | )""") | 258 | )""") |
| 241 | c.execute("""CREATE TABLE clients ( | 259 | c.execute("""CREATE TABLE clients ( |
| @@ -243,16 +261,53 @@ def create_db(): | |||
| 243 | key text, | 261 | key text, |
| 244 | totp_secret text | 262 | totp_secret text |
| 245 | )""") | 263 | )""") |
| 264 | conn.commit() | ||
| 265 | |||
| 266 | |||
| 267 | def get_vars(args,confparser): | ||
| 268 | if not os.path.exists(args.config): | ||
| 269 | print("Unable to open config file") | ||
| 270 | sys.exit(1) | ||
| 271 | |||
| 272 | bind_addr = None | ||
| 273 | client_port = None | ||
| 274 | pam_port = None | ||
| 275 | database = None | ||
| 276 | |||
| 277 | # Set values from config file first | ||
| 278 | if confparser.has_section("mfad"): | ||
| 279 | bind_addr = confparser.get("mfad","address",fallback=None) | ||
| 280 | client_port = confparser.get("mfad","client-port",fallback=None) | ||
| 281 | pam_port = confparser.get("mfad","pam-port",fallback=None) | ||
| 282 | database = confparser.get("mfad","database",fallback=None) | ||
| 283 | |||
| 284 | # Let command line args overwrite any values | ||
| 285 | if args.address: | ||
| 286 | bind_addr = args.address | ||
| 287 | if args.client_port: | ||
| 288 | client_port = args.client_port | ||
| 289 | if args.pam_port: | ||
| 290 | pam_port = args.pam_port | ||
| 291 | if args.database: | ||
| 292 | database = args.database | ||
| 293 | |||
| 294 | # Exit if any value is null | ||
| 295 | if None in [bind_addr,client_port,pam_port,database]: | ||
| 296 | print("error: one or more items unspecified") | ||
| 297 | sys.exit(1) | ||
| 298 | |||
| 299 | return bind_addr, int(client_port), int(pam_port), database | ||
| 246 | 300 | ||
| 247 | 301 | ||
| 248 | def main(): | 302 | def main(): |
| 249 | global connection_list | 303 | args = parse_arguments() |
| 250 | bind_addr = "127.0.0.1" | 304 | confparser = read_config(args.config) |
| 251 | pam_port = 8000 | 305 | |
| 252 | client_port = 8001 | 306 | bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) |
| 253 | 307 | ||
| 254 | if not os.path.exists(DB_NAME): | 308 | if not os.path.exists(DB_NAME): |
| 255 | create_db() | 309 | print("Creating DB") |
| 310 | create_db(DB_NAME) | ||
| 256 | 311 | ||
| 257 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) | 312 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) |
| 258 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) | 313 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) |
