#!/usr/bin/env python3
"""
Merchant API Automated Testing Suite
For use with the UPI PSP Platform API Documentation

This script provides comprehensive testing for all merchant-related APIs
and can be used in conjunction with the API docs live testing interface.
"""

import argparse
import json
import os
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List, Optional
from urllib.parse import urlencode

import requests

@dataclass
class TestResult:
    """Outcome of a single API call made by the test suite."""

    # "<METHOD> <endpoint>" label used in log lines and in the report.
    test_name: str
    # Path that was appended to the tester's base URL.
    endpoint: str
    # HTTP verb handed to the session.
    method: str
    # Status actually received; make_request stores 0 when the request itself failed.
    status_code: int
    # Status the caller asked for.
    expected_status: int
    # Elapsed time in milliseconds.
    response_time: float
    # Whether the call counts as a pass (make_request sets status_code == expected_status).
    success: bool
    # Description of the mismatch or transport error, if any.
    error_message: Optional[str] = None
    # Decoded JSON body, or {"raw_response": <text>} when the body was not JSON.
    response_data: Optional[Dict] = None

class MerchantAPITester:
    """Comprehensive testing suite for Merchant APIs"""
    
    def __init__(self, base_url: str, api_token: str, partner_id: str, verbose: bool = False):
        self.base_url = base_url.rstrip('/')
        self.partner_id = partner_id
        self.verbose = verbose
        self.results: List[TestResult] = []
        
        self.session = requests.Session()
        self.session.headers.update({
            'Authorization': f'Bearer {api_token}',
            'Content-Type': 'application/json',
            'User-Agent': 'MerchantAPI-Tester/1.0'
        })
        
        # Test data containers
        self.created_merchants: List[str] = []
        self.created_qr_codes: List[str] = []
    
    def log(self, message: str, level: str = "INFO"):
        """Enhanced logging with colors and timestamps"""
        colors = {
            "INFO": "\033[92m",
            "ERROR": "\033[91m", 
            "WARN": "\033[93m",
            "DEBUG": "\033[94m",
            "SUCCESS": "\033[95m",
            "END": "\033[0m"
        }
        
        timestamp = datetime.now().strftime("%H:%M:%S")
        prefix = f"[{timestamp}] {colors.get(level, '')}{level}{colors['END']}"
        print(f"{prefix} {message}")
        
        if self.verbose and level == "DEBUG":
            print(f"    └─ {message}")
    
    def make_request(self, method: str, endpoint: str, data: Optional[Dict] = None, 
                    expected_status: int = 200) -> TestResult:
        """Make HTTP request and return structured result"""
        start_time = time.time()
        test_name = f"{method} {endpoint}"
        
        try:
            if data:
                response = self.session.request(method, f"{self.base_url}{endpoint}", json=data)
            else:
                response = self.session.request(method, f"{self.base_url}{endpoint}")
            
            response_time = (time.time() - start_time) * 1000  # Convert to ms
            
            success = response.status_code == expected_status
            response_data = None
            error_message = None
            
            try:
                response_data = response.json()
            except json.JSONDecodeError:
                response_data = {"raw_response": response.text}
            
            if not success:
                error_message = f"Expected {expected_status}, got {response.status_code}"
                if response_data:
                    error_message += f": {response_data}"
            
            result = TestResult(
                test_name=test_name,
                endpoint=endpoint,
                method=method,
                status_code=response.status_code,
                expected_status=expected_status,
                response_time=response_time,
                success=success,
                error_message=error_message,
                response_data=response_data
            )
            
            self.results.append(result)
            
            if success:
                self.log(f"✓ {test_name} ({response_time:.1f}ms)", "SUCCESS")
            else:
                self.log(f"✗ {test_name} - {error_message}", "ERROR")
            
            if self.verbose:
                self.log(f"Response: {json.dumps(response_data, indent=2)}", "DEBUG")
            
            return result
            
        except requests.exceptions.RequestException as e:
            response_time = (time.time() - start_time) * 1000
            error_message = str(e)
            
            result = TestResult(
                test_name=test_name,
                endpoint=endpoint,
                method=method,
                status_code=0,
                expected_status=expected_status,
                response_time=response_time,
                success=False,
                error_message=error_message
            )
            
            self.results.append(result)
            self.log(f"✗ {test_name} - Request failed: {error_message}", "ERROR")
            return result
    
    def test_create_merchant(self, merchant_type: str = "retail") -> Optional[str]:
        """Test merchant creation with different business types"""
        merchant_templates = {
            "retail": {
                "merchant_code": f"RETAIL{int(time.time())}",
                "brand_name": "Automated Test Retail Store",
                "merchant_vpa": f"retailtest{int(time.time())}@upi",
                "business_type": "RETAIL",
                "corridors": ["DOMESTIC"],
                "contact_email": "test@retailstore.com",
                "contact_phone": "+91-9876543210"
            },
            "restaurant": {
                "merchant_code": f"REST{int(time.time())}",
                "brand_name": "Automated Test Restaurant",
                "merchant_vpa": f"resttest{int(time.time())}@upi",
                "business_type": "FOOD_AND_BEVERAGE",
                "corridors": ["DOMESTIC"],
                "contact_email": "manager@testrestaurant.com",
                "contact_phone": "+91-9876543211"
            },
            "ecommerce": {
                "merchant_code": f"ECOM{int(time.time())}",
                "brand_name": "Automated Test E-commerce",
                "merchant_vpa": f"ecomtest{int(time.time())}@upi",
                "business_type": "ECOMMERCE",
                "corridors": ["DOMESTIC", "SGD-INR"],
                "contact_email": "support@testecommerce.com",
                "contact_phone": "+91-9876543212",
                "website": "https://testecommerce.com"
            },
            "international": {
                "merchant_code": f"INTL{int(time.time())}",
                "brand_name": "Automated Test International",
                "merchant_vpa": f"intltest{int(time.time())}@upi",
                "business_type": "IMPORT_EXPORT",
                "corridors": ["SGD-INR", "USD-INR", "EUR-INR"],
                "contact_email": "finance@testinternational.com",
                "contact_phone": "+91-9876543213",
                "fbar_number": f"FBAR{int(time.time())}"
            }
        }
        
        merchant_data = merchant_templates.get(merchant_type, merchant_templates["retail"])
        
        self.log(f"Creating {merchant_type} merchant...", "INFO")
        result = self.make_request(
            "POST", 
            f"/api/v1/partners/{self.partner_id}/merchants",
            merchant_data,
            201
        )
        
        if result.success and result.response_data:
            merchant_id = result.response_data.get('data', {}).get('id')
            if merchant_id:
                self.created_merchants.append(merchant_id)
                self.log(f"Merchant created with ID: {merchant_id}", "SUCCESS")
                return merchant_id
        
        return None
    
    def test_list_merchants(self, filters: Optional[Dict] = None):
        """Test listing merchants with various filters"""
        endpoint = f"/api/v1/partners/{self.partner_id}/merchants"
        
        if filters:
            query_params = "&".join([f"{k}={v}" for k, v in filters.items()])
            endpoint += f"?{query_params}"
        
        self.log("Testing merchant listing...", "INFO")
        return self.make_request("GET", endpoint, None, 200)
    
    def test_get_merchant(self, merchant_id: str):
        """Test getting specific merchant details"""
        self.log(f"Getting merchant details for {merchant_id}...", "INFO")
        return self.make_request(
            "GET",
            f"/api/v1/partners/{self.partner_id}/merchants/{merchant_id}",
            None,
            200
        )
    
    def test_update_merchant(self, merchant_id: str):
        """Test updating merchant information"""
        update_data = {
            "brand_name": "Updated Test Store Name",
            "contact_email": "updated@teststore.com"
        }
        
        self.log(f"Updating merchant {merchant_id}...", "INFO")
        return self.make_request(
            "PUT",
            f"/api/v1/partners/{self.partner_id}/merchants/{merchant_id}",
            update_data,
            200
        )
    
    def test_merchant_status_updates(self, merchant_id: str):
        """Test various status transitions"""
        statuses = ["SUSPENDED", "ACTIVE", "INACTIVE", "ACTIVE"]
        
        for status in statuses:
            self.log(f"Updating merchant status to {status}...", "INFO")
            result = self.make_request(
                "PATCH",
                f"/api/v1/partners/{self.partner_id}/merchants/{merchant_id}/status",
                {"status": status},
                200
            )
            
            if not result.success:
                break
            
            time.sleep(0.5)  # Brief pause between status changes
    
    def test_merchant_validation(self, merchant_id: str):
        """Test merchant validation"""
        self.log(f"Validating merchant {merchant_id}...", "INFO")
        return self.make_request(
            "GET",
            f"/api/v1/partners/{self.partner_id}/merchants/{merchant_id}/validate",
            None,
            200
        )
    
    def test_transaction_limits(self, merchant_id: str):
        """Test transaction limit checking with various amounts"""
        test_amounts = ["50.00", "1000.00", "5000.00", "25000.00", "100000.00"]
        
        for amount in test_amounts:
            self.log(f"Checking limits for amount ₹{amount}...", "INFO")
            result = self.make_request(
                "POST",
                f"/api/v1/partners/{self.partner_id}/merchants/{merchant_id}/check-limits",
                {"amount": amount},
                200
            )
            
            if result.success and result.response_data:
                allowed = result.response_data.get('allowed', False)
                status = "✓ Allowed" if allowed else "✗ Blocked"
                self.log(f"  Amount ₹{amount}: {status}", "INFO")
    
    def test_merchant_search(self):
        """Test merchant search functionality"""
        search_queries = [
            {"q": "Test"},
            {"status": "ACTIVE"},
            {"business_type": "RETAIL"},
            {"q": "Automated", "status": "ACTIVE"}
        ]
        
        for query in search_queries:
            query_str = "&".join([f"{k}={v}" for k, v in query.items()])
            self.log(f"Searching merchants with query: {query_str}", "INFO")
            self.make_request(
                "GET",
                f"/api/v1/partners/{self.partner_id}/merchants-search?{query_str}",
                None,
                200
            )
    
    def test_merchant_statistics(self):
        """Test merchant statistics endpoint"""
        self.log("Getting merchant statistics...", "INFO")
        return self.make_request(
            "GET",
            f"/api/v1/partners/{self.partner_id}/merchants-stats",
            None,
            200
        )
    
    def test_qr_generation(self, merchant_id: str):
        """Test QR code generation"""
        qr_data = {
            "merchant_id": merchant_id,
            "amount": "500.00",
            "currency": "INR",
            "note": "Automated test payment",
            "expiry_minutes": 15
        }
        
        self.log("Generating QR code...", "INFO")
        result = self.make_request(
            "POST",
            "/api/v1/qr-generate",
            qr_data,
            201
        )
        
        if result.success and result.response_data:
            qr_id = result.response_data.get('qr_id')
            if qr_id:
                self.created_qr_codes.append(qr_id)
                self.log(f"QR code generated with ID: {qr_id}", "SUCCESS")
                return qr_id
        
        return None
    
    def test_qr_status(self, qr_id: str):
        """Test QR code status checking"""
        self.log(f"Checking QR status for {qr_id}...", "INFO")
        return self.make_request(
            "GET",
            f"/api/v1/qr-status/{qr_id}",
            None,
            200
        )
    
    def test_error_scenarios(self):
        """Test various error scenarios"""
        self.log("Testing error scenarios...", "WARN")
        
        # Test 1: Invalid merchant creation (duplicate code)
        if self.created_merchants:
            first_merchant_result = self.make_request(
                "GET",
                f"/api/v1/partners/{self.partner_id}/merchants/{self.created_merchants[0]}",
                None,
                200
            )
            
            if first_merchant_result.success and first_merchant_result.response_data:
                merchant_code = first_merchant_result.response_data.get('merchant_code')
                if merchant_code:
                    self.log("Testing duplicate merchant code error...", "WARN")
                    self.make_request(
                        "POST",
                        f"/api/v1/partners/{self.partner_id}/merchants",
                        {
                            "merchant_code": merchant_code,
                            "brand_name": "Duplicate Test",
                            "merchant_vpa": "duplicate@upi",
                            "business_type": "RETAIL"
                        },
                        409  # Expect conflict
                    )
        
        # Test 2: Invalid VPA format
        self.log("Testing invalid VPA format error...", "WARN")
        self.make_request(
            "POST",
            f"/api/v1/partners/{self.partner_id}/merchants",
            {
                "merchant_code": f"INVALID{int(time.time())}",
                "brand_name": "Invalid VPA Test",
                "merchant_vpa": "invalid-vpa-format",
                "business_type": "RETAIL"
            },
            422  # Expect validation error
        )
        
        # Test 3: Missing required fields
        self.log("Testing missing required fields error...", "WARN")
        self.make_request(
            "POST",
            f"/api/v1/partners/{self.partner_id}/merchants",
            {
                "merchant_code": f"INCOMPLETE{int(time.time())}"
            },
            422  # Expect validation error
        )
        
        # Test 4: Non-existent merchant access
        self.log("Testing non-existent merchant access...", "WARN")
        self.make_request(
            "GET",
            f"/api/v1/partners/{self.partner_id}/merchants/NON_EXISTENT_ID",
            None,
            404  # Expect not found
        )
        
        # Test 5: Invalid partner access
        self.log("Testing invalid partner access...", "WARN")
        self.make_request(
            "GET",
            "/api/v1/partners/INVALID_PARTNER/merchants",
            None,
            403  # Expect forbidden
        )
    
    def test_rate_limiting(self):
        """Test rate limiting behavior"""
        self.log("Testing rate limiting (this may take time)...", "WARN")
        
        # Make rapid requests to test rate limiting
        rapid_requests = 105  # Slightly over the 100/min limit
        start_time = time.time()
        
        for i in range(rapid_requests):
            result = self.make_request(
                "GET",
                f"/api/v1/partners/{self.partner_id}/merchants",
                None,
                200
            )
            
            # If we hit rate limit, expect 429
            if result.status_code == 429:
                elapsed = time.time() - start_time
                self.log(f"Rate limit hit after {i+1} requests in {elapsed:.1f}s", "SUCCESS")
                break
            
            if i % 20 == 0:
                self.log(f"Completed {i+1}/{rapid_requests} requests...", "INFO")
    
    def run_comprehensive_test_suite(self):
        """Run the complete test suite"""
        self.log("="*60, "INFO")
        self.log("STARTING COMPREHENSIVE MERCHANT API TEST SUITE", "INFO")
        self.log("="*60, "INFO")
        
        start_time = time.time()
        
        try:
            # Phase 1: Basic CRUD Operations
            self.log("\n🏗️  PHASE 1: BASIC CRUD OPERATIONS", "INFO")
            
            # Create different types of merchants
            retail_merchant = self.test_create_merchant("retail")
            restaurant_merchant = self.test_create_merchant("restaurant")
            ecommerce_merchant = self.test_create_merchant("ecommerce")
            international_merchant = self.test_create_merchant("international")
            
            # List merchants
            self.test_list_merchants()
            self.test_list_merchants({"status": "ACTIVE"})
            self.test_list_merchants({"business_type": "RETAIL"})
            
            # Get merchant details
            if retail_merchant:
                self.test_get_merchant(retail_merchant)
                self.test_update_merchant(retail_merchant)
            
            # Phase 2: Status Management
            self.log("\n🔧 PHASE 2: STATUS MANAGEMENT", "INFO")
            
            if restaurant_merchant:
                self.test_merchant_status_updates(restaurant_merchant)
            
            # Phase 3: Validation and Limits
            self.log("\n✅ PHASE 3: VALIDATION AND LIMITS", "INFO")
            
            if ecommerce_merchant:
                self.test_merchant_validation(ecommerce_merchant)
                self.test_transaction_limits(ecommerce_merchant)
            
            # Phase 4: Search and Analytics
            self.log("\n🔍 PHASE 4: SEARCH AND ANALYTICS", "INFO")
            
            self.test_merchant_search()
            self.test_merchant_statistics()
            
            # Phase 5: QR Code Management
            self.log("\n🏷️  PHASE 5: QR CODE MANAGEMENT", "INFO")
            
            if international_merchant:
                qr_id = self.test_qr_generation(international_merchant)
                if qr_id:
                    self.test_qr_status(qr_id)
            
            # Phase 6: Error Scenarios
            self.log("\n❌ PHASE 6: ERROR SCENARIOS", "INFO")
            
            self.test_error_scenarios()
            
            # Phase 7: Performance Testing
            if len(sys.argv) > 1 and "--include-performance" in sys.argv:
                self.log("\n⚡ PHASE 7: PERFORMANCE TESTING", "INFO")
                self.test_rate_limiting()
            
        except KeyboardInterrupt:
            self.log("Test suite interrupted by user", "WARN")
        except Exception as e:
            self.log(f"Test suite failed with error: {str(e)}", "ERROR")
        
        finally:
            # Generate test report
            self.generate_test_report(time.time() - start_time)
    
    def generate_test_report(self, total_time: float):
        """Generate comprehensive test report"""
        self.log("\n" + "="*60, "INFO")
        self.log("TEST SUITE REPORT", "INFO")
        self.log("="*60, "INFO")
        
        total_tests = len(self.results)
        passed_tests = sum(1 for r in self.results if r.success)
        failed_tests = total_tests - passed_tests
        
        success_rate = (passed_tests / total_tests * 100) if total_tests > 0 else 0
        avg_response_time = sum(r.response_time for r in self.results) / total_tests if total_tests > 0 else 0
        
        # Summary stats
        self.log(f"Total Tests: {total_tests}", "INFO")
        self.log(f"Passed: {passed_tests} ({success_rate:.1f}%)", "SUCCESS" if success_rate > 90 else "WARN")
        self.log(f"Failed: {failed_tests}", "ERROR" if failed_tests > 0 else "INFO")
        self.log(f"Total Time: {total_time:.1f}s", "INFO")
        self.log(f"Average Response Time: {avg_response_time:.1f}ms", "INFO")
        
        # Failed tests details
        if failed_tests > 0:
            self.log("\n❌ FAILED TESTS:", "ERROR")
            for result in self.results:
                if not result.success:
                    self.log(f"  • {result.test_name}: {result.error_message}", "ERROR")
        
        # Performance summary
        fast_tests = sum(1 for r in self.results if r.response_time < 200)
        slow_tests = sum(1 for r in self.results if r.response_time > 1000)
        
        self.log(f"\n⚡ PERFORMANCE SUMMARY:", "INFO")
        self.log(f"Fast responses (<200ms): {fast_tests}", "SUCCESS")
        self.log(f"Slow responses (>1000ms): {slow_tests}", "WARN" if slow_tests > 0 else "INFO")
        
        # Resources created
        self.log(f"\n🏗️  RESOURCES CREATED:", "INFO")
        self.log(f"Merchants: {len(self.created_merchants)}", "INFO")
        self.log(f"QR Codes: {len(self.created_qr_codes)}", "INFO")
        
        # Generate JSON report
        report_data = {
            "summary": {
                "total_tests": total_tests,
                "passed_tests": passed_tests,
                "failed_tests": failed_tests,
                "success_rate": success_rate,
                "total_time": total_time,
                "avg_response_time": avg_response_time
            },
            "results": [
                {
                    "test_name": r.test_name,
                    "endpoint": r.endpoint,
                    "method": r.method,
                    "status_code": r.status_code,
                    "expected_status": r.expected_status,
                    "response_time": r.response_time,
                    "success": r.success,
                    "error_message": r.error_message
                }
                for r in self.results
            ],
            "created_resources": {
                "merchants": self.created_merchants,
                "qr_codes": self.created_qr_codes
            }
        }
        
        # Save report to file
        report_filename = f"merchant_api_test_report_{int(time.time())}.json"
        with open(report_filename, 'w') as f:
            json.dump(report_data, f, indent=2)
        
        self.log(f"\n📋 Detailed report saved to: {report_filename}", "INFO")
        self.log("="*60, "INFO")

def main() -> int:
    """Parse the command line, run the suite and report the outcome.

    Connection settings come from command-line options, falling back to the
    BASE_URL, API_TOKEN and PARTNER_ID environment variables.

    Returns:
        Process exit status: 0 when every recorded request matched its
        expected status, 1 when the API token is missing or at least one
        request failed.
    """
    parser = argparse.ArgumentParser(description='Merchant API Testing Suite')
    parser.add_argument('--base-url', default=os.getenv('BASE_URL', 'https://api.mercurypay.com'),
                       help='Base API URL')
    parser.add_argument('--api-token', default=os.getenv('API_TOKEN'),
                       help='API authentication token')
    parser.add_argument('--partner-id', default=os.getenv('PARTNER_ID', 'PARTNER_001'),
                       help='Partner ID for testing')
    parser.add_argument('--verbose', '-v', action='store_true',
                       help='Enable verbose output')
    # Declared here so argparse accepts the flag and lists it in --help.
    parser.add_argument('--include-performance', action='store_true',
                       help='Include performance/rate limiting tests')

    args = parser.parse_args()

    if not args.api_token:
        print("Error: API token is required. Set API_TOKEN environment variable or use --api-token")
        return 1

    tester = MerchantAPITester(
        base_url=args.base_url,
        api_token=args.api_token,
        partner_id=args.partner_id,
        verbose=args.verbose
    )

    tester.run_comprehensive_test_suite()

    # A non-zero status lets CI jobs and shell scripts detect failed checks;
    # previously the process exited 0 no matter what the report said.
    return 0 if all(result.success for result in tester.results) else 1


if __name__ == "__main__":
    sys.exit(main())
