#include "Rivet/Analysis.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/FastJets.hh"

namespace Rivet {

  /// @brief Monte Carlo validation observables for HH->4b production
  class MC_HH4B : public Analysis {
  public:

    /// @name Constructors etc.
    //@{
    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(MC_HH4B);

    //@}
    /// @name Analysis methods
    //@{
    /// Book histograms and initialise projections before the run
    void init() {

      // set ptcut from input option
      const double jetptcut = getOption<double>("PTJMIN", 50.0);
      _jetptcut = jetptcut * GeV;

      // set clustering radius from input option
      const double R = getOption<double>("R", 0.4);

      // set clustering algorithm from input option
      JetAlg clusterAlgo;
      const string algoopt = getOption("ALGO", "ANTIKT");
      if ( algoopt == "KT" ) {
        clusterAlgo = JetAlg::KT;
      } else if ( algoopt == "CA" ) {
        clusterAlgo = JetAlg::CA;
      } else if ( algoopt == "ANTIKT" ) {
        clusterAlgo = JetAlg::ANTIKT;
      } else {
        MSG_WARNING("Unknown jet clustering algorithm option " + algoopt + ". Defaulting to anti-kT");
        clusterAlgo = JetAlg::ANTIKT;
      }

      declare(FastJets(FinalState(), clusterAlgo, R), "jets");

      /// Book histograms
      for (const string& type : vector<string>{"bjet_"s, "ljet_"s}) {
        for (const string& num : vector<string>{"1_"s, "2_"s, "3_"s, "4_"s}) {
          book(_h[type+num+"eta"],  type+num+"eta",  50, -4.,   4.);
          book(_h[type+num+"phi"],  type+num+"phi",  50,  0.,   1.);
          book(_h[type+num+"pT"],   type+num+"pT",   50,  0., 500.);
          book(_h[type+num+"mass"], type+num+"mass", 50,  0., 500.);
        }
      }
      book(_d["nbjets"], "nbjets", {0, 1, 2, 3, 4, 5, 6, 7, 8});
      book(_d["njets"],  "njets", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

      book(_h["H1Mass"],        "HiggsMass1",    50, 0.,  250.); // Higgs candidate 1 mass
      book(_h["H2Mass"],        "HiggsMass2",    50, 0.,  250.); // Higgs candidate 2 mass
      book(_h["H1Pt"],          "HiggsPt1",      50, 0., 1000.); // Higgs candidate 1 pt
      book(_h["H2Pt"],          "HiggsPt2",      50, 0., 1000.); // Higgs candidate 2 pt
      book(_h["DiHiggsMass"],   "DiHiggsMass",   50, 0., 2000.); // Di-Higgs system mass
      book(_h["DiHiggsPt"],     "DiHiggsPt",     50, 0., 1000.); // Di-Higgs system pT
      book(_h["HiggsMassDiff"], "HiggsMassDiff", 50, 0.,  100.); // Di-Higgs system mass
    }

    /// Perform the per-event analysis
    void analyze(const Event& event) {

      const Jets jets = apply<FastJets>(event, "jets").jetsByPt(Cuts::pT > _jetptcut && Cuts::abseta < 4.0);
      _d["njets"]->fill(jets.size());

      // Identify the b-jets
      Jets bjets;
      int ljets = 0;
      for (const Jet& jet : jets) {
        const double jetEta = jet.eta();
        const double jetPhi = jet.phi();
        const double jetPt = jet.pT();
        const double jetMass = jet.mass();

        if (jet.bTagged()) {
          bjets += jet;
          if (bjets.size() > 4)  continue;
          const string pre("bjet_"+toString(bjets.size()));
          _h[pre+"_eta"]->fill(jetEta);
          _h[pre+"_phi"]->fill(jetPhi/2/M_PI);
          _h[pre+"_pT"]->fill(jetPt/GeV);
          _h[pre+"_mass"]->fill(jetMass/GeV);
        }
        else if (ljets < 4) {
          const string pre("ljet_"+toString(++ljets));
          _h[pre+"_eta"]->fill(jetEta);
          _h[pre+"_phi"]->fill(jetPhi/2/M_PI);
          _h[pre+"_pT"]->fill(jetPt/GeV);
          _h[pre+"_mass"]->fill(jetMass/GeV);
        }
      }
      _d["nbjets"]->fill(bjets.size());

      // if(bjets.empty()) vetoEvent;

      if (bjets.size() < 4) vetoEvent;

      double best_mass_diff = FLT_MAX;
      std::pair<Jet, Jet> bpair1, bpair2;
      for (size_t i = 0; i < bjets.size(); ++i) {
        for (size_t j = i + 1; j < bjets.size(); ++j) {
          for (size_t k = 0; k < bjets.size(); ++k) {
            if (k == i || k == j) continue;
            for (size_t l = k + 1; l < bjets.size(); ++l) {
              if (l == i || l == j) continue;

              // Calculate masses
              const double m1 = (bjets[i].mom() + bjets[j].mom()).mass();
              const double m2 = (bjets[k].mom() + bjets[l].mom()).mass();
              const double mass_diff = fabs(m1 - m2);

              bool isOneHiggsInMassRange = inRange(m1, 95*GeV, 145*GeV) || inRange(m2, 95*GeV, 145*GeV);
              if (!isOneHiggsInMassRange)  continue;

              // Check for best pairing
              // and mass_diff < 20.0*GeV && isOneHiggsInMassRange
              if (mass_diff < best_mass_diff && mass_diff < 20*GeV) {
                best_mass_diff = mass_diff;
                bpair1 = make_pair(bjets[i], bjets[j]);
                bpair2 = make_pair(bjets[k], bjets[l]);
              }
            }
          }
        }
      }
      if (best_mass_diff > 20*GeV) vetoEvent;

      // Reconstruct Higgs masses
      const FourMomentum H1 = bpair1.first.mom() + bpair1.second.mom();
      const FourMomentum H2 = bpair2.first.mom() + bpair2.second.mom();

      bool h1_inrange = inRange(H1.mass(), 95*GeV, 145*GeV);
      bool h2_inrange = inRange(H2.mass(), 95*GeV, 145*GeV);
      if (!h1_inrange && !h2_inrange)  vetoEvent;

      const double mass_diff = fabs(H1.mass() - H2.mass());

      // Fill histograms (redundant mass check)
      _h["H1Mass"]->fill(H1.mass()/GeV);
      _h["H1Pt"]->fill(H1.pT()/GeV);

      _h["H2Mass"]->fill(H2.mass()/GeV);
      _h["H2Pt"]->fill(H2.pT()/GeV);

      _h["HiggsMassDiff"]->fill(mass_diff/GeV);

      // Reconstruct di-Higgs
      const FourMomentum dihiggs = H1 + H2;
      _h["DiHiggsMass"]->fill(dihiggs.mass()/GeV);
      _h["DiHiggsPt"]->fill(dihiggs.pT()/GeV);
    }

    /// Normalise histograms etc., after the run
    void finalize() {
      const double sf = crossSection()/femtobarn/sumOfWeights();
      scale(_h, sf);
      scale(_d, sf);
    }

  //@}

  private:

  /// @name Histograms
  //@{

    map<string,Histo1DPtr> _h;
    map<string,BinnedHistoPtr<int>> _d;
    double _jetptcut;

  //@}

  };

  RIVET_DECLARE_PLUGIN(MC_HH4B);
}
