######################################################
#
# Make a copy of this file and edit it to your liking
#
######################################################

# Species whose populations are plotted (one entry per species).
species_names = [
    "Bacteroides_vulgatus_57955",
]

# Where the figure is written: the file name is appended directly to the
# analysis directory string provided by parse_midas_data.
filename = (
    parse_midas_data.analysis_directory
    + "Bacteroides_vulgatus_antibiotic_snp_change_timecourse.png"
)


####################################################################
#
#  Function controlling which mutations are colored
#
#      Returns: True for colored
#               False for not colored
#
####################################################################
def color_condition(population_idx, chromosome, location, gene_name, variant_type, times, freqs, depths):
    """Decide whether a mutation's trajectory should be colored.

    A mutation is colored when, after polarizing the trajectory so that the
    allele in the majority at the antibiotic timepoint is the one tracked:

      * its frequency at the antibiotic timepoint is > 0.8,
      * its frequency at the end timepoint is > 0.6, and
      * it is below 0.5 at at least one sampled timepoint.

    If either the antibiotic or the end timepoint is absent from `times`,
    the mutation is not colored.

    Parameters:
        population_idx: unused by the current condition.
        chromosome, location, gene_name, variant_type: identify the
            mutation; written to stderr when the mutation is colored.
        times: sample times, compared elementwise against entries of
            sample_time_map (assumed to be a numpy array -- the code relies
            on elementwise `==` and boolean-mask indexing).
        freqs: allele frequencies aligned with `times` (same assumption).
            The caller's array is not modified.
        depths: unused by the current condition.

    Returns:
        True if the mutation should be colored, False otherwise.
    """
    antibiotic_time = sample_time_map[parse_timecourse_data.highcoverage_antibiotic]
    end_time = sample_time_map[parse_timecourse_data.highcoverage_end]

    # Keep these as boolean masks.  Testing presence via the sum of integer
    # positions (numpy.nonzero output) is wrong when the matching sample is
    # at index 0, because the sum of positions is then 0.
    antibiotic_mask = (times == antibiotic_time)
    end_mask = (times == end_time)

    condition = False

    if end_mask.any() and antibiotic_mask.any():

        antibiotic_freq = freqs[antibiotic_mask][0]
        end_freq = freqs[end_mask][0]

        # Polarize by the allele that is in the majority at the antibiotic
        # timepoint.  `1 - freqs` builds a new array, so the caller's data
        # is left untouched.
        if antibiotic_freq < 0.5:
            antibiotic_freq = 1 - antibiotic_freq
            end_freq = 1 - end_freq
            freqs = 1 - freqs

        condition = bool(
            (antibiotic_freq > 0.8)
            and (end_freq > 0.6)
            and (freqs < 0.5).any()
        )

    # Alternative conditions (uncomment one to override the rule above):
    #
    # condition = ( gene_name=='435590.9.peg.242' )
    # condition = (freqs.max()-freqs.min() > 0.8)
    # condition = (freqs[0]<0.05)*(freqs[1]<0.05)*((freqs.max()-freqs.min()) > 0.5)
    #
    # Examples:
    #
    # Mutation is in the majority at the end of the experiment
    # condition = f(60000) > 0.5
    #
    # Mutation is a structural variant in the nadR gene
    # condition = (gene_name == 'nadR') and (variant_type == 'sv')
    #
    # Mutation is in one of the mut* genes
    # condition = gene_name.startswith('mut')
    # ...

    if condition:
        # Report every colored mutation on stderr as a comma-separated line.
        items = [chromosome, location, gene_name, variant_type]
        print_str = ", ".join(str(item) for item in items)
        sys.stderr.write("%s\n" % print_str)

    return condition
