Parallelize TopDownMutators

Convert parallelVisitAllBottomUp into a generic parallelVisit that takes
a visitOrderer to select top down or bottom up.  Combine
runTopDownMutator and runBottomUpMutator into runMutator, which takes a
mutatorDirection to select the visitOrderer for parallelVisit and which
function pointer in the mutatorInfo to run.  RegisterTopDownMutator now
returns a MutatorHandle (renamed from BottomUpMutatorHandle) so that top
down mutators can also be marked Parallel.  Skip updateDependencies
unless a mutator has modified dependencies, to avoid running it after
TopDownMutators, which cannot modify dependencies.

Change-Id: Ib00302db1108ebab2ce8e01b20aa026140d382a4
diff --git a/context.go b/context.go
index 60a4363..fc11ec6 100644
--- a/context.go
+++ b/context.go
@@ -74,6 +74,8 @@
 	variantMutatorNames []string
 	moduleNinjaNames    map[string]*moduleGroup
 
+	depsModified uint32 // positive if a mutator modified the dependencies
+
 	dependenciesReady bool // set to true on a successful ResolveDependencies
 	buildActionsReady bool // set to true on a successful PrepareBuildActions
 
@@ -146,7 +148,7 @@
 
 	// set during updateDependencies
 	reverseDeps []*moduleInfo
-	depsCount   int
+	forwardDeps []*moduleInfo
 
 	// used by parallelVisitAllBottomUp
 	waitingCount int
@@ -366,38 +368,44 @@
 	return typ.PkgPath() + "." + typ.Name()
 }
 
-// RegisterTopDownMutator registers a mutator that will be invoked to propagate
-// dependency info top-down between Modules.  Each registered mutator
-// is invoked in registration order (mixing TopDownMutators and BottomUpMutators)
-// once per Module, and is invoked on a module before being invoked on any of its
-// dependencies.
+// RegisterTopDownMutator registers a mutator that will be invoked to propagate dependency info
+// top-down between Modules.  Each registered mutator is invoked in registration order (mixing
+// TopDownMutators and BottomUpMutators) once per Module, and the invocation on any module will
+// have returned before it is invoked on any of its dependencies.
 //
 // The mutator type names given here must be unique to all top down mutators in
 // the Context.
-func (c *Context) RegisterTopDownMutator(name string, mutator TopDownMutator) {
+//
+// Returns a MutatorHandle, on which Parallel can be called to set the mutator to visit modules in
+// parallel while maintaining ordering.
+func (c *Context) RegisterTopDownMutator(name string, mutator TopDownMutator) MutatorHandle {
 	for _, m := range c.mutatorInfo {
 		if m.name == name && m.topDownMutator != nil {
 			panic(fmt.Errorf("mutator name %s is already registered", name))
 		}
 	}
 
-	c.mutatorInfo = append(c.mutatorInfo, &mutatorInfo{
+	info := &mutatorInfo{
 		topDownMutator: mutator,
 		name:           name,
-	})
+	}
+
+	c.mutatorInfo = append(c.mutatorInfo, info)
+
+	return info
 }
 
-// RegisterBottomUpMutator registers a mutator that will be invoked to split
-// Modules into variants.  Each registered mutator is invoked in registration
-// order (mixing TopDownMutators and BottomUpMutators) once per Module, and is
-// invoked on dependencies before being invoked on dependers.
+// RegisterBottomUpMutator registers a mutator that will be invoked to split Modules into variants.
+// Each registered mutator is invoked in registration order (mixing TopDownMutators and
+// BottomUpMutators) once per Module, and will not be invoked on a module until the invocations on
+// all of the module's dependencies have returned.
 //
 // The mutator type names given here must be unique to all bottom up or early
 // mutators in the Context.
 //
-// Returns a BottomUpMutatorHandle, on which Parallel can be called to set
-// the mutator to visit modules in parallel while maintaining ordering.
-func (c *Context) RegisterBottomUpMutator(name string, mutator BottomUpMutator) BottomUpMutatorHandle {
+// Returns a MutatorHandle, on which Parallel can be called to set the mutator to visit modules in
+// parallel while maintaining ordering.
+func (c *Context) RegisterBottomUpMutator(name string, mutator BottomUpMutator) MutatorHandle {
 	for _, m := range c.variantMutatorNames {
 		if m == name {
 			panic(fmt.Errorf("mutator name %s is already registered", name))
@@ -415,12 +423,14 @@
 	return info
 }
 
-type BottomUpMutatorHandle interface {
-	// Set the mutator to visit modules in parallel while maintaining ordering
-	Parallel() BottomUpMutatorHandle
+type MutatorHandle interface {
+	// Set the mutator to visit modules in parallel while maintaining ordering.  Calling any
+	// method on the mutator context is thread-safe, but the mutator must handle synchronization
+	// for any modifications to global state or any modules outside the one it was invoked on.
+	Parallel() MutatorHandle
 }
 
-func (mutator *mutatorInfo) Parallel() BottomUpMutatorHandle {
+func (mutator *mutatorInfo) Parallel() MutatorHandle {
 	mutator.parallel = true
 	return mutator
 }
@@ -1019,6 +1029,8 @@
 	origModule.logicModule = nil
 	origModule.splitModules = newModules
 
+	atomic.AddUint32(&c.depsModified, 1)
+
 	return newModules, errs
 }
 
@@ -1235,6 +1247,7 @@
 			}
 		}
 		module.directDeps = append(module.directDeps, depInfo{m, tag})
+		atomic.AddUint32(&c.depsModified, 1)
 		return nil
 	}
 
@@ -1328,6 +1341,7 @@
 				}}
 			}
 			module.directDeps = append(module.directDeps, depInfo{m, tag})
+			atomic.AddUint32(&c.depsModified, 1)
 			return nil
 		}
 	}
@@ -1362,25 +1376,69 @@
 	}
 
 	fromInfo.directDeps = append(fromInfo.directDeps, depInfo{toInfo, tag})
+	atomic.AddUint32(&c.depsModified, 1)
 }
 
-func (c *Context) visitAllBottomUp(visit func(group *moduleInfo) bool) {
-	for _, module := range c.modulesSorted {
+type visitOrderer interface {
+	// waitCount returns the number of modules that must finish being visited before this module.
+	waitCount(module *moduleInfo) int
+	// propagate returns the modules whose wait counts drop once this module has been visited.
+	propagate(module *moduleInfo) []*moduleInfo
+	// visit calls visit serially on modules in this order, stopping early if visit returns true.
+	visit(modules []*moduleInfo, visit func(*moduleInfo) bool)
+}
+
+type bottomUpVisitorImpl struct{}
+
+func (bottomUpVisitorImpl) waitCount(module *moduleInfo) int {
+	return len(module.forwardDeps)
+}
+
+func (bottomUpVisitorImpl) propagate(module *moduleInfo) []*moduleInfo {
+	return module.reverseDeps
+}
+
+func (bottomUpVisitorImpl) visit(modules []*moduleInfo, visit func(*moduleInfo) bool) {
+	for _, module := range modules {
 		if visit(module) {
 			return
 		}
 	}
 }
 
+type topDownVisitorImpl struct{}
+
+func (topDownVisitorImpl) waitCount(module *moduleInfo) int {
+	return len(module.reverseDeps)
+}
+
+func (topDownVisitorImpl) propagate(module *moduleInfo) []*moduleInfo {
+	return module.forwardDeps
+}
+
+func (topDownVisitorImpl) visit(modules []*moduleInfo, visit func(*moduleInfo) bool) {
+	for i := 0; i < len(modules); i++ {
+		module := modules[len(modules)-1-i]
+		if visit(module) {
+			return
+		}
+	}
+}
+
+var (
+	bottomUpVisitor bottomUpVisitorImpl
+	topDownVisitor  topDownVisitorImpl
+)
+
 // Calls visit on each module, guaranteeing that visit is not called on a module until visit on all
 // of its dependencies has finished.
-func (c *Context) parallelVisitAllBottomUp(visit func(group *moduleInfo) bool) {
+func (c *Context) parallelVisit(order visitOrderer, visit func(group *moduleInfo) bool) {
 	doneCh := make(chan *moduleInfo)
 	count := 0
 	cancel := false
 
 	for _, module := range c.modulesSorted {
-		module.waitingCount = module.depsCount
+		module.waitingCount = order.waitCount(module)
 	}
 
 	visitOne := func(module *moduleInfo) {
@@ -1404,10 +1462,10 @@
 		select {
 		case doneModule := <-doneCh:
 			if !cancel {
-				for _, parent := range doneModule.reverseDeps {
-					parent.waitingCount--
-					if parent.waitingCount == 0 {
-						visitOne(parent)
+				for _, module := range order.propagate(doneModule) {
+					module.waitingCount--
+					if module.waitingCount == 0 {
+						visitOne(module)
 					}
 				}
 			}
@@ -1474,7 +1532,7 @@
 		}
 
 		module.reverseDeps = []*moduleInfo{}
-		module.depsCount = len(deps)
+		module.forwardDeps = []*moduleInfo{}
 
 		for dep := range deps {
 			if checking[dep] {
@@ -1504,6 +1562,7 @@
 				}
 			}
 
+			module.forwardDeps = append(module.forwardDeps, dep)
 			dep.reverseDeps = append(dep.reverseDeps, module)
 		}
 
@@ -1602,9 +1661,9 @@
 
 	for _, mutator := range mutators {
 		if mutator.topDownMutator != nil {
-			errs = c.runTopDownMutator(config, mutator)
+			errs = c.runMutator(config, mutator, topDownMutator)
 		} else if mutator.bottomUpMutator != nil {
-			errs = c.runBottomUpMutator(config, mutator)
+			errs = c.runMutator(config, mutator, bottomUpMutator)
 		} else {
 			panic("no mutator set on " + mutator.name)
 		}
@@ -1616,49 +1675,52 @@
 	return nil
 }
 
-func (c *Context) runTopDownMutator(config interface{}, mutator *mutatorInfo) (errs []error) {
-
-	for i := 0; i < len(c.modulesSorted); i++ {
-		module := c.modulesSorted[len(c.modulesSorted)-1-i]
-		mctx := &mutatorContext{
-			baseModuleContext: baseModuleContext{
-				context: c,
-				config:  config,
-				module:  module,
-			},
-			name: mutator.name,
-		}
-		func() {
-			defer func() {
-				if r := recover(); r != nil {
-					in := fmt.Sprintf("top down mutator %q for %s", mutator.name, module)
-					if err, ok := r.(panicError); ok {
-						err.addIn(in)
-						mctx.error(err)
-					} else {
-						mctx.error(newPanicErrorf(r, in))
-					}
-				}
-			}()
-			mutator.topDownMutator(mctx)
-		}()
-
-		if len(mctx.errs) > 0 {
-			errs = append(errs, mctx.errs...)
-			return errs
-		}
-	}
-
-	return errs
+type mutatorDirection interface {
+	run(mutator *mutatorInfo, ctx *mutatorContext)
+	orderer() visitOrderer
+	fmt.Stringer
 }
 
+type bottomUpMutatorImpl struct{}
+
+func (bottomUpMutatorImpl) run(mutator *mutatorInfo, ctx *mutatorContext) {
+	mutator.bottomUpMutator(ctx)
+}
+
+func (bottomUpMutatorImpl) orderer() visitOrderer {
+	return bottomUpVisitor
+}
+
+func (bottomUpMutatorImpl) String() string {
+	return "bottom up mutator"
+}
+
+type topDownMutatorImpl struct{}
+
+func (topDownMutatorImpl) run(mutator *mutatorInfo, ctx *mutatorContext) {
+	mutator.topDownMutator(ctx)
+}
+
+func (topDownMutatorImpl) orderer() visitOrderer {
+	return topDownVisitor
+}
+
+func (topDownMutatorImpl) String() string {
+	return "top down mutator"
+}
+
+var (
+	topDownMutator  topDownMutatorImpl
+	bottomUpMutator bottomUpMutatorImpl
+)
+
 type reverseDep struct {
 	module *moduleInfo
 	dep    depInfo
 }
 
-func (c *Context) runBottomUpMutator(config interface{},
-	mutator *mutatorInfo) (errs []error) {
+func (c *Context) runMutator(config interface{}, mutator *mutatorInfo,
+	direction mutatorDirection) (errs []error) {
 
 	newModuleInfo := make(map[Module]*moduleInfo)
 	for k, v := range c.moduleInfo {
@@ -1672,6 +1734,8 @@
 	newModulesCh := make(chan []*moduleInfo)
 	done := make(chan bool)
 
+	c.depsModified = 0
+
 	visit := func(module *moduleInfo) bool {
 		if module.splitModules != nil {
 			panic("split module found in sorted module list")
@@ -1689,7 +1753,7 @@
 		func() {
 			defer func() {
 				if r := recover(); r != nil {
-					in := fmt.Sprintf("bottom up mutator %q for %s", mutator.name, module)
+					in := fmt.Sprintf("%s %q for %s", direction, mutator.name, module)
 					if err, ok := r.(panicError); ok {
 						err.addIn(in)
 						mctx.error(err)
@@ -1698,7 +1762,7 @@
 					}
 				}
 			}()
-			mutator.bottomUpMutator(mctx)
+			direction.run(mutator, mctx)
 		}()
 
 		if len(mctx.errs) > 0 {
@@ -1738,9 +1802,9 @@
 	}()
 
 	if mutator.parallel {
-		c.parallelVisitAllBottomUp(visit)
+		c.parallelVisit(direction.orderer(), visit)
 	} else {
-		c.visitAllBottomUp(visit)
+		direction.orderer().visit(c.modulesSorted, visit)
 	}
 
 	done <- true
@@ -1773,12 +1837,14 @@
 	for module, deps := range reverseDeps {
 		sort.Sort(depSorter(deps))
 		module.directDeps = append(module.directDeps, deps...)
+		c.depsModified++
 	}
 
-	// TODO(ccross): update can be elided if no dependencies were modified
-	errs = c.updateDependencies()
-	if len(errs) > 0 {
-		return errs
+	if c.depsModified > 0 {
+		errs = c.updateDependencies()
+		if len(errs) > 0 {
+			return errs
+		}
 	}
 
 	return errs
@@ -1864,7 +1930,7 @@
 		}
 	}()
 
-	c.parallelVisitAllBottomUp(func(module *moduleInfo) bool {
+	c.parallelVisit(bottomUpVisitor, func(module *moduleInfo) bool {
 		// The parent scope of the moduleContext's local scope gets overridden to be that of the
 		// calling Go package on a per-call basis.  Since the initial parent scope doesn't matter we
 		// just set it to nil.