import pandas as pd
import numpy as np
import os
import os.path
import glob
import sys

# Script to extract the top hits of RNAbound predictions from single-sequence analysis.

# Both values are interpolated into the per-window results directory name and
# into the file name of the summary table written at the end of the script.
# NOTE(review): presumably the flank size and the base-pair probability
# threshold of the upstream RNAbound run -- confirm.
flank = 10
bpthr = 0.0005

def get_snorna_details(path="dataset/annotations/snoRNAdb_Rfam_overlap.bed"):
    """Map snoRNA annotation intervals to the ID stored in their overlap hit.

    The file is read as a headerless, tab-separated table. Only rows whose
    column 9 contains the substring 'infernal' contribute an entry.

    Args:
        path: Location of the overlap table. Defaults to the path the script
            has always used, so existing no-argument calls are unaffected.

    Returns:
        dict mapping "<col0>_<col1>_<col2>" to the third "|"-separated field
        of column 9.
    """
    snorna = pd.read_csv(path, sep="\t", header=None)
    snorna_rfam_match = {}
    for chrom, start, end, hit in zip(snorna[0], snorna[1], snorna[2], snorna[9]):
        # An empty field is parsed as NaN (a float), on which the substring
        # test would raise TypeError; such rows cannot carry a hit, skip them.
        if isinstance(hit, str) and 'infernal' in hit:
            snorna_rfam_match["%s_%s_%s" % (chrom, start, end)] = hit.split("|")[2]

    return snorna_rfam_match

# Module-level lookup built once when the script starts; get_shapegroup_name
# reads it to resolve snoRNA annotations by their "<chr>_<start>_<end>" key.
snorna_rfam_match = get_snorna_details()


def group_shapes(shape):
    """Assign a shape string to one of the four groups "I", "II", "III", "IV".

    The rules are checked in order:
      * "I"   -- the string is exactly two characters long,
      * "II"  -- the substring "[[]" does not occur,
      * "III" -- the string starts with "[[]" and ends with "]]",
      * "IV"  -- anything else.
    """
    if len(shape) == 2:
        return "I"

    if "[[]" not in shape:
        return "II"

    enclosed = shape.startswith("[[]") and shape.endswith("]]")
    return "III" if enclosed else "IV"
    

# Table read once when the script starts; get_shapegroup_name filters it on
# the 'RfamID' column and reads the 'rnashape' column of the first match.
rnashape = pd.read_csv("dataset/annotations/structure_percentage_rnashape.txt", sep="\t")

def get_shapegroup_name(row):
    """Return (name, shape, group) for one annotation row.

    The branch is chosen from substrings of row['name']:
      * 'Rfam'   -- the fourth "|"-separated field of the name is looked up
                    in the module-level ``rnashape`` table,
      * 'snoRNA' (without 'miRNA') -- the ID comes from the module-level
                    ``snorna_rfam_match`` via "<chr>_<anno_start>_<anno_end>"
                    and is then looked up in ``rnashape``,
      * 'miRNA'  -- fixed shape "[]" in group "I",
      * 'tRNA'   -- fixed shape "[[][][]]" in group "III".

    Whenever no shape can be resolved the result is
    (name, "no_shape", "unclassified").

    Args:
        row: Mapping-like record with a 'name' entry; the snoRNA branch also
            reads 'chr', 'anno_start' and 'anno_end'.

    Returns:
        Always a 3-tuple (name, shape, group).
    """
    name = row['name']
    unclassified = (name, "no_shape", "unclassified")

    if 'Rfam' in name:
        rid = name.split("|")[3]
    elif 'snoRNA' in name and 'miRNA' not in name:
        key = "%s_%s_%s" % (row['chr'], row['anno_start'], row['anno_end'])
        if key not in snorna_rfam_match:
            # Previously this path fell through and returned None, which made
            # callers that index the result fail; report it as unclassified
            # like every other unresolved case.
            return unclassified
        rid = snorna_rfam_match[key]
    elif 'miRNA' in name:
        return name, "[]", "I"
    elif 'tRNA' in name:
        return name, "[[][][]]", "III"
    else:
        return unclassified

    # Shared lookup for the two branches that resolved an ID.
    matches = rnashape[rnashape['RfamID'] == rid]
    if len(matches) > 0:
        shape = matches['rnashape'].values[0]
        return name, shape, group_shapes(shape)
    return unclassified


def _top_hit_fields(segments, ref_start, ref_end):
    """Return boundary fields of the best-scoring segment per score column.

    For 'D_kl_w' and then 'Dotu_segmen', the row with the highest value is
    taken and four values are emitted: its start_k, its end_l, the offset
    ref_start - start_k and the offset end_l - ref_end.
    """
    fields = []
    for scol in ['D_kl_w', 'Dotu_segmen']:
        hit = segments.sort_values([scol], ascending=False).head(1)
        start_k = hit['start_k'].values[0]
        end_l = hit['end_l'].values[0]
        fields += [start_k, end_l, ref_start - start_k, end_l - ref_end]
    return fields


# Each window size selects one benchmark table and one results directory.
windows = [100, 150, 200]

collect_results = []
for win in windows:

    annot = pd.read_csv("dataset/windows/window_%s_benchmark_data.bed" % win, sep="\t")

    # Attach the shape string and shape group of every annotation.
    shape_info = annot.apply(get_shapegroup_name, axis=1)
    annot['shape'] = [a[1] for a in shape_info]
    annot['group'] = [a[2] for a in shape_info]
    print(annot['group'].value_counts())

    for idx, row in annot.iterrows():

        # single sequence analysis (RNAfold)
        segfile = "results/rnafold_rnabound_window_%s_%s_%s/%s_%s_%s_segments.txt.gz" % (win, flank, bpthr, row['chr'], row['window_start'], row['window_end'])
        if not os.path.isfile(segfile):
            print("%s not found" % segfile)
            continue

        # Only the read is guarded: pandas reports an empty or malformed
        # file with ValueError subclasses, and such files are skipped.
        try:
            tmp = pd.read_csv(segfile, sep="\t")
        except ValueError:
            print(segfile)
            continue
        if tmp.empty:
            # A table without rows has no top hit to report; skip it instead
            # of failing on the first-element access.
            print("%s contains no segments" % segfile)
            continue

        # Annotation interval relative to the window start.
        bp_start = row['anno_start'] - row['window_start']
        bp_end = bp_start + (row['anno_end'] - row['anno_start'])

        if row['strand'] == "+":
            ref_start, ref_end = bp_start, bp_end
        else:
            # For any other strand value the interval is mirrored within the
            # window before it is compared with the segment coordinates.
            # NOTE(review): presumably because segments are reported on the
            # reverse-complemented window sequence -- confirm.
            seqlen = row['window_end'] - row['window_start']
            ref_start = seqlen - bp_end + 1
            ref_end = seqlen - bp_start + 1

        # The reported bp_start/bp_end stay un-mirrored for both strands;
        # only the difference columns use the strand-adjusted interval.
        tcollect = [win, bp_start, bp_end] + _top_hit_fields(tmp, ref_start, ref_end)
        collect_results.append(list(row.values) + tcollect)

# Passing the column names to the constructor (instead of assigning them
# afterwards) keeps the script working when no segment file was usable: an
# empty table with the full header is written rather than raising on a
# column-count mismatch.
df = pd.DataFrame(
    collect_results,
    columns=list(annot.columns) + ['winsize', 'bp_start', 'bp_end', 'rnabound_start_k', 'rnabound_end_l', 'rnabound_left_diff', 'rnabound_right_diff',
                                   'dotu_start_k', 'dotu_end_l', 'dotu_left_diff', 'dotu_right_diff'],
)
df.to_csv("results/summary/rnabound_dotu_boundaries_singleseq_f%s_p%s.tsv" % (flank, bpthr), sep="\t", header=True, index=False)
