summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xserver/mfad.py59
1 files changed, 34 insertions, 25 deletions
diff --git a/server/mfad.py b/server/mfad.py
index 46fc0cc..17d2585 100755
--- a/server/mfad.py
+++ b/server/mfad.py
@@ -18,7 +18,6 @@ import argparse
18## Return pass or fail response to PAM moudle 18## Return pass or fail response to PAM moudle
19 19
20 20
21DB_NAME = ""
22HEADER_LENGTH = 64 21HEADER_LENGTH = 64
23KEY_LENGTH = 64 22KEY_LENGTH = 64
24DISCONNECT_LENGTH = ACK_LENGTH = 3 23DISCONNECT_LENGTH = ACK_LENGTH = 3
@@ -43,6 +42,10 @@ CLIENT_SECRET_INDEX = 2
43# key and a tuple of (socket,(addr,port)) as the value 42# key and a tuple of (socket,(addr,port)) as the value
44client_connections = dict() 43client_connections = dict()
45 44
45def die(msg):
46 print(msg)
47 sys.exit(1)
48
46def parse_arguments(): 49def parse_arguments():
47 parser = argparse.ArgumentParser() 50 parser = argparse.ArgumentParser()
48 parser.add_argument("--address",type=str,help="Bind Address") 51 parser.add_argument("--address",type=str,help="Bind Address")
@@ -60,7 +63,7 @@ def read_config(config):
60 return parser 63 return parser
61 64
62 65
63def eval_mfa(client_key, mfa_methods, client_response): 66def eval_mfa(db, client_key, mfa_methods, client_response):
64 print("response: " + client_response) 67 print("response: " + client_response)
65 print("length: " + str(len(client_response))) 68 print("length: " + str(len(client_response)))
66 print("methods: " + str(mfa_methods)) 69 print("methods: " + str(mfa_methods))
@@ -74,13 +77,13 @@ def eval_mfa(client_key, mfa_methods, client_response):
74 totp_regex = re.compile(totp_format) 77 totp_regex = re.compile(totp_format)
75 matched = totp_regex.match(client_response) 78 matched = totp_regex.match(client_response)
76 if matched: 79 if matched:
77 return validate_totp(client_key, client_response) 80 return validate_totp(db, client_key, client_response)
78 return DENIED 81 return DENIED
79 82
80 83
81def validate_totp(client_key, client_response): 84def validate_totp(db, client_key, client_response):
82 secret = "" 85 secret = ""
83 with sqlite3.connect(DB_NAME) as conn: 86 with sqlite3.connect(db) as conn:
84 c = conn.cursor() 87 c = conn.cursor()
85 c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) 88 c.execute("SELECT * FROM clients WHERE key=?",(client_key,))
86 client = c.fetchone() 89 client = c.fetchone()
@@ -103,7 +106,7 @@ def validate_totp(client_key, client_response):
103 106
104# //TODO RSA public/private key pairs for proper authentication 107# //TODO RSA public/private key pairs for proper authentication
105 108
106def get_client_key(username,hostname,service): 109def get_client_key(db, username,hostname,service):
107 # Correlates a PAM request to a registered client 110 # Correlates a PAM request to a registered client
108 # This is done by checking the PAM request against a preconfigured 111 # This is done by checking the PAM request against a preconfigured
109 # database mapping request info (username,hostname,etc...) to clients 112 # database mapping request info (username,hostname,etc...) to clients
@@ -111,7 +114,7 @@ def get_client_key(username,hostname,service):
111 114
112 application = None 115 application = None
113 client = None 116 client = None
114 with sqlite3.connect(DB_NAME) as conn: 117 with sqlite3.connect(db) as conn:
115 c = conn.cursor() 118 c = conn.cursor()
116 c.execute("""SELECT * FROM applications WHERE username=? AND hostname=? 119 c.execute("""SELECT * FROM applications WHERE username=? AND hostname=?
117 AND service=?""",(username,hostname,service)) 120 AND service=?""",(username,hostname,service))
@@ -151,7 +154,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10):
151 conn.send(length_msg.encode(FORMAT)) 154 conn.send(length_msg.encode(FORMAT))
152 conn.send(prompt_msg.encode(FORMAT)) 155 conn.send(prompt_msg.encode(FORMAT))
153 # receive response 156 # receive response
154 response_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) 157 header = conn.recv(HEADER_LENGTH).decode(FORMAT)
158 if header == "":
159 die("error: lost connection to client")
160 response_length = int(header)
155 response = conn.recv(response_length).decode(FORMAT) 161 response = conn.recv(response_length).decode(FORMAT)
156 return response 162 return response
157 except BrokenPipeError: 163 except BrokenPipeError:
@@ -165,10 +171,10 @@ def prompt_client(client_key, user, host, service, methods, timeout=10):
165 return 0 171 return 0
166 172
167 173
168def validate_client(client_key): 174def validate_client(db, client_key):
169 # Validates a client 175 # Validates a client
170 client = None 176 client = None
171 with sqlite3.connect(DB_NAME) as conn: 177 with sqlite3.connect(db) as conn:
172 c = conn.cursor() 178 c = conn.cursor()
173 c.execute("SELECT * FROM clients WHERE key=?",(client_key,)) 179 c.execute("SELECT * FROM clients WHERE key=?",(client_key,))
174 client = c.fetchall() 180 client = c.fetchall()
@@ -184,11 +190,11 @@ def validate_client(client_key):
184 190
185 191
186 192
187def handle_client(conn, addr): 193def handle_client(db, conn, addr):
188 # Receive key from client 194 # Receive key from client
189 key = conn.recv(KEY_LENGTH).decode(FORMAT) 195 key = conn.recv(KEY_LENGTH).decode(FORMAT)
190 # Validate client 196 # Validate client
191 if not validate_client(key): 197 if not validate_client(db, key):
192 print("WARNING: client attempted to connect with invalid key") 198 print("WARNING: client attempted to connect with invalid key")
193 conn.send(DISCONNECT_MESSAGE.encode(FORMAT)) 199 conn.send(DISCONNECT_MESSAGE.encode(FORMAT))
194 conn.close() 200 conn.close()
@@ -202,15 +208,18 @@ def parse_pam_data(data):
202 # Parses pam data and returns (user,host,service) tuple 208 # Parses pam data and returns (user,host,service) tuple
203 return tuple(data.split(',')) 209 return tuple(data.split(','))
204 210
205def handle_pam(conn, addr): 211def handle_pam(db, conn, addr):
206 # Get request and data from PAM module 212 # Get request and data from PAM module
207 data_length = int(conn.recv(HEADER_LENGTH).decode(FORMAT)) 213 header = conn.recv(HEADER_LENGTH).decode(FORMAT)
214 if header == "":
215 die("error: lost connection to pam module")
216 data_length = int(header)
208 pam_data = conn.recv(data_length).decode(FORMAT) 217 pam_data = conn.recv(data_length).decode(FORMAT)
209 print("Got pam_data: " + pam_data) 218 print("Got pam_data: " + pam_data)
210 user,host,service = parse_pam_data(pam_data) 219 user,host,service = parse_pam_data(pam_data)
211 220
212 # Correlate request to client 221 # Correlate request to client
213 client_key,mfa_methods = get_client_key(user,host,service) 222 client_key,mfa_methods = get_client_key(db, user,host,service)
214 mfa_methods = mfa_methods.split(' ') 223 mfa_methods = mfa_methods.split(' ')
215 if client_key == None: 224 if client_key == None:
216 print("No applications found for user="+user+" host="+host+" service="+service) 225 print("No applications found for user="+user+" host="+host+" service="+service)
@@ -221,26 +230,26 @@ def handle_pam(conn, addr):
221 response = prompt_client(client_key,user,host,service,mfa_methods) 230 response = prompt_client(client_key,user,host,service,mfa_methods)
222 231
223 # Evaluate Response 232 # Evaluate Response
224 decision = eval_mfa(client_key, mfa_methods, response) 233 decision = eval_mfa(db, client_key, mfa_methods, response)
225 234
226 # Return response to PAM module 235 # Return response to PAM module
227 # Respone will either be 0 for authenticated and 1 for denied 236 # Respone will either be 0 for authenticated and 1 for denied
228 conn.send(str(decision).encode(FORMAT)) 237 conn.send(str(decision).encode(FORMAT))
229 238
230 239
231def listen_client(addr, port): 240def listen_client(db, addr, port):
232 with socket.create_server((addr, port)) as server: 241 with socket.create_server((addr, port)) as server:
233 while True: 242 while True:
234 conn, addr = server.accept() 243 conn, addr = server.accept()
235 thread = threading.Thread(target=handle_client,args=(conn,addr)) 244 thread = threading.Thread(target=handle_client,args=(db, conn,addr))
236 thread.start() 245 thread.start()
237 246
238 247
239def listen_pam(addr, port): 248def listen_pam(db, addr, port):
240 with socket.create_server((addr,port)) as pam_server: 249 with socket.create_server((addr,port)) as pam_server:
241 while True: 250 while True:
242 conn, addr = pam_server.accept() 251 conn, addr = pam_server.accept()
243 thread = threading.Thread(target=handle_pam,args=(conn,addr)) 252 thread = threading.Thread(target=handle_pam,args=(db, conn,addr))
244 thread.start() 253 thread.start()
245 254
246 255
@@ -303,14 +312,14 @@ def main():
303 args = parse_arguments() 312 args = parse_arguments()
304 confparser = read_config(args.config) 313 confparser = read_config(args.config)
305 314
306 bind_addr, client_port, pam_port, DB_NAME = get_vars(args,confparser) 315 bind_addr, client_port, pam_port, db = get_vars(args,confparser)
307 316
308 if not os.path.exists(DB_NAME): 317 if not os.path.exists(db):
309 print("Creating DB") 318 print("Creating DB")
310 create_db(DB_NAME) 319 create_db(db)
311 320
312 clients = threading.Thread(target=listen_client,args=(bind_addr,client_port)) 321 clients = threading.Thread(target=listen_client,args=(db, bind_addr,client_port))
313 pam = threading.Thread(target=listen_pam,args=(bind_addr,pam_port)) 322 pam = threading.Thread(target=listen_pam,args=(db, bind_addr,pam_port))
314 clients.start() 323 clients.start()
315 pam.start() 324 pam.start()
316 325