#! /usr/bin/env python3

"""\
Usage: %(prog)s <PDF1> [<PDF2> [...]]

Plot PDF (ratio) and alphaS values for the named LHAPDF6 sets.

A set of plots will be generated for single-flavour and multi-flavour PDFs
against x for various Q values, and against Q for various x values. The
parton IDs and the discrete x and Q values can be customised with the --pids,
--xs, and --qs options, and the plot types restricted using --plots. Limits
on the continuous x and Q axes can be specified with --xmin, --qmin and
--qmax, and the displayed PDF-value range optionally limited using --ymin and
--ymax (with --ylin to force a linear y scale). Shorthands for PDF names and
PIDs can be given via the --pdf-aliases and --pid-aliases options.

If the --ratio option is given, ratios of PDFs will be shown rather than
absolute values. The ratios are computed with the values of the first-listed
PDF on the denominator; this mode hence requires at least two PDF arguments
to be given.

TODO:
 * Allow user specification of the various PDF/parton line colours and styles
 * Show PDF error bands
 * Speed up plotting using parallelisation
 * Use dynamic sampling for smoothness
"""

import sys, os

import argparse
ap = argparse.ArgumentParser(usage=__doc__)


def _num_opt(flag, dest, helptext, default):
    """Register a float-valued option, shown as NUM in the help text."""
    ap.add_argument(flag, dest=dest, metavar="NUM", type=float, default=default, help=helptext)


def _switch_opt(flag, dest, helptext):
    """Register a boolean switch that is off unless given on the command line."""
    ap.add_argument(flag, dest=dest, action="store_true", default=False, help=helptext)


def _list_opt(flag, dest, metavar, helptext, default):
    """Register a string option holding a comma-separated list."""
    ap.add_argument(flag, dest=dest, metavar=metavar, default=default, help=helptext)


## Positional arguments: the PDF members to plot
ap.add_argument("PNAMES", metavar="NAME", nargs="+", help="PDF members to include in the plots")

## Axis ranges, scales and ratio mode
_num_opt("--xmin", "XMIN", "minimum x value [default: %(default)s]", 1e-10)
_num_opt("--qmin", "QMIN", "minimum Q value in GeV [default: %(default)s]", 2.0)
_num_opt("--qmax", "QMAX", "maximum Q value in GeV [default: %(default)s]", 1e4)
_switch_opt("--ylin", "YLIN", "use a linear scale for the y-axis")
_num_opt("--ymin", "YMIN", "minimum y value [default: %(default)s]", None)
_num_opt("--ymax", "YMAX", "maximum y value [default: %(default)s]", None)
_switch_opt("--ratio", "RATIO", "plot as ratios wrt the first PDF (requires > 1)")

## Output file format
ap.add_argument("-f", "--format", dest="FORMAT", metavar="F", default="pdf",
                help="plot file format, i.e. file extension pdf/png/... [default: %(default)s]")

## Discrete x/Q values, partons and plot types (comma-separated lists)
_list_opt("--qs", "QS", "Q1,Q2,...",
          "discrete Q values to use on plots vs. x [default: %(default)s]", "1,10,100,1000,10000")
_list_opt("--xs", "XS", "X1,X2,...",
          "discrete x values to use on plots vs. Q [default: %(default)s]", "1e-5,1e-3,1e-2,1e-1")
_list_opt("--pids", "PIDS", "ID1,ID2,...",
          "PID values to use on PDF plots [default: %(default)s]", "0,1,2,3,4,5,-1,-2,-3,-4,-5")
_list_opt("--plots", "PLOTS", "PLOT1,PLOT2,...",
          "plot types to show, default value lists all types [default: %(default)s]",
          "alphas,xf_x/pid,xf_x/q,xf_q/pid,xf_q/x")

## Label aliases (comma-separated PATT/REPL pairs)
_list_opt("--pdf-aliases", "PDFALIASES", "PATT/REPL,PATT/REPL,...",
          "pattern/replacement pairs for PDF names in the plot legend", None)
_list_opt("--pid-aliases", "PIDALIASES", "PATT/REPL,PATT/REPL,...",
          "pattern/replacement pairs for PID names in the plot labelling", None)

## Text rendering and verbosity (-q and -v share the VERBOSITY destination)
_switch_opt("--tex", "TEX", "use TeX rendering backend for text")
ap.add_argument("-q", "--quiet", dest="VERBOSITY", action="store_const", const=0, default=1,
                help="suppress non-essential messages")
ap.add_argument("-v", "--verbose", dest="VERBOSITY", action="store_const", const=2, default=1,
                help="output debug messages")
args = ap.parse_args()

## Normalise and sanity-check the parsed arguments
args.PLOTS = args.PLOTS.upper().split(",")  # plot types are matched case-insensitively
if not args.PNAMES:
    print(__doc__)
    sys.exit(1)
if len(args.PNAMES) < 2 and args.RATIO:
    # Ratio mode divides by the first PDF, so it needs at least one other PDF to show
    print("Ratio plots with respect to the first PDF require more than one PDF to be specified")
    print(__doc__)
    sys.exit(2)

import matplotlib.pyplot as plt
# Serif text, using the 14 PostScript core fonts in PDF output
plt.rcParams.update({"font.family": "serif", "pdf.use14corefonts": True})
# Alternative serif faces: "Palatino" or "Computer Modern Roman" (via rcParams["font.serif"])
if args.TEX:
    # --tex: render all text through LaTeX
    plt.rcParams["text.usetex"] = True

## Line styles and colours, cycled (by index modulo length) over the plotted curves
STYLES = ["-", "--", "-.", (0, (5, 2, 1, 2, 1, 2)), ":"]
COLORS = ["red", "blue", "darkgreen", "orange", "purple", "magenta", "gray", "cyan"]
## Math-mode labels for parton IDs: 0 -> g, 1..5 -> d u s c b, negative IDs -> barred
PARTONS = {
    -5: r"$\bar{b}$", -4: r"$\bar{c}$", -3: r"$\bar{s}$", -2: r"$\bar{u}$", -1: r"$\bar{d}$",
    1: r"$d$", 2: r"$u$", 3: r"$s$", 4: r"$c$", 5: r"$b$",
    0: r"$g$",
}

## Construct PDF-name and PID aliases from the --pdf-aliases/--pid-aliases options
def _parse_aliases(spec):
    """Parse a "PATT/REPL,PATT/REPL,..." option string into a {PATT: REPL} dict.

    Only the first "/" of each item separates the pattern from its replacement,
    so a replacement may itself contain "/". An item with no "/" maps its
    pattern to the empty string. Whitespace around the pattern is stripped so
    that e.g. "1/d, 2/u" works, and empty items (e.g. from a trailing comma)
    are skipped. A None or empty spec gives an empty dict.
    """
    aliases = {}
    if not spec:
        return aliases
    for item in spec.split(","):
        if not item.strip():
            continue
        patt, _, repl = item.partition("/")
        aliases[patt.strip()] = repl
    return aliases


PDFALIASES = _parse_aliases(args.PDFALIASES)
PIDALIASES = _parse_aliases(args.PIDALIASES)


def pdfalias(pname):
    """Return pname with its leading set-name part (before any "/") replaced
    by its PDFALIASES entry, if it has one; the rest of the name is kept."""
    setname, sep, rest = pname.partition("/")
    return PDFALIASES.get(setname, setname) + sep + rest

def pidalias(pid):
    """Return the label for parton ID pid: its PIDALIASES entry (looked up by
    the string form of pid) if one exists, else str(pid)."""
    key = str(pid)
    return PIDALIASES.get(key, key)


def tex_str(a):
    """Backslash-escape the "_" and "#" characters in a when --tex is active;
    otherwise return a unchanged."""
    if args.TEX:
        for char in ("_", "#"):
            a = a.replace(char, "\\" + char)
    return a

def tex_float(f):
    """Format f to two significant figures for use inside TeX math mode.

    Values that the "g" format would print in exponent form are rendered as
    "<mantissa> \\times 10^{<exponent>}"; all others are returned as printed.
    """
    text = f"{f:.2g}"
    mantissa, sep, exponent = text.partition("e")
    if not sep:
        return text
    return rf"{mantissa} \times 10^{{{int(exponent)}}}"


## Get sampling points in x,Q
from math import log10, sqrt
import numpy as np

# Continuous x grid: 150 log-spaced points from XMIN up to 0.1, then 50
# linearly spaced points on [0.1, 1] (0.1 itself appears in both parts)
xs = np.concatenate((np.logspace(log10(args.XMIN), -1, 150),
                     np.linspace(0.1, 1.0, 50)))
# Continuous Q grid: 200 log-spaced points on [QMIN, QMAX]
qs = np.logspace(log10(args.QMIN), log10(args.QMAX), 200)

# Discrete values parsed from the comma-separated --xs, --qs and --pids options
xs_few = list(map(float, args.XS.split(",")))
qs_few = list(map(float, args.QS.split(",")))
pids = list(map(int, args.PIDS.split(",")))

## Load PDFs for plotting, indexed by name
import lhapdf
lhapdf.setVerbosity(args.VERBOSITY)
pdfs = {pname: lhapdf.mkPDF(pname) for pname in args.PNAMES}
print()

## Determine the x and Q^2 ranges common to all PDFs (the intersection of
## their ranges) and filter the sampling points to them
xmins = [pdfs[p].xMin for p in args.PNAMES]
xmaxs = [pdfs[p].xMax for p in args.PNAMES]
q2mins = [pdfs[p].q2Min for p in args.PNAMES]
q2maxs = [pdfs[p].q2Max for p in args.PNAMES]

xmin_valid = max(xmins)
xmax_valid = min(xmaxs)
q2min_valid = max(q2mins)
q2max_valid = min(q2maxs)


def _x_ok(x):
    """Return True if x lies within the x range shared by all the PDFs."""
    return xmin_valid <= x <= xmax_valid


def _q_ok(q):
    """Return True if Q lies within the Q^2 range shared by all the PDFs."""
    # The bounds are in Q^2: compare q*q, but keep the original Q value rather
    # than round-tripping it through sqrt(q*q)
    return q2min_valid <= q*q <= q2max_valid


xs_valid     = [x for x in xs     if _x_ok(x)]
xs_few_valid = [x for x in xs_few if _x_ok(x)]
qs_valid     = [q for q in qs     if _q_ok(q)]
qs_few_valid = [q for q in qs_few if _q_ok(q)]

## Report explicitly requested discrete values that are being dropped
if args.VERBOSITY > 0:
    for x in xs_few:
        if not _x_ok(x):
            print("Skipping x = {:g}: outside the x range [{:g}, {:g}] common to all PDFs".format(
                x, xmin_valid, xmax_valid))
    for q in qs_few:
        if not _q_ok(q):
            print("Skipping Q = {:g} GeV: outside the Q range [{:g}, {:g}] GeV common to all PDFs".format(
                q, sqrt(q2min_valid), sqrt(q2max_valid)))

## Use only the valid points from here on. NB. qs must hold Q values, not Q^2:
## it is the Q axis of the plots and is passed as Q to alphasQ() and xfxQ()
xs = xs_valid
qs = qs_valid
xs_few = xs_few_valid
qs_few = qs_few_valid

## A single figure/axes pair is shared by every plot below: each plot clears
## the axes, draws its curves, and saves the figure to its own file
fig = plt.figure()
ax = fig.add_subplot(111)

## alpha_s vs Q plot: one solid line per PDF
if "ALPHAS" in args.PLOTS:
    ax.cla()
    for i, pname in enumerate(args.PNAMES):
        alphas_vals = [pdfs[pname].alphasQ(q) for q in qs]
        ax.plot(qs, alphas_vals, color=COLORS[i % len(COLORS)], ls="-",
                label=tex_str(pdfalias(pname)))
    ax.set_xlabel("$Q$")
    ax.set_ylabel(r"$\alpha_s(Q)$")
    ax.set_ylim(bottom=0)
    ax.set_xscale("log")
    ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")

    fname = "alpha_s.{}".format(args.FORMAT)
    if args.VERBOSITY > 0:
        print("Writing plot file", fname)
    fig.tight_layout()
    fig.savefig(fname)

## xf vs. x plots (per PID): one file per parton ID; PDFs are distinguished by
## colour and the discrete Q values by line style
if "XF_X/PID" in args.PLOTS:
    for pid in pids:
        ax.cla()
        ax.set_title(f"PID={pidalias(pid)}", loc="right")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$xf(x,Q) ~/~ xf_0(x,Q)$" if args.RATIO else "$xf(x,Q)$")

        for i, pname in enumerate(args.PNAMES):
            # In ratio mode the first PDF is the denominator and gets no line of its own
            if i == 0 and args.RATIO:
                continue
            color = COLORS[i % len(COLORS)]
            for j, q in enumerate(qs_few):
                style = STYLES[j % len(STYLES)]
                xf_vals = np.array([pdfs[pname].xfxQ(pid, x, q) for x in xs])
                if args.RATIO:
                    title = "{}/{}, $Q={}~\\mathrm{{GeV}}$".format(tex_str(pdfalias(pname)), tex_str(pdfalias(args.PNAMES[0])), tex_float(q))
                    ref_vals = np.array([pdfs[args.PNAMES[0]].xfxQ(pid, x, q) for x in xs])
                    # Where the reference PDF is zero the ratio is inf/nan; suppress the
                    # resulting RuntimeWarnings rather than flooding the output
                    with np.errstate(divide="ignore", invalid="ignore"):
                        xf_vals /= ref_vals
                else:
                    title = "{}, $Q={}~\\mathrm{{GeV}}$".format(tex_str(pdfalias(pname)), tex_float(q))
                ax.plot(xs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        ax.set_xlim(left=args.XMIN)
        if args.RATIO:
            # Reference line at ratio = 1
            ax.axhline(1.0, color="gray", ls="-", lw=1.0, alpha=0.5)
        # Compare against None rather than truthiness so an explicit limit of 0 is applied
        if args.YMIN is not None:
            ax.set_ylim(bottom=args.YMIN)
        if args.YMAX is not None:
            ax.set_ylim(top=args.YMAX)

        if not args.RATIO and not args.YLIN:
            ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")

        prefix = "pdfratio" if args.RATIO else "pdf"
        fname = "{}_pid{:d}_x.{}".format(prefix, pid, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.tight_layout()
        fig.savefig(fname)

## xf vs. x plots (per Q): one file per discrete Q value; partons are
## distinguished by colour and PDFs by line style
if "XF_X/Q" in args.PLOTS:
    for j, q in enumerate(qs_few):
        ax.cla()
        ax.set_title("$Q={}~\\mathrm{{GeV}}$".format(tex_float(q)), loc="right")
        ax.set_xlabel("$x$")
        if args.RATIO:
            ax.set_ylabel("$xf(x,Q) ~/~ xf_0(x,Q)$")
        else:
            ax.set_ylabel("$xf(x,Q={}~\\mathrm{{GeV}})$".format(tex_float(q)))

        for k, pid in enumerate(pids):
            # Colour by position in the PID list: pid % len(COLORS) gave e.g. 4 and -4
            # the same colour, and (for the same PDF) the same line style as well
            color = COLORS[k % len(COLORS)]
            for i, pname in enumerate(args.PNAMES):
                # In ratio mode the first PDF is the denominator and gets no line of its own
                if i == 0 and args.RATIO:
                    continue
                style = STYLES[i % len(STYLES)]
                xf_vals = np.array([pdfs[pname].xfxQ(pid, x, q) for x in xs])
                if args.RATIO:
                    title = "{}/{}, PID={}".format(tex_str(pdfalias(pname)), tex_str(pdfalias(args.PNAMES[0])), pidalias(pid))
                    ref_vals = np.array([pdfs[args.PNAMES[0]].xfxQ(pid, x, q) for x in xs])
                    # Where the reference PDF is zero the ratio is inf/nan; suppress the
                    # resulting RuntimeWarnings rather than flooding the output
                    with np.errstate(divide="ignore", invalid="ignore"):
                        xf_vals /= ref_vals
                else:
                    title = "{}, PID={}".format(tex_str(pdfalias(pname)), pidalias(pid))
                ax.plot(xs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        ax.set_xlim(left=args.XMIN)
        if args.RATIO:
            # Reference line at ratio = 1
            ax.axhline(1.0, color="gray", ls="-", lw=1.0, alpha=0.5)
        # Compare against None rather than truthiness so an explicit limit of 0 is applied
        if args.YMIN is not None:
            ax.set_ylim(bottom=args.YMIN)
        if args.YMAX is not None:
            ax.set_ylim(top=args.YMAX)

        if not args.RATIO and not args.YLIN:
            ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")

        # "{:g}" keeps integer Q values as before ("10", "10000") but, unlike int(q),
        # does not make non-integer values such as 1.3 and 1.5 share a file name
        prefix = "pdfratio" if args.RATIO else "pdf"
        fname = "{}_q{:g}_x.{}".format(prefix, q, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.tight_layout()
        fig.savefig(fname)

## xf vs. Q plots (per PID): one file per parton ID; PDFs are distinguished by
## colour and the discrete x values by line style
if "XF_Q/PID" in args.PLOTS:
    for pid in pids:
        ax.cla()
        ax.set_title(f"PID={pidalias(pid)}", loc="right")
        ax.set_xlabel("$Q$")
        ax.set_ylabel("$xf(x,Q) ~/~ xf_0(x,Q)$" if args.RATIO else "$xf(x,Q)$")

        for i, pname in enumerate(args.PNAMES):
            # In ratio mode the first PDF is the denominator and gets no line of its own
            if i == 0 and args.RATIO:
                continue
            color = COLORS[i % len(COLORS)]
            for j, x in enumerate(xs_few):
                style = STYLES[j % len(STYLES)]
                xf_vals = np.array([pdfs[pname].xfxQ(pid, x, q) for q in qs])
                if args.RATIO:
                    title = "{}/{}, $x={}$".format(tex_str(pdfalias(pname)), tex_str(pdfalias(args.PNAMES[0])), tex_float(x))
                    ref_vals = np.array([pdfs[args.PNAMES[0]].xfxQ(pid, x, q) for q in qs])
                    # Where the reference PDF is zero the ratio is inf/nan; suppress the
                    # resulting RuntimeWarnings rather than flooding the output
                    with np.errstate(divide="ignore", invalid="ignore"):
                        xf_vals /= ref_vals
                else:
                    title = "{}, $x={}$".format(tex_str(pdfalias(pname)), tex_float(x))
                ax.plot(qs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        ax.set_xlim(left=args.QMIN, right=args.QMAX)
        if args.RATIO:
            # Reference line at ratio = 1
            ax.axhline(1.0, color="gray", ls="-", lw=1.0, alpha=0.5)
        # Compare against None rather than truthiness so an explicit limit of 0 is applied
        if args.YMIN is not None:
            ax.set_ylim(bottom=args.YMIN)
        if args.YMAX is not None:
            ax.set_ylim(top=args.YMAX)

        if not args.RATIO and not args.YLIN:
            ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")

        prefix = "pdfratio" if args.RATIO else "pdf"
        fname = "{}_pid{:d}_q.{}".format(prefix, pid, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.tight_layout()
        fig.savefig(fname)

## xf vs. Q plots (per x): one file per discrete x value; partons are
## distinguished by colour and PDFs by line style
if "XF_Q/X" in args.PLOTS:
    for j, x in enumerate(xs_few):
        ax.cla()
        ax.set_title("$x={}$".format(tex_float(x)), loc="right")
        ax.set_xlabel("$Q$")
        if args.RATIO:
            ax.set_ylabel("$xf(x,Q) ~/~ xf_0(x,Q)$")
        else:
            ax.set_ylabel("$xf(x={},Q)$".format(tex_float(x)))

        for k, pid in enumerate(pids):
            # Colour by position in the PID list: pid % len(COLORS) gave e.g. 4 and -4
            # the same colour, and (for the same PDF) the same line style as well
            color = COLORS[k % len(COLORS)]
            for i, pname in enumerate(args.PNAMES):
                # In ratio mode the first PDF is the denominator and gets no line of its own
                if i == 0 and args.RATIO:
                    continue
                style = STYLES[i % len(STYLES)]
                xf_vals = np.array([pdfs[pname].xfxQ(pid, x, q) for q in qs])
                if args.RATIO:
                    title = "{}/{}, PID={}".format(tex_str(pdfalias(pname)), tex_str(pdfalias(args.PNAMES[0])), pidalias(pid))
                    ref_vals = np.array([pdfs[args.PNAMES[0]].xfxQ(pid, x, q) for q in qs])
                    # Where the reference PDF is zero the ratio is inf/nan; suppress the
                    # resulting RuntimeWarnings rather than flooding the output
                    with np.errstate(divide="ignore", invalid="ignore"):
                        xf_vals /= ref_vals
                else:
                    title = "{}, PID={}".format(tex_str(pdfalias(pname)), pidalias(pid))
                ax.plot(qs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        ax.set_xlim(left=args.QMIN, right=args.QMAX)
        if args.RATIO:
            # Reference line at ratio = 1
            ax.axhline(1.0, color="gray", ls="-", lw=1.0, alpha=0.5)
        # Compare against None rather than truthiness so an explicit limit of 0 is applied
        if args.YMIN is not None:
            ax.set_ylim(bottom=args.YMIN)
        if args.YMAX is not None:
            ax.set_ylim(top=args.YMAX)

        if not args.RATIO and not args.YLIN:
            ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")

        # "{:g}" rather than "{:0.2f}": the latter turned both 1e-5 and 1e-3 into
        # "0.00", so those plots overwrote each other
        prefix = "pdfratio" if args.RATIO else "pdf"
        fname = "{}_x{:g}_q.{}".format(prefix, x, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.tight_layout()
        fig.savefig(fname)

## All requested plots have been written: release the shared figure
plt.close(fig)
print("")
