import pysam
import pandas as pd
import sys
import os
import re
# Count the bases observed at one fixed genomic site in F2R1-oriented reads.
#
# Usage: <script> <sample>.star2.mark.bam <APOBEC label>
# Writes "<sample>.<APOBEC label>.base.txt" into the current directory.
#
# BED inputs of the disabled per-position mode (quoted block at the end of
# this file), kept for reference:
#   A3A: /home/users/ayh/Projects/27_A3B/03_sequencing/single_clone_sequencing/WTS/20_endogenous_expression/A3A.txt
#   A3B: /home/users/ayh/Projects/27_A3B/03_sequencing/single_clone_sequencing/WTS/20_endogenous_expression/A3B.txt

# Queried site. Positions are kept 0-based internally; fetch() is given the
# interval [pos1, pos2 + 1) and the report prints pos1 + 1.
chrom = "22"
pos1 = 39381999 - 1   # 0-based
pos2 = 39381999 - 1   # 0-based; equal to pos1, so a single base is queried

# Minimum mapping quality of a read and minimum quality of the base at the
# site for the read to be counted.
MIN_MAPQ = 20
MIN_BASEQ = 20

# The sample id is everything in front of the ".star2.mark.bam" suffix.
BAM_NAME_RE = re.compile(r"(.+)\.star2\.mark\.bam$")

# NOTE: the "spanning_total" column holds the number of F2R1 reads fetched and
# "qualified_reads" holds the reads that passed every filter. The labels are
# kept as they are so existing downstream parsers keep working.
REPORT_HEADER = "id\tAPOBEC\tchr\tpos1\tspanning_total\tqualified_reads\tC\tT\telse\n"


def extract_bam_id(filename):
    """Return the sample id encoded in a '<id>.star2.mark.bam' file name.

    Raises:
        ValueError: if *filename* does not end in '.star2.mark.bam'.
    """
    match = BAM_NAME_RE.match(filename)
    if match is None:
        raise ValueError(
            "BAM file name %r does not match '<id>.star2.mark.bam'" % filename
        )
    return match.group(1)


def is_f2r1(read):
    """Return True when *read* belongs to a pair in F2R1 orientation.

    That is: a paired read that is either read1 on the reverse strand with a
    forward mate, or read2 on the forward strand with a reverse mate.
    """
    reverse_read1 = (read.is_read1 and read.is_reverse
                     and read.is_paired and not read.mate_is_reverse)
    forward_read2 = (read.is_read2 and not read.is_reverse
                     and read.is_paired and read.mate_is_reverse)
    return bool(reverse_read1 or forward_read2)


def count_site_bases(reads, pos, min_mapq=MIN_MAPQ, min_baseq=MIN_BASEQ):
    """Tally the bases that *reads* show at 0-based reference position *pos*.

    Args:
        reads: iterable of pysam-style aligned segments.
        pos: 0-based reference position to inspect.
        min_mapq: reads with a lower mapping quality are skipped.
        min_baseq: reads whose base at *pos* has a lower quality are skipped.

    Returns:
        dict with the keys "total_F2R1" (F2R1 reads seen, counted before any
        quality filter), "total_spanning" (reads passing every filter) and
        "C", "T", "else" (base observed at *pos* in those reads).
    """
    counts = {"total_F2R1": 0, "total_spanning": 0, "C": 0, "T": 0, "else": 0}

    for read in reads:
        if not is_f2r1(read):
            continue  # only F2R1 orientation
        counts["total_F2R1"] += 1

        if (read.is_unmapped or read.is_duplicate
                or read.mapping_quality < min_mapq):
            continue

        query_bases = read.query_sequence
        base_quals = read.query_qualities
        if query_bases is None or base_quals is None:
            # No stored sequence/qualities: the base-quality filter cannot be
            # evaluated, so the read is skipped instead of crashing on None.
            continue

        # full_length=True yields one entry per query base (None where the
        # base is not aligned), so the list index is also the query index.
        try:
            idx = read.get_reference_positions(full_length=True).index(pos)
        except ValueError:
            continue  # read has no base aligned to the site

        if base_quals[idx] < min_baseq:
            continue

        counts["total_spanning"] += 1
        base = query_bases[idx]
        counts[base if base in ("C", "T") else "else"] += 1

    return counts


def format_report(bam_id, apobec, counts):
    """Return the two-line, tab-separated report (header plus one data row)."""
    row = "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" % (
        bam_id, apobec, chrom, pos1 + 1,
        counts["total_F2R1"], counts["total_spanning"],
        counts["C"], counts["T"], counts["else"],
    )
    return REPORT_HEADER + row


def main(argv=None):
    """Run the site count for the BAM file and APOBEC label given in *argv*."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit("Usage: %s <sample>.star2.mark.bam <APOBEC label>" % argv[0])

    bam_path = argv[1]
    APOBEC = argv[2]

    try:
        bam_id = extract_bam_id(os.path.basename(bam_path))
    except ValueError as err:
        sys.exit(str(err))
    output_txt = bam_id + "." + APOBEC + ".base.txt"

    bamfile = pysam.AlignmentFile(bam_path, "rb")
    try:
        counts = count_site_bases(bamfile.fetch(chrom, pos1, pos2 + 1), pos1)
    finally:
        # Release the BAM/index handles even if counting fails.
        bamfile.close()

    with open(output_txt, 'w') as f:
        f.write(format_report(bam_id, APOBEC, counts))


if __name__ == "__main__":
    main()



# Disabled legacy variant, kept for reference only.
# The triple-quoted string below is a bare expression statement: Python
# evaluates and discards it, so none of the code inside it runs.
# It holds a BED-driven version that tallies A/C/G/T at every position of each
# interval in `bed_path` and writes the table to `output_txt` through pandas.
# It relies on `bed_path` and `results`, neither of which is defined by the
# live code, so both must be restored before this block can be re-enabled.
"""
with open(bed_path) as bed:
	for line in bed:
		chrom, start, end = line.strip().split()[:3]
		start = int(start)-1
		end = int(end)
#		print start
#		print end
		for pos in range(start, end):
#			print pos
			base_counts = {'A': 0, 'C': 0, 'G': 0, 'T': 0}

			for read in bamfile.fetch(chrom, pos, pos + 1):
#				print j
#				j+=1
				if read.is_unmapped or read.mapping_quality < 20 or read.is_duplicate:
					continue

				ref_positions = read.get_reference_positions(full_length=True)
				query_bases = read.query_sequence
				base_quals = read.query_qualities

				for read_idx, ref_pos in enumerate(ref_positions):
					if ref_pos == pos:
						if read_idx is not None and read_idx < len(query_bases):
							if base_quals[read_idx] >= 20:
								base = query_bases[read_idx]
								if base in base_counts:
									base_counts[base] += 1
						break

			results.append({
				"#CHROM": chrom,
				"POS": pos+1,
				"A": base_counts['A'],
				"C": base_counts['C'],
				"G": base_counts['G'],
				"T": base_counts['T'],
				"id": bam_id,
				"APOBEC": APOBEC
			})

# Convert to DataFrame
df = pd.DataFrame(results)

# Reorder columns
#print df.columns.tolist()
df = df[["id","APOBEC","#CHROM", "POS", "A", "C", "G", "T"]]

# Save to tab-delimited .txt file
df.to_csv(output_txt, sep='\t', index=False)

# Print (Python 2.7 style)
#print df
"""
