#!/usr/bin/env python3
"""
Fix MySQL schema - Add PRIMARY KEY to CREATE TABLE for AUTO_INCREMENT columns
"""

import re
import sys

# Single-line "ALTER TABLE t ADD CONSTRAINT name PRIMARY KEY (cols);".
# ".+?" (rather than "[^P]+") lets the constraint name contain the letter "p";
# under IGNORECASE the negated class rejected both "p" and "P".
_PK_ALTER_RE = re.compile(
    r'ALTER TABLE\s+`?([^\s`]+)`?\s+ADD CONSTRAINT.+?PRIMARY KEY\s*\(([^)]+)\);',
    re.IGNORECASE,
)

# "CREATE TABLE [IF NOT EXISTS] name"; "(" is excluded from the name so that
# "CREATE TABLE users(" yields "users".
_CREATE_TABLE_RE = re.compile(
    r'CREATE TABLE\s+(?:IF NOT EXISTS\s+)?`?([^\s`(]+)`?',
    re.IGNORECASE,
)

# A back-quoted column whose definition carries AUTO_INCREMENT; the optional
# group accepts a parenthesised type argument such as "int(11)".
_AUTO_INC_RE = re.compile(
    r'`([^`]+)`\s+\w+(?:\s*\([^)]*\))?[^,\n)]*AUTO_INCREMENT',
    re.IGNORECASE,
)

# End of a CREATE TABLE statement: ");" anywhere on the line, a lone ")", or a
# line that starts with ")" and ends with ";" (e.g. ") ENGINE=InnoDB;").
_STATEMENT_END_RE = re.compile(r'\);|^\s*\)\s*$|^\s*\).*;\s*$')


def _add_primary_key(statement_lines, pk_cols):
    """Insert a PRIMARY KEY clause into the lines of one CREATE TABLE, in place.

    ``statement_lines`` runs from the CREATE TABLE line to the closing line.
    """
    pk_clause = 'PRIMARY KEY ({})'.format(', '.join(f'`{col}`' for col in pk_cols))
    closing_at = len(statement_lines) - 1
    closing = statement_lines[closing_at]

    if not closing.lstrip().startswith(')'):
        # The last definition shares its line with the closing parenthesis, so
        # the clause goes on its own line just before it, followed by a comma.
        indent = closing[:len(closing) - len(closing.lstrip())]
        statement_lines.insert(closing_at, f'{indent}{pk_clause},\n')
        return

    # The closing parenthesis starts its own line: the clause becomes the last
    # definition, so the definition before it must end with a comma.
    indent = '    '
    for k in range(closing_at - 1, 0, -1):
        current = statement_lines[k]
        stripped = current.strip()
        if not stripped or stripped.startswith(('--', '#')):
            continue  # skip blank lines and comment-only lines
        indent = current[:len(current) - len(current.lstrip())]
        if not stripped.endswith((',', '(')):
            body = current.rstrip()
            statement_lines[k] = body + ',' + current[len(body):]
        break
    statement_lines.insert(closing_at, f'{indent}{pk_clause}\n')


def fix_schema(input_file, output_file):
    """Fix MySQL schema by adding PRIMARY KEY to CREATE TABLE statements.

    The file is scanned twice. The first pass collects the primary-key columns
    declared by single-line ``ALTER TABLE ... ADD CONSTRAINT ... PRIMARY KEY
    (...);`` statements. The second pass copies the file through and, for every
    CREATE TABLE that has an AUTO_INCREMENT column belonging to the collected
    key and no PRIMARY KEY of its own, inserts a ``PRIMARY KEY (...)`` clause.
    The ALTER TABLE statements themselves are copied through unchanged.

    Args:
        input_file: Path of the SQL file to read (UTF-8).
        output_file: Path of the SQL file to write (UTF-8); may be the same
            path as ``input_file`` because the input is read fully first.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # First pass: collect all PRIMARY KEY definitions, keyed by lower-cased
    # table name.
    table_pks = {}
    for line in lines:
        pk_match = _PK_ALTER_RE.search(line)
        if pk_match:
            # Strip whitespace before the back-quotes so "`a`, `b`" gives
            # ["a", "b"] rather than leaving a stray quote on "b".
            table_pks[pk_match.group(1).lower()] = [
                col.strip().strip('`').strip()
                for col in pk_match.group(2).split(',')
            ]

    # Second pass: fix CREATE TABLE statements.
    output_lines = []
    total = len(lines)
    i = 0
    while i < total:
        line = lines[i]
        output_lines.append(line)
        i += 1

        create_match = _CREATE_TABLE_RE.search(line)
        if not create_match or _STATEMENT_END_RE.search(line):
            # Not a CREATE TABLE, or one that is complete on a single line;
            # scanning ahead would swallow the statements that follow it.
            continue

        # Scan forward to the end of the statement, remembering the
        # AUTO_INCREMENT column seen on the way.
        auto_inc_col = None
        end = None
        for j in range(i, total):
            auto_inc_match = _AUTO_INC_RE.search(lines[j])
            if auto_inc_match:
                auto_inc_col = auto_inc_match.group(1).strip()
            if _STATEMENT_END_RE.search(lines[j]):
                end = j
                break
        if end is None:
            # Didn't find the closing line; keep copying line by line.
            continue

        statement = lines[i - 1:end + 1]
        pk_cols = table_pks.get(create_match.group(1).lower())
        if (
            auto_inc_col
            and pk_cols
            and auto_inc_col.lower() in {col.lower() for col in pk_cols}
            and 'PRIMARY KEY' not in ''.join(statement).upper()
        ):
            _add_primary_key(statement, pk_cols)

        # The CREATE TABLE line itself has already been written.
        output_lines.extend(statement[1:])
        i = end + 1

    # Write output
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(output_lines)

    print(f"✅ Fixed schema: {output_file}")

if __name__ == "__main__":
    # Both paths default to the same dump file, so the schema is rewritten in
    # place unless the command line supplies an input and/or output path.
    default_path = "mysql/schema_mysql_20251113_185155.sql"
    cli_args = sys.argv[1:]

    input_file = cli_args[0] if len(cli_args) >= 1 else default_path
    output_file = cli_args[1] if len(cli_args) >= 2 else default_path

    fix_schema(input_file, output_file)

