import sys
import os
import glob
import math
import re
import pandas as pd
import numpy as np

# Numeric codes for the NVTX 'eventType' column.
# k_event_type_pushpop is compared against nvtx_df['eventType'] in
# compute_pushpop_extras; the other two are not referenced by the code in
# this file.
# NOTE(review): presumably these match the event-type ids of the profiler
# export this module consumes -- confirm against the exporter's schema.
k_event_type_marker = 34
k_event_type_pushpop = 59
k_event_type_startend = 60

# DTSP-14650: Code cleanup
def drop_payload(nvtx_df):
    """Return ``nvtx_df`` without its payload columns.

    The drop is all-or-nothing: when at least one of the payload columns is
    absent the frame is handed back untouched (same object); otherwise a new
    frame without any of them is returned.
    """
    payload_columns = ['uint64Value', 'int64Value', 'doubleValue', 'uint32Value', 'int32Value', 'floatValue', 'jsonTextId', 'jsonText']

    present = nvtx_df.columns
    has_every_payload_column = all(name in present for name in payload_columns)
    if not has_every_payload_column:
        return nvtx_df

    return nvtx_df.drop(columns=payload_columns)
  
def combine_text_fields(str_df, nvtx_df, nvtx_write_df):
    """Resolve ``textId`` references into the ``text`` column.

    Rows of ``nvtx_df`` whose ``textId`` is set get their ``text`` replaced by
    the matching ``str_df['value']`` (looked up through ``str_df['id']``).

    Side effect: the resolved strings are written into ``nvtx_write_df``
    in place; the returned frame is ``nvtx_write_df`` with the ``textId``
    column dropped.

    Early exits return ``nvtx_df`` unchanged (``textId`` column, if any, still
    present): when there is no ``textId`` column, or when no row has a
    non-null ``textId``.

    The lookup relies on ``reset_index()`` producing a column literally named
    ``'index'``.
    """
    if 'textId' not in nvtx_df.columns:
        return nvtx_df

    # Carry the original index through the merge as an ordinary column,
    # because merge() hands back a fresh positional index.
    indexed_df = nvtx_df.reset_index()
    referencing_df = indexed_df[indexed_df['textId'].notna()]
    if len(referencing_df) == 0:
        return nvtx_df

    string_lookup_df = pd.DataFrame(data={'textId': str_df['id'], 'textIdStr': str_df['value']})
    resolved_df = referencing_df[['index', 'textId']].merge(string_lookup_df, how='left', on='textId')

    # Put the original labels back so the assignment below aligns on them.
    resolved_df = resolved_df.set_index(resolved_df['index'])

    nvtx_write_df.loc[resolved_df.index, 'text'] = resolved_df['textIdStr']
    return nvtx_write_df.drop(columns=['textId'])


#compute level and parent
def compute_pushpop_extras(nvtx_df_arg):
    """Add nesting ``level`` and ``parent`` columns for push/pop ranges.

    Only rows whose ``eventType`` equals ``k_event_type_pushpop`` and whose
    ``start`` is not null are considered. They are grouped by
    ``(globalTid, domainId)`` (a missing ``domainId`` is treated as 0), sorted
    by ``start`` and replayed against a stack of currently open ranges:

    * ``level``  - number of ranges still open when the row starts (0 = top).
    * ``parent`` - original index label of the innermost open range, or
      ``<NA>`` for top-level rows.

    Returns a new frame: ``nvtx_df_arg`` with the two nullable ``Int64``
    columns appended. Rows that were not considered get ``<NA>`` in both.
    The input frame is not modified.

    Assumes the input index is unnamed (so ``reset_index()`` yields an
    ``'index'`` column) and holds unique integer labels.
    """
    orig_df = nvtx_df_arg

    # reset_index() returns a new frame, so the edits below never touch the
    # caller's data, and the original labels survive as the 'index' column.
    work_df = nvtx_df_arg.reset_index()

    # groupby drops NaN keys by default, so every row needs a domainId.
    # Plain assignment is used instead of a chained
    # ``df[col].fillna(..., inplace=True)``, which does not update the frame
    # when pandas Copy-on-Write is active.
    work_df['domainId'] = work_df['domainId'].fillna(0)

    is_pushpop = work_df['eventType'] == k_event_type_pushpop
    has_start = work_df['start'].notnull()
    work_df = work_df.loc[is_pushpop & has_start]

    level_parts = []
    parent_parts = []

    for _, group_df in work_df.groupby(['globalTid', 'domainId']):
        # Rows are ordered by start time within the group before the replay.
        group_df = group_df.sort_values(by='start', ignore_index=True)

        starts = group_df['start'].tolist()
        ends = group_df['end'].tolist()
        labels = group_df['index'].tolist()
        row_count = len(starts)

        levels = np.zeros(row_count, dtype=np.int64)
        parents = np.full(row_count, -1, dtype=np.int64)  # -1 = no parent

        open_rows = []  # positions of ranges that are still open
        for cur in range(row_count):
            # Close every range that does not strictly contain this start.
            # May pop several times. A NaN end never compares greater, so
            # such a range is popped here and never becomes a parent.
            while open_rows and not (starts[cur] < ends[open_rows[-1]]):
                open_rows.pop()

            levels[cur] = len(open_rows)
            if open_rows:
                parents[cur] = labels[open_rows[-1]]
            open_rows.append(cur)

        level_parts.append(pd.Series(levels, index=group_df['index'], dtype='Int64'))
        parent_parts.append(pd.Series(parents, index=group_df['index'], dtype='Int64'))

    if level_parts:
        level_ps = pd.concat(level_parts)
        parent_ps = pd.concat(parent_parts)
    else:
        # pd.concat([]) raises; with nothing to annotate, emit empty columns
        # that share the input's index dtype.
        empty_index = orig_df.index[:0]
        level_ps = pd.Series(index=empty_index, dtype='Int64')
        parent_ps = pd.Series(index=empty_index, dtype='Int64')

    # Turn the -1 placeholder into a real missing value.
    parent_ps = parent_ps.mask(parent_ps == -1)

    extra_df = pd.DataFrame(data={'level': level_ps, 'parent': parent_ps})
    return pd.concat([orig_df, extra_df], axis=1)


def cuda_projection(str_df, nvtx_df, cuda_api_df, cuda_graph_api_df, cuda_kernel_df, cuda_memcopies_df, cuda_memsets_df):
    """Project GPU activity onto NVTX ranges.

    Runs the four stages of the projection in order:

    1. resolve ``textId`` references in ``nvtx_df`` (``nvtx_df`` is passed as
       both the read and the write frame, so it is updated in place),
    2. reduce the CUDA API tables to the calls selected by
       ``__cuda_apis_reduce_to_gpu_workloads``,
    3. map each of those calls to the NVTX ranges found by
       ``__cuda_find_nvtx_ranges``,
    4. let ``__nvtx_add_gpu_ranges`` attach the GPU time span to the ranges.

    Returns the frame produced by the last stage.
    """
    resolved_nvtx_df = combine_text_fields(str_df, nvtx_df, nvtx_df)
    launching_calls_df = __cuda_apis_reduce_to_gpu_workloads(cuda_api_df, cuda_graph_api_df, str_df)
    correlation_to_nvtx, graphnode_to_nvtx = __cuda_find_nvtx_ranges(launching_calls_df, resolved_nvtx_df)
    return __nvtx_add_gpu_ranges(
        resolved_nvtx_df,
        correlation_to_nvtx,
        graphnode_to_nvtx,
        cuda_kernel_df,
        cuda_memcopies_df,
        cuda_memsets_df,
    )

def __cuda_apis_reduce_to_gpu_workloads(cuda_api_df, cuda_graph_api_df, str_df):
    """Keep only the CUDA API calls whose name matches a work-launching prefix.

    Args:
        cuda_api_df: CUDA API calls; needs ``nameId``, ``correlationId``,
            ``globalTid`` and ``start`` columns.
        cuda_graph_api_df: optional CUDA graph API calls (``None`` to skip);
            needs ``nameId``, ``graphNodeId``, ``originalGraphNodeId``,
            ``globalTid`` and ``start`` columns.
        str_df: string table with ``id`` and ``value`` columns; ``nameId``
            values are looked up in ``id``.

    Returns:
        A frame with ``correlationId``, ``globalTid``, ``nameId`` and
        ``start``; when graph calls are supplied their rows are appended,
        which adds the ``graphNodeId`` / ``originalGraphNodeId`` columns.
        The original row labels are kept (no re-indexing).
    """
    # API name prefixes that are projected onto NVTX ranges.
    launch_prefixes = (
        "cudaLaunch",
        "cuLaunch",
        "cudaMemcpy",
        "cuMemcpy",
        "cudaMemset",
        "cuMemset",
        "cudaGraphLaunch",
        "cuGraphLaunch",
        "cudaGraphAddKernel",
        "cuGraphAddKernel",
        "cudaGraphAddMemcpy",
        "cuGraphAddMemcpy",
        "cudaGraphAddMemset",
        "cuGraphAddMemset",
    )
    # Names that match a launch prefix but are excluded from the projection.
    host_func_prefixes = ("cudaLaunchHostFunc", "cuLaunchHostFunc")

    # na=False: entries that are not strings simply do not match.
    names = str_df['value']
    is_launch = names.str.startswith(launch_prefixes, na=False)
    is_host_func = names.str.startswith(host_func_prefixes, na=False)
    launch_name_ids = str_df.loc[is_launch & ~is_host_func, 'id']

    selected_api_df = cuda_api_df[cuda_api_df['nameId'].isin(launch_name_ids)]
    cuda_work_df = selected_api_df[['correlationId', 'globalTid', 'nameId', 'start']]

    # Identity test: comparing a DataFrame to None with != is element-wise
    # and cannot be used as a condition.
    if cuda_graph_api_df is not None:
        selected_graph_df = cuda_graph_api_df[cuda_graph_api_df['nameId'].isin(launch_name_ids)]
        cuda_work_df = pd.concat([
            cuda_work_df,
            selected_graph_df[['graphNodeId', 'originalGraphNodeId', 'globalTid', 'nameId', 'start']],
        ])

    return cuda_work_df

def __cuda_find_nvtx_ranges(cuda_apis_launching_work_df, nvtx_df):
    """Find, for every CUDA call, the NVTX ranges open when the call starts.

    Only NVTX rows with a non-null ``start`` and ``end`` and a null
    ``endGlobalTid`` are used. Matching is done per ``globalTid``: the range
    starts, range ends and CUDA call starts of one thread are replayed in
    time order while a set of currently open range labels is maintained.
    On equal timestamps a range start is applied before a range end, and both
    before the CUDA call, so a range contains a call when
    ``range.start <= call.start < range.end``.

    Args:
        cuda_apis_launching_work_df: CUDA calls with ``globalTid`` and
            ``start`` columns plus ``correlationId`` and/or ``graphNodeId``.
        nvtx_df: NVTX ranges with ``globalTid``, ``start``, ``end`` and
            ``endGlobalTid`` columns; its index labels identify the ranges.

    Returns:
        ``(correlation_map, graph_node_map)``: two dicts mapping a call's
        ``correlationId`` (or, when that is missing/NaN, its ``graphNodeId``)
        to a set of NVTX index labels. Calls outside every range, and calls
        with neither id, are not recorded.
    """
    cudaCorrelationToNVTX = dict()
    cudaGraphNodeIdToNVTX = dict()

    # Drop open ranges and ranges that end on a different thread.
    usable = (
        nvtx_df['start'].notnull()
        & nvtx_df['end'].notnull()
        & nvtx_df['endGlobalTid'].isnull()
    )
    nvtx_df = nvtx_df[usable]

    cuda_by_tid = {
        tid: tid_df
        for tid, tid_df in cuda_apis_launching_work_df.groupby(by="globalTid")
    }

    for globalTid, nvtx_tid_df in nvtx_df.groupby(by="globalTid"):
        cuda_tid_df = cuda_by_tid.get(globalTid)
        if cuda_tid_df is None:
            # Thread has NVTX ranges but launched no CUDA work.
            continue

        # Two time-ordered event streams over the same ranges.
        starts_s = nvtx_tid_df['start'].sort_values()
        start_times = starts_s.tolist()
        start_ids = starts_s.index.tolist()
        ends_s = nvtx_tid_df['end'].sort_values()
        end_times = ends_s.tolist()
        end_ids = ends_s.index.tolist()

        start_count = len(start_times)
        end_count = len(end_times)
        start_pos = 0
        end_pos = 0
        nvtx_active_bag = set()

        for cuda_row in cuda_tid_df.sort_values("start").itertuples():
            cuda_time = cuda_row.start

            # Replay every NVTX event that happens no later than this call.
            while end_pos < end_count:
                end_time = end_times[end_pos]
                if (start_pos < start_count
                        and start_times[start_pos] <= end_time
                        and start_times[start_pos] <= cuda_time):
                    nvtx_active_bag.add(start_ids[start_pos])
                    start_pos += 1
                elif end_time <= cuda_time:
                    nvtx_active_bag.remove(end_ids[end_pos])
                    end_pos += 1
                else:
                    break

            if end_pos >= end_count:
                # Every range of this thread has closed; the remaining calls
                # cannot be inside any of them.
                break

            if not nvtx_active_bag:
                continue

            # Rows that came from the graph API table carry NaN (not None)
            # in correlationId once both tables are concatenated.
            correlationId = getattr(cuda_row, 'correlationId', None)
            if correlationId is not None and not pd.isna(correlationId):
                cudaCorrelationToNVTX[correlationId] = nvtx_active_bag.copy()
                continue

            graphNodeId = getattr(cuda_row, 'graphNodeId', None)
            if graphNodeId is not None and not pd.isna(graphNodeId):
                cudaGraphNodeIdToNVTX[graphNodeId] = nvtx_active_bag.copy()

    return cudaCorrelationToNVTX, cudaGraphNodeIdToNVTX
        
def __nvtx_add_gpu_ranges(
    nvtx_df,
    cuda_correlation_nvtx_map,
    cuda_graphnode_nvtx_map,
    cuda_kernel_df,
    cuda_memcopies_df,
    cuda_memsets_df):
    """Attach the GPU time span of the launched work to each NVTX range.

    For every kernel / memset / memcopy row whose ``correlationId`` is a key
    of ``cuda_correlation_nvtx_map``, the row's ``[start, end]`` widens the
    span of every NVTX range in the mapped set.

    Args:
        nvtx_df: NVTX ranges; modified in place. Its index labels are used as
            array positions, so they must be non-negative integers.
        cuda_correlation_nvtx_map: ``correlationId`` -> set of NVTX labels.
        cuda_graphnode_nvtx_map: accepted for interface symmetry.
            NOTE(review): not consulted here, so work identified only by a
            graph node id is not projected -- confirm whether intended.
        cuda_kernel_df, cuda_memcopies_df, cuda_memsets_df: GPU activities
            with ``correlationId``, ``start`` and ``end`` columns.

    Returns:
        ``nvtx_df`` with nullable ``Int64`` columns ``gpu_start``, ``gpu_end``
        and ``gpu_duration`` (``<NA>`` for ranges without GPU work).
    """
    gpu_columns = ['correlationId', 'start', 'end']
    # One frame for all GPU activity kinds, purely to simplify iteration.
    cuda_gpu_df = pd.concat([
        cuda_kernel_df[gpu_columns],
        cuda_memsets_df[gpu_columns],
        cuda_memcopies_df[gpu_columns],
    ], ignore_index=True)

    # index.max() of an empty index is NaN, which cannot size an array.
    if len(nvtx_df.index) > 0:
        nvtx_np_length = int(nvtx_df.index.max()) + 1
    else:
        nvtx_np_length = 0

    # Identity elements for min/max double as "never touched" markers.
    unset_start = np.iinfo(np.int64).max
    unset_end = np.iinfo(np.int64).min
    nvtx_gpu_start_np = np.full(nvtx_np_length, unset_start, dtype=np.int64)
    nvtx_gpu_end_np = np.full(nvtx_np_length, unset_end, dtype=np.int64)

    for cuda_gpu_row in cuda_gpu_df.itertuples():
        nvtx_active_bag = cuda_correlation_nvtx_map.get(cuda_gpu_row.correlationId)
        if nvtx_active_bag is None:
            continue

        # Widen every enclosing NVTX range to cover this GPU activity.
        for nvtx_active_idx in nvtx_active_bag:
            if cuda_gpu_row.start < nvtx_gpu_start_np[nvtx_active_idx]:
                nvtx_gpu_start_np[nvtx_active_idx] = cuda_gpu_row.start
            if cuda_gpu_row.end > nvtx_gpu_end_np[nvtx_active_idx]:
                nvtx_gpu_end_np[nvtx_active_idx] = cuda_gpu_row.end

    # Positions that were never touched become <NA>.
    nvtx_gpu_start_ps = pd.Series(nvtx_gpu_start_np, dtype='Int64').mask(nvtx_gpu_start_np == unset_start)
    nvtx_gpu_end_ps = pd.Series(nvtx_gpu_end_np, dtype='Int64').mask(nvtx_gpu_end_np == unset_end)

    # Column assignment aligns the positional series on nvtx_df's labels.
    nvtx_df["gpu_start"] = nvtx_gpu_start_ps
    nvtx_df["gpu_end"] = nvtx_gpu_end_ps
    nvtx_df["gpu_duration"] = nvtx_gpu_end_ps - nvtx_gpu_start_ps

    return nvtx_df


#print('nvtx_cuda_projection.py imported')
