# Copyright 2011-2013 Colin Scott
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from base import *
import socket
import logging
import base64
class ServerSocketDemultiplexer(SocketDemultiplexer):
    """Server-side demultiplexer that turns channel messages into MockSockets.

    Every message received over the true connection carries a channel ``id``
    and a ``type``:

    * ``"SYN"``  -- a new channel; a ServerMockSocket is created for it and
      handed to the listening mock socket so that it can be accept()'ed.
    * ``"data"`` -- base64-encoded payload for an already-known channel; the
      decoded bytes are appended to that socket's read buffer.

    Any other type raises ValueError.
    """

    def __init__(self, true_io_worker, mock_listen_sock):
        '''Whenever we see a handshake from the client, hand new MockSockets to
        mock_listen_sock so that they can be accept()'ed.

        :param true_io_worker: I/O worker for the true connection; forwarded
            to SocketDemultiplexer.__init__.
        :param mock_listen_sock: listening mock socket; must provide
            append_new_mock_socket(sock).
        '''
        super(ServerSocketDemultiplexer, self).__init__(true_io_worker)
        self.mock_listen_sock = mock_listen_sock

    def _on_receive(self, worker, json_hash):
        """Dispatch one decoded message received from the true connection.

        The base-class handler runs first; then the message is routed by its
        ``type`` field.

        :raises ValueError: for a ``"data"`` message addressed to an unknown
            socket id, or for an unrecognized message type.
        """
        super(ServerSocketDemultiplexer, self)._on_receive(worker, json_hash)
        sock_id = json_hash['id']
        msg_type = json_hash['type']
        if msg_type == "SYN":
            # We just saw an unknown channel.
            address = json_hash['address']
            # Wrap in a 1-tuple so a tuple-valued address is not unpacked as
            # the format-argument tuple.
            print("Incoming MockSocket connection %s" % (address,))
            new_sock = self.new_socket(sock_id=sock_id, peer_address=address)
            self.mock_listen_sock.append_new_mock_socket(new_sock)
        elif msg_type == "data":
            raw_data = base64.b64decode(json_hash['data'])
            if sock_id not in self.id2socket:
                # %s (not %d) so a non-integer id still yields ValueError
                # rather than a TypeError from the format operation.
                raise ValueError("Unknown socket id %s" % (sock_id,))
            self.id2socket[sock_id].append_read(raw_data)
        else:
            raise ValueError("Unknown msg_type %s" % (msg_type,))

    def new_socket(self, sock_id=-1, peer_address=None):
        """Create and register a ServerMockSocket for channel ``sock_id``.

        The socket's ready_to_read is stored in the class-level
        MultiplexedSelect.fileno2ready_to_read mapping under ``sock_id``, and
        the socket itself in self.id2socket.

        :returns: the new ServerMockSocket.
        """
        sock = ServerMockSocket(None, None, sock_id=sock_id,
                                json_worker=self.json_worker,
                                peer_address=peer_address)
        MultiplexedSelect.fileno2ready_to_read[sock_id] = sock.ready_to_read
        self.id2socket[sock_id] = sock
        return sock
class ServerMockSocket(MockSocket):
    """MockSocket used on the server side of the socket multiplexer.

    An instance plays one of two roles:

    * a listener (after listen() is called), which queues the per-channel
      sockets produced by ServerSocketDemultiplexer and hands them out via
      accept();
    * a per-channel connection socket, as created by
      ServerSocketDemultiplexer.new_socket().
    """

    def __init__(self, protocol, sock_type, sock_id=-1, json_worker=None,
                 set_true_listen_socket=lambda *args, **kwargs: None,
                 peer_address=None):
        """Create the mock socket.

        :param protocol: first argument given to socket.socket() when
            listen() creates the real socket (i.e. the address family).
        :param sock_type: socket type given to socket.socket() by listen().
        :param sock_id: channel id, forwarded to MockSocket.
        :param json_worker: forwarded to MockSocket.
        :param set_true_listen_socket: callback invoked by listen() as
            ``set_true_listen_socket(true_socket, self, accept_callback=...)``.
            The default is a no-op that accepts and ignores those arguments;
            a zero-argument default would make listen() raise TypeError.
        :param peer_address: remote address of this connection; reported as
            the address half of accept()'s result when this socket is
            accepted from a listener.
        """
        super(ServerMockSocket, self).__init__(protocol, sock_type,
                                               sock_id=sock_id,
                                               json_worker=json_worker)
        self.set_true_listen_socket = set_true_listen_socket
        self.peer_address = peer_address
        # Connection sockets waiting to be returned by accept().
        self.new_sockets = []
        self.log = logging.getLogger("mock_sock")
        # Set to True by listen(); only affects __repr__ here.
        self.listener = False

    def ready_to_read(self):
        """True if there is buffered data or a connection waiting in accept()."""
        return self.pending_reads != [] or self.new_sockets != []

    def bind(self, server_info):
        """Record the address that listen() will bind the real socket to."""
        # Before bind() is called, we don't know the
        # address of the true connection.
        self.server_info = server_info

    def listen(self, _):
        """Open a real listen socket for the client SocketDemultiplexer.

        Creates a real, non-blocking socket bound to the address passed to
        bind(), listening with a backlog of 1 (the backlog argument is
        ignored), and hands it to set_true_listen_socket together with
        _accept_callback.
        """
        self.listener = True
        # Here, we create a *real* socket, bind it to server_info, and wait
        # for the client SocketDemultiplexer to connect. After this is done,
        # we can instantiate our own SocketDemultiplexer.
        # Assumes that all invocations of bind() are intended for connection
        # to STS. TODO(cs): STS should tell pox_monkeypatcher exactly what
        # ports it intends to connect to. If bind() is called for some other
        # port, delegate to a real socket.
        # socket.socket may be monkeypatched; prefer the saved original.
        if hasattr(socket, "_old_socket"):
            true_socket = socket._old_socket(self.protocol, self.sock_type)
        else:
            true_socket = socket.socket(self.protocol, self.sock_type)
        try:
            true_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            true_socket.bind(self.server_info)
            true_socket.setblocking(0)
            true_socket.listen(1)
        except Exception:
            # Don't leak the file descriptor when setup fails (e.g. address
            # already in use, or bind() was never called so server_info is
            # missing); re-raise the original error.
            true_socket.close()
            raise
        # We give this true socket to select.select and
        # wait for the client SocketDemultiplexer connection.
        self.set_true_listen_socket(true_socket, self,
                                    accept_callback=self._accept_callback)

    def _accept_callback(self, io_worker):
        """Handle the accepted true connection from the client.

        Builds the ServerSocketDemultiplexer that feeds new MockSockets back
        to this listener, then restores the original socket.socket.
        """
        # Keep a reference so that it won't be gc'ed?
        self.demux = ServerSocketDemultiplexer(io_worker, mock_listen_sock=self)
        # Revert the monkeypatch of socket.socket in case the server
        # makes auxiliary TCP connections.
        if hasattr(socket, "_old_socket"):
            socket.socket = socket._old_socket

    def accept(self):
        """Return the oldest queued connection as ``(sock, address)``.

        Like socket.accept(), ``address`` is the remote address of the
        accepted connection, i.e. the peer_address recorded on ``sock``
        (previously this returned the listener's own peer_address).
        Raises IndexError if nothing is queued; presumably callers check
        ready_to_read() first -- confirm against callers.
        """
        sock = self.new_sockets.pop(0)
        return (sock, sock.peer_address)

    def append_new_mock_socket(self, mock_sock):
        """Queue ``mock_sock`` to be returned by a later accept()."""
        self.new_sockets.append(mock_sock)

    def __repr__(self):
        return "MockListenerSocket" if self.listener else "MockServerSocket"
class ServerMultiplexedSelect(MultiplexedSelect):
    """MultiplexedSelect that additionally waits on real listen sockets.

    ServerMockSocket.listen() registers a real listen socket here through
    set_true_listen_socket(). Those sockets are added to the read list
    returned by grab_workers_rwe(). When one becomes readable,
    handle_socks_rwe() accepts the connection, closes and forgets the listen
    socket, creates a true I/O worker for the connection, and passes that
    worker to the registered accept_callback.
    """

    def __init__(self, *args, **kwargs):
        super(ServerMultiplexedSelect, self).__init__(*args, **kwargs)
        # Real listen sockets still waiting for the client's connection.
        self.true_listen_socks = []
        # Mock listener sockets whose readiness the mock itself reports.
        self.mock_listen_socks = []
        # Real listen socket -> callback taking the new true io_worker.
        self.listen_sock_to_accept_callback = {}
        # Number of registered listen sockets not yet accepted.
        self.pending_accepts = 0

    def set_true_listen_socket(self, true_listen_socket, mock_listen_sock,
                               accept_callback):
        """Register a real listen socket and its mock counterpart.

        :param true_listen_socket: real listening socket to add to the read
            list until the client SocketDemultiplexer connects.
        :param mock_listen_sock: the mock listener that created it.
        :param accept_callback: called with the new true io_worker once the
            connection on true_listen_socket has been accepted.
        """
        # Keep around true_listen_socket until STS's SocketDemultiplexer
        # connects.
        self.true_listen_socks.append(true_listen_socket)
        self.mock_listen_socks.append(mock_listen_sock)
        # At this point, bind() has been called, and we need to wait for the
        # client SocketDemultiplexer to connect. After it connects, we invoke
        # accept_callback with the new io_worker as a parameter.
        self.pending_accepts += 1
        self.listen_sock_to_accept_callback[true_listen_socket] = accept_callback

    def ready_to_read(self, sock_or_io_worker):
        """Ask mock listeners directly; defer everything else to the base."""
        if sock_or_io_worker in self.mock_listen_socks:
            return sock_or_io_worker.ready_to_read()
        return super(ServerMultiplexedSelect, self).ready_to_read(
            sock_or_io_worker)

    def grab_workers_rwe(self):
        """Return the base (read, write, error) lists plus the listen sockets."""
        (rl, wl, xl) = super(ServerMultiplexedSelect, self).grab_workers_rwe()
        # Build a new list instead of extending in place, so that whatever
        # list the base class returned (possibly one it keeps internally) is
        # not mutated.
        rl = list(rl) + self.true_listen_socks
        return (rl, wl, xl)

    def handle_socks_rwe(self, rl, wl, xl, mock_read_socks,
                         mock_write_workers):
        """Accept pending connections on listen sockets, then defer to base.

        Readable listen sockets are removed from ``rl`` before the base class
        sees it.

        :raises RuntimeError: if a listen socket appears in ``xl``.
        """
        # Iterate over a snapshot: accepted listen sockets are removed from
        # self.true_listen_socks inside the loop, and removing from the list
        # being iterated would skip the next listen socket (leaving it in rl
        # for the base class to mishandle).
        for true_listen_sock in list(self.true_listen_socks):
            if true_listen_sock in xl:
                raise RuntimeError("Error in listen socket")
            if self.pending_accepts > 0 and true_listen_sock in rl:
                rl.remove(true_listen_sock)
                # Once the listen sock gets an accept(), throw it out (no
                # longer needed), replace it with the return of accept(),
                # and invoke the accept_callback.
                print("Incoming true socket connected")
                self.pending_accepts -= 1
                new_sock = true_listen_sock.accept()[0]
                true_listen_sock.close()
                self.true_listen_socks.remove(true_listen_sock)
                self.set_true_io_worker(self.create_worker_for_socket(new_sock))
                accept_callback = self.listen_sock_to_accept_callback.pop(
                    true_listen_sock)
                accept_callback(self.true_io_workers[-1])
        return super(ServerMultiplexedSelect, self).handle_socks_rwe(
            rl, wl, xl, mock_read_socks, mock_write_workers)