# Pipeline version string.
__version__ = "2.0.0"
# Pull in the per-step sub-workflow snakefiles from modules/
# (init, SAV, FastQC, trimming, mapping, classification).
include: "modules/init.snakefile"
include: "modules/sav.snakefile"
include: "modules/fastqc.snakefile"
include: "modules/trimming.snakefile"
include: "modules/mapping.snakefile"
include: "modules/classification.snakefile"
#-------------------< Helper functions >---------------------------------------------------------#
from modules.json_output import write_summary_json, write_summary_json_new, get_fastqc_results, combine_csv, get_plot_type_names
from modules.utils import which

def trimming_input(wildcards):
    '''
    List the trimmed FastQC report directories for every sample.

    Args:
        wildcards: Snakemake wildcards object (not used).

    Returns:
        ``None`` when ``config["notrimming"]`` is set. Otherwise a list of
        ``<fastqc_path>/trimmed/<sample>_fastqc`` paths (single-end) or
        ``<fastqc_path>/trimmed/<sample>_R1_fastqc`` / ``..._R2_fastqc``
        paths (paired-end).

    NOTE(review): the paired-end branch takes sample names from
    ``geninfo_config["Sample information"]["samples"]`` whereas the
    single-end branch uses ``sample_dict``; confirm both hold the same names.
    '''
    if config["notrimming"]:
        return None
    if geninfo_config["Sample information"]["type"] != "PE":
        return expand("{path}/trimmed/{sample}_fastqc",
                      path=fastqc_path, sample=sample_dict.keys())
    # Paired-end data: one report per mate.
    mates = ["R1", "R2"]
    return expand("{path}/trimmed/{sample}_{read}_fastqc",
                  path=fastqc_path, read=mates,
                  sample=geninfo_config["Sample information"]["samples"])

def get_input(wildcards, if_not, ext, samplelist=None, path=""):
    '''
    Build ``<path>/<sample><ext>`` file names unless a step is disabled.

    Args:
        wildcards: Snakemake wildcards; ``wildcards.sample`` is used when
            *samplelist* is empty or ``None``.
        if_not: key into ``config``; when that flag is truthy the step is
            considered disabled and an empty string is returned.
        ext: suffix appended to every sample name (e.g. ``".csv"``).
        samplelist: iterable of sample names to expand over. ``None`` (the
            default) or an empty iterable means "use ``wildcards.sample``".
        path: directory prefix; a ``/`` separator is added when non-empty.

    Returns:
        list of file names produced by ``expand``, or ``""`` when the step
        is disabled.
    '''
    if config[if_not]:
        return ""
    if path != "":
        path += "/"
    # Fall back to the single sample of the current job.
    samples = samplelist if samplelist else wildcards.sample
    return expand("{path}{sample}{ext}",
                  path=path, ext=ext, sample=samples)

def get_all_fastqc(wildcards, path=fastqc_path + "/raw"):
    '''
    List the raw FastQC outputs that belong to one sample.

    For every input file in ``unique_samples[wildcards.sample]`` two paths
    are returned, in this order:

    * ``<path>/<renamed>_fastqc`` (FastQC report directory)
    * ``<path>/<renamed>_fastqc/fastqc_data.txt`` (its data file)

    Here ``<renamed>`` is ``geninfo_config["Sample information"]["rename"]``
    looked up with ``get_name(<input file>)``.

    Args:
        wildcards: Snakemake wildcards providing ``sample``.
        path: directory holding the raw FastQC output.

    Returns:
        list of path strings.
    '''
    fastqc_files = []
    for raw_file in unique_samples[wildcards.sample]:
        renamed = geninfo_config["Sample information"]["rename"][get_name(raw_file)]
        report_dir = "%s/%s_fastqc" % (path, renamed)
        fastqc_files.append(report_dir)
        fastqc_files.append(report_dir + "/fastqc_data.txt")
    return fastqc_files

def get_trimmomatic_fastqc(wildcards, ext, path = trimming_path):
    '''
    Generate list of filepaths ending with read identifying string and *ext*.

    A sample listed in ``geninfo_config["Sample information"]["samples"]`` is
    used as is. Any other sample name is resolved through ``["join_lanes"]``
    to the lane-level sample names it was joined from. For paired-end data
    and ``ext == "_fastqc"``, the mate identifiers ``_R1``/``_R2`` are
    inserted before *ext*.

    Args:
        wildcards: Snakemake wildcards providing ``sample``.
        ext: suffix appended to every path (e.g. ``"_fastqc"``).
        path: directory containing the files.

    Returns:
        obj::`list` of filenames; empty when ``config["notrimming"]`` is set.

    Example:
        ["Path/to/QCResults/FastQC/trimmed/Sample1_S1_L001_R1_fastqc",
         "Path/to/QCResults/FastQC/trimmed/Sample1_S1_L001_R2_fastqc"]
    '''
    if config["notrimming"]:
        return []
    sample_info = geninfo_config["Sample information"]
    paired = sample_info["type"] == "PE" and ext == "_fastqc"
    if wildcards.sample in sample_info["samples"]:
        samples = wildcards.sample
    else:
        samples = sample_info["join_lanes"][wildcards.sample]
    if paired:
        return expand("{path}/{sample}{paired}{ext}",
                      sample=samples, ext=ext, path=path,
                      paired=["_R1", "_R2"])
    return expand("{path}/{sample}{ext}",
                  sample=samples, ext=ext, path=path)


def get_trimmomatic_pseudofile(wildcards):
    '''
    Return the pseudofile paths used to force trimmomatic to run.

    Files are named ``<log_path>/<sample>.trimmomatic.log``. A sample not
    listed in ``geninfo_config["Sample information"]["samples"]`` is expanded
    to its lane-level names via ``["join_lanes"]``.

    Log files were previously used for this purpose, but they disappeared
    whenever a job failed. These pseudofiles were also used for reporting
    via get_trimmomatic_results(), but because they contain no data they
    produced bad values in the report.

    NOTE(review): rule write_sample_json still passes these files
    (``input.trimming``) to get_trimmomatic_result(); confirm that is
    intended.
    '''
    sample_info = geninfo_config["Sample information"]
    if wildcards.sample in sample_info["samples"].keys():
        names = wildcards.sample
    else:
        names = sample_info["join_lanes"][wildcards.sample]
    return expand("{path}/{sample}.trimmomatic.log",
                  sample=names, path=log_path)

def get_trimmomatic_params(wildcards):
    '''
    Return the trimmomatic parameter file path(s) for one sample.

    Files are named ``<trimming_path>/<sample>.trimmomatic.params``. A sample
    not listed in ``geninfo_config["Sample information"]["samples"]`` is
    expanded to its lane-level names via ``["join_lanes"]``.
    '''
    sample_info = geninfo_config["Sample information"]
    if wildcards.sample not in sample_info["samples"].keys():
        names = sample_info["join_lanes"][wildcards.sample]
    else:
        names = wildcards.sample
    return expand("{path}/{sample}.trimmomatic.params",
                  sample=names, path=trimming_path)

def get_batch_files(wildcards):
    '''
    Collect the named inputs of rule write_final_report.

    Always contains ``summary_json``. It also contains:

    * ``sample_report``: per-sample PDFs, only when ``which("pdflatex")``
      finds pdflatex (otherwise PDF output is skipped);
    * ``sav``: ``sav_results``, when ``config["sav"]`` is set;
    * ``kraken_html`` / ``kraken_png``, unless ``config["nokraken"]`` is set.
    '''
    batch_files = dict(summary_json=data_path + "/summary.json")
    pdflatex_available = which("pdflatex") is not None
    if pdflatex_available:
        batch_files["sample_report"] = expand(
            "{path}/{sample}.pdf", sample=unique_samples.keys(),
            path=main_path)
    if config["sav"]:
        batch_files["sav"] = sav_results
    if config["nokraken"]:
        return batch_files
    batch_files["kraken_html"] = main_path + "/kraken.html"
    batch_files["kraken_png"] = classification_path + "/kraken_batch.png"
    return batch_files
#--------------------------------------------< RULES >-----------------------------------------------------------------#

rule run_all:
    input:
        main_path + "/batch_report.html",
        # Additionally request every per-sample SAM file when
        # config["save_mapping"] is set. Always return a list: Snakemake input
        # functions are expected to yield a str or a list of str, not a
        # generator.
        lambda wildcards: (
            ["%s/%s.sam" % (mapping_path, samp)
             for samp in unique_samples.keys()]
            if config["save_mapping"] else [])
    params:
        save_mapping = config["save_mapping"]

rule write_final_report:
    input:
        unpack(get_batch_files)
    output:
        main_path + "/batch_report.html"
    run:
        # Render batch_report.html from the template in QCumber_path; the
        # template uses {{~ ... ~}} as its variable delimiters.
        env = Environment(
            trim_blocks=True,
            variable_start_string='{{~', variable_end_string="~}}")
        env.loader = FileSystemLoader(geninfo_config["QCumber_path"])
        template = env.get_template("batch_report.html")
        with open(str(input.summary_json), "r") as summary_file:
            summary = json.load(summary_file)
        # NOTE(review): general_information.json is loaded (so a missing or
        # malformed file still fails the job) but never used; the template
        # receives json.dumps(config) instead. Confirm which is intended.
        with open(data_path + "/general_information.json", "r") as geninfo_file:
            general_information = json.load(geninfo_file)
        if config["sav"]:
            with open(str(input.sav), "r") as sav_file:
                sav_json = json.dumps(json.load(sav_file))
        else:
            sav_json = []
        geninfo_config["Commandline"] = cmd_input

        html = template.render(
            general_information=json.dumps(config),
            summary=json.dumps(summary["Results"]),
            summary_img=json.dumps(summary["summary_img"]),
            sav=sav_json)
        with open(str(output), "w") as html_file:
            html_file.write(html)

# Write PDF report for each sample

def get_steps_per_sample(wildcards):
    '''
    Map each pipeline step to the files rule "write_sample_json" needs.

    The set of steps depends on the user's options:

    * ``raw_fastqc``: always present.
    * ``trimming``, ``trimming_params``, ``trimming_fastqc``: unless
      ``config["notrimming"]`` is set.
    * ``mapping``: when ``config["reference"]`` or ``config["index"]`` is
      set.
    * ``kraken``, ``kraken_log``: unless ``config["nokraken"]`` is set.

    Returns:
        obj::`dict` with step names (obj::`str`) as keys and file names as
        values.

    NOTE(review): ``mapping`` is gated on reference/index here, while
    write_sample_json reads ``input.mapping`` whenever ``nomapping`` is
    false; confirm the two conditions always agree.
    '''
    steps = dict(raw_fastqc=get_all_fastqc(wildcards))
    if not config["notrimming"]:
        steps.update(
            trimming=get_trimmomatic_pseudofile(wildcards),
            trimming_params=get_trimmomatic_params(wildcards),
            trimming_fastqc=get_trimmomatic_fastqc(
                wildcards, "_fastqc", path=fastqc_path + "/trimmed"),
        )
    if config["reference"] or config["index"]:
        steps["mapping"] = get_input(
            wildcards, if_not="nomapping", ext=".bowtie2.log",
            path=log_path)
    if config["nokraken"]:
        return steps
    steps["kraken"] = get_input(
        wildcards, if_not="nokraken", ext=".csv", path=classification_path)
    steps["kraken_log"] = get_input(
        wildcards, if_not="nokraken", ext=".kraken.log", path=log_path)
    return steps

# NOTE(review): the string literal below is a disabled copy of an older inline
# input declaration; its keys mirror those built by get_steps_per_sample().
# It is a bare expression with no runtime effect (and names
# get_trimmomatic_log, which is not defined in this file); consider deleting.
'''        raw_fastqc = get_all_fastqc,
        trimming =get_trimmomatic_log,
        trimming_params = lambda wildcards: get_trimmomatic_params(wildcards),
        trimming_fastqc = lambda wildcards:  get_trimmomatic_fastqc(wildcards, "_fastqc", path = fastqc_path + "/trimmed"),
        mapping = lambda wildcards: get_input(wildcards,if_not = "nomapping", ext=".bowtie2.log",samplelist=[], path = log_path),
        kraken = lambda wildcards: get_input(wildcards,if_not = "nokraken", ext=".csv", samplelist=[], path = classification_path ),
        kraken_log = lambda wildcards: get_input(wildcards,if_not = "nokraken", ext=".kraken.log", samplelist=[], path = log_path)
        '''

def get_sample_json_output():
    '''
    Build the named outputs of rule write_sample_json.

    Returns:
        dict containing, in this order:

        * ``json`` and ``newjson``: ``<data_path>/{sample}.json`` and
          ``<data_path>/{sample}_new.json``;
        * ``samplecsv<name>``: one temporary CSV per name returned by
          get_plot_type_names();
        * ``kraken_plot``: ``<classification_path>/{sample}.kraken.png``,
          unless ``config["nokraken"]`` is set.
    '''
    sample_prefix = data_path + "/{sample}"
    outputs = {
        "json": sample_prefix + ".json",
        "newjson": sample_prefix + "_new.json",
    }
    outputs.update(
        ("samplecsv" + name, temp(sample_prefix + "_" + name + ".csv"))
        for name in get_plot_type_names())
    if not config["nokraken"]:
        outputs["kraken_plot"] = classification_path + "/{sample}.kraken.png"
    return outputs
'''
##### Note: Most runtime bugs are somehow related to this rule ######

It calls getter functions from the submodule snakefiles in "./modules/".
This rule has many side effects.
'''
rule write_sample_json:
    input:
        unpack(get_steps_per_sample)
    output:
        **get_sample_json_output()
    params:
        notrimming=config["notrimming"],
        nokraken=config["nokraken"],
        nomapping=config["nomapping"]
    message:
        "Write {wildcards.sample}.json"
    run:
        # ---- {sample}.json: detailed per-sample summary ----
        summary_dict = OrderedDict()
        summary_dict["Name"] = wildcards.sample
        summary_dict["Files"] = unique_samples[wildcards.sample]
        summary_dict["Date"] = datetime.date.today().isoformat()
        paired_end = geninfo_config["Sample information"]["type"] == "PE"
        # input.raw_fastqc also lists the fastqc_data.txt files; only the
        # report directories are passed on for parsing.
        fastqc_dict, total_seq, overrepr_count, adapter_content = (
            get_fastqc_results(
                parameter,
                (x for x in input.raw_fastqc if x[-4:] != ".txt"),
                data_path, "raw", to_base64,
                paired_end=paired_end))
        summary_dict["Total sequences"] = total_seq
        summary_dict["%Overrepr sequences"] = overrepr_count
        summary_dict["%Adapter content"] = adapter_content
        summary_dict["raw_fastqc_results"] = fastqc_dict

        if not params.notrimming:
            summary_dict.update(get_trimmomatic_result(
                list(input.trimming),
                list(input.trimming_params)))
            fastqc_dict, total_seq, overrepr_count, adapter_content = (
                get_fastqc_results(parameter, input.trimming_fastqc,
                                   data_path, "trimmed", to_base64))
            if fastqc_dict != []:
                summary_dict["trimmed_fastqc_results"] = fastqc_dict
                summary_dict["%Overrepr sequences (trimmed)"] = overrepr_count
                summary_dict["%Adapter content (trimmed)"] = adapter_content
                # Put the well-known keys first, then all remaining keys in
                # their existing insertion order. (Iterating a set here would
                # make that order depend on string hash randomization and so
                # vary between runs.)
                new_order = ["Name", "Files", "Date", "Total sequences",
                             "#Remaining Reads", "%Remaining Reads",
                             "%Adapter content", "%Adapter content (trimmed)",
                             "%Overrepr sequences",
                             "%Overrepr sequences (trimmed)",
                             "raw_fastqc_results", "trimmed_fastqc_results"]
                remaining = [key for key in summary_dict
                             if key not in new_order]
                new_order.extend(remaining)
                summary_dict = OrderedDict(
                    (key, summary_dict[key]) for key in new_order)
        if not params.nomapping:
            summary_dict.update(get_bowtie2_result(str(input.mapping)))
            summary_dict["Reference"] = config["reference"]
        if not params.nokraken:
            kraken_results = get_kraken_result(
                str(input.kraken), str(output.kraken_plot))
            if kraken_results:
                summary_dict.update(kraken_results)
                # Keep the kraken log minus progress lines containing "...".
                with open(str(input.kraken_log), "r") as kraken_reader:
                    kraken_log = "".join(
                        line for line in kraken_reader if "..." not in line)
                summary_dict["kraken_log"] = kraken_log
        with open(str(output.json), "w") as json_file:
            json.dump(summary_dict, json_file)

        # ---- {sample}_new.json: condensed per-sample values ----
        # NOTE(review): the helpers below are called a second time with the
        # same inputs as above. This raw call omits paired_end=..., so for PE
        # data its values (and anything it may write under data_path) can
        # differ from the first call; confirm the intended behavior before
        # reusing the earlier results.
        fastqc_dict, total_seq, overrepr_perc, adapter_content = (
            get_fastqc_results(parameter,
                (x for x in input.raw_fastqc if x[-4:] != ".txt"),
                data_path, "raw", to_base64))
        res = dict()
        res["Sample"] = dict()
        res["Sample"]["Name"] = wildcards.sample
        res["Sample"]["TS"] = total_seq
        res["Sample"]["PAC"] = adapter_content
        res["Sample"]["PORS"] = overrepr_perc
        res["Sample"]["POST"] = "N/A"
        res["Sample"]["PACT"] = "N/A"
        res["Sample"]["NRR"] = "N/A"
        res["Sample"]["PRR"] = "N/A"
        res["Sample"]["NAR"] = "N/A"
        res["Sample"]["PAR"] = "N/A"
        res["Sample"]["NC"] = "N/A"
        res["Sample"]["PC"] = "N/A"

        if not config["notrimming"]:
            fastqc_dict, total_seq, overrepr_perc, adapter_content = (
                get_fastqc_results(parameter, input.trimming_fastqc,
                                   data_path, "trimmed", to_base64))
            trimmomatic_results = get_trimmomatic_result(
                list(input.trimming), list(input.trimming_params))

            res["Sample"]["POST"] = overrepr_perc
            res["Sample"]["PACT"] = adapter_content
            res["Sample"]["NRR"] = trimmomatic_results["#Remaining Reads"]
            res["Sample"]["PRR"] = trimmomatic_results["%Remaining Reads"]
        if not config["nomapping"]:
            mapping_result = get_bowtie2_result(str(input.mapping))
            res["Sample"]["NAR"] = mapping_result["#AlignedReads"]
            res["Sample"]["PAR"] = mapping_result["%AlignedReads"]
        if not config["nokraken"]:
            # NOTE(review): get_kraken_result is run again (it is given the
            # plot output path again), yet NC/PC remain "N/A" on every path.
            kraken_results = get_kraken_result(
                str(input.kraken), str(output.kraken_plot))
            if kraken_results is None:
                res["Sample"]["NC"] = "N/A"
                res["Sample"]["PC"] = "N/A"
        with open(str(output.newjson), "w") as newjson_file:
            json.dump(res, newjson_file)


def get_report_info(wildcards):
    '''
    Collect the named inputs of rule write_sample_report for one sample.

    Returns:
        dict with

        * ``sample_json``: ``<data_path>/<sample>.json``;
        * ``raw_fastqc``: raw FastQC reports and data files
          (see get_all_fastqc);
        * ``trimming_fastqc``: trimmed FastQC reports under
          ``<fastqc_path>/trimmed``, unless ``config["notrimming"]`` is set.
    '''
    steps = {
        "sample_json": "{path}/{sample}.json".format(
            sample=wildcards.sample, path=data_path),
        "raw_fastqc": get_all_fastqc(wildcards)}
    if not config["notrimming"]:
        steps["trimming_fastqc"] = get_trimmomatic_fastqc(
            wildcards, "_fastqc", path=fastqc_path + "/trimmed")
    return steps


rule write_sample_report:
    input:
        unpack(get_report_info)
    output:
        temp(main_path + "/{sample}.aux"),
        pdf=main_path + "/{sample}.pdf",
        tex=temp(main_path + "/{sample}.tex")
    log:
        log_path + "/texreport.log"
    message:
        "Write {wildcards.sample}.pdf"
    run:
        # Render report.tex (custom {{~ ... ~}} variable delimiters) from the
        # template in QCumber_path.
        env = Environment(trim_blocks = True, variable_start_string='{{~',
                          variable_end_string = "~}}")
        env.loader = FileSystemLoader(geninfo_config["QCumber_path"])
        template = env.get_template("report.tex")

        with open(str(input.sample_json), "r") as sample_file:
            sample = json.load(sample_file, object_pairs_hook=OrderedDict)
        # The PDF shows base names only; the directory of the first input
        # file is stored separately under "path".
        if "Reference" in sample.keys():
            sample["Reference"] = basename(sample["Reference"])
        sample["path"] = dirname(sample["Files"][0])
        sample["Files"] = [basename(x) for x in sample["Files"]]
        pdf_latex = template.render(
            general_information=geninfo_config,
            sample=sample)
        # The .tex file must be closed (flushed) before pdflatex reads it.
        with open(str(output.tex), "w") as latex:
            latex.write(pdf_latex)
        # pdflatex stdout is appended to the log; stderr goes to our stdout.
        # NOTE(review): pdflatex's exit status is not checked; a failed run
        # only surfaces if the declared PDF output is missing.
        with open(log[0], 'a') as f_log:
            with subprocess.Popen(
                ["pdflatex", "-interaction=nonstopmode",
                 "-output-directory=%s" % dirname(output.pdf), output.tex],
                stdout=f_log, stderr=sys.stdout) as pdflatex_proc:
                pdflatex_proc.wait()
        # TODO: pdflatex also writes <sample>.log next to the PDF; moving it
        # into log_path (e.g. as <sample>.texreport.log) is still unresolved.



# Build one interactive Krona chart (ktImportText) from the per-sample Kraken
# CSV tables of all samples.
rule write_kraken_report:
    input:
        kraken=lambda wildcards: get_input(
            wildcards,
            if_not="nokraken",
            ext=".csv",
            samplelist=unique_samples.keys(),
            path=classification_path,
        )
    output:
        kraken_html=main_path + "/kraken.html"
    shell:
        "ktImportText {input.kraken} -o {output.kraken_html}"

def get_files_of_all_steps():
    '''
    Collect the files produced by every enabled pipeline step.

    Returns:
        dict mapping step name to its files:

        * ``raw_fastqc``: raw FastQC report directories of all samples;
        * ``trimming``: the input function trimming_input (unless
          ``config["notrimming"]``);
        * ``mapping``: an input function yielding the per-sample ``.sam``
          files (unless ``config["nomapping"]``);
        * ``kraken_png``: path of the batch Kraken plot (unless
          ``config["nokraken"]``);
        * ``sample_json``: per-sample JSON reports.
    '''
    steps = {"raw_fastqc": expand(
                "{path}/raw/{sample}_fastqc",
                sample=sample_dict.keys(), path=fastqc_path)}
    if not config["notrimming"]:
        steps["trimming"] = trimming_input
    if not config["nomapping"]:
        steps["mapping"] = lambda wildcards: get_input(
            wildcards, if_not="nomapping", ext=".sam",
            samplelist=unique_samples.keys(), path=mapping_path)
    if not config["nokraken"]:
        # A plain path string (no trailing comma, which would make a tuple).
        steps["kraken_png"] = classification_path + "/kraken_batch.png"
    steps["sample_json"] = expand(
        "{path}/{sample}.json", sample=unique_samples.keys(), path=data_path)
    return steps


def get_batch_output():
    '''
    Build the named outputs needed to finish one batch (rule
    write_batch_report).

    Keys:
        summary_json, summary_json_new: ``<data_path>/summary.json`` and
            ``<data_path>/summary_new.json``.
        fastqc_plots: PNGs for Per_sequence_GC_content,
            Per_sequence_quality_scores and Sequence_Length_Distribution.
        n_read_plot: ``reads_after_trimming.png``.
        mapping_plot, insertsize_plot: only unless ``config["nomapping"]``.

    NOTE(review): the plot paths are hard-coded to "QCResults/_data" while
    the JSON files use ``data_path``; confirm these always coincide.
    '''
    plot_dir = "QCResults/_data"
    fastqc_images = ["Per_sequence_GC_content", "Per_sequence_quality_scores",
                     "Sequence_Length_Distribution"]
    outputs = {
        "summary_json": data_path + "/summary.json",
        "summary_json_new": data_path + "/summary_new.json",
        "fastqc_plots": list(
            expand("{path}/{img}.png", path=plot_dir, img=fastqc_images)),
        "n_read_plot": plot_dir + "/reads_after_trimming.png",
    }
    if config["nomapping"]:
        return outputs
    outputs["mapping_plot"] = plot_dir + "/mapping.png"
    outputs["insertsize_plot"] = plot_dir + "/insertsize.png"
    return outputs

def get_batch_report_input():
    '''
    Build the named inputs of rule write_batch_report.

    Keys:
        sample_json, sample_json_new: per-sample JSON reports.
        samplecsv: per-sample CSV for every name from get_plot_type_names().
        kraken_batch: batch Kraken plot, unless ``config["nokraken"]``.
        insertsize: per-sample ``<mapping_path>/<sample>_insertsizes.txt``,
            when ``config["reference"]`` or ``config["index"]`` is set.
    '''
    samples = unique_samples.keys()
    inputs = {
        "sample_json": expand("{path}/{sample}.json",
                              sample=samples, path=data_path),
        "sample_json_new": expand("{path}/{sample}_new.json",
                                  sample=samples, path=data_path),
        "samplecsv": expand(data_path + "/{sample}_{plot_type}.csv",
                            sample=samples,
                            plot_type=get_plot_type_names()),
    }
    if not config["nokraken"]:
        inputs["kraken_batch"] = classification_path + "/kraken_batch.png"
    mapping_enabled = config["reference"] or config["index"]
    if mapping_enabled:
        inputs["insertsize"] = expand(
            "{mapping_path}/{sample}_insertsizes.txt",
            sample=samples, mapping_path=mapping_path)
    return inputs

# Write html report for all samples
rule write_batch_report:
    input:
        **get_batch_report_input()
    output:
        **get_batch_output()
    params:
        nokraken = config["nokraken"]
    run:
        # Presumably merges the per-sample plot CSVs into batch CSVs under
        # data_path (see modules.json_output.combine_csv).
        combine_csv(input.samplecsv, data_path)
        plot_names = ["Per_sequence_GC_content",
                      "Per_sequence_quality_scores",
                      "Sequence_Length_Distribution"]
        fastqc_csv = expand("{path}/{img}.csv", path="QCResults/_data",
                            img=plot_names)

        # Legacy summary first, then the condensed "new" summary.
        write_summary_json(output, config, input, fastqc_csv, geninfo_config,
                           boxplots, shell, get_name, to_base64)
        write_summary_json_new(output, input.sample_json_new)

