/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| #define LOG_TAG "resolv" |
| |
| #include "DnsTlsQueryMap.h" |
| |
| #include <android-base/logging.h> |
| |
| #include "Experiments.h" |
| |
| namespace android { |
| namespace net { |
| |
| DnsTlsQueryMap::DnsTlsQueryMap() { |
| mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries); |
| if (mMaxTries < 1) mMaxTries = 1; |
| } |
| |
| std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery( |
| const netdutils::Slice query) { |
| std::lock_guard guard(mLock); |
| |
| // Store the query so it can be matched to the response or reissued. |
| if (query.size() < 2) { |
| LOG(WARNING) << "Query is too short"; |
| return nullptr; |
| } |
| int32_t newId = getFreeId(); |
| if (newId < 0) { |
| LOG(WARNING) << "All query IDs are in use"; |
| return nullptr; |
| } |
| |
| // Make a copy of the query. |
| std::vector<uint8_t> tmp(query.base(), query.base() + query.size()); |
| Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)}; |
| |
| const auto [it, inserted] = mQueries.try_emplace(newId, q); |
| if (!inserted) { |
| LOG(ERROR) << "Failed to store pending query"; |
| return nullptr; |
| } |
| return std::make_unique<QueryFuture>(q, it->second.result.get_future()); |
| } |
| |
| void DnsTlsQueryMap::expire(QueryPromise* p) { |
| Result r = { .code = Response::network_error }; |
| p->result.set_value(r); |
| } |
| |
| void DnsTlsQueryMap::markTried(uint16_t newId) { |
| std::lock_guard guard(mLock); |
| auto it = mQueries.find(newId); |
| if (it != mQueries.end()) { |
| it->second.tries++; |
| } |
| } |
| |
| void DnsTlsQueryMap::cleanup() { |
| std::lock_guard guard(mLock); |
| for (auto it = mQueries.begin(); it != mQueries.end();) { |
| auto& p = it->second; |
| if (p.tries >= mMaxTries) { |
| expire(&p); |
| it = mQueries.erase(it); |
| } else { |
| ++it; |
| } |
| } |
| } |
| |
| int32_t DnsTlsQueryMap::getFreeId() { |
| if (mQueries.empty()) { |
| return 0; |
| } |
| uint16_t maxId = mQueries.rbegin()->first; |
| if (maxId < UINT16_MAX) { |
| return maxId + 1; |
| } |
| if (mQueries.size() == UINT16_MAX + 1) { |
| // Map is full. |
| return -1; |
| } |
| // Linear scan. |
| uint16_t nextId = 0; |
| for (auto& pair : mQueries) { |
| uint16_t id = pair.first; |
| if (id != nextId) { |
| // Found a gap. |
| return nextId; |
| } |
| nextId = id + 1; |
| } |
| // Unreachable (but the compiler isn't smart enough to prove it). |
| return -1; |
| } |
| |
| std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() { |
| std::lock_guard guard(mLock); |
| std::vector<DnsTlsQueryMap::Query> queries; |
| for (auto& q : mQueries) { |
| queries.push_back(q.second.query); |
| } |
| return queries; |
| } |
| |
| bool DnsTlsQueryMap::empty() { |
| std::lock_guard guard(mLock); |
| return mQueries.empty(); |
| } |
| |
| void DnsTlsQueryMap::clear() { |
| std::lock_guard guard(mLock); |
| for (auto& q : mQueries) { |
| expire(&q.second); |
| } |
| mQueries.clear(); |
| } |
| |
| void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) { |
| LOG(VERBOSE) << "Got response of size " << response.size(); |
| if (response.size() < 2) { |
| LOG(WARNING) << "Response is too short"; |
| return; |
| } |
| uint16_t id = response[0] << 8 | response[1]; |
| std::lock_guard guard(mLock); |
| auto it = mQueries.find(id); |
| if (it == mQueries.end()) { |
| LOG(WARNING) << "Discarding response: unknown ID " << id; |
| return; |
| } |
| Result r = { .code = Response::success, .response = std::move(response) }; |
| // Rewrite ID to match the query |
| const uint8_t* data = it->second.query.query.data(); |
| r.response[0] = data[0]; |
| r.response[1] = data[1]; |
| LOG(DEBUG) << "Sending result to dispatcher"; |
| it->second.result.set_value(std::move(r)); |
| mQueries.erase(it); |
| } |
| |
| } // end of namespace net |
| } // end of namespace android |