import heapq
import math
from typing import List, Tuple, Dict, Set

class Node:
    """A labelled point in the 2-D plane.

    Attributes:
        id: Caller-chosen identifier; only stored and shown by ``__repr__``.
        x: Horizontal coordinate.
        y: Vertical coordinate.
    """

    # NOTE: the parameter name ``id`` shadows the builtin, but it is kept
    # because callers may pass it by keyword.
    def __init__(self, id: int, x: float, y: float):
        self.id = id
        self.x = x
        self.y = y

    def distance_to(self, other: 'Node') -> float:
        """Return the Euclidean distance from this node to ``other``.

        ``math.hypot`` is used instead of ``sqrt(dx**2 + dy**2)`` because
        squaring a float raises ``OverflowError`` for very large coordinate
        differences and underflows to zero for very small ones.
        """
        return math.hypot(self.x - other.x, self.y - other.y)

    def __repr__(self):
        return f"Node({self.id}, {self.x}, {self.y})"

class OptimalNodePairing:
    """Split an even-sized collection of nodes into pairs with a small total distance.

    Every algorithm returns ``(pairs, total_distance)`` where ``pairs`` is a
    list of 2-tuples of nodes and ``total_distance`` is the sum of the
    distances inside the pairs. Nodes only need a ``distance_to(other)``
    method returning a number.

    All bookkeeping is done on ordered lists (never on sets), so results are
    reproducible from run to run, ties are broken by input order, and a node
    that appears twice in the input is treated as two separate entries.
    """

    # Instances with at most this many nodes are solved exactly by
    # ``minimum_weight_perfect_matching_approx``.
    _EXACT_MAX_NODES = 8
    # Upper bound on the number of full improvement sweeps of the local search.
    _MAX_IMPROVEMENT_PASSES = 10

    def __init__(self, nodes: List["Node"]):
        """Store ``nodes`` for pairing.

        Raises:
            ValueError: if the number of nodes is odd, since a perfect
                pairing is then impossible.
        """
        if len(nodes) % 2 != 0:
            raise ValueError("Number of nodes must be even for perfect pairing")
        self.nodes = nodes

    def greedy_pairing(self) -> Tuple[List[Tuple["Node", "Node"]], float]:
        """
        Greedy algorithm: repeatedly pair the two closest remaining nodes.

        Fast but not guaranteed to be globally optimal. Ties are resolved in
        favour of the pair that comes first in input order. The loop always
        terminates: each round removes exactly two nodes, even when distances
        are not finite (the total is then not finite either).
        Time complexity: O(n^3) distance evaluations.
        """
        remaining = list(self.nodes)
        pairs = []
        total_distance = 0.0

        while len(remaining) > 1:
            # Each unordered pair is examined once. Comparing
            # (distance, i, j) tuples makes the earliest pair win ties.
            min_distance, first, second = min(
                (remaining[i].distance_to(remaining[j]), i, j)
                for i in range(len(remaining) - 1)
                for j in range(i + 1, len(remaining))
            )
            # Pop the larger index first so the smaller index stays valid.
            node2 = remaining.pop(second)
            node1 = remaining.pop(first)
            pairs.append((node1, node2))
            total_distance += min_distance

        return pairs, total_distance

    def nearest_neighbor_pairing(self) -> Tuple[List[Tuple["Node", "Node"]], float]:
        """
        Nearest neighbor approach: take the first unpaired node (in input
        order) and pair it with its closest unpaired neighbor; repeat.

        Ties are resolved in favour of the earlier candidate.
        Time complexity: O(n^2) distance evaluations.
        """
        unpaired = list(self.nodes)
        pairs = []
        total_distance = 0.0

        while len(unpaired) > 1:
            current = unpaired.pop(0)
            # (distance, index) tuples: the smallest distance wins, and the
            # earliest candidate wins among equal distances.
            min_distance, nearest_index = min(
                (current.distance_to(candidate), index)
                for index, candidate in enumerate(unpaired)
            )
            pairs.append((current, unpaired.pop(nearest_index)))
            total_distance += min_distance

        return pairs, total_distance

    def minimum_weight_perfect_matching_approx(self) -> Tuple[List[Tuple["Node", "Node"]], float]:
        """
        Heuristic matching: exact search for small inputs, otherwise the
        greedy pairing refined by pairwise partner swaps.

        With at most ``_EXACT_MAX_NODES`` nodes the result of
        ``brute_force_optimal`` is returned. For larger inputs the greedy
        pairing is improved by repeatedly looking at two pairs (a, b) and
        (c, d) and re-pairing them as (a, c)/(b, d) or (a, d)/(b, c) whenever
        that is strictly shorter. The search stops when a full sweep finds no
        improvement or after ``_MAX_IMPROVEMENT_PASSES`` sweeps, so the
        result is not guaranteed to be optimal.
        """
        if len(self.nodes) <= self._EXACT_MAX_NODES:
            return self.brute_force_optimal()

        pairs, _ = self.greedy_pairing()

        improved = True
        passes = 0
        while improved and passes < self._MAX_IMPROVEMENT_PASSES:
            improved = False
            passes += 1

            for i in range(len(pairs)):
                for j in range(i + 1, len(pairs)):
                    # Re-read pairs[i] every time: an earlier swap in this
                    # sweep may have replaced it.
                    a, b = pairs[i]
                    c, d = pairs[j]

                    current_dist = a.distance_to(b) + c.distance_to(d)
                    alt1_dist = a.distance_to(c) + b.distance_to(d)
                    alt2_dist = a.distance_to(d) + b.distance_to(c)

                    if alt1_dist < current_dist and alt1_dist < alt2_dist:
                        pairs[i] = (a, c)
                        pairs[j] = (b, d)
                        improved = True
                    elif alt2_dist < current_dist:
                        pairs[i] = (a, d)
                        pairs[j] = (b, c)
                        improved = True

        # Recompute the total from the final pairs rather than patching a
        # running sum, which would accumulate rounding error on every swap.
        return pairs, math.fsum(p.distance_to(q) for p, q in pairs)

    def brute_force_optimal(self) -> Tuple[List[Tuple["Node", "Node"]], float]:
        """
        Exact solution by exhaustive recursive search.

        The first remaining node is paired with every other remaining node in
        turn and the rest is solved recursively, so every perfect matching is
        visited once. Among equally short matchings the first one found is
        kept. Only use for small instances.
        Time complexity: O((n-1)!!) matchings, where
        (n-1)!! = n! / (2^(n/2) * (n/2)!).
        """
        def backtrack(remaining_nodes, current_distance):
            # Nothing left to pair (a single leftover node stays unpaired).
            if len(remaining_nodes) < 2:
                return [], current_distance

            first_node = remaining_nodes[0]
            best_pairs = None
            best_distance = float('inf')

            for i in range(1, len(remaining_nodes)):
                partner = remaining_nodes[i]
                rest = remaining_nodes[1:i] + remaining_nodes[i + 1:]
                sub_pairs, sub_distance = backtrack(
                    rest, current_distance + first_node.distance_to(partner)
                )
                # Always accept the first candidate so a matching is returned
                # even if every total is non-finite.
                if best_pairs is None or sub_distance < best_distance:
                    best_distance = sub_distance
                    best_pairs = [(first_node, partner)] + sub_pairs

            return best_pairs, best_distance

        # Start from 0.0 so the empty instance also yields a float total.
        return backtrack(list(self.nodes), 0.0)

# Example usage and testing
def create_sample_nodes() -> List[Node]:
    """Build the fixed six-node layout used by the demo in ``test_algorithms``."""
    # (x, y) per node; the position in this table becomes the node id.
    coordinates = (
        (0.0, 0.0),
        (1.0, 1.0),
        (3.0, 0.0),
        (4.0, 1.0),
        (1.5, 2.5),
        (2.5, 1.5),
    )
    return [Node(index, x, y) for index, (x, y) in enumerate(coordinates)]

def test_algorithms():
    """Run every pairing strategy on the sample nodes and print a report.

    For each strategy the total distance is printed, followed by one line
    per pair with the ids of its two nodes and their distance.
    """
    sample = create_sample_nodes()
    solver = OptimalNodePairing(sample)

    print("Nodes:")
    for entry in sample:
        print(f"  {entry}")

    # Section header paired with the strategy it reports on, in print order.
    sections = (
        ("=== Greedy Pairing ===", solver.greedy_pairing),
        ("=== Nearest Neighbor Pairing ===", solver.nearest_neighbor_pairing),
        ("=== Approximation Algorithm ===", solver.minimum_weight_perfect_matching_approx),
        ("=== Brute Force Optimal ===", solver.brute_force_optimal),
    )

    for header, strategy in sections:
        # A blank line separates each section from whatever precedes it.
        print()
        print(header)
        matched, total = strategy()
        print(f"Total distance: {total:.2f}")
        for number, (left, right) in enumerate(matched, start=1):
            print(f"Pair {number}: Node {left.id} - Node {right.id} (distance: {left.distance_to(right):.2f})")

# Run the algorithm comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_algorithms()