blob: ba7ee0096b8efb352c07a121ee58ab186ef623ed [file] [log] [blame]
Shinichiro Hamaji702befc2016-01-27 17:21:39 +09001// Copyright 2016 Google Inc. All rights reserved
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "thread_pool.h"

#include <stack>
#include <utility>
#include <vector>

#include "condvar.h"
#include "mutex.h"
#include "thread.h"
23
24class ThreadPoolImpl : public ThreadPool {
25 public:
26 explicit ThreadPoolImpl(int num_threads)
27 : is_waiting_(false) {
28 threads_.reserve(num_threads);
29 for (int i = 0; i < num_threads; i++) {
30 threads_.push_back(thread([this]() { Loop(); }));
31 }
32 }
33
34 virtual ~ThreadPoolImpl() override {
35 }
36
37 virtual void Submit(function<void(void)> task) override {
38 unique_lock<mutex> lock(mu_);
39 tasks_.push(task);
40 cond_.notify_one();
41 }
42
43 virtual void Wait() override {
44 {
45 unique_lock<mutex> lock(mu_);
46 is_waiting_ = true;
47 cond_.notify_all();
48 }
49
50 for (thread& th : threads_) {
51 th.join();
52 }
53 }
54
55 private:
56 void Loop() {
57 while (true) {
58 function<void(void)> task;
59 {
60 unique_lock<mutex> lock(mu_);
61 if (tasks_.empty()) {
62 if (is_waiting_)
63 return;
64 cond_.wait(lock);
65 }
66
67 if (tasks_.empty())
68 continue;
69
70 task = tasks_.top();
71 tasks_.pop();
72 }
73 task();
74 }
75 }
76
77 vector<thread> threads_;
78 mutex mu_;
79 condition_variable cond_;
80 stack<function<void(void)>> tasks_;
81 bool is_waiting_;
82};
83
84ThreadPool* NewThreadPool(int num_threads) {
85 return new ThreadPoolImpl(num_threads);
86}