diff options
| author | Sam Chudnick <sam@chudnick.com> | 2022-07-03 05:45:21 -0400 |
|---|---|---|
| committer | Sam Chudnick <sam@chudnick.com> | 2022-07-03 05:45:21 -0400 |
| commit | ce3c9f1e849b871db2fa91b5aa030e8ea471a7ca (patch) | |
| tree | 24cecd0a4f25d151066ab53b4875b76399019038 /server | |
| parent | 41132171e065081736c82ee283c64309e30baa9b (diff) | |
Fixed issue caused by non-static database location
Pass database location as argument where needed now that location is not
static.
Diffstat (limited to 'server')
| -rwxr-xr-x | server/mfad.py | 59 |
1 files changed, 34 insertions, 25 deletions
diff --git a/server/mfad.py b/server/mfad.py index 46fc0cc..17d2585 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
| @@ -18,7 +18,6 @@ import argparse | |||
| 18 | ## Return pass or fail response to PAM moudle | 18 | ## Return pass or fail response to PAM moudle |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | DB_NAME = "" | ||
| 22 | HEADER_LENGTH = 64 | 21 | HEADER_LENGTH = 64 |
| 23 | KEY_LENGTH = 64 | 22 | KEY_LENGTH = 64 |
| 24 | DISCONNECT_LENGTH = ACK_LENGTH = 3 | 23 | DISCONNECT_LENGTH = ACK_LENGTH = 3 |
| @@ -43,6 +42,10 @@ CLIENT_SECRET_INDEX = 2 | |||
| 43 | # key and a tuple of (socket,(addr,port)) as the value | 42 | # key and a tuple of (socket,(addr,port)) as the value |
| 44 | client_connections = dict() | 43 | client_connections = dict() |
| 45 | 44 | ||
| 45 | def die(msg): | ||
| 46 | print(msg) | ||
| 47 | sys.exit(1) | ||
| 48 | |||
| 46 | def parse_arguments(): | 49 | def parse_arguments(): |
| 47 | parser = argparse.ArgumentParser() | 50 | parser = argparse.ArgumentParser() |
| 48 | parser.add_argument("--address",type=str,help="Bind Address") | 51 | parser.add_argument("--address",type=str,help="Bind Address") |
| @@ -60,7 +63,7 @@ def read_config(config): | |||
| 60 | return parser | 63 | return parser |
| 61 | 64 | ||
| 62 | 65 | ||
| 63 | def eval_mfa(client_key, mfa_methods, client_response): | 66 | def eval_mfa(db, client_key, mfa_methods, client_response): |
| 64 | print("response: " + client_response) | 67 | print("response: " + client_response) |
| 65 | print("length: " + str(len(client_response))) | 68 | print("length: " + str(len(client_response))) |
| 66 | print("methods: " + str(mfa_methods)) | 69 | print("methods: " + str(mfa_methods)) |
| @@ -74,13 +77,13 @@ def eval_mfa(client_key, mfa_methods, client_response): | |||
| 74 | totp_regex = re.compile(totp_format) | 77 | totp_regex = re.compile(totp_format) |
| 75 | matched = totp_regex.match(client_response) | 78 | matched = totp_regex.match(client_response) |
| 76 | if matched: | 79 | if matched: |
| 77 | return validate_totp(client_key, client_response) | 80 | return validate_totp(db, client_key, client_response) |
| 78 | return DENIED | 81 | return DENIED |
| 79 | 82 | ||
| 80 | 83 | ||
| 81 | def validate_totp(client_key, client_response): | 84 | def validate_totp(db, client_key, client_response): |
| 82 | secret = "" | 85 | secret = "" |
| 83 | with sqlite3.connect(DB_NAME) as conn: | 86 | with sqlite3.connect(db) as conn: |
| 84 | c = conn.cursor() | 87 | c = conn.cursor() |
| 85 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) | 88 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) |
| 86 | client = c.fetchone() | 89 | client = c.fetchone() |
| @@ -103,7 +106,7 @@ def validate_totp(client_key, client_response): | |||
| 103 | 106 | ||
| 104 | # //TODO RSA public/private key pairs for proper authentication | 107 | # //TODO RSA public/private key pairs for proper authentication |
| 105 | 108 | ||
| 106 | def get_client_key(username,hostname,service): | 109 | def get_client_key(db, username,hostname,service): |
| 107 | # Correlates a PAM request to a registered client | 110 | # Correlates a PAM request to a registered client |
| 108 | # This is done by checking the PAM request against a preconfigured | 111 | # This is done by checking the PAM request against a preconfigured |
| 109 | # database mapping request info (username,hostname,etc...) to clients | 112 | # database mapping request info (username,hostname,etc...) to clients |
| @@ -111,7 +114,7 @@ def get_client_key(username,hostname,service): | |||
| 111 | 114 | ||
| 112 | application = None | 115 | application = None |
| 113 | client = None | 116 | client = None |
| 114 | with sqlite3.connect(DB_NAME) as conn: | 117 | with sqlite3.connect(db) as conn: |
| 115 | c = conn.cursor() | 118 | c = conn.cursor() |
| 116 | c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? | 119 | c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? |
| 117 | AND service=?""",(username,hostname,service)) | 120 | AND service=?""",(username,hostname,service)) |
| @@ -151,7 +154,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): | |||
| 151 | conn.send(length_msg.encode(FORMAT)) | 154 | conn.send(length_msg.encode(FORMAT)) |
| 152 | conn.send(prompt_msg.encode(FORMAT)) | 155 | conn.send(prompt_msg.encode(FORMAT)) |
| 153 | # receive response | 156 | # receive response |
| 154 | response_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) | 157 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) |
| 158 | if header == "": | ||
| 159 | die("error: lost connection to client") | ||
| 160 | response_length = int(header) | ||
| 155 | response = conn.recv(response_length).decode(FORMAT) | 161 | response = conn.recv(response_length).decode(FORMAT) |
| 156 | return response | 162 | return response |
| 157 | except BrokenPipeError: | 163 | except BrokenPipeError: |
| @@ -165,10 +171,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): | |||
| 165 | return 0 | 171 | return 0 |
| 166 | 172 | ||
| 167 | 173 | ||
| 168 | def validate_client(client_key): | 174 | def validate_client(db, client_key): |
| 169 | # Validates a client | 175 | # Validates a client |
| 170 | client = None | 176 | client = None |
| 171 | with sqlite3.connect(DB_NAME) as conn: | 177 | with sqlite3.connect(db) as conn: |
| 172 | c = conn.cursor() | 178 | c = conn.cursor() |
| 173 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) | 179 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) |
| 174 | client = c.fetchall() | 180 | client = c.fetchall() |
| @@ -184,11 +190,11 @@ def validate_client(client_key): | |||
| 184 | 190 | ||
| 185 | 191 | ||
| 186 | 192 | ||
| 187 | def handle_client(conn, addr): | 193 | def handle_client(db, conn, addr): |
| 188 | # Receive key from client | 194 | # Receive key from client |
| 189 | key = conn.recv(KEY_LENGTH).decode(FORMAT) | 195 | key = conn.recv(KEY_LENGTH).decode(FORMAT) |
| 190 | # Validate client | 196 | # Validate client |
| 191 | if not validate_client(key): | 197 | if not validate_client(db, key): |
| 192 | print("WARNING: client attempted to connect with invalid key") | 198 | print("WARNING: client attempted to connect with invalid key") |
| 193 | conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) | 199 | conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) |
| 194 | conn.close() | 200 | conn.close() |
| @@ -202,15 +208,18 @@ def parse_pam_data(data): | |||
| 202 | # Parses pam data and returns (user,host,service) tuple | 208 | # Parses pam data and returns (user,host,service) tuple |
| 203 | return tuple(data.split(',')) | 209 | return tuple(data.split(',')) |
| 204 | 210 | ||
| 205 | def handle_pam(conn, addr): | 211 | def handle_pam(db, conn, addr): |
| 206 | # Get request and data from PAM module | 212 | # Get request and data from PAM module |
| 207 | data_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) | 213 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) |
| 214 | if header == "": | ||
| 215 | die("error: lost connection to pam module") | ||
| 216 | data_length = int(header) | ||
| 208 | pam_data = conn.recv(data_length).decode(FORMAT) | 217 | pam_data = conn.recv(data_length).decode(FORMAT) |
| 209 | print("Got pam_data: " + pam_data) | 218 | print("Got pam_data: " + pam_data) |
| 210 | user,host,service = parse_pam_data(pam_data) | 219 | user,host,service = parse_pam_data(pam_data) |
| 211 | 220 | ||
| 212 | # Correlate request to client | 221 | # Correlate request to client |
| 213 | client_key,mfa_methods = get_client_key(user,host,service) | 222 | client_key,mfa_methods = get_client_key(db, user,host,service) |
| 214 | mfa_methods = mfa_methods.split(' ') | 223 | mfa_methods = mfa_methods.split(' ') |
| 215 | if client_key == None: | 224 | if client_key == None: |
| 216 | print("No applications found for user="+user+" host="+host+" service="+service) | 225 | print("No applications found for user="+user+" host="+host+" service="+service) |
| @@ -221,26 +230,26 @@ def handle_pam(conn, addr): | |||
| 221 | response = prompt_client(client_key,user,host,service,mfa_methods) | 230 | response = prompt_client(client_key,user,host,service,mfa_methods) |
| 222 | 231 | ||
| 223 | # Evaluate Response | 232 | # Evaluate Response |
| 224 | decision = eval_mfa(client_key, mfa_methods, response) | 233 | decision = eval_mfa(db, client_key, mfa_methods, response) |
| 225 | 234 | ||
| 226 | # Return response to PAM module | 235 | # Return response to PAM module |
| 227 | # Respone will either be 0 for authenticated and 1 for denied | 236 | # Respone will either be 0 for authenticated and 1 for denied |
| 228 | conn.send(str(decision).encode(FORMAT)) | 237 | conn.send(str(decision).encode(FORMAT)) |
| 229 | 238 | ||
| 230 | 239 | ||
| 231 | def listen_client(addr, port): | 240 | def listen_client(db, addr, port): |
| 232 | with socket.create_server((addr, port)) as server: | 241 | with socket.create_server((addr, port)) as server: |
| 233 | while True: | 242 | while True: |
| 234 | conn, addr = server.accept() | 243 | conn, addr = server.accept() |
| 235 | thread = threading.Thread(target=handle_client,args=(conn,addr)) | 244 | thread = threading.Thread(target=handle_client,args=(db, conn,addr)) |
| 236 | thread.start() | 245 | thread.start() |
| 237 | 246 | ||
| 238 | 247 | ||
| 239 | def listen_pam(addr, port): | 248 | def listen_pam(db, addr, port): |
| 240 | with socket.create_server((addr,port)) as pam_server: | 249 | with socket.create_server((addr,port)) as pam_server: |
| 241 | while True: | 250 | while True: |
| 242 | conn, addr = pam_server.accept() | 251 | conn, addr = pam_server.accept() |
| 243 | thread = threading.Thread(target=handle_pam,args=(conn,addr)) | 252 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) |
| 244 | thread.start() | 253 | thread.start() |
| 245 | 254 | ||
| 246 | 255 | ||
| @@ -303,14 +312,14 @@ def main(): | |||
| 303 | args = parse_arguments() | 312 | args = parse_arguments() |
| 304 | confparser = read_config(args.config) | 313 | confparser = read_config(args.config) |
| 305 | 314 | ||
| 306 | bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) | 315 | bind_addr, client_port, pam_port, db = get_vars(args,confparser) |
| 307 | 316 | ||
| 308 | if not os.path.exists(DB_NAME): | 317 | if not os.path.exists(db): |
| 309 | print("Creating DB") | 318 | print("Creating DB") |
| 310 | create_db(DB_NAME) | 319 | create_db(db) |
| 311 | 320 | ||
| 312 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) | 321 | clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) |
| 313 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) | 322 | pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) |
| 314 | clients.start() | 323 | clients.start() |
| 315 | pam.start() | 324 | pam.start() |
| 316 | 325 | ||
