import os
import pandas as pd
from sqlalchemy import create_engine, Table, Column, MetaData, String, Integer, Float, Date, DateTime
from sqlalchemy.exc import ProgrammingError
from sqlalchemy import inspect
from elasticsearch import Elasticsearch, helpers
from sqlalchemy.sql import text  
import openai
import re
import warnings
from sqlalchemy import Text

# Silence pandas/SQLAlchemy deprecation and user-level warnings for this script.
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=UserWarning)

# Database configuration.
# SECURITY: credentials should come from the environment, not from source
# control. The literal below is kept only as a backward-compatible fallback;
# set MYSQL_DATABASE_URI, then rotate the password and remove the fallback.
MYSQL_DATABASE_URI = os.environ.get(
    "MYSQL_DATABASE_URI",
    "mysql+pymysql://root:dataaegis123@localhost/reconcile_rule_engine",
)
engine = create_engine(MYSQL_DATABASE_URI)

# MetaData object
metadata = MetaData()

# Elasticsearch configuration (currently disabled)
#ELASTICSEARCH_HOST = "http://localhost:9200"  # Replace with your Elasticsearch host
#es = Elasticsearch(ELASTICSEARCH_HOST)

# SECURITY: same as above - prefer the OPENAI_API_KEY environment variable.
# The committed key should be treated as leaked: rotate it and drop the fallback.
openai.api_key = os.environ.get(
    "OPENAI_API_KEY",
    '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',
)

def infer_type_with_openai(sample_data):
    """Send sample values to the OpenAI chat API and return a VARCHAR(255) type.

    The model's reply is only printed for debugging; it never influences the
    result. Whether the call succeeds or raises, the function returns
    ``String(255)``.

    Args:
        sample_data: Object exposing ``tolist()`` whose values are embedded,
            comma-separated, in the prompt.

    Returns:
        A new ``String(255)`` SQLAlchemy type.
    """
    # Comma-separated rendering of the sample values for the prompt.
    data = ", ".join(map(str, sample_data.tolist()))
    prompt = f"Given the following data values: {data}, always suggest the SQLAlchemy data type as VARCHAR(255), regardless of the data."
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=chat_messages,
            max_tokens=100,
            temperature=0,
        )
        print(f"Full OpenAI response: {response}")  # Debugging
        inferred_type = response['choices'][0]['message']['content'].strip()
        print(f"Inferred type: {inferred_type}")  # Debugging
    except Exception as e:
        # Best effort: an API failure is reported but does not change the result.
        print(f"Error calling OpenAI API: {e}")

    # VARCHAR(255) is used on both the success and the failure path.
    return String(255)

def sanitize_dataframe_columns(dataframe):
    """Rename every column of ``dataframe`` through ``sanitize_column_name``.

    The column labels are replaced in place on the object passed in, and that
    same object is returned for convenience.
    """
    dataframe.columns = list(map(sanitize_column_name, dataframe.columns))
    return dataframe

def sanitize_column_name(col_name):
    """Return a database-safe version of a column header.

    The header is converted to text, trimmed, lowercased, has its spaces
    replaced with underscores, and finally loses every character that is not
    ``a-z``, ``0-9`` or ``_``.

    Args:
        col_name: Original header. Non-string labels (e.g. integers that
            Excel headers can produce) are converted with ``str()``.

    Returns:
        The sanitized name; may be an empty string if nothing survives.
    """
    # Coerce first so numeric/date headers do not crash on .lower().
    col_name = str(col_name).strip().lower()
    # Spaces must become underscores BEFORE the character filter runs;
    # otherwise the filter deletes them and words get glued together.
    col_name = col_name.replace(' ', '_')
    col_name = re.sub(r'[^a-z0-9_]', '', col_name)
    return col_name

def infer_sqlalchemy_type(column, sample_data=None):
    """Return the SQLAlchemy type for ``column`` as decided by ``infer_type_with_openai``.

    Args:
        column: Column-like object; its ``head(5)`` is used as the sample when
            ``sample_data`` is not supplied.
        sample_data: Optional explicit sample to send instead.

    Returns:
        Whatever ``infer_type_with_openai`` returns for the chosen sample.
    """
    chosen_sample = column.head(5) if sample_data is None else sample_data
    return infer_type_with_openai(chosen_sample)

def upload_data_to_mysql(engine, table_name, dataframe, file_path):
    """Append ``dataframe`` to the MySQL table ``table_name``.

    Column names are sanitized, the table is created or extended to fit them,
    and the rows are appended with ``DataFrame.to_sql``.

    Args:
        engine: SQLAlchemy engine used for the schema work and the insert.
        table_name: Destination table.
        dataframe: Rows to upload; its column labels are sanitized in place.
        file_path: Source file path, used only in the confirmation message.
    """
    prepared = sanitize_dataframe_columns(dataframe)
    # Make sure the table and every needed column exist before inserting.
    create_or_update_table(engine, table_name, prepared)
    prepared.to_sql(name=table_name, con=engine, if_exists="append", index=False)
    print(f"Data from '{file_path}' uploaded to MySQL table '{table_name}'.")

#def create_or_update_table(engine, table_name, dataframe):
#    """Create table if it does not exist and set all columns as VARCHAR(255)."""
#    metadata = MetaData()
#    inspector = inspect(engine)
#    if table_name not in inspector.get_table_names():
#        print(f"Table '{table_name}' does not exist. Creating a new table.")
#        columns = [Column("id", Integer, primary_key=True, autoincrement=True)]
#        for col in dataframe.columns:
#            sanitized_col_name = sanitize_column_name(col)
#            print(f"Creating column '{sanitized_col_name}' as VARCHAR(255)")
#            #columns.append(Column(sanitized_col_name, String(255)))  # Default to VARCHAR(255)
#            if len(dataframe.columns) > 150:
                # Use TEXT for all columns to avoid row size limit
#                columns.append(Column(sanitized_col_name, String(1000).with_variant(Text(), 'mysql')))
#            else:
#                columns.append(Column(sanitized_col_name, String(255)))
#        Table(table_name, metadata, *columns).create(engine)
#        print(f"Table '{table_name}' created with all columns as VARCHAR(255).")
#        return

    # Check for missing columns and add them
#    existing_table = Table(table_name, metadata, autoload_with=engine)
#    for col in dataframe.columns:
#        sanitized_col_name = sanitize_column_name(col)
#        if sanitized_col_name not in [c.name for c in existing_table.columns]:
#            print(f"Adding missing column '{sanitized_col_name}' as VARCHAR(255)")
#            with engine.connect() as conn:
#                try:
#                    # Use text() to wrap the raw SQL string
#                    conn.execute(text(f"ALTER TABLE `{table_name}` ADD COLUMN `{sanitized_col_name}` VARCHAR(255)"))
#                except Exception as e:
#                    print(f"Error adding column '{sanitized_col_name}' to table '{table_name}': {e}")

def create_or_update_table(engine, table_name, dataframe):
    """Ensure ``table_name`` exists with a TEXT column for each dataframe column.

    If the table is missing it is created with an auto-increment integer
    ``id`` primary key plus one TEXT column per (sanitized) dataframe column.
    If it already exists, any column it lacks is added with ``ALTER TABLE``.
    TEXT is used for every data column to avoid the row size overflow.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        table_name: Name of the table to create or extend.
        dataframe: Data whose column labels define the required columns.

    Returns:
        None. Failures while adding a column are printed, not raised.
    """
    # Sanitize once; dict.fromkeys drops duplicates while keeping order, so
    # two headers that sanitize to the same name yield a single column.
    column_names = list(
        dict.fromkeys(sanitize_column_name(col) for col in dataframe.columns)
    )

    if table_name not in inspect(engine).get_table_names():
        print(f"Table '{table_name}' does not exist. Creating a new table.")
        columns = [Column("id", Integer, primary_key=True, autoincrement=True)]

        for name in column_names:
            print(f"Creating column '{name}' as TEXT() to avoid row size limit")
            columns.append(Column(name, Text()))

        Table(table_name, MetaData(), *columns).create(engine)
        print(f"Table '{table_name}' created with all columns as TEXT().")
        return

    # If table exists, check for missing columns and add them
    existing_table = Table(table_name, MetaData(), autoload_with=engine)
    existing_column_names = {c.name for c in existing_table.columns}

    # Identifiers cannot be bound as parameters; escape backticks instead.
    # Column names are already restricted to [a-z0-9_] by the sanitizer.
    quoted_table = table_name.replace("`", "``")

    for name in column_names:
        if name in existing_column_names:
            continue
        print(f"Adding missing column '{name}' as TEXT()")
        try:
            # engine.begin() commits on success and rolls back on error, so
            # the DDL does not depend on the driver's autocommit behaviour.
            with engine.begin() as conn:
                conn.execute(
                    text(f"ALTER TABLE `{quoted_table}` ADD COLUMN `{name}` TEXT")
                )
        except Exception as e:
            # Deliberately best effort: report and continue with other columns.
            print(f"Error adding column '{name}' to table '{table_name}': {e}")

def clean_dataframe(dataframe):
    """Return a copy of ``dataframe`` prepared for database insertion.

    Missing values are replaced via ``DataFrame.where(..., None)`` and every
    datetime column is rendered as ``YYYY-MM-DD`` text. The frame passed in is
    not reassigned; the result of ``where`` is what gets modified and returned.
    """
    cleaned = dataframe.where(dataframe.notna(), None)
    datetime_columns = list(cleaned.select_dtypes(include=["datetime"]))
    for column_name in datetime_columns:
        formatted = cleaned[column_name].dt.strftime("%Y-%m-%d")
        cleaned[column_name] = formatted
    return cleaned

def process_and_upload(file_path, sheet_name, header=0):
    """Read one sheet of an Excel file, clean it, and append it to MySQL.

    The destination table is named after the file: base name without the
    extension, lowercased, with spaces replaced by underscores. The
    module-level ``engine`` is used for the upload.

    Args:
        file_path: Path of the Excel workbook to load.
        sheet_name: Sheet to read. If ``pandas.read_excel`` raises
            ``ValueError`` the file is reported as missing the sheet and skipped.
        header: Header row index forwarded to ``pandas.read_excel``.

    Returns:
        None. The file is skipped (with a message) when the sheet cannot be
        read, when it holds no data, or when no usable column remains.
    """
    # Derive table name from the file name (without extension)
    table_name = os.path.splitext(os.path.basename(file_path))[0].lower().replace(" ", "_")
    try:
        dataframe = pd.read_excel(file_path, sheet_name=sheet_name, header=header)
    except ValueError:
        print(f"Sheet '{sheet_name}' not found in file '{file_path}'. Skipping...")
        return

    # Headers can be ints or dates; normalise to text so the string-based
    # checks below (and the sanitizer) never hit a non-string label.
    dataframe.columns = dataframe.columns.astype(str)

    # Remove columns with empty header names
    dataframe = dataframe.loc[:, dataframe.columns != '']

    # Check if the DataFrame is empty
    if dataframe.empty:
        print(f"The DataFrame for file '{file_path}' is empty. Skipping...")
        return

    print(f"Original columns: {dataframe.columns}")

    # Drop pandas' auto-generated "Unnamed: N" columns and all-NaN columns
    dataframe = dataframe.loc[:, ~dataframe.columns.str.startswith('Unnamed')]
    dataframe = dataframe.dropna(how='all', axis=1)
    print(f"Cleaned columns: {dataframe.columns}")

    # Everything may have been dropped above; nothing left to upload then.
    if dataframe.empty:
        print(f"No usable columns left in file '{file_path}'. Skipping...")
        return

    # Replace "N/A" with None. The dict form states the replacement value
    # explicitly (a scalar with value=None has meant "pad fill" on some
    # pandas versions) and returns a new frame instead of mutating a slice.
    dataframe = dataframe.replace({"N/A": None})

    # Sanitize column names
    dataframe = sanitize_dataframe_columns(dataframe)
    print(f"Sanitized columns: {dataframe.columns}")

    for col in dataframe.columns:
        print(f"Checking column: {col}, type: {type(dataframe[col])}, dtype: {dataframe.dtypes[col]}")
        if dataframe.dtypes[col] != "object":
            continue
        try:
            # Attempt to convert to datetime; a column that does not parse is
            # left exactly as it was (explicit form of errors="ignore").
            converted = pd.to_datetime(dataframe[col])
        except (ValueError, TypeError, OverflowError):
            continue
        except Exception as e:
            print(f"Error processing column '{col}': {e}")
            continue
        dataframe[col] = converted

    # Replace NaN with None for database compatibility
    dataframe = dataframe.where(pd.notna(dataframe), None)

    print("DataFrame before upload:")
    print(dataframe.head())

    # upload_data_to_mysql creates/extends the table itself before inserting,
    # so a separate create_or_update_table call here would be redundant.
    upload_data_to_mysql(engine, table_name, dataframe, file_path)

def traverse_and_process(root_folder, sheet_name, header=0):
    """Recursively walk ``root_folder`` and process every ``.xlsx`` workbook.

    Args:
        root_folder: Directory to walk with ``os.walk``.
        sheet_name: Sheet name forwarded to ``process_and_upload``.
        header: Header row index forwarded to ``process_and_upload``.

    Returns:
        None. Progress is reported with ``print``.
    """
    print(f"Starting to traverse: {root_folder}")
    for dirpath, _, filenames in os.walk(root_folder):
        print(f"Checking directory: {dirpath}")
        for file in filenames:
            # "~$name.xlsx" files are lock files Excel writes next to an open
            # workbook; they are not readable workbooks, so skip them.
            if file.startswith("~$"):
                continue
            # Match the extension case-insensitively (".XLSX" is common).
            if not file.lower().endswith(".xlsx"):
                continue
            file_path = os.path.join(dirpath, file)
            print(f"Processing file: {file_path}")
            process_and_upload(file_path, sheet_name, header=header)
    print("Finished processing all files.")

if __name__ == "__main__":
    # Script entry point: hand the hard-coded folder and sheet name to
    # traverse_and_process, treating the first row of each sheet as the header.
    root_folder = r"/home/kavya/Continental/RoyalCare_1"
    sheet_name = 'Sheet1'
    traverse_and_process(
        root_folder=root_folder,
        sheet_name=sheet_name,
        header=0,
    )
