# Copyright 2011-2013 Colin Scott
# Copyright 2011-2013 Andreas Wundsam
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import os
import shutil
import signal
import sys
import tempfile
import types
import unittest
from copy import copy

# Make the repository root importable when this file is run directly. This
# must run BEFORE the project imports below, since they are what it enables.
sys.path.append(os.path.dirname(__file__) + "/../../..")

from config.experiment_config_lib import ControllerConfig
from sts.control_flow import Replayer, MCSFinder, EfficientMCSFinder
from sts.topology import FatTree, MeshTopology
from sts.simulation_state import Simulation, SimulationConfig
from sts.replay_event import Event, InternalEvent, InputEvent
from sts.event_dag import EventDag
from sts.entities import Host, Controller
class MockMCSFinderBase(MCSFinder):
    """Test double for MCSFinder that never runs a real simulation.

    Overrides ``self.invariant_check`` (with ``_invariant_check``) and
    ``replay()``: "replaying" a candidate DAG only records it in
    ``self.new_dag`` and reports a violation iff every event of ``self.mcs``
    is still present in that DAG.
    """

    def __init__(self, event_dag, mcs):
        """
        Args:
            event_dag: the EventDag to minimize; stored as ``self.dag``.
            mcs: the events the mock treats as the minimal causal sequence.
                A candidate DAG "violates" iff it contains all of them.
        """
        # Hack: hand the base constructor an invariant name it knows (see
        # config.invariant_checks.name_to_invariant_checks) so its sanity
        # checks do not raise, then immediately replace self.invariant_check
        # with the mock defined below.
        super(MockMCSFinderBase, self).__init__(
            None, None,
            invariant_check_name="InvariantChecker.check_liveness")
        self.invariant_check = self._invariant_check
        self.dag = event_dag
        self.new_dag = None
        self.mcs = mcs
        self.simulation = None
        self.transform_dag = None

    def log(self, message):
        """Log ``message`` at INFO level via ``self._log``.

        ``self._log`` is not set here; the concrete mock subclasses assign it
        in their constructors.
        """
        self._log.info(message)

    def _invariant_check(self, _):
        """Return ``["violation"]`` iff every MCS event is in ``self.new_dag``.

        The argument is ignored; the DAG most recently recorded by
        ``replay()`` is inspected instead. Returns ``[]`` otherwise.
        """
        # Generator keeps the lookup lazy: with an empty MCS, new_dag is never
        # touched, exactly as in the original loop.
        if all(e in self.new_dag._events_set for e in self.mcs):
            return ["violation"]
        return []

    def replay(self, new_dag, hook=None):
        """Record ``new_dag`` and return the mocked invariant-check result.

        ``hook`` is accepted for signature compatibility and ignored.
        """
        self.new_dag = new_dag
        return self.invariant_check(new_dag)
# Hack: the mock finders mix MockMCSFinderBase (mocked replay and invariant
# check) into a concrete finder class via multiple inheritance.
class MockMCSFinder(MockMCSFinderBase, MCSFinder):
    """MockMCSFinderBase applied to the plain MCSFinder.

    Listing MCSFinder again is redundant (MockMCSFinderBase already derives
    from it) but mirrors MockEfficientMCSFinder's shape.
    """

    def __init__(self, event_dag, mcs):
        # MockMCSFinderBase is next in the MRO, so this reaches its __init__.
        super(MockMCSFinder, self).__init__(event_dag, mcs)
        self._log = logging.getLogger("mock_mcs_finder")
class MockEfficientMCSFinder(MockMCSFinderBase, EfficientMCSFinder):
    """MockMCSFinderBase applied to EfficientMCSFinder.

    MockMCSFinderBase is listed first, so its replay()/log()/_invariant_check
    take precedence over EfficientMCSFinder's in the MRO.
    """

    def __init__(self, event_dag, mcs):
        # MockMCSFinderBase is next in the MRO, so this reaches its __init__.
        super(MockEfficientMCSFinder, self).__init__(event_dag, mcs)
        self._log = logging.getLogger("mock_efficient_mcs_finder")
# Legacy fixed results location, kept for anything that may import it. The
# tests below no longer use it: each run gets its own fresh temp directory.
mcs_results_path = "/tmp/mcs_results"


class MCSFinderTest(unittest.TestCase):
    """Checks that the MCS finders recover a known MCS from a synthetic trace.

    Each case builds a trace of mock input events, declares a subset of it to
    be the MCS, runs ``simulate()``, and asserts the finder's DAG ends up with
    exactly those input events.
    """

    # Number of mock input events in every synthetic trace.
    _TRACE_LENGTH = 6

    def _check_finds_mcs(self, mcs_finder_type, mcs_indices):
        """Run ``mcs_finder_type`` on a fresh trace and verify the result.

        Args:
            mcs_finder_type: finder class, constructed as ``cls(dag, mcs)``.
            mcs_indices: positions in the trace of the events forming the MCS.
        """
        # NOTE(review): MockInputEvent is neither defined nor imported in the
        # visible part of this file -- confirm it is provided elsewhere.
        trace = [MockInputEvent(fingerprint=("class", f))
                 for f in range(1, self._TRACE_LENGTH + 1)]
        dag = EventDag(trace)
        # A new list, so the expectation never aliases the list given to the DAG.
        mcs = [trace[i] for i in mcs_indices]
        mcs_finder = mcs_finder_type(dag, mcs)
        # A unique directory per run avoids collisions with concurrent runs.
        # It is created outside the try so that a failed creation can never
        # trigger rmtree on a directory this run does not own.
        results_dir = tempfile.mkdtemp(prefix="mcs_results_")
        try:
            mcs_finder.init_results(results_dir)
            mcs_finder.simulate()
        finally:
            shutil.rmtree(results_dir)
        self.assertEqual(mcs, mcs_finder.dag.input_events)

    def test_basic(self):
        self.basic(MockMCSFinder)

    def test_basic_efficient(self):
        self.basic(MockEfficientMCSFinder)

    def basic(self, mcs_finder_type):
        """The MCS is just the first event of the trace."""
        self._check_finds_mcs(mcs_finder_type, [0])

    def test_straddle(self):
        self.straddle(MockMCSFinder)

    def test_straddle_efficient(self):
        self.straddle(MockEfficientMCSFinder)

    def straddle(self, mcs_finder_type):
        """The MCS is the first and last events, straddling the trace."""
        self._check_finds_mcs(mcs_finder_type, [0, self._TRACE_LENGTH - 1])

    def test_all(self):
        self.all(MockMCSFinder)

    def test_all_efficient(self):
        self.all(MockEfficientMCSFinder)

    # Name kept for compatibility; it shadows the builtin only as a class
    # attribute (method bodies still see the builtin ``all``).
    def all(self, mcs_finder_type):
        """Every event of the trace belongs to the MCS."""
        self._check_finds_mcs(mcs_finder_type, range(self._TRACE_LENGTH))
if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()