diff options
author | Sam Chudnick <sam@chudnick.com> | 2022-07-03 05:45:21 -0400 |
---|---|---|
committer | Sam Chudnick <sam@chudnick.com> | 2022-07-03 05:45:21 -0400 |
commit | ce3c9f1e849b871db2fa91b5aa030e8ea471a7ca (patch) | |
tree | 24cecd0a4f25d151066ab53b4875b76399019038 /server | |
parent | 41132171e065081736c82ee283c64309e30baa9b (diff) |
Fixed issue caused by non-static database location
Pass database location as argument where needed now that location is not
static.
Diffstat (limited to 'server')
-rwxr-xr-x | server/mfad.py | 59 |
1 files changed, 34 insertions, 25 deletions
diff --git a/server/mfad.py b/server/mfad.py index 46fc0cc..17d2585 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -18,7 +18,6 @@ import argparse | |||
18 | ## Return pass or fail response to PAM moudle | 18 | ## Return pass or fail response to PAM moudle |
19 | 19 | ||
20 | 20 | ||
21 | DB_NAME = "" | ||
22 | HEADER_LENGTH = 64 | 21 | HEADER_LENGTH = 64 |
23 | KEY_LENGTH = 64 | 22 | KEY_LENGTH = 64 |
24 | DISCONNECT_LENGTH = ACK_LENGTH = 3 | 23 | DISCONNECT_LENGTH = ACK_LENGTH = 3 |
@@ -43,6 +42,10 @@ CLIENT_SECRET_INDEX = 2 | |||
43 | # key and a tuple of (socket,(addr,port)) as the value | 42 | # key and a tuple of (socket,(addr,port)) as the value |
44 | client_connections = dict() | 43 | client_connections = dict() |
45 | 44 | ||
45 | def die(msg): | ||
46 | print(msg) | ||
47 | sys.exit(1) | ||
48 | |||
46 | def parse_arguments(): | 49 | def parse_arguments(): |
47 | parser = argparse.ArgumentParser() | 50 | parser = argparse.ArgumentParser() |
48 | parser.add_argument("--address",type=str,help="Bind Address") | 51 | parser.add_argument("--address",type=str,help="Bind Address") |
@@ -60,7 +63,7 @@ def read_config(config): | |||
60 | return parser | 63 | return parser |
61 | 64 | ||
62 | 65 | ||
63 | def eval_mfa(client_key, mfa_methods, client_response): | 66 | def eval_mfa(db, client_key, mfa_methods, client_response): |
64 | print("response: " + client_response) | 67 | print("response: " + client_response) |
65 | print("length: " + str(len(client_response))) | 68 | print("length: " + str(len(client_response))) |
66 | print("methods: " + str(mfa_methods)) | 69 | print("methods: " + str(mfa_methods)) |
@@ -74,13 +77,13 @@ def eval_mfa(client_key, mfa_methods, client_response): | |||
74 | totp_regex = re.compile(totp_format) | 77 | totp_regex = re.compile(totp_format) |
75 | matched = totp_regex.match(client_response) | 78 | matched = totp_regex.match(client_response) |
76 | if matched: | 79 | if matched: |
77 | return validate_totp(client_key, client_response) | 80 | return validate_totp(db, client_key, client_response) |
78 | return DENIED | 81 | return DENIED |
79 | 82 | ||
80 | 83 | ||
81 | def validate_totp(client_key, client_response): | 84 | def validate_totp(db, client_key, client_response): |
82 | secret = "" | 85 | secret = "" |
83 | with sqlite3.connect(DB_NAME) as conn: | 86 | with sqlite3.connect(db) as conn: |
84 | c = conn.cursor() | 87 | c = conn.cursor() |
85 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) | 88 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) |
86 | client = c.fetchone() | 89 | client = c.fetchone() |
@@ -103,7 +106,7 @@ def validate_totp(client_key, client_response): | |||
103 | 106 | ||
104 | # //TODO RSA public/private key pairs for proper authentication | 107 | # //TODO RSA public/private key pairs for proper authentication |
105 | 108 | ||
106 | def get_client_key(username,hostname,service): | 109 | def get_client_key(db, username,hostname,service): |
107 | # Correlates a PAM request to a registered client | 110 | # Correlates a PAM request to a registered client |
108 | # This is done by checking the PAM request against a preconfigured | 111 | # This is done by checking the PAM request against a preconfigured |
109 | # database mapping request info (username,hostname,etc...) to clients | 112 | # database mapping request info (username,hostname,etc...) to clients |
@@ -111,7 +114,7 @@ def get_client_key(username,hostname,service): | |||
111 | 114 | ||
112 | application = None | 115 | application = None |
113 | client = None | 116 | client = None |
114 | with sqlite3.connect(DB_NAME) as conn: | 117 | with sqlite3.connect(db) as conn: |
115 | c = conn.cursor() | 118 | c = conn.cursor() |
116 | c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? | 119 | c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? |
117 | AND service=?""",(username,hostname,service)) | 120 | AND service=?""",(username,hostname,service)) |
@@ -151,7 +154,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): | |||
151 | conn.send(length_msg.encode(FORMAT)) | 154 | conn.send(length_msg.encode(FORMAT)) |
152 | conn.send(prompt_msg.encode(FORMAT)) | 155 | conn.send(prompt_msg.encode(FORMAT)) |
153 | # receive response | 156 | # receive response |
154 | response_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) | 157 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) |
158 | if header == "": | ||
159 | die("error: lost connection to client") | ||
160 | response_length = int(header) | ||
155 | response = conn.recv(response_length).decode(FORMAT) | 161 | response = conn.recv(response_length).decode(FORMAT) |
156 | return response | 162 | return response |
157 | except BrokenPipeError: | 163 | except BrokenPipeError: |
@@ -165,10 +171,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10): | |||
165 | return 0 | 171 | return 0 |
166 | 172 | ||
167 | 173 | ||
168 | def validate_client(client_key): | 174 | def validate_client(db, client_key): |
169 | # Validates a client | 175 | # Validates a client |
170 | client = None | 176 | client = None |
171 | with sqlite3.connect(DB_NAME) as conn: | 177 | with sqlite3.connect(db) as conn: |
172 | c = conn.cursor() | 178 | c = conn.cursor() |
173 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) | 179 | c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) |
174 | client = c.fetchall() | 180 | client = c.fetchall() |
@@ -184,11 +190,11 @@ def validate_client(client_key): | |||
184 | 190 | ||
185 | 191 | ||
186 | 192 | ||
187 | def handle_client(conn, addr): | 193 | def handle_client(db, conn, addr): |
188 | # Receive key from client | 194 | # Receive key from client |
189 | key = conn.recv(KEY_LENGTH).decode(FORMAT) | 195 | key = conn.recv(KEY_LENGTH).decode(FORMAT) |
190 | # Validate client | 196 | # Validate client |
191 | if not validate_client(key): | 197 | if not validate_client(db, key): |
192 | print("WARNING: client attempted to connect with invalid key") | 198 | print("WARNING: client attempted to connect with invalid key") |
193 | conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) | 199 | conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) |
194 | conn.close() | 200 | conn.close() |
@@ -202,15 +208,18 @@ def parse_pam_data(data): | |||
202 | # Parses pam data and returns (user,host,service) tuple | 208 | # Parses pam data and returns (user,host,service) tuple |
203 | return tuple(data.split(',')) | 209 | return tuple(data.split(',')) |
204 | 210 | ||
205 | def handle_pam(conn, addr): | 211 | def handle_pam(db, conn, addr): |
206 | # Get request and data from PAM module | 212 | # Get request and data from PAM module |
207 | data_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) | 213 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) |
214 | if header == "": | ||
215 | die("error: lost connection to pam module") | ||
216 | data_length = int(header) | ||
208 | pam_data = conn.recv(data_length).decode(FORMAT) | 217 | pam_data = conn.recv(data_length).decode(FORMAT) |
209 | print("Got pam_data: " + pam_data) | 218 | print("Got pam_data: " + pam_data) |
210 | user,host,service = parse_pam_data(pam_data) | 219 | user,host,service = parse_pam_data(pam_data) |
211 | 220 | ||
212 | # Correlate request to client | 221 | # Correlate request to client |
213 | client_key,mfa_methods = get_client_key(user,host,service) | 222 | client_key,mfa_methods = get_client_key(db, user,host,service) |
214 | mfa_methods = mfa_methods.split(' ') | 223 | mfa_methods = mfa_methods.split(' ') |
215 | if client_key == None: | 224 | if client_key == None: |
216 | print("No applications found for user="+user+" host="+host+" service="+service) | 225 | print("No applications found for user="+user+" host="+host+" service="+service) |
@@ -221,26 +230,26 @@ def handle_pam(conn, addr): | |||
221 | response = prompt_client(client_key,user,host,service,mfa_methods) | 230 | response = prompt_client(client_key,user,host,service,mfa_methods) |
222 | 231 | ||
223 | # Evaluate Response | 232 | # Evaluate Response |
224 | decision = eval_mfa(client_key, mfa_methods, response) | 233 | decision = eval_mfa(db, client_key, mfa_methods, response) |
225 | 234 | ||
226 | # Return response to PAM module | 235 | # Return response to PAM module |
227 | # Respone will either be 0 for authenticated and 1 for denied | 236 | # Respone will either be 0 for authenticated and 1 for denied |
228 | conn.send(str(decision).encode(FORMAT)) | 237 | conn.send(str(decision).encode(FORMAT)) |
229 | 238 | ||
230 | 239 | ||
231 | def listen_client(addr, port): | 240 | def listen_client(db, addr, port): |
232 | with socket.create_server((addr, port)) as server: | 241 | with socket.create_server((addr, port)) as server: |
233 | while True: | 242 | while True: |
234 | conn, addr = server.accept() | 243 | conn, addr = server.accept() |
235 | thread = threading.Thread(target=handle_client,args=(conn,addr)) | 244 | thread = threading.Thread(target=handle_client,args=(db, conn,addr)) |
236 | thread.start() | 245 | thread.start() |
237 | 246 | ||
238 | 247 | ||
239 | def listen_pam(addr, port): | 248 | def listen_pam(db, addr, port): |
240 | with socket.create_server((addr,port)) as pam_server: | 249 | with socket.create_server((addr,port)) as pam_server: |
241 | while True: | 250 | while True: |
242 | conn, addr = pam_server.accept() | 251 | conn, addr = pam_server.accept() |
243 | thread = threading.Thread(target=handle_pam,args=(conn,addr)) | 252 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) |
244 | thread.start() | 253 | thread.start() |
245 | 254 | ||
246 | 255 | ||
@@ -303,14 +312,14 @@ def main(): | |||
303 | args = parse_arguments() | 312 | args = parse_arguments() |
304 | confparser = read_config(args.config) | 313 | confparser = read_config(args.config) |
305 | 314 | ||
306 | bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) | 315 | bind_addr, client_port, pam_port, db = get_vars(args,confparser) |
307 | 316 | ||
308 | if not os.path.exists(DB_NAME): | 317 | if not os.path.exists(db): |
309 | print("Creating DB") | 318 | print("Creating DB") |
310 | create_db(DB_NAME) | 319 | create_db(db) |
311 | 320 | ||
312 | clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) | 321 | clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port)) |
313 | pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) | 322 | pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port)) |
314 | clients.start() | 323 | clients.start() |
315 | pam.start() | 324 | pam.start() |
316 | 325 | ||