Source code for sts.control_flow.mcs_finder

# Copyright 2011-2013 Colin Scott
# Copyright 2011-2013 Andreas Wundsam
# Copyright 2012-2013 Andrew Or
# Copyright 2012-2013 Sam Whitlock
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
An orchestrating control flow that invokes Replayer several times to
find the minimal causal sequence (MCS) of a failure.
'''

from sts.util.console import msg, color
from sts.util.convenience import timestamp_string, mkdir_p, ExitCode
from sts.util.rpc_forker import LocalForker, test_serialize_response
from sts.util.precompute_cache import PrecomputeCache, PrecomputePowerSetCache
from sts.replay_event import *
from sts.event_dag import EventDag, split_list
import sts.input_traces.log_parser as log_parser
from sts.input_traces.input_logger import InputLogger
from sts.control_flow.base import ControlFlow, ReplaySyncCallback
from sts.control_flow.replayer import Replayer
from sts.control_flow.peeker import Peeker
from config.invariant_checks import name_to_invariant_check

from collections import defaultdict, Counter
import copy
import itertools
import sys
import time
import random
import logging
import json
import os

class MCSFinder(ControlFlow):
  def __init__(self, simulation_cfg, superlog_path_or_dag,
               invariant_check_name=None, transform_dag=None,
               end_wait_seconds=0.5, mcs_trace_path=None, extra_log=None,
               runtime_stats_file=None, wait_on_deterministic_values=False,
               no_violation_verification_runs=1, optimized_filtering=False,
               forker=LocalForker(), replay_final_trace=False, **kwargs):
    super(MCSFinder, self).__init__(simulation_cfg)
    self.mcs_log_tracker = None
    self.replay_log_tracker = None
    self.mcs_trace_path = mcs_trace_path
    self.sync_callback = None
    self._log = logging.getLogger("mcs_finder")

    if invariant_check_name is None:
      raise ValueError("Must specify invariant check")
    if invariant_check_name not in name_to_invariant_check:
      raise ValueError('''Unknown invariant check %s.\n'''
                       '''Invariant check name must be defined in '''
                       '''config.invariant_checks''' % invariant_check_name)
    self.invariant_check = name_to_invariant_check[invariant_check_name]

    if type(superlog_path_or_dag) == str:
      self.superlog_path = superlog_path_or_dag
      # The dag is codified as a list, where each element has
      # a list of its dependents
      self.dag = EventDag(log_parser.parse_path(self.superlog_path))
    else:
      self.dag = superlog_path_or_dag

    self.transform_dag = transform_dag
    # A second log with just our MCS progress log messages
    self._extra_log = extra_log
    self.kwargs = kwargs
    self.end_wait_seconds = end_wait_seconds
    self.wait_on_deterministic_values = wait_on_deterministic_values
    # `no' means "number"
    self.no_violation_verification_runs = no_violation_verification_runs
    self._runtime_stats = RuntimeStats(runtime_stats_file)
    # Whether to try alternate trace splitting techniques besides splitting
    # by time.
    self.optimized_filtering = optimized_filtering
    self.forker = forker
    self.replay_final_trace = replay_final_trace

  def log(self, s):
    ''' Output a message to both self._log and self._extra_log '''
    msg.mcs_event(s)
    if self._extra_log is not None:
      self._extra_log.write(s + '\n')
      self._extra_log.flush()

  def log_violation(self, s):
    ''' Output a red message to both self._log and self._extra_log '''
    msg.mcs_event(color.RED + s)
    if self._extra_log is not None:
      self._extra_log.write(s + '\n')
      self._extra_log.flush()

  def log_no_violation(self, s):
    ''' Output a green message to both self._log and self._extra_log '''
    msg.mcs_event(color.GREEN + s)
    if self._extra_log is not None:
      self._extra_log.write(s + '\n')
      self._extra_log.flush()

  def init_results(self, results_dir):
    if self._extra_log is None:
      self._extra_log = open("%s/mcs_finder.log" % results_dir, "w")
    if self._runtime_stats.runtime_stats_file is None:
      self._runtime_stats.runtime_stats_file = "%s/runtime_stats.json" % results_dir
    if self.mcs_trace_path is None:
      self.mcs_trace_path = "%s/mcs.trace" % results_dir
    # TODO(cs): assumes that the transform dag is a peeker, not some other
    # transformer
    peeker_exists = self.transform_dag is not None
    self.mcs_log_tracker = MCSLogTracker(results_dir, self.mcs_trace_path,
                                         self._runtime_stats,
                                         self.simulation_cfg, peeker_exists)
    self.replay_log_tracker = ReplayLogTracker(results_dir)

  def simulate(self, check_reproducability=True):
    self._runtime_stats.set_dag_stats(self.dag)

    # Apply domain knowledge: treat failure/recovery pairs atomically, and
    # filter event types we don't want to include in the MCS
    # (e.g. CheckInvariants)
    self.dag.mark_invalid_input_sequences()
    self.dag = self.dag.filter_unsupported_input_types()

    if len(self.dag) == 0:
      raise RuntimeError("No supported input types?")

    if check_reproducability:
      # First, run through without pruning to verify that the violation exists
      self._runtime_stats.record_replay_start()

      for i in range(0, self.no_violation_verification_runs):
        violations = self.replay(self.dag, "reproducibility")
        if violations != []:
          break
      self._runtime_stats.set_initial_verification_runs_needed(i)
      self._runtime_stats.record_replay_end()
      if violations == []:
        msg.fail("Unable to reproduce correctness violation!")
        sys.exit(5)
      self.log("Violation reproduced successfully! Proceeding with pruning")

    Replayer.total_replays = 0
    Replayer.total_inputs_replayed = 0
    self._runtime_stats.record_prune_start()

    # TODO(cs): Better than a boolean flag: check if
    # log(len(self.dag)) > number of input types to try
    if self.optimized_filtering:
      self._optimize_event_dag()
    precompute_cache = PrecomputeCache()

    # Invoke delta debugging
    (dag, total_inputs_pruned) = self._ddmin(self.dag, 2,
                                             precompute_cache=precompute_cache)
    # Make sure to track the final iteration size
    self._track_iteration_size(total_inputs_pruned)
    self.dag = dag

    self._runtime_stats.record_prune_end()
    self.mcs_log_tracker.dump_runtime_stats()
    self.log("Final MCS (%d elements):" % len(self.dag.input_events))
    for i in self.dag.input_events:
      self.log(" - %s" % str(i))

    if self.replay_final_trace:
      # Replaying the final trace achieves two goals:
      # - verifies that the MCS indeed ends in the violation
      # - allows us to prune internal events that time out
      violations = self.replay(self.dag, "final_mcs_trace")
      if violations == []:
        self.log('''Warning! Final MCS did not result in violation.'''
                 ''' Try without timed out events? '''
                 ''' See tools/visualize_event_trace.html for debugging''')

    if self.mcs_trace_path is not None:
      self.mcs_log_tracker.dump_mcs_trace(self.dag, self)
    self.log("=== Total replays: %d ===" % Replayer.total_replays)
    return ExitCode(0)

  def _ddmin(self, dag, split_ways, precompute_cache=None, label_prefix=(),
             total_inputs_pruned=0):
    # This is the delta debugging algorithm from:
    #   http://www.st.cs.uni-saarland.de/papers/tse2002/tse2002.pdf,
    # Section 3.2
    # TODO(cs): we could do much better if we leverage domain knowledge (e.g.,
    # start by pruning all LinkFailures, or splitting by nodes rather than
    # time)
    if split_ways > len(dag.input_events):
      self.log("Done")
      return (dag, total_inputs_pruned)

    local_label = lambda i, inv=False: "%s%d/%d" % ("~" if inv else "", i, split_ways)
    subset_label = lambda label: ".".join(map(str, label_prefix + (label,)))
    print_subset = lambda label, s: subset_label(label) + ": " + " ".join(map(lambda e: e.label, s))

    subsets = split_list(dag.input_events, split_ways)
    self.log("Subsets:\n" + "\n".join(print_subset(local_label(i), s)
                                      for i, s in enumerate(subsets)))
    for i, subset in enumerate(subsets):
      label = local_label(i)
      new_dag = dag.input_subset(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current subset: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)
      if input_sequence == ():
        self.log("Subset after pruning dependencies was empty. Skipping")
        continue
      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i, label)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting."
                           % subset_label(label))
        self.mcs_log_tracker.maybe_dump_intermediate_mcs(new_dag,
                                                         subset_label(label), self)
        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        return self._ddmin(new_dag, 2, precompute_cache=precompute_cache,
                           label_prefix=label_prefix + (label,),
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No subsets with violations. Checking complements")
    for i, subset in enumerate(subsets):
      label = local_label(i, True)
      prefix = label_prefix + (label,)
      new_dag = dag.input_complement(subset)
      input_sequence = tuple(new_dag.input_events)
      self.log("Current complement: %s" % print_subset(label, input_sequence))
      if precompute_cache.already_done(input_sequence):
        self.log("Already computed. Skipping")
        continue
      precompute_cache.update(input_sequence)
      if input_sequence == ():
        self.log("Subset %s after pruning dependencies was empty. Skipping"
                 % subset_label(label))
        continue
      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(new_dag, i, label)
      if violation:
        self.log_violation("Subset %s reproduced violation. Subselecting."
                           % subset_label(label))
        self.mcs_log_tracker.maybe_dump_intermediate_mcs(new_dag,
                                                         subset_label(label), self)
        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        return self._ddmin(new_dag, max(split_ways - 1, 2),
                           precompute_cache=precompute_cache,
                           label_prefix=prefix,
                           total_inputs_pruned=total_inputs_pruned)

    self.log_no_violation("No complements with violations.")
    if split_ways < len(dag.input_events):
      self.log("Increasing granularity.")
      return self._ddmin(dag, min(len(dag.input_events), split_ways * 2),
                         precompute_cache=precompute_cache,
                         label_prefix=label_prefix,
                         total_inputs_pruned=total_inputs_pruned)
    return (dag, total_inputs_pruned)

  def _track_iteration_size(self, total_inputs_pruned):
    self._runtime_stats.record_iteration_size(len(self.dag.input_events) -
                                              total_inputs_pruned)

  def _check_violation(self, new_dag, subset_index, label):
    ''' Check if there were violations '''
    # Try no_violation_verification_runs times to see if the bug shows up
    for i in range(0, self.no_violation_verification_runs):
      violations = self.replay(new_dag, label)
      if violations != []:
        # Violation in the subset
        self.log_violation("Violation! Considering %d'th" % subset_index)
        self._runtime_stats.record_violation_found(i)
        return True
    # No violation!
    self.log_no_violation("No violation in %d'th..." % subset_index)
    return False

  def replay(self, new_dag, label):
    # Run the simulation forward
    if self.transform_dag:
      new_dag = self.transform_dag(new_dag)

    def play_forward(results_dir):
      # TODO(cs): need to serialize the parameters to Replayer rather than
      # wrapping them in a closure... otherwise, can't use RemoteForker
      # TODO(aw): MCSFinder needs to configure Simulation to always let
      # DataplaneEvents pass through
      ReplayLogTracker.create_replay_logger_dir(results_dir)
      input_logger = InputLogger()
      input_logger.open(results_dir)

      replayer = Replayer(self.simulation_cfg, new_dag,
                          wait_on_deterministic_values=self.wait_on_deterministic_values,
                          input_logger=input_logger,
                          **self.kwargs)

      violations = []
      simulation = None
      try:
        simulation = replayer.simulate()
        self._track_new_internal_events(simulation, replayer)
        # Wait a bit in case the bug takes a while to happen
        self.log("Sleeping %d seconds after run" % self.end_wait_seconds)
        time.sleep(self.end_wait_seconds)
        violations = self.invariant_check(simulation)
        if violations != []:
          input_logger.log_input_event(InvariantViolation(violations))
      except SystemExit:
        # One of the invariant checks bailed early. Oddly, this is not an
        # error for us; it just means that there were no violations...
        # [this logic is arguably broken]
        # Return no violations, and let Forker handle system exit for us.
        violations = []
      finally:
        input_logger.close(replayer, self.simulation_cfg, skip_mcs_cfg=True)
        if simulation is not None:
          simulation.clean_up()

      test_serialize_response(violations, self._runtime_stats.client_dict())
      return (violations, self._runtime_stats.client_dict())

    # TODO(cs): once play_forward() is no longer a closure, register it only once
    self.forker.register_task("play_forward", play_forward)
    results_dir = self.replay_log_tracker.get_replay_logger_dir(label)
    (violations, client_runtime_stats) = self.forker.fork("play_forward",
                                                          results_dir)
    self._runtime_stats.merge_client_dict(client_runtime_stats)
    return violations

  def _optimize_event_dag(self):
    ''' Employs domain knowledge of event classes to reduce the size of the
    event dag. Currently prunes event types. '''
    # TODO(cs): Another approach for later: split by nodes
    event_types = [TrafficInjection, DataplaneDrop, SwitchFailure,
                   SwitchRecovery, LinkFailure, LinkRecovery, HostMigration,
                   ControllerFailure, ControllerRecovery, PolicyChange,
                   ControlChannelBlock, ControlChannelUnblock]
    for event_type in event_types:
      pruned = [e for e in self.dag.input_events if not isinstance(e, event_type)]
      if len(pruned) == len(self.dag.input_events):
        self.log("\t** No events pruned for event type %s. Next!" % event_type)
        continue
      pruned_dag = self.dag.input_complement(pruned)
      violations = self.replay(pruned_dag, "opt_%s" % event_type.__name__)
      if violations != []:
        self.log("\t** VIOLATION for pruning event type %s! Resizing original dag" % event_type)
        self.dag = pruned_dag

  def _track_new_internal_events(self, simulation, replayer):
    ''' Pre: simulation must have been run through a replay '''
    # We always check against internal events that were buffered at the end of
    # the original run (don't want to overcount)
    path = self.superlog_path + ".unacked"
    if not os.path.exists(path):
      self._log.warn("unacked internal events file from original run does not exist")
      return
    prev_buffered_receives = [e.pending_receive for e in
                              EventDag(log_parser.parse_path(path)).events]
    new_message_receipts = []
    for p in simulation.god_scheduler.pending_receives():
      if p not in prev_buffered_receives:
        new_message_receipts.append(repr(p))
      else:
        prev_buffered_receives.remove(p)
    new_state_changes = replayer.unexpected_state_changes
    new_internal_events = new_state_changes + new_message_receipts
    self._runtime_stats.record_new_internal_events(new_internal_events)
    self._runtime_stats.record_early_internal_events(replayer.early_state_changes)
    self._runtime_stats.record_timed_out_events(dict(replayer.event_scheduler_stats.event2timeouts))
    self._runtime_stats.record_matched_events(dict(replayer.event_scheduler_stats.event2matched))

# TODO(cs): Hack alert. Shouldn't be a subclass
class EfficientMCSFinder(MCSFinder):
  ''' Exactly the same functionality as MCSFinder, but assumes that
  indeterminate results cannot occur. Worst-case runtime of O(n) as opposed
  to O(n^2) replays.

  Taken from the predecessor paper:
    http://www.st.cs.uni-saarland.de/publications/files/zeller-esec-1999.pdf
  Section 4
  '''
  def _ddmin(self, dag, carryover_inputs, precompute_cache=None,
             recursion_level=0, label_prefix=(), total_inputs_pruned=0):
    ''' carryover_inputs is the variable "r" from the paper. '''
    # Hack: the superclass calls _ddmin with an integer, which doesn't match
    # our API. Translate that to an empty sequence. (We also don't use
    # precompute_cache.)
    if type(carryover_inputs) == int:
      carryover_inputs = []

    local_label = lambda i: "%s/%d" % ("l" if i == 0 else "r", recursion_level)
    subset_label = lambda label: ".".join(map(str, label_prefix + (label,)))
    print_subset = lambda label, s: subset_label(label) + ": " + " ".join(map(lambda e: e.label, s))

    # Base case. Note that atomic inputs are grouped-together failure/recovery
    # pairs, or normal inputs otherwise.
    if len(dag.atomic_input_events) == 1:
      self.log("Base case %s" % str(dag.input_events))
      return (dag, total_inputs_pruned)

    (left, right) = split_list(dag.atomic_input_events, 2)
    self.log("Subsets:\n" + "\n".join(print_subset(local_label(i), s)
                                      for i, s in enumerate([left, right])))

    # This is: [dag.input_subset(left), dag.input_subset(right)]
    left_right_dag = []
    for i, subsequence in enumerate([left, right]):
      label = local_label(i)
      prefix = label_prefix + (label,)
      new_dag = dag.atomic_input_subset(subsequence)
      self.log("Current subset: %s" % print_subset(label, new_dag.atomic_input_events))
      left_right_dag.append(new_dag)
      # We test on subsequence U carryover_inputs
      test_dag = new_dag.insert_atomic_inputs(carryover_inputs)
      self._track_iteration_size(total_inputs_pruned)
      violation = self._check_violation(test_dag, i, label)
      if violation:
        self.log("Violation found in %dth half. Recursing" % i)
        total_inputs_pruned += len(dag.input_events) - len(new_dag.input_events)
        self.mcs_log_tracker.maybe_dump_intermediate_mcs(new_dag, "", self)
        return self._ddmin(new_dag, carryover_inputs,
                           recursion_level=recursion_level + 1,
                           label_prefix=prefix,
                           total_inputs_pruned=total_inputs_pruned)

    self.log("Interference")
    (left_dag, right_dag) = left_right_dag
    self.log("Recursing on left half")
    prefix = label_prefix + ("il/%d" % recursion_level,)
    (left_result, total_inputs_pruned) = \
      self._ddmin(left_dag,
                  right_dag.insert_atomic_inputs(carryover_inputs).atomic_input_events,
                  recursion_level=recursion_level + 1, label_prefix=prefix,
                  total_inputs_pruned=total_inputs_pruned)
    self.log("Recursing on right half")
    prefix = label_prefix + ("ir/%d" % recursion_level,)
    (right_result, total_inputs_pruned) = \
      self._ddmin(right_dag,
                  left_dag.insert_atomic_inputs(carryover_inputs).atomic_input_events,
                  recursion_level=recursion_level + 1, label_prefix=prefix,
                  total_inputs_pruned=total_inputs_pruned)
    return (left_result.insert_atomic_inputs(right_result.atomic_input_events),
            total_inputs_pruned)

class ReplayLogTracker(object):
  ''' Logs intermediate and final replay traces chosen by delta debugging '''
  def __init__(self, results_dir):
    self.results_dir = results_dir
    self.count = 0

  def get_replay_logger_dir(self, label):
    dst = os.path.join(self.results_dir,
                       "interreplay_%d_%s" % (self.count, label.replace("/", ".")))
    self.count += 1
    return dst

  @staticmethod
  def create_replay_logger_dir(results_dir):
    mkdir_p(results_dir)

class MCSLogTracker(object):
  ''' Logs intermediate and final MCS results that are the outcome(s) of
  delta debugging '''
  def __init__(self, results_dir, mcs_trace_path, runtime_stats,
               simulation_cfg, peeker_exists):
    self.results_dir = results_dir
    self.mcs_trace_path = mcs_trace_path
    self.runtime_stats = runtime_stats
    self.simulation_cfg = simulation_cfg
    self.peeker_exists = peeker_exists
    self.min_size = sys.maxint
    self.count = 0

  def dump_runtime_stats(self, runtime_stats_file=None):
    runtime_stats = self.runtime_stats.clone()
    if runtime_stats_file is not None:
      runtime_stats.runtime_stats_file = runtime_stats_file
    runtime_stats.set_peeker(self.peeker_exists)
    runtime_stats.set_config(str(self.simulation_cfg))
    runtime_stats.write_runtime_stats()

  def maybe_dump_intermediate_mcs(self, dag, label, control_flow):
    # Only dump if the MCS decreases in size
    if len(dag.events) < self.min_size:
      self.min_size = len(dag.events)
      self.count += 1
      dst = os.path.join(self.results_dir,
                         "intermcs_%d_%s" % (self.count, label.replace("/", ".")))
      mkdir_p(dst)
      self.dump_mcs_trace(dag, control_flow,
                          os.path.join(dst, os.path.basename(self.mcs_trace_path)))
      self.dump_runtime_stats(
        os.path.join(dst, os.path.basename(self.runtime_stats.runtime_stats_file)))

  def dump_mcs_trace(self, dag, control_flow, mcs_trace_path=None):
    if mcs_trace_path is None:
      mcs_trace_path = self.mcs_trace_path
    for extension in ["", ".notimeouts"]:
      output_path = mcs_trace_path + extension
      input_logger = InputLogger()
      input_logger.open(os.path.dirname(output_path))
      for e in dag.events:
        if extension == ".notimeouts" and e.timed_out:
          continue
        input_logger.log_input_event(e)
      input_logger.close(control_flow, self.simulation_cfg, skip_mcs_cfg=True)

class RuntimeStats(object):
  ''' Tracks statistics and configuration information of the delta debugging
  runs '''
  def __init__(self, runtime_stats_file):
    self.runtime_stats_file = runtime_stats_file
    self.iteration_size = {}
    self.violation_found_in_run = Counter()
    # { replay iteration -> [string representations of new internal events] }
    self.new_internal_events = {}
    # { replay iteration -> [string representations of internal events that
    #   violated causality] }
    self.early_internal_events = {}
    # { replay iteration -> { event type -> timeouts } }
    self.timed_out_events = {}
    # { replay iteration -> { event type -> successful matches } }
    self.matched_events = {}
    self.total_inputs = 0
    self.total_events = 0
    self.original_duration_seconds = 0
    self.replay_start_epoch = 0
    self.replay_end_epoch = 0
    self.replay_duration_seconds = 0
    self.prune_start_epoch = 0
    self.prune_duration_seconds = 0
    self.initial_verification_runs_needed = 0
    self.peeker = ""
    self.config = ""
    self.total_replays = 0
    self.total_inputs_replayed = 0
    self.ambiguous_counts = {}
    self.ambiguous_events = {}

  def write_runtime_stats(self):
    # Now write contents to a file
    now = timestamp_string()
    if self.runtime_stats_file is None:
      self.runtime_stats_file = "runtime_stats/" + now + ".json"
    with open(self.runtime_stats_file, "w") as output:
      json_string = json.dumps(self.__dict__, sort_keys=True, indent=2,
                               separators=(',', ': '))
      output.write(json_string)

  def set_dag_stats(self, dag):
    self.total_inputs = len(dag.input_events)
    self.total_events = len(dag)
    self.original_duration_seconds = \
      (dag.events[-1].time.as_float() - dag.events[0].time.as_float())

  def record_replay_start(self):
    self.replay_start_epoch = time.time()

  def record_replay_end(self):
    self.replay_end_epoch = time.time()
    self.replay_duration_seconds = self.replay_end_epoch - self.replay_start_epoch

  def record_prune_start(self):
    self.prune_start_epoch = time.time()

  def record_prune_end(self):
    self.prune_end_epoch = time.time()
    self.prune_duration_seconds = self.prune_end_epoch - self.prune_start_epoch

  def set_initial_verification_runs_needed(self, verification_runs):
    self.initial_verification_runs_needed = verification_runs

  def set_peeker(self, peeker):
    self.peeker = peeker

  def set_config(self, config):
    self.config = config

  def record_iteration_size(self, iteration_size):
    self.iteration_size[Replayer.total_replays] = iteration_size

  def record_violation_found(self, verification_iteration):
    self.violation_found_in_run[verification_iteration] += 1

  def record_new_internal_events(self, new_internal_events):
    self.new_internal_events[Replayer.total_replays] = new_internal_events

  def record_early_internal_events(self, early_internal_events):
    self.early_internal_events[Replayer.total_replays] = early_internal_events

  def record_timed_out_events(self, timed_out_events):
    self.timed_out_events[Replayer.total_replays] = timed_out_events

  def record_matched_events(self, matched_events):
    self.matched_events[Replayer.total_replays] = matched_events

  def record_global_stats(self):
    self.total_replays = Replayer.total_replays
    self.total_inputs_replayed = Replayer.total_inputs_replayed
    # TODO(cs): assumes that Peeker is the dag transformer
    self.ambiguous_counts = dict(Peeker.ambiguous_counts)
    self.ambiguous_events = dict(Peeker.ambiguous_events)

  def clone(self):
    return copy.deepcopy(self)

  def client_dict(self):
    ''' Return a serializable dict '''
    # Only include relevant fields for the parent
    d = {}
    self.record_global_stats()
    for field in ['new_internal_events', 'early_internal_events',
                  'timed_out_events', 'matched_events', 'total_replays',
                  'total_inputs_replayed', 'ambiguous_counts',
                  'ambiguous_events']:
      v = getattr(self, field)
      # xmlrpclib doesn't allow non-string keys
      if type(v) == dict:
        v = dict((str(key), value) for key, value in v.items())
      d[field] = v
    return d

  def merge_client_dict(self, client_dict):
    for field, value in client_dict.iteritems():
      try:
        field = int(field)
      except:
        pass
      if type(value) == dict:
        setattr(self, field, dict(getattr(self, field).items() + value.items()))
      elif type(value) == int:
        setattr(self, field, getattr(self, field) + value)
      else:
        raise ValueError("Unknown field %s: %s" % (str(field), str(value)))
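
# Example usage (a sketch, not part of this module). The trace path, results
# directory, and invariant check name below are hypothetical: the check name
# must be a key registered in config.invariant_checks, and `sim_cfg` is
# assumed to be a SimulationConfig built elsewhere in an STS config file.
#
#   mcs_finder = MCSFinder(sim_cfg, "traces/demo.trace",
#                          invariant_check_name="check_everything")
#   mcs_finder.init_results("results/demo_mcs")
#   mcs_finder.simulate()   # replays, delta-debugs, and writes mcs.trace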