Source code for sts.util.socket_mux.sts_socket_multiplexer

# Copyright 2011-2013 Colin Scott
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from base import *
from itertools import count
import logging
import base64

log = logging.getLogger("sts_sock_mux")

[docs]class STSSocketDemultiplexer(SocketDemultiplexer): # All mock sockets have negative fileno()s, to differentiate them from # normal files # -1 is reserved for the listen socket _id_gen = count(start=-2, step=-1)
[docs] def __init__(self, true_io_worker, server_info): super(STSSocketDemultiplexer, self).__init__(true_io_worker) self.server_info = server_info # let MockSockets know who their Demuxer is upon connect() STSMockSocket.address2demuxer[server_info] = self
def _on_receive(self, worker, json_hash): super(STSSocketDemultiplexer, self)._on_receive(worker, json_hash) assert(json_hash['type'] == 'data' and 'data' in json_hash) sock_id = json_hash['id'] if sock_id not in self.id2socket: raise ValueError("Unknown socket id %d" % sock_id) sock = self.id2socket[sock_id] raw_data = base64.b64decode(json_hash['data']) sock.append_read(raw_data)
[docs] def add_new_socket(self, new_socket): sock_id = self._id_gen.next() new_socket.sock_id = sock_id new_socket.json_worker = self.json_worker MultiplexedSelect.fileno2ready_to_read[sock_id] = new_socket.ready_to_read self.id2socket[sock_id] = new_socket
[docs]class STSMockSocket(MockSocket): # Set by STSSocketDemuxers so we know who our demuxer is upon connect() address2demuxer = {}
[docs] def connect(self, address): if address not in self.address2demuxer: raise RuntimeError("don't know our demuxer %s" % str(address)) self.peer_address = address demuxer = self.address2demuxer[address] demuxer.add_new_socket(self) # Send a SYN true_address = demuxer.client_info wrapped = {'id' : self.sock_id, 'type' : 'SYN', 'address' : true_address } self.json_worker.send(wrapped) # Note: select() won't be called by STS with this socket as a param until # the switch receives a HELLO message. But for that to occur, we need the # controller to initiate the HELLO message in reaction to our connection # attempt. Therefore, we need to explicitly # cause the underlying socket to send here. try: l = demuxer.true_io_worker.socket.send(demuxer.true_io_worker.send_buf) if l > 0: demuxer.true_io_worker._consume_send_buf(l) except socket.error as (s_errno, strerror): log.error("Socket error: " + strerror) raise
[docs] def getpeername(self): return self.peer_address # To monkey patch client side: # - After booting the controller, # - and after STSSyncProtocol's socket has been created (no more auxiliary # sockets remain to be instantiated) # - create a single real socket for each ControllerInfo # - connect them normally # - wrap them in MultiplexedSelect's io_worker # - create a STSSocketDemultiplexer for them # - override select.select with MultiplexedSelect (this will create a true # socket for the pinger) # - override socket.socket # - takes two params: protocol, socket type # - if not SOCK_STREAM type, return a normal socket # - else, return STSMockSocket