Parallelize BottomUpMutators

Allow BottomUpMutators to run in parallel by calling Parallel() on the
return value of RegisterBottomUpMutator.  To avoid locking, updates of
global state are moved into a separate goroutine that receives them
over channels from the mutator goroutines.

Change-Id: Ic59b612da9b406cf59ec44940f0c1dee0c051a51
diff --git a/context.go b/context.go
index a2b0890..a6d8626 100644
--- a/context.go
+++ b/context.go
@@ -224,6 +224,7 @@
 	topDownMutator  TopDownMutator
 	bottomUpMutator BottomUpMutator
 	name            string
+	parallel        bool // set by Parallel(); makes runBottomUpMutator use parallelVisitAllBottomUp
 }
 
 func (e *Error) Error() string {
@@ -393,19 +394,35 @@
 //
 // The mutator type names given here must be unique to all bottom up or early
 // mutators in the Context.
-func (c *Context) RegisterBottomUpMutator(name string, mutator BottomUpMutator) {
+//
+// Returns a BottomUpMutatorHandle, on which Parallel can be called to make
+// the mutator visit modules concurrently while preserving dependency order.
+func (c *Context) RegisterBottomUpMutator(name string, mutator BottomUpMutator) BottomUpMutatorHandle {
 	for _, m := range c.variantMutatorNames {
 		if m == name {
 			panic(fmt.Errorf("mutator name %s is already registered", name))
 		}
 	}
 
-	c.mutatorInfo = append(c.mutatorInfo, &mutatorInfo{
+	info := &mutatorInfo{
 		bottomUpMutator: mutator,
 		name:            name,
-	})
+	}
+	c.mutatorInfo = append(c.mutatorInfo, info)
 
 	c.variantMutatorNames = append(c.variantMutatorNames, name)
+	return info
+}
+
+// BottomUpMutatorHandle is returned by RegisterBottomUpMutator to configure the registered mutator.
+type BottomUpMutatorHandle interface {
+	// Parallel makes the mutator visit modules concurrently; a module is still only visited after all of its dependencies.
+	Parallel() BottomUpMutatorHandle
+}
+
+func (mutator *mutatorInfo) Parallel() BottomUpMutatorHandle {
+	mutator.parallel = true // consulted by runBottomUpMutator to pick the visitor
+	return mutator
 }
 
 // RegisterEarlyMutator registers a mutator that will be invoked to split
@@ -991,11 +1008,6 @@
 
 		newModules = append(newModules, newModule)
 
-		// Insert the new variant into the global module map.  If this is the first variant then
-		// it reuses logicModule from the original module, which causes this to replace the
-		// original module in the global module map.
-		c.moduleInfo[newModule.logicModule] = newModule
-
 		newErrs := c.convertDepsToVariation(newModule, mutatorName, variationName)
 		if len(newErrs) > 0 {
 			errs = append(errs, newErrs...)
@@ -1352,6 +1364,16 @@
 	fromInfo.directDeps = append(fromInfo.directDeps, depInfo{toInfo, tag})
 }
 
+// visitAllBottomUp calls visit serially on each module in c.modulesSorted order (presumably dependencies first — confirm), stopping as soon as visit returns true.
+func (c *Context) visitAllBottomUp(visit func(group *moduleInfo) bool) {
+	for _, module := range c.modulesSorted {
+		if visit(module) {
+			return
+		}
+	}
+}
+
+// parallelVisitAllBottomUp calls visit on each module, guaranteeing that visit is not called on a module until visit on all of its dependencies has finished.
 func (c *Context) parallelVisitAllBottomUp(visit func(group *moduleInfo) bool) {
 	doneCh := make(chan *moduleInfo)
 	count := 0
@@ -1580,9 +1602,9 @@
 
 	for _, mutator := range mutators {
 		if mutator.topDownMutator != nil {
-			errs = c.runTopDownMutator(config, mutator.name, mutator.topDownMutator)
+			errs = c.runTopDownMutator(config, mutator)
 		} else if mutator.bottomUpMutator != nil {
-			errs = c.runBottomUpMutator(config, mutator.name, mutator.bottomUpMutator)
+			errs = c.runBottomUpMutator(config, mutator)
 		} else {
 			panic("no mutator set on " + mutator.name)
 		}
@@ -1594,8 +1616,7 @@
 	return nil
 }
 
-func (c *Context) runTopDownMutator(config interface{},
-	name string, mutator TopDownMutator) (errs []error) {
+func (c *Context) runTopDownMutator(config interface{}, mutator *mutatorInfo) (errs []error) { // visits c.modulesSorted in reverse order
 
 	for i := 0; i < len(c.modulesSorted); i++ {
 		module := c.modulesSorted[len(c.modulesSorted)-1-i]
@@ -1605,12 +1626,12 @@
 				config:  config,
 				module:  module,
 			},
-			name: name,
+			name: mutator.name,
 		}
 		func() {
 			defer func() {
 				if r := recover(); r != nil {
-					in := fmt.Sprintf("top down mutator %q for %s", name, module)
+					in := fmt.Sprintf("top down mutator %q for %s", mutator.name, module)
 					if err, ok := r.(panicError); ok {
 						err.addIn(in)
 						mctx.error(err)
@@ -1619,7 +1640,7 @@
 					}
 				}
 			}()
-			mutator(mctx)
+			mutator.topDownMutator(mctx)
 		}()
 
 		if len(mctx.errs) > 0 {
@@ -1631,14 +1652,27 @@
 	return errs
 }
 
+type reverseDep struct {
+	module *moduleInfo // module whose directDeps gains dep
+	dep    depInfo     // dependency appended to module.directDeps
+}
+
 func (c *Context) runBottomUpMutator(config interface{},
-	name string, mutator BottomUpMutator) (errs []error) {
+	mutator *mutatorInfo) (errs []error) {
+	// Module map updates go into a copy, which replaces c.moduleInfo after the visits unless errs is non-empty.
+	newModuleInfo := make(map[Module]*moduleInfo)
+	for k, v := range c.moduleInfo {
+		newModuleInfo[k] = v
+	}
 
 	reverseDeps := make(map[*moduleInfo][]depInfo)
 
-	for _, module := range c.modulesSorted {
-		newModules := make([]*moduleInfo, 0, 1)
+	errsCh := make(chan []error)
+	reverseDepsCh := make(chan []reverseDep)
+	newModulesCh := make(chan []*moduleInfo)
+	done := make(chan bool)
 
+	visit := func(module *moduleInfo) bool { // returns true when the mutator reported errors
 		if module.splitModules != nil {
 			panic("split module found in sorted module list")
 		}
@@ -1649,14 +1683,13 @@
 				config:  config,
 				module:  module,
 			},
-			name:        name,
-			reverseDeps: reverseDeps,
+			name: mutator.name,
 		}
 
 		func() {
 			defer func() {
 				if r := recover(); r != nil {
-					in := fmt.Sprintf("bottom up mutator %q for %s", name, module)
+					in := fmt.Sprintf("bottom up mutator %q for %s", mutator.name, module)
 					if err, ok := r.(panicError); ok {
 						err.addIn(in)
 						mctx.error(err)
@@ -1665,28 +1698,76 @@
 					}
 				}
 			}()
-			mutator(mctx)
+			mutator.bottomUpMutator(mctx)
 		}()
+
 		if len(mctx.errs) > 0 {
-			errs = append(errs, mctx.errs...)
-			return errs
+			errsCh <- mctx.errs
+			return true
 		}
 
-		// Fix up any remaining dependencies on modules that were split into variants
-		// by replacing them with the first variant
-		for i, dep := range module.directDeps {
-			if dep.module.logicModule == nil {
-				module.directDeps[i].module = dep.module.splitModules[0]
+		if len(mctx.newModules) > 0 {
+			newModulesCh <- mctx.newModules
+		}
+
+		if len(mctx.reverseDeps) > 0 {
+			reverseDepsCh <- mctx.reverseDeps
+		}
+
+		return false
+	}
+
+	// Collect errs, reverseDeps and newModuleInfo updates in this single goroutine so they need no locking
+	go func() {
+		for {
+			select {
+			case newErrs := <-errsCh:
+				errs = append(errs, newErrs...)
+			case newReverseDeps := <-reverseDepsCh:
+				for _, r := range newReverseDeps {
+					reverseDeps[r.module] = append(reverseDeps[r.module], r.dep)
+				}
+			case newModules := <-newModulesCh:
+				for _, m := range newModules {
+					newModuleInfo[m.logicModule] = m
+				}
+			case <-done:
+				return
 			}
 		}
+	}()
 
-		if module.splitModules != nil {
-			newModules = append(newModules, module.splitModules...)
-		} else {
-			newModules = append(newModules, module)
+	if mutator.parallel {
+		c.parallelVisitAllBottomUp(visit)
+	} else {
+		c.visitAllBottomUp(visit)
+	}
+
+	done <- true // every visit must have returned by now; a later send on the unbuffered channels would block forever
+
+	if len(errs) > 0 { // the collector no longer writes errs once it has received done
+		return errs
+	}
+
+	c.moduleInfo = newModuleInfo
+
+	for _, group := range c.moduleGroups {
+		for i := 0; i < len(group.modules); i++ {
+			module := group.modules[i]
+
+			// Update module group to contain newly split variants
+			if module.splitModules != nil {
+				group.modules, i = spliceModules(group.modules, i, module.splitModules)
+			}
+
+			// Fix up any remaining dependencies on modules that were split into variants
+			// by replacing them with the first variant
+			for j, dep := range module.directDeps {
+				if dep.module.logicModule == nil {
+					module.directDeps[j].module = dep.module.splitModules[0]
+				}
+			}
 		}
-
-		module.group.modules = spliceModules(module.group.modules, module, newModules)
 	}
 
 	for module, deps := range reverseDeps {
@@ -1694,6 +1775,7 @@
 		module.directDeps = append(module.directDeps, deps...)
 	}
 
+	// TODO(ccross): update can be elided if no dependencies were modified
 	errs = c.updateDependencies()
 	if len(errs) > 0 {
 		return errs
@@ -1714,18 +1796,9 @@
 	}
 }
 
-func spliceModules(modules []*moduleInfo, origModule *moduleInfo,
-	newModules []*moduleInfo) []*moduleInfo {
-	for i, m := range modules {
-		if m == origModule {
-			return spliceModulesAtIndex(modules, i, newModules)
-		}
-	}
-
-	panic("failed to find original module to splice")
-}
-
-func spliceModulesAtIndex(modules []*moduleInfo, i int, newModules []*moduleInfo) []*moduleInfo {
+// spliceModules replaces modules[i] with newModules..., returning the resulting slice and the index
+// of the last inserted element, so a caller iterating by index can continue after the new entries.
+func spliceModules(modules []*moduleInfo, i int, newModules []*moduleInfo) ([]*moduleInfo, int) {
 	spliceSize := len(newModules)
 	newLen := len(modules) + spliceSize - 1
 	var dest []*moduleInfo
@@ -1743,7 +1816,7 @@
 	// Copy the new modules into the slice
 	copy(dest[i:], newModules)
 
-	return dest
+	return dest, i + spliceSize - 1
 }
 
 func (c *Context) initSpecialVariables() {