diff options
Diffstat (limited to 'scripts/NetworkProxy.py')
-rw-r--r-- | scripts/NetworkProxy.py | 65 |
1 files changed, 44 insertions, 21 deletions
diff --git a/scripts/NetworkProxy.py b/scripts/NetworkProxy.py index 4d4493e8b59..464c80580be 100644 --- a/scripts/NetworkProxy.py +++ b/scripts/NetworkProxy.py @@ -7,14 +7,13 @@ # # ********************************************************************** -import sys, os, threading, socket, select, atexit +import sys, os, threading, socket, select class InvalidRequest(Exception): pass class BaseConnection(threading.Thread): def __init__(self, socket, remote): threading.Thread.__init__(self) - self.setDaemon(True) self.socket = socket self.remote = remote self.remoteSocket = None @@ -27,6 +26,8 @@ class BaseConnection(threading.Thread): pass def close(self): + if self.closed: + return self.closed = True try: if self.socket: @@ -77,39 +78,59 @@ class BaseProxy(threading.Thread): threading.Thread.__init__(self) self.port = port self.closed = False - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.cond = threading.Condition() + self.socket = None + self.failed = None self.connections = [] - atexit.register(self.terminate) - self.setDaemon(True) self.start() + with self.cond: + while not self.socket and not self.failed: + self.cond.wait(60) + if self.failed: + raise self.failed def createConnection(self): return None def run(self): - self.socket.bind(("127.0.0.1", self.port)) - self.socket.listen(1) + with self.cond: + try: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + self.socket.bind(("127.0.0.1", self.port)) + self.socket.listen(1) + self.cond.notify() + except Exception as ex: + self.failed = ex + self.cond.notify() + return + try: while not self.closed: incoming, peer = self.socket.accept() connection = self.createConnection(incoming, peer) connection.start() - self.connections.append(connection) + with self.cond: + self.connections.append(connection) except: pass finally: self.socket.close() def terminate(self): - if self.closed: - return - self.closed = True - for c in self.connections: - try: - c.close() - except Exception as ex: - print(ex) + with self.cond: + if self.closed: + return + self.closed = True + for c in self.connections: + try: + c.close() + c.join() + except Exception as ex: + print(ex) + connectToSelf = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: connectToSelf.connect(("127.0.0.1", self.port)) @@ -118,6 +139,8 @@ class BaseProxy(threading.Thread): finally: connectToSelf.close() + self.join() + class SocksConnection(BaseConnection): def request(self, s): @@ -152,7 +175,7 @@ class SocksConnection(BaseConnection): return packet if sys.version_info[0] == 2 else bytes(packet,"ascii") class SocksProxy(BaseProxy): - + def createConnection(self, socket, peer): return SocksConnection(socket, peer) @@ -172,12 +195,12 @@ class HttpConnection(BaseConnection): sep = data.find(":") if sep < len("CONNECT ") + 1: raise InvalidRequest - + host = data[len("CONNECT "):sep] space = data.find(" ", sep) if space < sep + 1: raise InvalidRequest - + port = int(data[sep + 1:space]) return (host, port) @@ -189,6 +212,6 @@ class HttpConnection(BaseConnection): return s if sys.version_info[0] == 2 else bytes(s,"ascii") class HttpProxy(BaseProxy): - + def createConnection(self, socket, peer): return HttpConnection(socket, peer) |