# Copyright 2011-2013 Colin Scott
# Copyright 2011-2013 Andreas Wundsam
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import itertools
import os
import string
import sys
import re
import socket
import random
from sts.util.convenience import find_port
from sts.entities import Controller, POXController, BigSwitchController
# Short controller-type names mapped to the Controller class that implements
# them. ControllerConfig looks these keys up both when a type name is passed
# explicitly and when it scans the start command for a known name.
controller_type_map = {
    "pox": POXController,
    "bsc": BigSwitchController,
}
class ControllerConfig(object):
    '''
    Metadata describing how to start, stop and reach a single controller.

    Class-level state shared by every instance:
    - _port_gen: hands out default TCP ports, counting up from 6633
    - _controller_count_gen: numbers the auto-generated labels ("c1", "c2", ...)
    - _controller_labels: every label registered so far; labels must be unique
    '''
    _port_gen = itertools.count(6633)
    _controller_count_gen = itertools.count(1)
    _controller_labels = set()

    def __init__(self, start_cmd="", kill_cmd="", address="127.0.0.1", port=None,
                 additional_ports=None, cwd=None, sync=None, controller_type=None,
                 label=None, config_file=None, config_template=None,
                 try_new_ports=False):
        '''
        Store metadata for the controller.
        - start_cmd: command that starts a controller or a set of controllers,
          followed by a list of command line tokens as arguments
        - kill_cmd: command that kills a controller or a set of controllers,
          followed by a list of command line tokens as arguments
        - address, port: controller socket info to listen for switches on. An
          address that looks like a dotted-quad IP (or "localhost") is treated
          as a TCP endpoint; anything else is treated as a unix domain socket
          and port is set to None
        - additional_ports: mapping of name -> port; each entry replaces the
          placeholder __<name>_port__ in commands and config templates.
          Defaults to a fresh empty dict per instance
        - cwd: working directory for the controller; a warning is written to
          stderr when it is not given
        - sync: stored as-is on the instance
        - controller_type: controller vendor, specified by the corresponding Controller
          class itself (or an instance of it), or a string chosen from one of
          the keys in controller_type_map. When omitted, the type is inferred
          by looking for one of those keys inside start_cmd
        - label: unique name for this controller; auto-generated if omitted
        - config_file, config_template: see generate_config_file()
        - try_new_ports: if true, the port is chosen by find_port() from the
          2000 ports starting at the requested/default port
          (presumably the first free one -- confirm against find_port)

        Raises RuntimeError if start_cmd is empty or the controller type is
        unknown / cannot be inferred, and ValueError if label is already taken.
        '''
        if start_cmd == "":
            raise RuntimeError("Must specify boot parameters.")
        self.start_cmd = start_cmd
        self.kill_cmd = kill_cmd
        self.address = address

        # Raw string: same pattern as before, without relying on invalid
        # string escapes. Deliberately left unanchored at the end, as before.
        if (re.match(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", address) or
                address == "localhost"):
            # Normal TCP socket
            if not port:
                port = next(self._port_gen)
            if try_new_ports:
                port = find_port(range(port, port + 2000))
            self.port = port
            self._server_info = (self.address, port)
        else:
            # Unix domain socket
            self.port = None
            self._server_info = address

        if controller_type is None:
            # Infer the type from the first known name found in the command.
            for type_name in controller_type_map.keys():
                if type_name in self.start_cmd:
                    self.type = controller_type_map[type_name]
                    break
            else:
                raise RuntimeError("Controller type not inferred from command line!")
        elif (isinstance(controller_type, Controller) or
              self._is_controller_class(controller_type)):
            # Accept the Controller class itself (as documented) as well as
            # an instance, which is what the original check allowed.
            self.type = controller_type
        else:
            try:
                self.type = controller_type_map[controller_type]
            except (KeyError, TypeError):
                raise RuntimeError("Unknown controller type: %s" % controller_type)

        self.cwd = cwd
        if not cwd:
            sys.stderr.write(
                "\n"
                "=======================================================================\n"
                "WARN - no working directory defined for controller with command line\n"
                "%s\n"
                "The controller is run in the STS base directory. This may result\n"
                "in unintended consequences (i.e. controller not logging correctly).\n"
                "=======================================================================\n"
                "\n" % (self.start_cmd,))
        self.sync = sync

        if not label:
            label = "c" + str(next(self._controller_count_gen))
        if label in self._controller_labels:
            raise ValueError("Label %s already registered!" % label)
        self._controller_labels.add(label)
        self.label = label

        self.config_file = config_file
        self.config_template = config_template
        # Never share one dict between instances via a mutable default.
        self.additional_ports = {} if additional_ports is None else additional_ports

    @staticmethod
    def _is_controller_class(candidate):
        ''' Return True iff candidate is a class derived from Controller '''
        try:
            return issubclass(candidate, Controller)
        except TypeError:
            # issubclass() rejects non-class arguments (e.g. strings).
            return False

    @property
    def cid(self):
        ''' Return this controller's id '''
        return self.label

    @property
    def server_info(self):
        """ information about the _real_ socket that the controller is listening on"""
        return self._server_info

    def _expand_vars(self, s):
        '''
        Return s with the placeholders __<name>_port__ (one per entry in
        additional_ports), __port__, __address__ and __config__ substituted.
        __config__ becomes the absolute path of config_file, or "" if unset.
        '''
        for name, val in self.additional_ports.items():
            s = s.replace("__%s_port__" % name, str(val))
        config_path = os.path.abspath(self.config_file) if self.config_file else ""
        s = s.replace("__port__", str(self.port))
        s = s.replace("__address__", str(self.address))
        return s.replace("__config__", str(config_path))

    @property
    def expanded_start_cmd(self):
        ''' start_cmd split on whitespace, as a list of expanded tokens '''
        return [self._expand_vars(token) for token in self.start_cmd.split()]

    @property
    def expanded_kill_cmd(self):
        ''' kill_cmd split on whitespace, as a list of expanded tokens '''
        return [self._expand_vars(token) for token in self.kill_cmd.split()]

    def generate_config_file(self, target_dir):
        '''
        Write the expanded contents of config_template to config_file.

        If config_file is not set yet, it is placed in target_dir and named
        after the template with ".template" removed. config_template must
        be set, since it is opened for reading unconditionally.
        '''
        if self.config_file is None:
            template_name = os.path.basename(self.config_template)
            self.config_file = os.path.join(target_dir,
                                            template_name.replace(".template", ""))
        with open(self.config_template, "r") as in_file:
            with open(self.config_file, "w") as out_file:
                out_file.write(self._expand_vars(in_file.read()))

    def __repr__(self):
        # Only attributes with a truthy value are shown.
        attributes = ("start_cmd", "address", "port", "cwd", "sync")
        pairs = ((attr, getattr(self, attr)) for attr in attributes)
        quoted = ("%s=%s" % (attr, repr(value)) for (attr, value) in pairs if value)
        return self.__class__.__name__ + "(" + ", ".join(quoted) + ")"