diff options
Diffstat (limited to 'server')
-rwxr-xr-x | server/mfad.py | 38 |
1 files changed, 26 insertions, 12 deletions
diff --git a/server/mfad.py b/server/mfad.py index 17d2585..18a048a 100755 --- a/server/mfad.py +++ b/server/mfad.py | |||
@@ -1,5 +1,6 @@ | |||
1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 |
2 | import socket | 2 | import socket |
3 | import ssl | ||
3 | import os | 4 | import os |
4 | import sys | 5 | import sys |
5 | import time | 6 | import time |
@@ -211,8 +212,9 @@ def parse_pam_data(data): | |||
211 | def handle_pam(db, conn, addr): | 212 | def handle_pam(db, conn, addr): |
212 | # Get request and data from PAM module | 213 | # Get request and data from PAM module |
213 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) | 214 | header = conn.recv(HEADER_LENGTH).decode(FORMAT) |
214 | if header == "": | 215 | if len(header) != HEADER_LENGTH: |
215 | die("error: lost connection to pam module") | 216 | conn.close() |
217 | die("error: invalid data from PAM module") | ||
216 | data_length = int(header) | 218 | data_length = int(header) |
217 | pam_data = conn.recv(data_length).decode(FORMAT) | 219 | pam_data = conn.recv(data_length).decode(FORMAT) |
218 | print("Got pam_data: " + pam_data) | 220 | print("Got pam_data: " + pam_data) |
@@ -238,19 +240,31 @@ def handle_pam(db, conn, addr): | |||
238 | 240 | ||
239 | 241 | ||
240 | def listen_client(db, addr, port): | 242 | def listen_client(db, addr, port): |
241 | with socket.create_server((addr, port)) as server: | 243 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
242 | while True: | 244 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") |
243 | conn, addr = server.accept() | 245 | with socket.create_server((addr, port)) as sock: |
244 | thread = threading.Thread(target=handle_client,args=(db, conn,addr)) | 246 | with context.wrap_socket(sock, server_side=True) as tls_socket: |
245 | thread.start() | 247 | while True: |
248 | try: | ||
249 | conn, addr = tls_socket.accept() | ||
250 | thread = threading.Thread(target=handle_client,args=(db,conn,addr)) | ||
251 | thread.start() | ||
252 | except ssl.SSLError: | ||
253 | print("client: ssl handshake error") | ||
246 | 254 | ||
247 | 255 | ||
248 | def listen_pam(db, addr, port): | 256 | def listen_pam(db, addr, port): |
249 | with socket.create_server((addr,port)) as pam_server: | 257 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
250 | while True: | 258 | context.load_cert_chain(certfile="server/cert.pem", keyfile="server/key.pem") |
251 | conn, addr = pam_server.accept() | 259 | with socket.create_server((addr,port)) as sock: |
252 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) | 260 | with context.wrap_socket(sock, server_side=True) as tls_socket: |
253 | thread.start() | 261 | while True: |
262 | try: | ||
263 | conn, addr = tls_socket.accept() | ||
264 | thread = threading.Thread(target=handle_pam,args=(db, conn,addr)) | ||
265 | thread.start() | ||
266 | except ssl.SSLError: | ||
267 | print("pam: ssl handshake error") | ||
254 | 268 | ||
255 | 269 | ||
256 | ################################################################################ | 270 | ################################################################################ |