import pandas as pd
from openpyxl import load_workbook, Workbook
from openpyxl.styles import PatternFill, Font, Alignment
from openpyxl.utils.dataframe import dataframe_to_rows
import numpy as np
from datetime import datetime, date
import os

class SrikaraSettlementComparison:
    """Reconcile Srikara HIMS collections against card/online settlement reports.

    HIMS "Card" collections are compared with Axis bank card settlements and
    HIMS "Online" collections with Paytm settlements, per location and date.
    The result is written to an Excel workbook with one sheet per location
    plus a consolidated sheet.
    """

    # Amount columns of the comparison frame, in report order. Sheet headers
    # are derived from these names by replacing "_" with " ".
    _AMOUNT_COLUMNS = (
        'HIMS_Card_Amount', 'HIMS_Online_Amount', 'HIMS_Total_Amount',
        'Axis_Gross_Amount', 'Axis_Net_Amount', 'Axis_MDR', 'Axis_GST', 'Axis_EMI',
        'Paytm_Amount', 'Paytm_Net_Amount', 'Paytm_Commission', 'Paytm_GST',
        'Card_Variance', 'Online_Variance', 'Total_Variance',
    )
    # Excel number format for amount cells. Cells stay numeric so that the
    # SUM totals formulas add them up (text cells are ignored by SUM).
    _CURRENCY_FORMAT = '"₹"#,##0.00'

    def __init__(self):
        self.header_fill = PatternFill(fgColor='1274bd', fill_type='solid')
        self.header_font = Font(color='FFFFFF', bold=True)
        self.alignment = Alignment(horizontal='center', vertical='center')

        # Location mapping for Srikara (7 locations as mentioned in requirements).
        # Only these locations appear in the comparison and get their own sheet.
        self.location_mapping = {
            'SRIKARA MAIN': 'SRIKARA MAIN',
            'SRIKARA BRANCH 1': 'SRIKARA BRANCH 1',
            'SRIKARA BRANCH 2': 'SRIKARA BRANCH 2',
            'SRIKARA BRANCH 3': 'SRIKARA BRANCH 3',
            'SRIKARA BRANCH 4': 'SRIKARA BRANCH 4',
            'SRIKARA BRANCH 5': 'SRIKARA BRANCH 5',
            'SRIKARA BRANCH 6': 'SRIKARA BRANCH 6'
        }

    @staticmethod
    def _read_table(file_path):
        """Read ``file_path`` into a DataFrame: CSV if it ends in ``.csv``, else Excel."""
        if os.path.splitext(str(file_path))[1].lower() == '.csv':
            return pd.read_csv(file_path)
        return pd.read_excel(file_path)

    @staticmethod
    def _default_missing_columns(df, columns, source_name):
        """Add each of ``columns`` missing from ``df`` as a column of zeros, with a warning."""
        for column in columns:
            if column not in df.columns:
                print(f"Warning: {source_name} report has no '{column}' column; treating it as 0")
                df[column] = 0

    def load_hims_collection_report(self, file_path):
        """Load the HIMS collection report and total it per payment group.

        Expects columns 'Payment Group Name', 'Amount', 'Date' and 'Location'.

        Returns:
            DataFrame with one row per (Payment Group Name, Date, Location),
            holding the summed 'Amount' and a 'Transaction Count'. Returns an
            empty DataFrame (after printing the error) if loading fails.
        """
        try:
            hims_df = self._read_table(file_path)
            grouped = hims_df.groupby(['Payment Group Name', 'Date', 'Location'])

            if 'Transaction Count' in hims_df.columns:
                # 'count' = number of rows with a non-null Transaction Count
                # (original behaviour; 'sum' may be intended — TODO confirm).
                hims_summary = grouped.agg({'Amount': 'sum', 'Transaction Count': 'count'})
            else:
                # No count column in the export: count rows per group instead.
                hims_summary = grouped.agg({'Amount': 'sum'})
                hims_summary['Transaction Count'] = grouped.size()

            return hims_summary.reset_index()

        except Exception as e:
            print(f"Error loading HIMS collection report: {e}")
            return pd.DataFrame()

    def load_axis_bank_settlement(self, file_path):
        """Load the Axis bank settlement report for Card transactions.

        Net Amount = Gross Amount - MDR - GST - EMI. GST is taken from a 'GST'
        column, or summed from 'IGST'/'SGST'/'CGST' if those exist instead.
        Missing GST/MDR/EMI columns are treated as 0.

        Returns:
            DataFrame with one row per (Date, Location), plus
            Payment Type 'Card'. Returns an empty DataFrame on failure.
        """
        try:
            axis_df = self._read_table(file_path)

            if 'GST' not in axis_df.columns:
                gst_parts = [col for col in ('IGST', 'SGST', 'CGST') if col in axis_df.columns]
                axis_df['GST'] = axis_df[gst_parts].sum(axis=1) if gst_parts else 0
            self._default_missing_columns(axis_df, ['MDR', 'EMI'], 'Axis bank settlement')

            axis_df['Net Amount'] = (
                axis_df['Gross Amount'].fillna(0)
                - axis_df['MDR'].fillna(0)
                - axis_df['GST'].fillna(0)
                - axis_df['EMI'].fillna(0)
            )

            axis_summary = axis_df.groupby(['Date', 'Location']).agg({
                'Gross Amount': 'sum',
                'Net Amount': 'sum',
                'MDR': 'sum',
                'GST': 'sum',
                'EMI': 'sum'
            }).reset_index()
            axis_summary['Payment Type'] = 'Card'
            return axis_summary

        except Exception as e:
            print(f"Error loading Axis bank settlement report: {e}")
            return pd.DataFrame()

    def load_paytm_settlement(self, file_path):
        """Load the Paytm settlement report for Online/UPI transactions.

        Net Amount = Amount - Commission - GST. Missing Commission/GST columns
        are treated as 0.

        Returns:
            DataFrame with one row per (Date, Location), plus
            Payment Type 'Online'. Returns an empty DataFrame on failure.
        """
        try:
            paytm_df = self._read_table(file_path)
            self._default_missing_columns(paytm_df, ['Commission', 'GST'], 'Paytm settlement')

            paytm_df['Net Amount'] = (
                paytm_df['Amount'].fillna(0)
                - paytm_df['Commission'].fillna(0)
                - paytm_df['GST'].fillna(0)
            )

            paytm_summary = paytm_df.groupby(['Date', 'Location']).agg({
                'Amount': 'sum',
                'Net Amount': 'sum',
                'Commission': 'sum',
                'GST': 'sum'
            }).reset_index()
            paytm_summary['Payment Type'] = 'Online'
            return paytm_summary

        except Exception as e:
            print(f"Error loading Paytm settlement report: {e}")
            return pd.DataFrame()

    def load_phonepe_settlement(self, file_path):
        """Return an empty PhonePe settlement frame (placeholder).

        PhonePe is not used yet, so ``file_path`` is ignored. The frame has the
        settlement columns and Payment Type 'Others/PhonePe'.
        """
        phonepe_df = pd.DataFrame(columns=['Date', 'Location', 'Amount', 'Net Amount', 'Commission', 'GST'])
        phonepe_df['Payment Type'] = 'Others/PhonePe'
        return phonepe_df

    @staticmethod
    def _sum_for(df, column, location, date_val, payment_group=None):
        """Sum ``df[column]`` over the rows for ``location`` on ``date_val``.

        When ``payment_group`` is given, rows are also filtered on
        'Payment Group Name'. Returns 0 when ``df`` is None or empty, when it
        lacks a needed column (e.g. a settlement report that failed to load),
        or when no rows match.
        """
        needed = ['Location', 'Date', column]
        if payment_group is not None:
            needed.append('Payment Group Name')
        if df is None or df.empty or any(col not in df.columns for col in needed):
            return 0

        mask = (df['Location'] == location) & (df['Date'] == date_val)
        if payment_group is not None:
            mask &= df['Payment Group Name'] == payment_group
        selected = df.loc[mask, column]
        return selected.sum() if not selected.empty else 0

    def create_settlement_comparison_report(self, hims_data, axis_data, paytm_data, phonepe_data, report_date):
        """Build one comparison row per (configured location, HIMS date).

        Args:
            hims_data: Output of ``load_hims_collection_report``.
            axis_data: Output of ``load_axis_bank_settlement``. May be empty.
            paytm_data: Output of ``load_paytm_settlement``. May be empty.
            phonepe_data: PhonePe settlements. Not used in the comparison yet.
            report_date: Currently unused; kept for interface compatibility.

        Returns:
            DataFrame with 'Date', 'Location' and the amount/variance columns.
            Variances are HIMS amount minus settled *net* amount. Only locations
            in ``self.location_mapping`` and dates present in HIMS are reported.
            The frame is empty (but has its columns) if HIMS has no data.
        """
        columns = ['Date', 'Location', *self._AMOUNT_COLUMNS]
        if hims_data is None or hims_data.empty or 'Date' not in hims_data.columns:
            return pd.DataFrame(columns=columns)

        rows = []
        # NOTE: dates are matched by exact equality, so every source must use
        # the same Date representation (all datetimes, or identically
        # formatted strings) for settlements to line up with HIMS.
        for location in self.location_mapping:
            for date_val in hims_data['Date'].unique():
                hims_card = self._sum_for(hims_data, 'Amount', location, date_val, 'Card')
                hims_online = self._sum_for(hims_data, 'Amount', location, date_val, 'Online')
                axis_net = self._sum_for(axis_data, 'Net Amount', location, date_val)
                paytm_net = self._sum_for(paytm_data, 'Net Amount', location, date_val)

                rows.append({
                    'Date': date_val,
                    'Location': location,
                    'HIMS_Card_Amount': hims_card,
                    'HIMS_Online_Amount': hims_online,
                    'HIMS_Total_Amount': hims_card + hims_online,
                    'Axis_Gross_Amount': self._sum_for(axis_data, 'Gross Amount', location, date_val),
                    'Axis_Net_Amount': axis_net,
                    'Axis_MDR': self._sum_for(axis_data, 'MDR', location, date_val),
                    'Axis_GST': self._sum_for(axis_data, 'GST', location, date_val),
                    'Axis_EMI': self._sum_for(axis_data, 'EMI', location, date_val),
                    'Paytm_Amount': self._sum_for(paytm_data, 'Amount', location, date_val),
                    'Paytm_Net_Amount': paytm_net,
                    'Paytm_Commission': self._sum_for(paytm_data, 'Commission', location, date_val),
                    'Paytm_GST': self._sum_for(paytm_data, 'GST', location, date_val),
                    'Card_Variance': hims_card - axis_net,
                    'Online_Variance': hims_online - paytm_net,
                    'Total_Variance': (hims_card + hims_online) - (axis_net + paytm_net),
                })

        return pd.DataFrame(rows, columns=columns)

    def _write_sheet(self, ws, data, columns):
        """Write styled headers, data rows and a bold SUM totals row to ``ws``.

        Amount cells are written as numbers with a rupee number format, so the
        totals formulas actually add them up. The totals row sits one blank
        row below the data. Columns missing from a data row are written as 0.
        """
        bold = Font(bold=True)
        amount_columns = set(self._AMOUNT_COLUMNS)

        for col_num, column in enumerate(columns, 1):
            cell = ws.cell(row=1, column=col_num, value=column.replace('_', ' '))
            cell.fill = self.header_fill
            cell.font = self.header_font
            cell.alignment = self.alignment

        for row_num, (_, record) in enumerate(data.iterrows(), 2):
            for col_num, column in enumerate(columns, 1):
                value = record.get(column, 0)
                cell = ws.cell(row=row_num, column=col_num)
                if column in amount_columns and pd.api.types.is_number(value):
                    cell.value = float(value)
                    cell.number_format = self._CURRENCY_FORMAT
                else:
                    cell.value = value

        totals_row = len(data) + 3
        label = ws.cell(row=totals_row, column=1, value="TOTAL")
        label.font = bold
        for col_num, column in enumerate(columns, 1):
            if column not in amount_columns:
                continue
            col_letter = ws.cell(row=1, column=col_num).column_letter
            cell = ws.cell(
                row=totals_row,
                column=col_num,
                value=f"=SUM({col_letter}2:{col_letter}{totals_row - 1})",
            )
            cell.font = bold
            cell.number_format = self._CURRENCY_FORMAT

    def create_excel_report(self, comparison_df, output_file):
        """Write the comparison to ``output_file`` as a formatted workbook.

        Creates one sheet per configured location that has rows, plus a
        "Consolidated Report" sheet summing all locations per date. An empty
        comparison produces a workbook with just an empty consolidated sheet.
        """
        wb = Workbook()
        amount_columns = list(self._AMOUNT_COLUMNS)
        has_rows = not comparison_df.empty and 'Location' in comparison_df.columns

        if has_rows:
            for location in self.location_mapping:
                location_data = comparison_df[comparison_df['Location'] == location]
                if location_data.empty:
                    continue
                self._write_sheet(
                    wb.create_sheet(title=location),
                    location_data,
                    ['Date', 'Location', *amount_columns],
                )
            consolidated_data = comparison_df.groupby('Date')[amount_columns].sum().reset_index()
        else:
            consolidated_data = pd.DataFrame(columns=['Date', *amount_columns])

        self._write_sheet(
            wb.create_sheet(title="Consolidated Report"),
            consolidated_data,
            ['Date', *amount_columns],
        )

        # Remove the default sheet that Workbook() creates.
        if "Sheet" in wb.sheetnames:
            wb.remove(wb["Sheet"])

        wb.save(output_file)
        print(f"Settlement comparison report saved to: {output_file}")

    def generate_settlement_report(self, hims_file, axis_file, paytm_file, phonepe_file=None, output_file=None):
        """Load all sources, build the comparison and write the Excel report.

        Args:
            hims_file: HIMS collection report (CSV or Excel).
            axis_file: Axis bank settlement report (CSV or Excel).
            paytm_file: Paytm settlement report (CSV or Excel).
            phonepe_file: Optional; PhonePe loading is a placeholder.
            output_file: Output path. Defaults to a timestamped name in the
                current directory.

        Returns:
            The output path, or None if the HIMS report could not be loaded
            (no workbook is written in that case).
        """
        if output_file is None:
            output_file = f"Srikara_Settlement_Comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"

        print("Loading data files...")

        hims_data = self.load_hims_collection_report(hims_file)
        axis_data = self.load_axis_bank_settlement(axis_file)
        paytm_data = self.load_paytm_settlement(paytm_file)
        phonepe_data = self.load_phonepe_settlement(phonepe_file) if phonepe_file else pd.DataFrame()

        if hims_data.empty:
            print("Warning: HIMS collection report is empty or could not be loaded")
            return None
        # Missing settlement data is reported as zero settled amount, so it
        # surfaces as variance rather than aborting the report.
        if axis_data.empty:
            print("Warning: Axis bank settlement report is empty or could not be loaded")
        if paytm_data.empty:
            print("Warning: Paytm settlement report is empty or could not be loaded")

        print("Creating settlement comparison report...")

        comparison_df = self.create_settlement_comparison_report(
            hims_data, axis_data, paytm_data, phonepe_data, datetime.now().date()
        )

        self.create_excel_report(comparison_df, output_file)

        print("Settlement comparison report generated successfully!")
        return output_file

# Example usage
if __name__ == "__main__":
    # Build the comparator (styles and the Srikara location list).
    settlement_comparison = SrikaraSettlementComparison()

    # Input file paths — update these to the actual report locations.
    hims_file = "SrikaraintegratedHIS.xlsx"  # HIMS collection report
    axis_file = "MomentsPaySrikara_Bank_new.xlsx"  # Axis bank settlement report
    paytm_file = "Srikaramomentpay06-06.csv"  # Paytm settlement report

    try:
        output_file = settlement_comparison.generate_settlement_report(
            hims_file=hims_file,
            axis_file=axis_file,
            paytm_file=paytm_file
        )
    except Exception as e:
        print(f"Error generating report: {e}")
        print("Please ensure all input files exist and have the correct format.")
    else:
        # generate_settlement_report returns None when it aborts (e.g. the
        # HIMS report could not be loaded); don't claim success in that case.
        if output_file is None:
            print("Report was not generated; see the warnings above.")
        else:
            print(f"Report generated: {output_file}")