#!/usr/bin/env python3
"""
Robust line-by-line parser to fix MySQL schema:
- Adds PRIMARY KEY to CREATE TABLE statements with AUTO_INCREMENT
- Comments out duplicate ALTER TABLE PRIMARY KEY statements
"""

import re

# A single SQL identifier: `backticked`, "double-quoted", or a bare word.
_IDENT = r'(?:`[^`]+`|"[^"]+"|[^\s`".(;,]+)'

# Statement openers. They are anchored at line start so commented-out lines
# ("-- ALTER TABLE ...", e.g. left by a previous run) are not re-parsed.
# Optional keywords and schema qualifiers ("IF NOT EXISTS", "ONLY",
# "public.") are skipped, so group 1 is always the bare table name.
_CREATE_TABLE_RE = re.compile(
    rf'^\s*CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:{_IDENT}\.)*({_IDENT})',
    re.IGNORECASE,
)
_ALTER_TABLE_RE = re.compile(
    rf'^\s*ALTER\s+TABLE\s+(?:ONLY\s+|IF\s+EXISTS\s+)?(?:{_IDENT}\.)*({_IDENT})',
    re.IGNORECASE,
)
# Column list of a PRIMARY KEY clause.
_PK_COLUMNS_RE = re.compile(r'PRIMARY\s+KEY\s*\(([^)]+)\)', re.IGNORECASE)
# Any mention of PRIMARY KEY (inline column attribute or table constraint).
_PK_RE = re.compile(r'PRIMARY\s+KEY', re.IGNORECASE)
# AUTO_INCREMENT as a column attribute, but not the "AUTO_INCREMENT=N" table option.
_AUTO_INC_RE = re.compile(r'\bAUTO_INCREMENT\b(?!\s*=)', re.IGNORECASE)
# Name at the start of a column-definition line.
_COLUMN_NAME_RE = re.compile(r'^\s*(?:`([^`]+)`|"([^"]+)"|(\w+))')

_REMOVED_NOTE = ' -- Removed: PRIMARY KEY already in CREATE TABLE\n'


def _table_key(identifier):
    """Normalise a captured table identifier to the lowercase lookup key."""
    return identifier.strip('`"').lower()


def _split_columns(column_list):
    """Split a PRIMARY KEY column list into bare column names.

    Whitespace is stripped before the quotes. Stripping the quotes first
    would leave a stray backtick on every column after the first
    (" `b`" -> "`b").
    """
    return [c.strip().strip('`"') for c in column_list.split(',') if c.strip()]


def _statement_end(lines, start):
    """Return the index of the first line at/after *start* ending with ';', or None."""
    for j in range(start, len(lines)):
        if lines[j].strip().endswith(';'):
            return j
    return None


def _parse_alter(lines, start):
    """Parse the ALTER TABLE statement that begins on ``lines[start]``, if any.

    Returns None when ``lines[start]`` does not open an ALTER TABLE statement,
    or when the statement is never terminated by ';'. Otherwise returns
    ``(end, table_key, pk_cols)``. *end* is the index of the statement's last
    line. *pk_cols* is its PRIMARY KEY column list, empty when the statement
    adds no primary key.
    """
    match = _ALTER_TABLE_RE.match(lines[start])
    if match is None:
        return None
    end = _statement_end(lines, start)
    if end is None:
        return None
    # Join the lines so a PRIMARY KEY clause split across lines is still found.
    pk_match = _PK_COLUMNS_RE.search(''.join(lines[start:end + 1]))
    pk_cols = _split_columns(pk_match.group(1)) if pk_match else []
    return end, _table_key(match.group(1)), pk_cols


def _collect_alter_primary_keys(lines):
    """Step 1: map table key -> PRIMARY KEY columns from ALTER TABLE statements."""
    pk_map = {}
    i = 0
    while i < len(lines):
        parsed = _parse_alter(lines, i)
        if parsed is None:
            i += 1
            continue
        end, table, pk_cols = parsed
        if pk_cols:
            pk_map[table] = pk_cols
            print(f"  Found PK for {table}: {pk_cols}")
        i = end + 1
    return pk_map


def _create_table_end(lines, start):
    """Return the index of the line closing the CREATE TABLE at *start*, or None.

    A closing line is one of:
    - a line containing ');' or ending with ';' (e.g. ') ENGINE=InnoDB;');
    - a lone ')' followed by a line starting with ';'.

    Comment lines are ignored. Scanning starts at *start* itself, so a
    single-line statement closes on its own line.
    """
    for j in range(start, len(lines)):
        stripped = lines[j].strip()
        if stripped.startswith('--'):
            continue
        if ');' in stripped or stripped.endswith(';'):
            return j
        if stripped == ')' and j + 1 < len(lines) and lines[j + 1].strip().startswith(';'):
            return j
    return None


def _add_primary_key(block, table, pk_map):
    """Step 2 worker: add the mapped PRIMARY KEY to one CREATE TABLE block.

    *block* holds the statement's lines, from the CREATE TABLE line through
    the closing line, and is modified in place. A PRIMARY KEY is added only
    when the table has an AUTO_INCREMENT column, no PRIMARY KEY of its own,
    and an ALTER-derived PK in *pk_map* that contains the AUTO_INCREMENT
    column.

    Returns ``(has_pk, fixed)``: whether the statement defines a PRIMARY KEY
    afterwards, and whether this call added it.
    """
    found_auto_inc = False
    auto_inc_col = None
    for text in block[1:]:
        if text.strip().startswith('--'):
            continue
        if _PK_RE.search(text):
            return True, False
        if _AUTO_INC_RE.search(text):
            found_auto_inc = True
            name = _COLUMN_NAME_RE.match(text)
            if name:
                auto_inc_col = next(g for g in name.groups() if g)

    if not found_auto_inc:
        return False, False
    pk_cols = pk_map.get(table)
    if pk_cols is None:
        print(f"  Skipping {table}: No PK found in map (table_key: {table})")
        return False, False
    # MySQL column names are case-insensitive.
    if auto_inc_col is None or auto_inc_col.lower() not in {c.lower() for c in pk_cols}:
        print(f"  Skipping {table}: AUTO_INCREMENT col '{auto_inc_col}' not in PK {pk_cols}")
        return False, False

    pk_line = '    PRIMARY KEY ({})'.format(', '.join(f'`{col}`' for col in pk_cols))
    closing = block[-1]
    if closing.strip().startswith(')'):
        # ")", ");" or ") ENGINE=...;" on its own line. The PK goes on a new
        # line just before it, and the preceding element needs a trailing
        # comma. The closing line itself is kept, so a following ";" line
        # still terminates the statement.
        for k in range(len(block) - 2, -1, -1):
            text = block[k].strip()
            if text and not text.startswith('--'):
                if not text.endswith((',', '(')):
                    block[k] = block[k].rstrip() + ',\n'
                break
        block.insert(len(block) - 1, pk_line + '\n')
    elif ');' in closing:
        # The last column and ");" share a line: split it around the PK clause.
        cut = closing.rfind(');')
        block[-1] = closing[:cut] + ',\n' + pk_line + '\n' + closing[cut:]
    else:
        print(f"  Skipping {table}: unrecognised end of column list")
        return False, False

    print(f"  Fixing {table}: Adding {pk_line.strip()}")
    print(f"    ✓ Added PRIMARY KEY to {table}")
    return True, True


def _fix_create_tables(lines, pk_map):
    """Step 2: copy *lines*, adding missing PRIMARY KEYs to CREATE TABLE statements.

    Returns ``(output_lines, tables_with_pk, tables_fixed)``.
    *tables_with_pk* holds the keys of every table whose CREATE TABLE
    defines a PRIMARY KEY in the output, whether it was already there or
    was added here.
    """
    output = []
    tables_with_pk = set()
    tables_fixed = 0
    i = 0
    while i < len(lines):
        match = _CREATE_TABLE_RE.match(lines[i])
        end = _create_table_end(lines, i) if match else None
        if end is None:
            # Not a CREATE TABLE, or one that never closes: copy it verbatim.
            output.append(lines[i])
            i += 1
            continue
        table = _table_key(match.group(1))
        block = lines[i:end + 1]  # slice copy; may be edited below
        if end == i:
            # Single-line statement: never rewritten, but its PK still counts.
            has_pk, fixed = bool(_PK_RE.search(lines[i])), False
        else:
            print(f"  Found CREATE TABLE: {table} at line {i + 1}")
            has_pk, fixed = _add_primary_key(block, table, pk_map)
        output.extend(block)  # includes the CREATE TABLE line itself
        if has_pk:
            tables_with_pk.add(table)
        tables_fixed += fixed
        i = end + 1
    return output, tables_with_pk, tables_fixed


def _comment_duplicate_alters(lines, tables_with_pk):
    """Step 3: comment out ALTER TABLE ... PRIMARY KEY for *tables_with_pk*.

    Each matching statement is commented line by line up to its ';', so the
    statement that follows is never swallowed.
    NOTE: a multi-clause ALTER that also adds a PRIMARY KEY loses its other
    clauses too.

    Returns ``(lines, statements_commented)``.
    """
    result = []
    commented = 0
    i = 0
    while i < len(lines):
        parsed = _parse_alter(lines, i)
        if parsed is None:
            result.append(lines[i])
            i += 1
            continue
        end, table, pk_cols = parsed
        statement = lines[i:end + 1]
        if pk_cols and table in tables_with_pk:
            first, *rest = statement
            result.append('-- ' + first.rstrip() + _REMOVED_NOTE)
            result.extend(
                text if text.strip().startswith('--') else '-- ' + text
                for text in rest
            )
            commented += 1
            print(f"  Commented out ALTER TABLE PRIMARY KEY for {table}")
        else:
            result.extend(statement)
        i = end + 1
    return result, commented


def fix_schema(input_file, output_file):
    """Make AUTO_INCREMENT tables in a MySQL schema declare their PRIMARY KEY.

    1. Collect PRIMARY KEY columns from ALTER TABLE statements.
    2. Add that PRIMARY KEY to each CREATE TABLE that has an AUTO_INCREMENT
       column in it but no PRIMARY KEY of its own.
    3. Comment out ALTER TABLE ... PRIMARY KEY statements for every table
       whose CREATE TABLE now defines a PRIMARY KEY (they would be
       duplicates).

    Tables are matched by bare name, case-insensitively, ignoring any
    schema qualifier. Only lines starting with CREATE TABLE / ALTER TABLE
    (after whitespace) are treated as statement starts.

    Args:
        input_file: path of the UTF-8 schema to read.
        output_file: path to write the result to. It may equal *input_file*,
            because the whole input is read before anything is written.
    """
    print(f"Reading {input_file}...")
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    print("Building PRIMARY KEY map from ALTER TABLE statements...")
    pk_map = _collect_alter_primary_keys(lines)
    print(f"Found {len(pk_map)} PRIMARY KEY definitions\n")

    print("Processing CREATE TABLE statements...")
    output_lines, tables_with_pk, tables_fixed = _fix_create_tables(lines, pk_map)
    print(f"\nFixed {tables_fixed} CREATE TABLE statements\n")

    print("Commenting out duplicate ALTER TABLE PRIMARY KEY statements...")
    final_output, alters_commented = _comment_duplicate_alters(output_lines, tables_with_pk)

    print(f"\nWriting fixed schema to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(final_output)

    print("✅ Successfully fixed schema!")
    print(f"   - Fixed {tables_fixed} CREATE TABLE statements")
    print(f"   - Commented out {alters_commented} duplicate ALTER TABLE statements")

if __name__ == "__main__":
    import sys
    
    input_file = "mysql/schema_mysql_20251113_185155.sql"
    output_file = "mysql/schema_mysql_20251113_185155.sql"
    
    if len(sys.argv) > 1:
        input_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    
    fix_schema(input_file, output_file)

