/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <inttypes.h>
#include <limits.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "system/extras/simpleperf/report_sample.pb.h"

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

#include "command.h"
#include "record_file.h"
#include "thread_tree.h"
#include "utils.h"

namespace proto = simpleperf_report_proto;

namespace {

// Adapts a stdio FILE* to protobuf's CopyingOutputStream so that a
// CopyingOutputStreamAdaptor/CodedOutputStream can write into it.
// The FILE* is borrowed; this class never closes it.
class ProtobufFileWriter : public google::protobuf::io::CopyingOutputStream {
 public:
  explicit ProtobufFileWriter(FILE* out_fp) : out_fp_(out_fp) {}

  // Writes all |size| bytes of |buffer|. Returns false on a short write.
  // Bytes are counted individually (item size 1) so that an empty write
  // succeeds instead of being reported as a failure.
  bool Write(const void* buffer, int size) override {
    if (size < 0) {
      return false;
    }
    return fwrite(buffer, 1, static_cast<size_t>(size), out_fp_) ==
           static_cast<size_t>(size);
  }

 private:
  FILE* out_fp_;
};

// Adapts a stdio FILE* to protobuf's CopyingInputStream so that a
// CopyingInputStreamAdaptor/CodedInputStream can read from it.
// The FILE* is borrowed; this class never closes it.
class ProtobufFileReader : public google::protobuf::io::CopyingInputStream {
 public:
  explicit ProtobufFileReader(FILE* in_fp) : in_fp_(in_fp) {}

  // Reads up to |size| bytes into |buffer|. Returns the number of bytes read,
  // 0 at end of file, or -1 on a read error.
  int Read(void* buffer, int size) override {
    if (size < 0) {
      return -1;
    }
    size_t n = fread(buffer, 1, static_cast<size_t>(size), in_fp_);
    // fread() returns 0 both at EOF and on error. Errors must be reported as
    // -1 so the adaptor does not mistake them for a clean end of stream.
    if (n == 0 && ferror(in_fp_) != 0) {
      return -1;
    }
    return static_cast<int>(n);
  }

 private:
  FILE* in_fp_;
};

// The `report-sample` command prints raw sample information from a record
// file, either as indented text or as length-prefixed protobuf records. It can
// also dump a previously generated protobuf report back as text.
class ReportSampleCommand : public Command {
 public:
  ReportSampleCommand()
      : Command(
            "report-sample", "report raw sample information in perf.data",
            // clang-format off
"Usage: simpleperf report-sample [options]\n"
"--dump-protobuf-report  <file>\n"
"           Dump report file generated by\n"
"           `simpleperf report-sample --protobuf -o <file>`.\n"
"-i <file>  Specify path of record file, default is perf.data.\n"
"-o report_file_name  Set report file name, default is stdout.\n"
"--protobuf  Use protobuf format in report_sample.proto to output samples.\n"
"            Need to set a report_file_name when using this option.\n"
"--show-callchain  Print callchain samples.\n"
            // clang-format on
            ) {}

  bool Run(const std::vector<std::string>& args) override;

 private:
  bool ParseOptions(const std::vector<std::string>& args);
  // Reads a protobuf report from |filename| and prints it as text.
  bool DumpProtobufReport(const std::string& filename);
  // Per-record callback used while reading the record file's data section.
  bool ProcessRecord(std::unique_ptr<Record> record);
  bool PrintSampleRecordInProtobuf(const SampleRecord& record);
  // Resolves |ip| to dump ids (file id, symbol id; -1 for an unknown symbol).
  bool GetCallEntry(const ThreadEntry* thread, bool in_kernel, uint64_t ip, bool omit_unknown_dso,
                    uint64_t* pvaddr_in_file, uint32_t* pfile_id, int32_t* psymbol_id);
  // Resolves |ip| to its dso and symbol.
  bool GetCallEntry(const ThreadEntry* thread, bool in_kernel, uint64_t ip, bool omit_unknown_dso,
                    uint64_t* pvaddr_in_file, Dso** pdso, const Symbol** psymbol);
  bool PrintLostSituationInProtobuf();
  bool PrintFileInfoInProtobuf();
  bool PrintSampleRecord(const SampleRecord& record);
  void PrintLostSituation();

  // Input record file (-i).
  std::string record_filename_ = "perf.data";
  std::unique_ptr<RecordFileReader> record_file_reader_;
  // Protobuf report to dump as text (--dump-protobuf-report); empty if unused.
  std::string dump_protobuf_report_file_;
  bool show_callchain_ = false;
  bool use_protobuf_ = false;
  ThreadTree thread_tree_;
  // Output file (-o); empty means stdout.
  std::string report_filename_;
  // Destination of all output; not owned (either stdout or a file owned by Run()).
  FILE* report_fp_ = nullptr;
  // Protobuf output stream while a protobuf report is written; not owned.
  google::protobuf::io::CodedOutputStream* coded_os_ = nullptr;
  size_t sample_count_ = 0;
  size_t lost_count_ = 0;
};

bool ReportSampleCommand::Run(const std::vector<std::string>& args) {
  // 1. Parse options.
  if (!ParseOptions(args)) {
    return false;
  }
  // 2. Prepare report fp.
  report_fp_ = stdout;
  std::unique_ptr<FILE, decltype(&fclose)> fp(nullptr, fclose);
  if (!report_filename_.empty()) {
    const char* open_mode = "w";
    if (!dump_protobuf_report_file_.empty() && use_protobuf_) {
      open_mode = "wb";
    }
    fp.reset(fopen(report_filename_.c_str(), open_mode));
    if (fp == nullptr) {
      PLOG(ERROR) << "failed to open " << report_filename_;
      return false;
    }
    report_fp_ = fp.get();
  }

  // 3. Dump protobuf report.
  if (!dump_protobuf_report_file_.empty()) {
    return DumpProtobufReport(dump_protobuf_report_file_);
  }

  // 4. Open record file.
  record_file_reader_ = RecordFileReader::CreateInstance(record_filename_);
  if (record_file_reader_ == nullptr) {
    return false;
  }
  record_file_reader_->LoadBuildIdAndFileFeatures(thread_tree_);

  if (use_protobuf_) {
    GOOGLE_PROTOBUF_VERIFY_VERSION;
  } else {
    thread_tree_.ShowMarkForUnknownSymbol();
    thread_tree_.ShowIpForUnknownSymbol();
  }

  // 5. Prepare protobuf output stream.
  std::unique_ptr<ProtobufFileWriter> protobuf_writer;
  std::unique_ptr<google::protobuf::io::CopyingOutputStreamAdaptor> protobuf_os;
  std::unique_ptr<google::protobuf::io::CodedOutputStream> protobuf_coded_os;
  if (use_protobuf_) {
    protobuf_writer.reset(new ProtobufFileWriter(report_fp_));
    protobuf_os.reset(new google::protobuf::io::CopyingOutputStreamAdaptor(
        protobuf_writer.get()));
    protobuf_coded_os.reset(
        new google::protobuf::io::CodedOutputStream(protobuf_os.get()));
    coded_os_ = protobuf_coded_os.get();
  }

  // 6. Read record file, and print samples online.
  if (!record_file_reader_->ReadDataSection(
          [this](std::unique_ptr<Record> record) {
            return ProcessRecord(std::move(record));
          })) {
    return false;
  }

  if (use_protobuf_) {
    if (!PrintLostSituationInProtobuf()) {
      return false;
    }
    if (!PrintFileInfoInProtobuf()) {
      return false;
    }
    coded_os_->WriteLittleEndian32(0);
    if (coded_os_->HadError()) {
      LOG(ERROR) << "print protobuf report failed";
      return false;
    }
    protobuf_coded_os.reset(nullptr);
  } else {
    PrintLostSituation();
    fflush(report_fp_);
  }
  if (ferror(report_fp_) != 0) {
    PLOG(ERROR) << "print report failed";
    return false;
  }
  return true;
}

// Parses command-line |args| into the command's members.
// Returns false on an unknown option, a missing option value, or when
// --protobuf is given without an output file.
bool ReportSampleCommand::ParseOptions(const std::vector<std::string>& args) {
  for (size_t i = 0; i < args.size(); ++i) {
    const std::string& option = args[i];
    // For options that take a value, this points at the member receiving it.
    std::string* value_target = nullptr;
    if (option == "--dump-protobuf-report") {
      value_target = &dump_protobuf_report_file_;
    } else if (option == "-i") {
      value_target = &record_filename_;
    } else if (option == "-o") {
      value_target = &report_filename_;
    } else if (option == "--protobuf") {
      use_protobuf_ = true;
    } else if (option == "--show-callchain") {
      show_callchain_ = true;
    } else {
      ReportUnknownOption(args, i);
      return false;
    }
    if (value_target == nullptr) {
      continue;
    }
    if (!NextArgumentOrError(args, &i)) {
      return false;
    }
    *value_target = args[i];
  }

  // Protobuf output is binary, so it must go to a named file, not stdout.
  const bool protobuf_without_file = use_protobuf_ && report_filename_.empty();
  if (protobuf_without_file) {
    LOG(ERROR) << "please specify a report filename to write protobuf data";
    return false;
  }
  return true;
}

bool ReportSampleCommand::DumpProtobufReport(const std::string& filename) {
  GOOGLE_PROTOBUF_VERIFY_VERSION;
  std::unique_ptr<FILE, decltype(&fclose)> fp(fopen(filename.c_str(), "rb"),
                                              fclose);
  if (fp == nullptr) {
    PLOG(ERROR) << "failed to open " << filename;
    return false;
  }
  ProtobufFileReader protobuf_reader(fp.get());
  google::protobuf::io::CopyingInputStreamAdaptor adaptor(&protobuf_reader);
  google::protobuf::io::CodedInputStream coded_is(&adaptor);
  // map from file_id to max_symbol_id requested on the file.
  std::unordered_map<uint32_t, int32_t> max_symbol_id_map;
  // files[file_id] is the number of symbols in the file.
  std::vector<uint32_t> files;
  uint32_t max_message_size = 64 * (1 << 20);
  uint32_t warning_message_size = 512 * (1 << 20);
  coded_is.SetTotalBytesLimit(max_message_size, warning_message_size);
  while (true) {
    uint32_t size;
    if (!coded_is.ReadLittleEndian32(&size)) {
      PLOG(ERROR) << "failed to read " << filename;
      return false;
    }
    if (size == 0) {
      break;
    }
    // Handle files having large symbol table.
    if (size > max_message_size) {
      max_message_size = size;
      coded_is.SetTotalBytesLimit(max_message_size, warning_message_size);
    }
    auto limit = coded_is.PushLimit(size);
    proto::Record proto_record;
    if (!proto_record.ParseFromCodedStream(&coded_is)) {
      PLOG(ERROR) << "failed to read " << filename;
      return false;
    }
    coded_is.PopLimit(limit);
    if (proto_record.has_sample()) {
      auto& sample = proto_record.sample();
      static size_t sample_count = 0;
      FprintIndented(report_fp_, 0, "sample %zu:\n", ++sample_count);
      FprintIndented(report_fp_, 1, "time: %" PRIu64 "\n", sample.time());
      FprintIndented(report_fp_, 1, "event_count: %" PRIu64 "\n", sample.event_count());
      FprintIndented(report_fp_, 1, "thread_id: %d\n", sample.thread_id());
      FprintIndented(report_fp_, 1, "callchain:\n");
      for (int i = 0; i < sample.callchain_size(); ++i) {
        const proto::Sample_CallChainEntry& callchain = sample.callchain(i);
        FprintIndented(report_fp_, 2, "vaddr_in_file: %" PRIx64 "\n",
                       callchain.vaddr_in_file());
        FprintIndented(report_fp_, 2, "file_id: %u\n", callchain.file_id());
        int32_t symbol_id = callchain.symbol_id();
        FprintIndented(report_fp_, 2, "symbol_id: %d\n", symbol_id);
        if (symbol_id < -1) {
          LOG(ERROR) << "unexpected symbol_id " << symbol_id;
          return false;
        }
        if (symbol_id != -1) {
          max_symbol_id_map[callchain.file_id()] =
              std::max(max_symbol_id_map[callchain.file_id()], symbol_id);
        }
      }
    } else if (proto_record.has_lost()) {
      auto& lost = proto_record.lost();
      FprintIndented(report_fp_, 0, "lost_situation:\n");
      FprintIndented(report_fp_, 1, "sample_count: %" PRIu64 "\n",
                     lost.sample_count());
      FprintIndented(report_fp_, 1, "lost_count: %" PRIu64 "\n",
                     lost.lost_count());
    } else if (proto_record.has_file()) {
      auto& file = proto_record.file();
      FprintIndented(report_fp_, 0, "file:\n");
      FprintIndented(report_fp_, 1, "id: %u\n", file.id());
      FprintIndented(report_fp_, 1, "path: %s\n", file.path().c_str());
      for (int i = 0; i < file.symbol_size(); ++i) {
        FprintIndented(report_fp_, 1, "symbol: %s\n", file.symbol(i).c_str());
      }
      if (file.id() != files.size()) {
        LOG(ERROR) << "file id doesn't increase orderly, expected "
                   << files.size() << ", really " << file.id();
        return false;
      }
      files.push_back(file.symbol_size());
    } else {
      LOG(ERROR) << "unexpected record type ";
      return false;
    }
  }
  for (auto pair : max_symbol_id_map) {
    if (pair.first >= files.size()) {
      LOG(ERROR) << "file_id(" << pair.first << ") >= file count ("
                 << files.size() << ")";
      return false;
    }
    if (static_cast<uint32_t>(pair.second) >= files[pair.first]) {
      LOG(ERROR) << "symbol_id(" << pair.second << ") >= symbol count ("
                 << files[pair.first] << ") in file_id( " << pair.first << ")";
      return false;
    }
  }
  return true;
}

bool ReportSampleCommand::ProcessRecord(std::unique_ptr<Record> record) {
  thread_tree_.Update(*record);
  if (record->type() == PERF_RECORD_SAMPLE) {
    sample_count_++;
    auto& r = *static_cast<const SampleRecord*>(record.get());
    if (use_protobuf_) {
      return PrintSampleRecordInProtobuf(r);
    } else {
      return PrintSampleRecord(r);
    }
  } else if (record->type() == PERF_RECORD_LOST) {
    lost_count_ += static_cast<const LostRecord*>(record.get())->lost;
  }
  return true;
}

bool ReportSampleCommand::PrintSampleRecordInProtobuf(const SampleRecord& r) {
  uint64_t vaddr_in_file;
  uint32_t file_id;
  int32_t symbol_id;
  proto::Record proto_record;
  proto::Sample* sample = proto_record.mutable_sample();
  sample->set_time(r.time_data.time);
  sample->set_event_count(r.period_data.period);
  sample->set_thread_id(r.tid_data.tid);

  bool in_kernel = r.InKernel();
  const ThreadEntry* thread =
      thread_tree_.FindThreadOrNew(r.tid_data.pid, r.tid_data.tid);
  bool ret = GetCallEntry(thread, in_kernel, r.ip_data.ip, false, &vaddr_in_file, &file_id,
                          &symbol_id);
  CHECK(ret);
  proto::Sample_CallChainEntry* callchain = sample->add_callchain();
  callchain->set_vaddr_in_file(vaddr_in_file);
  callchain->set_file_id(file_id);
  callchain->set_symbol_id(symbol_id);

  if (show_callchain_ && (r.sample_type & PERF_SAMPLE_CALLCHAIN)) {
    bool first_ip = true;
    for (uint64_t i = 0; i < r.callchain_data.ip_nr; ++i) {
      uint64_t ip = r.callchain_data.ips[i];
      if (ip >= PERF_CONTEXT_MAX) {
        switch (ip) {
          case PERF_CONTEXT_KERNEL:
            in_kernel = true;
            break;
          case PERF_CONTEXT_USER:
            in_kernel = false;
            break;
          default:
            LOG(DEBUG) << "Unexpected perf_context in callchain: " << std::hex
                       << ip << std::dec;
        }
      } else {
        if (first_ip) {
          first_ip = false;
          // Remove duplication with sample ip.
          if (ip == r.ip_data.ip) {
            continue;
          }
        }
        if (!GetCallEntry(thread, in_kernel, ip, true, &vaddr_in_file, &file_id, &symbol_id)) {
          break;
        }
        callchain = sample->add_callchain();
        callchain->set_vaddr_in_file(vaddr_in_file);
        callchain->set_file_id(file_id);
        callchain->set_symbol_id(symbol_id);
      }
    }
  }
  coded_os_->WriteLittleEndian32(proto_record.ByteSize());
  if (!proto_record.SerializeToCodedStream(coded_os_)) {
    LOG(ERROR) << "failed to write sample to protobuf";
    return false;
  }
  return true;
}

bool ReportSampleCommand::GetCallEntry(const ThreadEntry* thread,
                                       bool in_kernel, uint64_t ip,
                                       bool omit_unknown_dso,
                                       uint64_t* pvaddr_in_file,
                                       uint32_t* pfile_id,
                                       int32_t* psymbol_id) {
  Dso* dso;
  const Symbol* symbol;
  if (!GetCallEntry(thread, in_kernel, ip, omit_unknown_dso, pvaddr_in_file, &dso, &symbol)) {
    return false;
  }
  if (!dso->GetDumpId(pfile_id)) {
    *pfile_id = dso->CreateDumpId();
  }
  if (symbol != thread_tree_.UnknownSymbol()) {
    if (!symbol->GetDumpId(reinterpret_cast<uint32_t*>(psymbol_id))) {
      *psymbol_id = dso->CreateSymbolDumpId(symbol);
    }
  } else {
    *psymbol_id = -1;
  }
  return true;
}

bool ReportSampleCommand::GetCallEntry(const ThreadEntry* thread,
                                       bool in_kernel, uint64_t ip,
                                       bool omit_unknown_dso,
                                       uint64_t* pvaddr_in_file, Dso** pdso,
                                       const Symbol** psymbol) {
  const MapEntry* map = thread_tree_.FindMap(thread, ip, in_kernel);
  if (omit_unknown_dso && thread_tree_.IsUnknownDso(map->dso)) {
    return false;
  }
  *psymbol = thread_tree_.FindSymbol(map, ip, pvaddr_in_file, pdso);
  // If we can't find symbol, use the dso shown in the map.
  if (*psymbol == thread_tree_.UnknownSymbol()) {
    *pdso = map->dso;
  }
  return true;
}

bool ReportSampleCommand::PrintLostSituationInProtobuf() {
  proto::Record proto_record;
  proto::LostSituation* lost = proto_record.mutable_lost();
  lost->set_sample_count(sample_count_);
  lost->set_lost_count(lost_count_);
  coded_os_->WriteLittleEndian32(proto_record.ByteSize());
  if (!proto_record.SerializeToCodedStream(coded_os_)) {
    LOG(ERROR) << "failed to write lost situation to protobuf";
    return false;
  }
  return true;
}

// Strict-weak-order comparator on dump ids. A dso whose GetDumpId() reports
// no id keeps the UINT_MAX default, so it sorts after dsos that have an id.
// This assumes GetDumpId() leaves the output untouched when it returns false.
static bool CompareDsoByDumpId(Dso* d1, Dso* d2) {
  auto dump_id_or_max = [](Dso* dso) {
    uint32_t id = UINT_MAX;
    dso->GetDumpId(&id);
    return id;
  };
  const uint32_t lhs = dump_id_or_max(d1);
  const uint32_t rhs = dump_id_or_max(d2);
  return lhs < rhs;
}

bool ReportSampleCommand::PrintFileInfoInProtobuf() {
  std::vector<Dso*> dsos = thread_tree_.GetAllDsos();
  std::sort(dsos.begin(), dsos.end(), CompareDsoByDumpId);
  for (Dso* dso : dsos) {
    uint32_t file_id;
    if (!dso->GetDumpId(&file_id)) {
      continue;
    }
    proto::Record proto_record;
    proto::File* file = proto_record.mutable_file();
    file->set_id(file_id);
    file->set_path(dso->Path());
    const std::vector<Symbol>& symbols = dso->GetSymbols();
    std::vector<const Symbol*> dump_symbols;
    for (const auto& sym : symbols) {
      if (sym.HasDumpId()) {
        dump_symbols.push_back(&sym);
      }
    }
    std::sort(dump_symbols.begin(), dump_symbols.end(),
              Symbol::CompareByDumpId);

    for (const auto& sym : dump_symbols) {
      std::string* symbol = file->add_symbol();
      *symbol = sym->DemangledName();
    }
    coded_os_->WriteLittleEndian32(proto_record.ByteSize());
    if (!proto_record.SerializeToCodedStream(coded_os_)) {
      LOG(ERROR) << "failed to write file info to protobuf";
      return false;
    }
  }
  return true;
}

// Prints sample |r| as indented text to report_fp_: the sample header and its
// ip location. With --show-callchain, the callchain follows: perf context
// markers are skipped, and the walk stops at the first frame in an unknown dso.
bool ReportSampleCommand::PrintSampleRecord(const SampleRecord& r) {
  // Output slots filled by GetCallEntry() for the location being printed.
  uint64_t vaddr_in_file;
  Dso* dso;
  const Symbol* symbol;
  auto print_location = [&](int indent) {
    FprintIndented(report_fp_, indent, "vaddr_in_file: %" PRIx64 "\n", vaddr_in_file);
    FprintIndented(report_fp_, indent, "file: %s\n", dso->Path().c_str());
    FprintIndented(report_fp_, indent, "symbol: %s\n", symbol->DemangledName());
  };

  FprintIndented(report_fp_, 0, "sample:\n");
  FprintIndented(report_fp_, 1, "time: %" PRIu64 "\n", r.time_data.time);
  FprintIndented(report_fp_, 1, "event_count: %" PRIu64 "\n", r.period_data.period);
  FprintIndented(report_fp_, 1, "thread_id: %d\n", r.tid_data.tid);
  bool in_kernel = r.InKernel();
  const ThreadEntry* thread =
      thread_tree_.FindThreadOrNew(r.tid_data.pid, r.tid_data.tid);
  // With omit_unknown_dso == false, GetCallEntry() never fails.
  bool resolved = GetCallEntry(thread, in_kernel, r.ip_data.ip, false, &vaddr_in_file, &dso,
                               &symbol);
  CHECK(resolved);
  print_location(1);

  if (!show_callchain_ || !(r.sample_type & PERF_SAMPLE_CALLCHAIN)) {
    return true;
  }
  FprintIndented(report_fp_, 1, "callchain:\n");
  bool first_ip = true;
  for (uint64_t i = 0; i < r.callchain_data.ip_nr; ++i) {
    const uint64_t ip = r.callchain_data.ips[i];
    if (ip >= PERF_CONTEXT_MAX) {
      // Context markers switch between kernel and user frames.
      if (ip == PERF_CONTEXT_KERNEL) {
        in_kernel = true;
      } else if (ip == PERF_CONTEXT_USER) {
        in_kernel = false;
      } else {
        LOG(DEBUG) << "Unexpected perf_context in callchain: " << std::hex
                   << ip;
      }
      continue;
    }
    // Remove duplication with sample ip.
    const bool repeats_sample_ip = first_ip && ip == r.ip_data.ip;
    first_ip = false;
    if (repeats_sample_ip) {
      continue;
    }
    if (!GetCallEntry(thread, in_kernel, ip, true, &vaddr_in_file, &dso, &symbol)) {
      break;
    }
    print_location(2);
  }
  return true;
}

void ReportSampleCommand::PrintLostSituation() {
  FprintIndented(report_fp_, 0, "lost_situation:\n");
  FprintIndented(report_fp_, 1, "sample_count: %" PRIu64 "\n", sample_count_);
  FprintIndented(report_fp_, 1, "lost_count: %" PRIu64 "\n", sample_count_);
}

}  // namespace

// Registers the `report-sample` command. Each invocation of the registered
// factory creates a fresh ReportSampleCommand.
void RegisterReportSampleCommand() {
  auto make_command = []() -> std::unique_ptr<Command> {
    return std::unique_ptr<Command>(new ReportSampleCommand);
  };
  RegisterCommand("report-sample", make_command);
}
