Spaces:
Running on Zero
Running on Zero
| import os | |
| import time | |
| import heapq | |
| from collections import defaultdict, deque | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import numpy as np | |
| import trimesh | |
| from numba import njit, prange | |
| from scipy.sparse import coo_matrix, csr_matrix | |
| from scipy.sparse.csgraph import connected_components | |
| from scipy.spatial import cKDTree | |
# Default ``distance_threshold`` for build_face_distance_adjacency: faces whose
# exact triangle-to-triangle distance is below this (in mesh coordinate units)
# are treated as adjacent.
DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD = 1.0e-4
| def _get_face_centroids(mesh): | |
| """Return face centroids in a compact dtype for NN queries.""" | |
| return np.asarray(mesh.triangles_center, dtype=np.float32) | |
| def _query_nearest(tree, query_points): | |
| """ | |
| Query nearest neighbors with all available CPU workers. | |
| SciPy exposes ``workers`` in cKDTree.query; using -1 lets it parallelize. | |
| """ | |
| return tree.query(query_points, k=1, workers=-1) | |
def _mesh_face_connected_components(mesh):
    """Split the mesh faces into connected components.

    Components follow ``mesh.face_adjacency`` (via
    ``trimesh.graph.connected_components``). Every face is a node, and
    ``min_len=1`` keeps isolated faces as singleton components. Each component
    is returned as an int64 array of face indices.
    """
    face_nodes = np.arange(len(mesh.faces))
    raw_components = trimesh.graph.connected_components(
        edges=mesh.face_adjacency, nodes=face_nodes, min_len=1
    )
    face_groups = []
    for members in raw_components:
        face_groups.append(np.array(list(members), dtype=np.int64))
    return face_groups
def assign_undefined_faces_to_nearest_defined(mesh, face_part_ids):
    """
    Fill undefined (-1) face labels by nearest defined face label.

    "Nearest" is measured between face centroids (float32) using a KD-tree
    built over the centroids of the defined faces only, so the cost scales
    with the mesh size rather than quadratically. The input array is not
    modified; a copy is returned. If there are no undefined faces, or no
    defined ones, the copy is returned unchanged.
    """
    filled = face_part_ids.copy()
    missing = np.flatnonzero(filled == -1)
    if missing.size == 0:
        return filled
    known = np.flatnonzero(filled != -1)
    if known.size == 0:
        return filled
    centers = np.asarray(mesh.triangles_center, dtype=np.float32)
    lookup = cKDTree(centers[known])
    _, hit = lookup.query(centers[missing], k=1, workers=-1)
    donors = known[np.atleast_1d(hit)]
    filled[missing] = filled[donors]
    return filled
def refine_part_ids_strict(mesh, face_part_ids):
    """
    Refine face part IDs by treating each connected component (CC) independently.

    Components come from ``_mesh_face_connected_components(mesh)``. For each CC:
    - If it has any defined labels, all faces are overwritten with the dominant
      part ID by surface area: ``mesh.area_faces`` is summed per part ID over
      the CC's defined faces, and the part with the largest total wins. On an
      exact area tie the smallest part ID wins, because ``np.unique`` returns
      sorted IDs and ``np.argmax`` takes the first maximum.
    - If all faces are undefined (-1), assign all faces from the nearest defined CC.
      Per undefined CC, the single face whose centroid is closest to any face
      centroid of a defined CC decides which defined CC it inherits from.
    If no CC has any defined label, undefined faces are left at -1.

    Args:
        mesh: trimesh object
        face_part_ids: part ID for each face [num_faces]; -1 marks undefined.
            The input is not modified (an int32 copy is made).
    Returns:
        refined_face_part_ids: refined part ID for each face [num_faces], int32
    """
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32).copy()
    mesh_components = _mesh_face_connected_components(mesh)
    # comp_idx -> dominant part ID, only for components with some defined label.
    component_dominant_part_id = {}
    # Indices of components whose faces are all -1.
    undefined_components = []
    # For each connected component, find the dominant part ID by surface area.
    for comp_idx, component in enumerate(mesh_components):
        if len(component) == 0:
            continue
        component_part_ids = face_part_ids[component]
        valid_mask = component_part_ids != -1
        if not np.any(valid_mask):
            undefined_components.append(comp_idx)
            continue
        valid_part_ids = component_part_ids[valid_mask].astype(np.int64)
        valid_face_areas = mesh.area_faces[component[valid_mask]]
        # Sum face areas per distinct part ID, then take the largest total.
        unique_part_ids, inverse = np.unique(valid_part_ids, return_inverse=True)
        part_area_sums = np.bincount(inverse, weights=valid_face_areas)
        dominant_part_id = int(unique_part_ids[np.argmax(part_area_sums)])
        component_dominant_part_id[comp_idx] = dominant_part_id
        face_part_ids[component] = dominant_part_id
    # Components that are entirely undefined are assigned from the nearest
    # component that has a defined dominant part label.
    if undefined_components and component_dominant_part_id:
        centroids = _get_face_centroids(mesh)
        # face index -> owning component index (-1 for faces not collected below).
        face_to_component = np.full(len(mesh.faces), -1, dtype=np.int32)
        defined_face_chunks = []
        undefined_face_chunks = []
        for comp_idx in component_dominant_part_id.keys():
            comp_faces = mesh_components[comp_idx]
            face_to_component[comp_faces] = comp_idx
            defined_face_chunks.append(comp_faces)
        for comp_idx in undefined_components:
            comp_faces = mesh_components[comp_idx]
            face_to_component[comp_faces] = comp_idx
            undefined_face_chunks.append(comp_faces)
        defined_faces = np.concatenate(defined_face_chunks, axis=0)
        undefined_faces = np.concatenate(undefined_face_chunks, axis=0)
        # One KD-tree query answers "closest defined face" for every face of
        # every undefined component at once.
        tree = cKDTree(centroids[defined_faces])
        nearest_dist, nearest_local = _query_nearest(tree, centroids[undefined_faces])
        nearest_local = np.atleast_1d(nearest_local)
        nearest_dist = np.atleast_1d(nearest_dist)
        undefined_face_components = face_to_component[undefined_faces]
        nearest_defined_faces = defined_faces[nearest_local]
        nearest_defined_components = face_to_component[nearest_defined_faces]
        # Group the query results by undefined component. The stable sort keeps
        # each component's faces contiguous, and np.unique(return_index=True)
        # gives the first sorted position of every group.
        order = np.argsort(undefined_face_components, kind="mergesort")
        sorted_undefined_comps = undefined_face_components[order]
        sorted_dists = nearest_dist[order]
        sorted_nearest_defined_comps = nearest_defined_components[order]
        unique_undefined_comps, group_start = np.unique(sorted_undefined_comps, return_index=True)
        group_end = np.concatenate([group_start[1:], np.array([len(sorted_undefined_comps)])])
        for comp_idx, start, end in zip(unique_undefined_comps, group_start, group_end):
            # The face of this component closest to any defined face decides
            # which defined component's label the whole component inherits.
            best_local = start + int(np.argmin(sorted_dists[start:end]))
            nearest_defined_comp = int(sorted_nearest_defined_comps[best_local])
            face_part_ids[mesh_components[int(comp_idx)]] = component_dominant_part_id[nearest_defined_comp]
    return face_part_ids
| def _majority_vote_face_part_ids(mesh, part_ids, face_indices): | |
| """Assigns each sampled face the majority label of its query points. | |
| Faces that never received a sampled query point remain `-1`. | |
| Args: | |
| mesh: trimesh object | |
| part_ids: part IDs for each sampled point [num_points] | |
| face_indices: which face each point lies on (-1 means on edge) [num_points] | |
| Returns: | |
| Face labels with unresolved faces marked as `-1`. | |
| """ | |
| num_faces = len(mesh.faces) | |
| face_part_ids = np.full(num_faces, -1, dtype=np.int32) | |
| face_to_points = {} | |
| for point_idx, face_idx in enumerate(face_indices): | |
| if face_idx == -1: | |
| continue | |
| if face_idx not in face_to_points: | |
| face_to_points[face_idx] = [] | |
| face_to_points[face_idx].append(part_ids[point_idx]) | |
| for face_idx, point_part_ids in face_to_points.items(): | |
| counts = np.bincount(point_part_ids) | |
| majority_part_id = np.argmax(counts) | |
| face_part_ids[face_idx] = majority_part_id | |
| return face_part_ids | |
def find_unrefined_part_ids_for_faces(mesh, part_ids, face_indices):
    """Build the per-face labels that precede connected-component refinement.

    This is the "unrefined" representation, produced in two steps:
    1. every face that received sampled points takes the majority label of
       those points (``_majority_vote_face_part_ids``);
    2. faces that received no points copy the label of the nearest labelled
       face (``assign_undefined_faces_to_nearest_defined``).

    Args:
        mesh: trimesh object
        part_ids: part IDs for each sampled point [num_points]
        face_indices: which face each point lies on (-1 means on edge) [num_points]
    Returns:
        Face labels after majority vote plus nearest-face fill.
    """
    voted = _majority_vote_face_part_ids(mesh, part_ids, face_indices)
    filled = assign_undefined_faces_to_nearest_defined(mesh, voted)
    return filled
def refine_face_part_ids(mesh, face_part_ids, strict=False):
    """Run the base face-label post-processing stage.

    Args:
        mesh: trimesh object
        face_part_ids: Face labels after the unrefined majority-vote stage.
        strict: When True, delegate to ``refine_part_ids_strict``. When False,
            return an int32 copy of the labels with no other change.
    Returns:
        Base per-face part IDs used for the final segmentation export.
    """
    if not strict:
        return np.asarray(face_part_ids, dtype=np.int32).copy()
    return refine_part_ids_strict(mesh, face_part_ids)
def point_to_triangle_distance_batch(points_batch, tri_verts_batch):
    """Squared distance from batched points to their paired triangles.

    Args:
        points_batch: ``(M, N, 3)`` query points; row ``m`` is measured against
            triangle ``m``.
        tri_verts_batch: ``(M, 3, 3)`` triangle vertices.

    Returns:
        ``(M, N)`` squared distances. Points whose orthogonal projection falls
        inside a non-degenerate triangle use the squared distance to the plane.
        All other points use the minimum squared distance to the three edges.
    """
    a = tri_verts_batch[:, 0, :]
    b = tri_verts_batch[:, 1, :]
    c = tri_verts_batch[:, 2, :]
    ab = b - a
    ac = c - a

    raw_normal = np.cross(ab, ac)
    normal_len = np.linalg.norm(raw_normal, axis=1, keepdims=True)
    non_degenerate = normal_len[:, 0] >= 1e-10
    unit_normal = raw_normal / np.maximum(normal_len, 1e-10)

    # Signed height above the plane, and the foot of the perpendicular.
    rel = points_batch - a[:, np.newaxis, :]
    height = np.einsum("mnk,mk->mn", rel, unit_normal)
    foot = points_batch - height[:, :, np.newaxis] * unit_normal[:, np.newaxis, :]
    foot_rel = foot - a[:, np.newaxis, :]

    # Barycentric coordinates of the foot point via the Gram matrix of (ab, ac).
    g00 = np.einsum("mk,mk->m", ab, ab)
    g01 = np.einsum("mk,mk->m", ab, ac)
    g11 = np.einsum("mk,mk->m", ac, ac)
    r0 = np.einsum("mnk,mk->mn", foot_rel, ab)
    r1 = np.einsum("mnk,mk->mn", foot_rel, ac)
    gram = g00 * g11 - g01 * g01
    gram_ok = np.abs(gram) >= 1e-10
    safe_gram = np.where(gram_ok, gram, 1.0)[:, np.newaxis]
    w_b = (g11[:, np.newaxis] * r0 - g01[:, np.newaxis] * r1) / safe_gram
    w_c = (g00[:, np.newaxis] * r1 - g01[:, np.newaxis] * r0) / safe_gram
    w_a = 1.0 - w_b - w_c

    tolerance = -1e-10
    projects_inside = (w_a >= tolerance) & (w_b >= tolerance) & (w_c >= tolerance)
    projects_inside &= non_degenerate[:, np.newaxis] & gram_ok[:, np.newaxis]

    result = height * height
    needs_edges = ~projects_inside
    if needs_edges.any():
        closest_edge_sq = None
        for seg_start, seg_end in ((a, b), (b, c), (c, a)):
            seg = seg_end - seg_start
            seg_len_sq = np.einsum("mk,mk->m", seg, seg)
            offset = points_batch - seg_start[:, np.newaxis, :]
            frac = np.clip(
                np.einsum("mnk,mk->mn", offset, seg)
                / np.maximum(seg_len_sq[:, np.newaxis], 1e-10),
                0,
                1,
            )
            nearest = seg_start[:, np.newaxis, :] + frac[:, :, np.newaxis] * seg[:, np.newaxis, :]
            gap = points_batch - nearest
            edge_sq = np.einsum("mnk,mnk->mn", gap, gap)
            if closest_edge_sq is None:
                closest_edge_sq = edge_sq
            else:
                closest_edge_sq = np.minimum(closest_edge_sq, edge_sq)
        result = np.where(needs_edges, closest_edge_sq, result)
    return result
def resolve_point_prompt_face_ids(
    mesh,
    point_prompts,
    *,
    exact_batch_size=8192,
):
    """Resolve each point prompt to the nearest mesh face in the same coordinate frame.

    This is used for backward compatibility when saved prompt-face IDs are unavailable.
    The search is exact but keeps the work bounded by:
    - seeding every prompt's best candidate from one batched face-centroid KD-tree query
    - pruning exact triangle checks with face AABB lower bounds

    Args:
        mesh: trimesh object exposing ``triangles`` and ``triangles_center``.
        point_prompts: array-like of shape ``(num_prompts, 3)``.
        exact_batch_size: number of candidate faces per exact point-to-triangle
            distance batch; must be >= 1.

    Returns:
        int64 array of shape ``(num_prompts,)`` with the nearest face index per
        prompt. On ties the centroid-seed face, or the earliest candidate
        found, is kept.

    Raises:
        ValueError: if ``point_prompts`` is not ``(num_prompts, 3)``, the mesh
            has no faces, or ``exact_batch_size`` < 1.
    """
    point_prompts = np.asarray(point_prompts, dtype=np.float32)
    if point_prompts.ndim != 2 or point_prompts.shape[1] != 3:
        raise ValueError(
            "point_prompts must have shape (num_prompts, 3), "
            f"got {point_prompts.shape}"
        )
    if point_prompts.shape[0] == 0:
        return np.zeros((0,), dtype=np.int64)
    face_verts = np.asarray(mesh.triangles, dtype=np.float64)
    if face_verts.shape[0] == 0:
        raise ValueError("cannot resolve point-prompt faces on a mesh with zero faces")
    batch_size = int(exact_batch_size)
    if batch_size < 1:
        # A zero step crashes range(), and a negative step silently skips every
        # exact check, so the seed face would be returned unverified.
        raise ValueError(f"exact_batch_size must be >= 1, got {exact_batch_size}")
    bbox_mins = np.min(face_verts, axis=1)
    bbox_maxs = np.max(face_verts, axis=1)
    face_centroids = np.asarray(mesh.triangles_center, dtype=np.float64)
    centroid_tree = cKDTree(face_centroids)
    prompts64 = point_prompts.astype(np.float64, copy=False)
    # The centroid seed is looked up for all prompts in a single query instead
    # of once per loop iteration; each query is independent, so results match.
    _, seed_face_ids = _query_nearest(centroid_tree, prompts64)
    seed_face_ids = np.atleast_1d(seed_face_ids)
    face_ids = np.zeros((point_prompts.shape[0],), dtype=np.int64)
    all_face_ids = np.arange(len(face_verts), dtype=np.int64)
    for prompt_idx, point_prompt in enumerate(prompts64):
        best_face_id = int(seed_face_ids[prompt_idx])
        best_sq = float(
            point_to_triangle_distance_batch(
                np.broadcast_to(point_prompt, (1, 1, 3)),
                face_verts[[best_face_id]],
            )[0, 0]
        )
        # Squared distance from the prompt to each face AABB is a lower bound
        # on its distance to the triangle; only faces that could beat the seed
        # need an exact check.
        axis_gap = np.maximum(
            np.maximum(bbox_mins - point_prompt[None, :], point_prompt[None, :] - bbox_maxs),
            0.0,
        )
        lower_bound_sq = np.einsum("ij,ij->i", axis_gap, axis_gap)
        candidate_face_ids = all_face_ids[lower_bound_sq < best_sq]
        for start in range(0, len(candidate_face_ids), batch_size):
            batch_face_ids = candidate_face_ids[start:start + batch_size]
            batch_dist_sq = point_to_triangle_distance_batch(
                np.broadcast_to(point_prompt, (len(batch_face_ids), 1, 3)),
                face_verts[batch_face_ids],
            )[:, 0]
            batch_best_local = int(np.argmin(batch_dist_sq))
            batch_best_sq = float(batch_dist_sq[batch_best_local])
            if batch_best_sq < best_sq:
                best_sq = batch_best_sq
                best_face_id = int(batch_face_ids[batch_best_local])
        face_ids[prompt_idx] = best_face_id
    return face_ids
def segment_segment_distance_sq_batch(p1, q1, p2, q2):
    """Compute squared distances between batched 3D line segments.

    Segment ``k`` of the first batch runs from ``p1[k]`` to ``q1[k]`` and is
    measured against segment ``k`` of the second batch (``p2[k]`` to ``q2[k]``).
    Inputs are ``(N, 3)`` arrays, and the result is an ``(N,)`` array.

    This is the classic clamped closest-point construction: solve for the
    line-line parameters ``s`` and ``t``, then clamp them to ``[0, 1]`` and
    re-project. Tolerances are scale-relative, so very short segments (for
    example edges of a finely tessellated mesh) are treated exactly like long
    ones. Zero-length segments are handled as points.
    """
    eps = 1e-12
    u = q1 - p1
    v = q2 - p2
    w = p1 - p2
    a = np.einsum("ij,ij->i", u, u)
    b = np.einsum("ij,ij->i", u, v)
    c = np.einsum("ij,ij->i", v, v)
    d = np.einsum("ij,ij->i", u, w)
    e = np.einsum("ij,ij->i", v, w)
    det = a * c - b * b
    s_n = np.empty_like(det)
    t_n = np.empty_like(det)
    s_d = np.empty_like(det)
    t_d = np.empty_like(det)
    # det = a * c * sin^2(angle). Comparing it with a * c makes the parallel
    # test depend only on the angle, not on segment length. "<=" also routes
    # zero-length segments (a * c == 0, hence det == 0) into this branch.
    parallel_mask = det <= eps * a * c
    non_parallel_mask = ~parallel_mask
    s_n[parallel_mask] = 0.0
    s_d[parallel_mask] = 1.0
    t_n[parallel_mask] = e[parallel_mask]
    t_d[parallel_mask] = c[parallel_mask]
    s_n[non_parallel_mask] = (
        b[non_parallel_mask] * e[non_parallel_mask]
        - c[non_parallel_mask] * d[non_parallel_mask]
    )
    t_n[non_parallel_mask] = (
        a[non_parallel_mask] * e[non_parallel_mask]
        - b[non_parallel_mask] * d[non_parallel_mask]
    )
    s_d[non_parallel_mask] = det[non_parallel_mask]
    t_d[non_parallel_mask] = det[non_parallel_mask]
    # A zero-length second segment is just the point p2. Project it onto the
    # first segment instead of pinning s to 0.
    point2_mask = parallel_mask & (c <= 0.0)
    if np.any(point2_mask):
        s_n[point2_mask] = np.clip(-d[point2_mask], 0.0, a[point2_mask])
        s_d[point2_mask] = a[point2_mask]
        t_n[point2_mask] = 0.0
        t_d[point2_mask] = 1.0
    # Clamp s to [0, 1]; t is then re-optimized on the matching edge.
    mask = non_parallel_mask & (s_n < 0.0)
    s_n[mask] = 0.0
    t_n[mask] = e[mask]
    t_d[mask] = c[mask]
    mask = non_parallel_mask & (s_n > s_d)
    s_n[mask] = s_d[mask]
    t_n[mask] = e[mask] + b[mask]
    t_d[mask] = c[mask]
    # Clamp t to [0, 1] and re-optimize (and clamp) s for the chosen endpoint.
    mask = t_n < 0.0
    t_n[mask] = 0.0
    s_n[mask] = -d[mask]
    s_d[mask] = a[mask]
    mask2 = mask & (s_n < 0.0)
    s_n[mask2] = 0.0
    mask2 = mask & (s_n > s_d)
    s_n[mask2] = s_d[mask2]
    mask = t_n > t_d
    t_n[mask] = t_d[mask]
    s_n[mask] = -d[mask] + b[mask]
    s_d[mask] = a[mask]
    mask2 = mask & (s_n < 0.0)
    s_n[mask2] = 0.0
    mask2 = mask & (s_n > s_d)
    s_n[mask2] = s_d[mask2]
    sc = np.zeros_like(s_n)
    tc = np.zeros_like(t_n)
    # Numerators are clamped to [0, denominator] above, so any positive
    # denominator is safe. A zero denominator only occurs with a zero numerator
    # (a degenerate segment), where the parameter is 0.
    valid_s = s_d > 0.0
    valid_t = t_d > 0.0
    sc[valid_s] = s_n[valid_s] / s_d[valid_s]
    tc[valid_t] = t_n[valid_t] / t_d[valid_t]
    delta = w + sc[:, np.newaxis] * u - tc[:, np.newaxis] * v
    return np.einsum("ij,ij->i", delta, delta)
def segment_intersects_triangle_batch(seg_start, seg_end, tri_verts_batch, eps=1e-10):
    """Test batched segment-triangle intersections (Moller-Trumbore).

    Args:
        seg_start: ``(N, 3)`` segment start points.
        seg_end: ``(N, 3)`` segment end points.
        tri_verts_batch: ``(N, 3, 3)`` triangles; segment ``k`` is tested
            against triangle ``k``.
        eps: relative tolerance. The parallel test compares ``|det|`` with
            ``eps * |direction| * |edge1| * |edge2|``, so it does not depend on
            the coordinate scale. The barycentric and segment-parameter bounds,
            which are already dimensionless, are widened by ``eps``.

    Returns:
        ``(N,)`` boolean mask, True where the segment crosses the triangle.
        Segments parallel to (or lying in) the triangle plane, and zero-length
        segments, report False.
    """
    direction = seg_end - seg_start
    v0 = tri_verts_batch[:, 0, :]
    v1 = tri_verts_batch[:, 1, :]
    v2 = tri_verts_batch[:, 2, :]
    edge1 = v1 - v0
    edge2 = v2 - v0
    pvec = np.cross(direction, edge2)
    det = np.einsum("ij,ij->i", edge1, pvec)
    # |det| = |direction . (edge2 x edge1)| <= |direction| |edge1| |edge2|, so
    # this ratio measures the angle to the plane independently of units.
    # An absolute threshold would reject every hit on small triangles.
    scale = (
        np.linalg.norm(direction, axis=-1)
        * np.linalg.norm(edge1, axis=-1)
        * np.linalg.norm(edge2, axis=-1)
    )
    non_parallel = np.abs(det) > eps * scale
    inv_det = np.zeros_like(det)
    inv_det[non_parallel] = 1.0 / det[non_parallel]
    tvec = seg_start - v0
    u = np.einsum("ij,ij->i", tvec, pvec) * inv_det
    qvec = np.cross(tvec, edge1)
    v = np.einsum("ij,ij->i", direction, qvec) * inv_det
    t = np.einsum("ij,ij->i", edge2, qvec) * inv_det
    return (
        non_parallel
        & (u >= -eps)
        & (v >= -eps)
        & (u + v <= 1.0 + eps)
        & (t >= -eps)
        & (t <= 1.0 + eps)
    )
def triangle_pairs_within_threshold_batch(tri_a_batch, tri_b_batch, threshold_sq):
    """Flag triangle pairs whose exact distance is below the threshold.

    Pair ``k`` is ``tri_a_batch[k]`` vs ``tri_b_batch[k]`` (both ``(N, 3, 3)``).
    The cheap tests run first, and each later stage only sees the pairs that
    are still unresolved:
      1. closest vertex-vertex distance;
      2. vertex-to-triangle distance in both directions;
      3. edge-edge distance over all 9 edge pairs;
      4. an edge of either triangle crossing the other triangle.

    Returns:
        ``(N,)`` boolean mask: True where the squared distance is
        ``< threshold_sq`` (or the triangles intersect).
    """
    n_pairs = len(tri_a_batch)
    if n_pairs == 0:
        return np.zeros(0, dtype=bool)
    sides = ((0, 1), (1, 2), (2, 0))

    # Stage 1: nearest pair of vertices.
    vertex_gap_sq = np.full(n_pairs, np.inf, dtype=tri_a_batch.dtype)
    for ia in range(3):
        for ib in range(3):
            delta = tri_a_batch[:, ia, :] - tri_b_batch[:, ib, :]
            vertex_gap_sq = np.minimum(vertex_gap_sq, np.einsum("ij,ij->i", delta, delta))
    hit = np.zeros(n_pairs, dtype=bool)
    hit |= vertex_gap_sq < threshold_sq
    if hit.all():
        return hit

    # Stage 2: every vertex of one triangle against the other triangle.
    pending = np.flatnonzero(~hit)
    a_rem = tri_a_batch[pending]
    b_rem = tri_b_batch[pending]
    vertex_face_sq = np.minimum(
        point_to_triangle_distance_batch(a_rem, b_rem).min(axis=1),
        point_to_triangle_distance_batch(b_rem, a_rem).min(axis=1),
    )
    hit[pending[vertex_face_sq < threshold_sq]] = True
    if hit.all():
        return hit

    # Stage 3: closest approach between any edge of A and any edge of B.
    pending = np.flatnonzero(~hit)
    a_rem = tri_a_batch[pending]
    b_rem = tri_b_batch[pending]
    edge_gap_sq = np.full(len(pending), np.inf, dtype=a_rem.dtype)
    for i0, i1 in sides:
        for j0, j1 in sides:
            edge_gap_sq = np.minimum(
                edge_gap_sq,
                segment_segment_distance_sq_batch(
                    a_rem[:, i0, :], a_rem[:, i1, :], b_rem[:, j0, :], b_rem[:, j1, :]
                ),
            )
    hit[pending[edge_gap_sq < threshold_sq]] = True
    if hit.all():
        return hit

    # Stage 4: interpenetration (an edge of one triangle pierces the other).
    pending = np.flatnonzero(~hit)
    a_rem = tri_a_batch[pending]
    b_rem = tri_b_batch[pending]
    crossing = np.zeros(len(pending), dtype=bool)
    for k0, k1 in sides:
        crossing |= segment_intersects_triangle_batch(a_rem[:, k0, :], a_rem[:, k1, :], b_rem)
        crossing |= segment_intersects_triangle_batch(b_rem[:, k0, :], b_rem[:, k1, :], a_rem)
    hit[pending[crossing]] = True
    return hit
| def _triangle_pair_distance_sq_batch(tri_a_batch, tri_b_batch): | |
| """Compute exact triangle-triangle squared distances for a batch.""" | |
| num_pairs = len(tri_a_batch) | |
| if num_pairs == 0: | |
| return np.zeros(0, dtype=np.float64) | |
| edge_indices = ((0, 1), (1, 2), (2, 0)) | |
| min_sq = np.full(num_pairs, np.inf, dtype=np.float64) | |
| d_a_to_b_sq = point_to_triangle_distance_batch(tri_a_batch, tri_b_batch) | |
| d_b_to_a_sq = point_to_triangle_distance_batch(tri_b_batch, tri_a_batch) | |
| min_sq = np.minimum(min_sq, np.min(d_a_to_b_sq, axis=1)) | |
| min_sq = np.minimum(min_sq, np.min(d_b_to_a_sq, axis=1)) | |
| for a0, a1 in edge_indices: | |
| p1 = tri_a_batch[:, a0, :] | |
| q1 = tri_a_batch[:, a1, :] | |
| for b0, b1 in edge_indices: | |
| p2 = tri_b_batch[:, b0, :] | |
| q2 = tri_b_batch[:, b1, :] | |
| min_sq = np.minimum( | |
| min_sq, | |
| segment_segment_distance_sq_batch(p1, q1, p2, q2), | |
| ) | |
| intersects = np.zeros(num_pairs, dtype=bool) | |
| for a0, a1 in edge_indices: | |
| intersects |= segment_intersects_triangle_batch( | |
| tri_a_batch[:, a0, :], | |
| tri_a_batch[:, a1, :], | |
| tri_b_batch, | |
| ) | |
| intersects |= segment_intersects_triangle_batch( | |
| tri_b_batch[:, a0, :], | |
| tri_b_batch[:, a1, :], | |
| tri_a_batch, | |
| ) | |
| min_sq[intersects] = 0.0 | |
| return min_sq | |
def generate_candidate_pairs_sweep_numba(
    order,
    mins_a,
    maxs_a,
    mins_b,
    maxs_b,
    upper_bounds,
    distance_threshold,
):
    """Generate exact bbox candidate pairs using the reviewed parallel sweep-line logic.

    Faces are indexed by their position in sweep order. For sorted position
    ``i``, the later positions ``j`` with ``i < j < upper_bounds[i]`` form the
    sweep window. build_face_distance_adjacency computes ``upper_bounds`` with
    ``searchsorted(mins, maxs + distance_threshold)`` along the sweep axis. A
    pair is kept when the bbox gap on both remaining axes is also
    ``< distance_threshold``.

    Args:
        order: sorted position -> original face index.
        mins_a, maxs_a: bbox min/max on the first non-sweep axis, in sorted order.
        mins_b, maxs_b: bbox min/max on the second non-sweep axis, in sorted order.
        upper_bounds: exclusive end of the sweep window for each sorted position.
        distance_threshold: a pair is rejected when its bbox gap on either
            axis is >= this value.

    Returns:
        ``(num_candidates, 2)`` int64 array of original face indices, each row
        written as (smaller, larger).

    Two passes avoid any shared write cursor. The first counts survivors per
    ``i``, a serial prefix sum turns the counts into write offsets, and the
    second pass writes each ``i``'s pairs into its own slice.

    NOTE(review): despite the ``_numba`` name and the use of ``prange``, no
    ``@njit`` decorator is visible on this function. It presumably runs as
    plain Python, where ``prange`` behaves like ``range``. Confirm whether
    ``@njit(parallel=True)`` was intended.
    """
    n_faces = len(order)
    # Pass 1: count surviving pairs per sorted position.
    counts = np.zeros(n_faces, dtype=np.int64)
    for i in prange(n_faces):
        upper_bound = upper_bounds[i]
        if upper_bound <= i + 1:
            continue
        min_ai = mins_a[i]
        max_ai = maxs_a[i]
        min_bi = mins_b[i]
        max_bi = maxs_b[i]
        local_count = 0
        for j in range(i + 1, upper_bound):
            # Reject when the boxes are separated by >= threshold on either
            # remaining axis (checked in both directions).
            if min_ai - maxs_a[j] >= distance_threshold:
                continue
            if mins_a[j] - max_ai >= distance_threshold:
                continue
            if min_bi - maxs_b[j] >= distance_threshold:
                continue
            if mins_b[j] - max_bi >= distance_threshold:
                continue
            local_count += 1
        counts[i] = local_count
    # Exclusive prefix sum: offsets[i] is where position i's pairs start.
    offsets = np.empty(n_faces, dtype=np.int64)
    total_count = 0
    for i in range(n_faces):
        offsets[i] = total_count
        total_count += counts[i]
    candidate_pairs = np.empty((total_count, 2), dtype=np.int64)
    # Pass 2: repeat the same tests and write the survivors into position i's slice.
    for i in prange(n_faces):
        upper_bound = upper_bounds[i]
        if upper_bound <= i + 1:
            continue
        min_ai = mins_a[i]
        max_ai = maxs_a[i]
        min_bi = mins_b[i]
        max_bi = maxs_b[i]
        face_i = order[i]
        out_idx = offsets[i]
        for j in range(i + 1, upper_bound):
            if min_ai - maxs_a[j] >= distance_threshold:
                continue
            if mins_a[j] - max_ai >= distance_threshold:
                continue
            if min_bi - maxs_b[j] >= distance_threshold:
                continue
            if mins_b[j] - max_bi >= distance_threshold:
                continue
            face_j = order[j]
            # Store each pair with the smaller original face index first.
            if face_i < face_j:
                candidate_pairs[out_idx, 0] = face_i
                candidate_pairs[out_idx, 1] = face_j
            else:
                candidate_pairs[out_idx, 0] = face_j
                candidate_pairs[out_idx, 1] = face_i
            out_idx += 1
    return candidate_pairs
def filter_adjacent_pairs_batch(batch_pairs, verts, faces, threshold_sq):
    """Keep only candidate face pairs whose exact triangle distance is below threshold.

    Pairs are resolved cheapest-first:
      * pairs that share a vertex index are kept immediately;
      * pairs whose bounding boxes are already at least ``sqrt(threshold_sq)``
        apart are dropped;
      * the rest go to ``triangle_pairs_within_threshold_batch``.

    Args:
        batch_pairs: ``(K, 2)`` array of face-index pairs.
        verts: vertex positions indexed by ``faces``.
        faces: integer array of per-face vertex indices.
        threshold_sq: squared distance threshold.

    Returns:
        The subset of ``batch_pairs`` rows that are adjacent.
    """
    if len(batch_pairs) == 0:
        return batch_pairs
    vids_i = faces[batch_pairs[:, 0]]
    vids_j = faces[batch_pairs[:, 1]]
    keep = (vids_i[:, :, np.newaxis] == vids_j[:, np.newaxis, :]).any(axis=(1, 2))
    undecided = np.flatnonzero(~keep)
    if undecided.size:
        tri_i = verts[vids_i[undecided]]
        tri_j = verts[vids_j[undecided]]
        # Per-axis AABB separation gives a lower bound on the triangle distance.
        gap = np.maximum(
            np.maximum(
                tri_i.min(axis=1) - tri_j.max(axis=1),
                tri_j.min(axis=1) - tri_i.max(axis=1),
            ),
            0.0,
        )
        close_enough = np.einsum("ij,ij->i", gap, gap) < threshold_sq
        if close_enough.any():
            exact = triangle_pairs_within_threshold_batch(
                tri_i[close_enough],
                tri_j[close_enough],
                threshold_sq,
            )
            keep[undecided[close_enough][exact]] = True
    return batch_pairs[keep]
def build_face_edge_adjacency(mesh):
    """Map each face to the set of faces adjacent to it in ``mesh.face_adjacency``.

    Every pair is recorded in both directions. The result is a
    ``defaultdict(set)``, so looking up a face with no neighbours yields an
    empty set.
    """
    neighbours = defaultdict(set)
    edge_pairs = np.asarray(mesh.face_adjacency, dtype=np.int64).tolist()
    for first, second in edge_pairs:
        neighbours[first].add(second)
        neighbours[second].add(first)
    return neighbours
def _sweep_plan_for_axis(bbox_mins, bbox_maxs, axis, distance_threshold, positions):
    """Sort faces along ``axis``; return (order, window upper bounds, estimated checks)."""
    order = np.argsort(bbox_mins[:, axis], kind="mergesort")
    mins_axis = bbox_mins[order, axis]
    maxs_axis = bbox_maxs[order, axis]
    # upper[i] is the first sorted position whose bbox min is >= max_i + threshold,
    # i.e. the exclusive end of the faces that can still be within range of face i.
    upper = np.searchsorted(
        mins_axis,
        maxs_axis + distance_threshold,
        side="left",
    ).astype(np.int64, copy=False)
    window = upper - positions - 1
    window[window < 0] = 0
    return order, upper, int(np.sum(window))


def _filter_candidate_pairs(candidate_pairs, verts, faces, threshold_sq, batch_size, max_workers):
    """Exact-distance filter the candidates in batches; return (rows, cols, batch count)."""
    batch_starts = range(0, len(candidate_pairs), batch_size)
    accepted_rows = []
    accepted_cols = []

    def _accept(adjacent_pairs):
        if len(adjacent_pairs) > 0:
            accepted_rows.append(adjacent_pairs[:, 0].astype(np.int32, copy=False))
            accepted_cols.append(adjacent_pairs[:, 1].astype(np.int32, copy=False))

    if max_workers == 1:
        for start in batch_starts:
            _accept(
                filter_adjacent_pairs_batch(
                    candidate_pairs[start:start + batch_size],
                    verts,
                    faces,
                    threshold_sq,
                )
            )
    else:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    filter_adjacent_pairs_batch,
                    candidate_pairs[start:start + batch_size],
                    verts,
                    faces,
                    threshold_sq,
                )
                for start in batch_starts
            ]
            # Completion order is arbitrary. The CSR conversion done by the
            # caller makes the final adjacency independent of it.
            for future in as_completed(futures):
                _accept(future.result())
    return accepted_rows, accepted_cols, len(batch_starts)


def build_face_distance_adjacency(
    verts,
    faces,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
    component_labels=None,
    cross_component_only=False,
    log_prefix="",
):
    """Build face adjacency based on exact triangle-to-triangle distances.

    Two faces are adjacent when their exact triangle distance is below
    ``distance_threshold``. The work runs in two timed steps, and both print
    a progress line:

    1. Candidate generation: a sweep along whichever axis yields the fewest
       estimated window checks, followed by bbox tests on the other two axes.
    2. Exact filtering of the candidates in batches of 200,000, using a
       thread pool when more than one worker is allowed.

    Args:
        verts: vertex positions indexed by ``faces`` (3 coordinates per vertex).
        faces: integer array of per-face (triangle) vertex indices.
        distance_threshold: faces closer than this are adjacent.
        max_distance_workers: threads for step 2. ``None`` uses
            ``min(os.cpu_count(), 64)``; values below 1 are treated as 1.
        component_labels: one label per face; required when
            ``cross_component_only`` is True.
        cross_component_only: if True, keep only pairs whose faces carry
            different ``component_labels``.
        log_prefix: optional prefix for the progress messages.

    Returns:
        ``defaultdict(set)`` mapping face index -> set of adjacent face indices.
        Initially only faces with at least one neighbour have an entry.

    Raises:
        ValueError: if ``cross_component_only`` is set without one
            ``component_labels`` entry per face.
    """
    n_faces = len(faces)
    face_verts_all = verts[faces]
    bbox_mins = np.min(face_verts_all, axis=1)
    bbox_maxs = np.max(face_verts_all, axis=1)
    if cross_component_only:
        if component_labels is None:
            raise ValueError("component_labels is required when cross_component_only=True")
        component_labels = np.asarray(component_labels, dtype=np.int32)
        if component_labels.shape != (n_faces,):
            raise ValueError(
                "component_labels must have one entry per face, "
                f"got shape {component_labels.shape} for {n_faces} faces"
            )

    # ---- Step 1: sparse candidate generation --------------------------------
    step1_start = time.time()
    n_pairs_total = n_faces * (n_faces - 1) // 2
    if cross_component_only:
        component_sizes = np.bincount(component_labels.astype(np.int64, copy=False))
        intra_component_pairs = int(
            np.sum(component_sizes * np.maximum(component_sizes - 1, 0) // 2)
        )
        n_pairs_total -= intra_component_pairs
    positions = np.arange(n_faces, dtype=np.int64)
    best_axis = None
    best_order = None
    best_upper = None
    best_estimated_checks = None
    for axis in (0, 1, 2):
        order_axis, upper_axis, estimated_checks = _sweep_plan_for_axis(
            bbox_mins, bbox_maxs, axis, distance_threshold, positions
        )
        # Strict "<" keeps the earliest axis on ties (x, then y, then z).
        if best_estimated_checks is None or estimated_checks < best_estimated_checks:
            best_axis = axis
            best_order = order_axis
            best_upper = upper_axis
            best_estimated_checks = estimated_checks
    axis_a, axis_b = [axis for axis in (0, 1, 2) if axis != best_axis]
    sorted_mins = bbox_mins[best_order]
    sorted_maxs = bbox_maxs[best_order]
    candidate_pairs = generate_candidate_pairs_sweep_numba(
        best_order,
        sorted_mins[:, axis_a],
        sorted_maxs[:, axis_a],
        sorted_mins[:, axis_b],
        sorted_maxs[:, axis_b],
        best_upper,
        distance_threshold,
    )
    if cross_component_only and len(candidate_pairs) > 0:
        cross_component_mask = (
            component_labels[candidate_pairs[:, 0]] != component_labels[candidate_pairs[:, 1]]
        )
        candidate_pairs = candidate_pairs[cross_component_mask]
    step1_time = time.time() - step1_start
    axis_name = "xyz"[best_axis]
    sparsity = len(candidate_pairs) / n_pairs_total if n_pairs_total > 0 else 0.0
    prefix = f"{log_prefix} " if log_prefix else ""
    print(
        f"{prefix}Step 1 (Exact sparse candidate generation): {step1_time:.4f}s - "
        f"{best_estimated_checks} axis-{axis_name} sweep checks -> "
        f"{len(candidate_pairs):,} candidates ({sparsity:.8f} of all pairs)"
    )

    # ---- Step 2: exact distance filtering ------------------------------------
    step2_start = time.time()
    face_adjacency = defaultdict(set)
    batch_size = 200_000
    threshold_sq = distance_threshold * distance_threshold
    if max_distance_workers is None:
        cpu_count = os.cpu_count() or 1
        max_workers = min(cpu_count, 64)
    else:
        max_workers = max(1, int(max_distance_workers))
    accepted_rows, accepted_cols, total_distance_batches = _filter_candidate_pairs(
        candidate_pairs, verts, faces, threshold_sq, batch_size, max_workers
    )
    if accepted_rows:
        rows = np.concatenate(accepted_rows, axis=0)
        cols = np.concatenate(accepted_cols, axis=0)
        # Symmetrize, then let CSR group each face's neighbours contiguously.
        sym_rows = np.concatenate((rows, cols), axis=0)
        sym_cols = np.concatenate((cols, rows), axis=0)
        data = np.ones(len(sym_rows), dtype=np.uint8)
        adjacency_csr = coo_matrix(
            (data, (sym_rows, sym_cols)),
            shape=(n_faces, n_faces),
            dtype=np.uint8,
        ).tocsr()
        indptr = adjacency_csr.indptr
        indices = adjacency_csr.indices
        non_empty_rows = np.flatnonzero(np.diff(indptr))
        for face_i in non_empty_rows:
            start = indptr[face_i]
            end = indptr[face_i + 1]
            face_adjacency[int(face_i)] = set(indices[start:end].tolist())
    step2_time = time.time() - step2_start
    print(
        f"{prefix}Step 2 (Precise distance computation): {step2_time:.4f}s - "
        f"{len(candidate_pairs):,} candidates across {total_distance_batches:,} "
        f"distance batches; {len(face_adjacency)} faces have adjacencies"
    )
    return face_adjacency
| def _component_face_ids_from_labels(component_labels, n_components): | |
| """Group face IDs by connected-component label without scanning once per component.""" | |
| component_labels = np.asarray(component_labels, dtype=np.int64) | |
| order = np.argsort(component_labels, kind="mergesort") | |
| counts = np.bincount(component_labels, minlength=int(n_components)) | |
| offsets = np.concatenate( | |
| ( | |
| np.zeros((1,), dtype=np.int64), | |
| np.cumsum(counts, dtype=np.int64), | |
| ) | |
| ) | |
| return [ | |
| order[offsets[component_id]:offsets[component_id + 1]] | |
| for component_id in range(int(n_components)) | |
| ] | |
| def _component_bounds_from_face_bounds( | |
| bbox_mins, | |
| bbox_maxs, | |
| component_labels, | |
| n_components, | |
| ): | |
| component_bbox_mins = np.full((int(n_components), 3), np.inf, dtype=np.float64) | |
| component_bbox_maxs = np.full((int(n_components), 3), -np.inf, dtype=np.float64) | |
| np.minimum.at(component_bbox_mins, component_labels, bbox_mins) | |
| np.maximum.at(component_bbox_maxs, component_labels, bbox_maxs) | |
| return component_bbox_mins, component_bbox_maxs | |
def find_face_groups(faces, face_labels, adjacency=None, verts=None):
    """Find connected components of faces with the same link ID.

    Two faces join the same group when they are neighbours in ``adjacency``
    and carry equal labels (compared after ``int()`` conversion).

    Args:
        faces: Face array. Only ``len(faces)`` is read when ``adjacency`` is
            given.
        face_labels: Per-face label, indexable by face id.
        adjacency: Optional mapping ``face_id -> iterable of neighbour ids``.
            When omitted, it is built from trimesh's shared-edge
            ``face_adjacency``, which requires ``verts``.
        verts: Vertex array, needed only when ``adjacency`` is None.

    Returns:
        A list of ``(label, face_ids)`` tuples, ordered by each group's smallest
        face id. ``face_ids`` lists the group's faces in breadth-first order
        starting from that smallest id.

    Raises:
        ValueError: If both ``adjacency`` and ``verts`` are None.
    """
    face_count = len(faces)
    if adjacency is None:
        if verts is None:
            raise ValueError("verts must be provided if adjacency is None")
        tri_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
        adjacency = defaultdict(set)
        for left, right in tri_mesh.face_adjacency:
            left, right = int(left), int(right)
            adjacency[left].add(right)
            adjacency[right].add(left)

    seen = np.zeros(face_count, dtype=bool)
    result = []
    for seed in range(face_count):
        if seen[seed]:
            continue
        label = int(face_labels[seed])
        seen[seed] = True
        frontier = deque((seed,))
        members = []
        while frontier:
            current = frontier.popleft()
            members.append(current)
            for neighbour in adjacency.get(current, []):
                # Already claimed, or belongs to a different label: not part of this group.
                if seen[neighbour] or int(face_labels[neighbour]) != label:
                    continue
                seen[neighbour] = True
                frontier.append(neighbour)
        result.append((label, members))
    return result
def find_connected_components_fast(n_nodes, adjacency):
    """Find connected components in an undirected sparse graph.

    Args:
        n_nodes: Number of nodes. Only keys ``0 .. n_nodes - 1`` of
            ``adjacency`` are read.
        adjacency: Mapping ``node -> iterable of neighbour nodes``. Missing
            nodes are isolated.

    Returns:
        ``(n_components, labels)``, where ``labels`` is an int32 array giving
        each node's component id. With no edges at all, every node is its own
        component.
    """
    edge_sources = []
    edge_targets = []
    for source in range(n_nodes):
        targets = adjacency.get(source, set())
        edge_sources.extend([source] * len(targets))
        edge_targets.extend(targets)

    if not edge_sources:
        return n_nodes, np.arange(n_nodes, dtype=np.int32)

    graph = csr_matrix(
        (np.ones(len(edge_sources), dtype=bool), (edge_sources, edge_targets)),
        shape=(n_nodes, n_nodes),
    )
    component_count, component_labels = connected_components(graph, directed=False)
    return int(component_count), component_labels.astype(np.int32, copy=False)
| def _copy_face_adjacency(adjacency): | |
| return {int(face_id): set(int(neighbor) for neighbor in neighbors) for face_id, neighbors in adjacency.items()} | |
def _merge_face_adjacency(base_adjacency, extra_adjacency):
    """Merge undirected face adjacency maps.

    Returns a new mapping: a copy of ``base_adjacency`` (via
    ``_copy_face_adjacency``) with each neighbour set from ``extra_adjacency``
    unioned in. All ids are coerced with ``int()``. Neither input is mutated.
    """
    merged = _copy_face_adjacency(base_adjacency)
    for face_id, neighbors in extra_adjacency.items():
        key = int(face_id)
        bucket = merged.get(key)
        if bucket is None:
            bucket = merged[key] = set()
        bucket.update(int(neighbor) for neighbor in neighbors)
    return merged
def _closest_face_pair_between_components(
    face_component_a,
    face_component_b,
    *,
    face_verts,
    bbox_mins,
    bbox_maxs,
    face_centroids,
    upper_bound_sq=np.inf,
):
    """Find the exact closest face pair across two disconnected components.

    The search seeds a running best from the centroid-nearest pair. It then
    scans 128x128 face blocks, using per-face AABB gaps as lower bounds to
    skip pairs that cannot beat the running best. Surviving pairs are
    evaluated in ascending lower-bound order.

    Args:
        face_component_a, face_component_b: Non-empty sequences of face ids.
        face_verts: Per-face triangle vertices, indexed by face id.
        bbox_mins, bbox_maxs: Per-face AABB corners, indexed by face id.
        face_centroids: Per-face centroids, indexed by face id.
        upper_bound_sq: Optional cap on the running best squared distance.

    Returns:
        ``((face_a, face_b), distance_sq)``, with ``face_a`` taken from
        ``face_component_a`` and ``face_b`` from ``face_component_b``.
        NOTE(review): if ``upper_bound_sq`` is smaller than every pair's true
        distance, ``distance_sq`` equals ``upper_bound_sq``, but the returned
        pair is the centroid-nearest seed pair, which does not attain that
        distance. The non-recursive caller visible here uses the default
        (``inf``).

    Raises:
        ValueError: If either component is empty.
    """
    face_component_a = np.asarray(face_component_a, dtype=np.int64)
    face_component_b = np.asarray(face_component_b, dtype=np.int64)
    if face_component_a.size == 0 or face_component_b.size == 0:
        raise ValueError("component face lists must be non-empty")
    # Normalize so A is the smaller component, then flip the pair back into
    # the caller's (a, b) order.
    if face_component_a.size > face_component_b.size:
        best_pair, best_sq = _closest_face_pair_between_components(
            face_component_b,
            face_component_a,
            face_verts=face_verts,
            bbox_mins=bbox_mins,
            bbox_maxs=bbox_maxs,
            face_centroids=face_centroids,
            upper_bound_sq=upper_bound_sq,
        )
        return best_pair[::-1], best_sq
    # Seed: for every B face, find the nearest A centroid. Take the closest
    # centroid pair and evaluate its exact distance to get a finite starting
    # bound for the AABB pruning below.
    centroid_tree = cKDTree(face_centroids[face_component_a])
    centroid_distances, nearest_local = _query_nearest(
        centroid_tree,
        face_centroids[face_component_b],
    )
    # A single query point yields scalars; normalize to 1-D arrays.
    centroid_distances = np.atleast_1d(centroid_distances)
    nearest_local = np.atleast_1d(nearest_local)
    best_b_local = int(np.argmin(centroid_distances))
    best_face_a = int(face_component_a[nearest_local[best_b_local]])
    best_face_b = int(face_component_b[best_b_local])
    # _triangle_pair_distance_sq_batch is defined elsewhere in this file;
    # presumably it returns one squared triangle-triangle distance per row pair.
    best_sq = min(
        float(upper_bound_sq),
        float(
            _triangle_pair_distance_sq_batch(
                face_verts[[best_face_a]],
                face_verts[[best_face_b]],
            )[0]
        ),
    )
    best_pair = (best_face_a, best_face_b)
    if best_sq <= 0.0:
        # Touching/intersecting faces: nothing can be closer.
        return best_pair, best_sq
    block_size = 128
    exact_batch_size = 8_192
    for start_a in range(0, len(face_component_a), block_size):
        face_ids_a = face_component_a[start_a:start_a + block_size]
        mins_a = bbox_mins[face_ids_a]
        maxs_a = bbox_maxs[face_ids_a]
        for start_b in range(0, len(face_component_b), block_size):
            face_ids_b = face_component_b[start_b:start_b + block_size]
            mins_b = bbox_mins[face_ids_b]
            maxs_b = bbox_maxs[face_ids_b]
            # Per-axis gap between every (a, b) box pair in the block. It is
            # negative where the boxes overlap on that axis, hence the clamp
            # to 0 below.
            axis_gap = np.maximum(
                mins_a[:, np.newaxis, :] - maxs_b[np.newaxis, :, :],
                mins_b[np.newaxis, :, :] - maxs_a[:, np.newaxis, :],
            )
            axis_gap = np.maximum(axis_gap, 0.0)
            # Squared box-to-box distance: a lower bound on the triangle distance.
            lower_bound_sq = np.einsum("ijk,ijk->ij", axis_gap, axis_gap)
            candidate_mask = lower_bound_sq < best_sq
            if not np.any(candidate_mask):
                continue
            candidate_a_local, candidate_b_local = np.nonzero(candidate_mask)
            candidate_lower_bounds = lower_bound_sq[candidate_a_local, candidate_b_local]
            # Evaluate the most promising pairs first so the bound tightens early.
            candidate_order = np.argsort(candidate_lower_bounds, kind="mergesort")
            candidate_a = face_ids_a[candidate_a_local[candidate_order]]
            candidate_b = face_ids_b[candidate_b_local[candidate_order]]
            candidate_lower_bounds = candidate_lower_bounds[candidate_order]
            for batch_start in range(0, len(candidate_a), exact_batch_size):
                batch_end = min(batch_start + exact_batch_size, len(candidate_a))
                # Bounds are sorted: once the first bound in a batch reaches
                # the best, no later pair in this block can improve on it.
                if candidate_lower_bounds[batch_start] >= best_sq:
                    break
                batch_face_ids_a = candidate_a[batch_start:batch_end]
                batch_face_ids_b = candidate_b[batch_start:batch_end]
                batch_dist_sq = _triangle_pair_distance_sq_batch(
                    face_verts[batch_face_ids_a],
                    face_verts[batch_face_ids_b],
                )
                batch_best_local = int(np.argmin(batch_dist_sq))
                batch_best_sq = float(batch_dist_sq[batch_best_local])
                if batch_best_sq < best_sq:
                    best_sq = batch_best_sq
                    best_pair = (
                        int(batch_face_ids_a[batch_best_local]),
                        int(batch_face_ids_b[batch_best_local]),
                    )
                    if best_sq <= 0.0:
                        return best_pair, best_sq
    return best_pair, best_sq
| class _UnionFind: | |
| def __init__(self, n_items): | |
| self.parent = np.arange(int(n_items), dtype=np.int32) | |
| self.rank = np.zeros((int(n_items),), dtype=np.int8) | |
| self.n_sets = int(n_items) | |
| def find(self, item): | |
| item = int(item) | |
| parent = self.parent | |
| while int(parent[item]) != item: | |
| parent[item] = parent[int(parent[item])] | |
| item = int(parent[item]) | |
| return item | |
| def union(self, item_a, item_b): | |
| root_a = self.find(item_a) | |
| root_b = self.find(item_b) | |
| if root_a == root_b: | |
| return False | |
| if self.rank[root_a] < self.rank[root_b]: | |
| root_a, root_b = root_b, root_a | |
| self.parent[root_b] = root_a | |
| if self.rank[root_a] == self.rank[root_b]: | |
| self.rank[root_a] += 1 | |
| self.n_sets -= 1 | |
| return True | |
| def _component_bbox_pair_lower_bound_sq( | |
| component_bbox_mins, | |
| component_bbox_maxs, | |
| component_indices_a, | |
| component_indices_b, | |
| ): | |
| """Return exact AABB lower-bound distances for component bbox pairs.""" | |
| mins_a = component_bbox_mins[component_indices_a] | |
| maxs_a = component_bbox_maxs[component_indices_a] | |
| mins_b = component_bbox_mins[component_indices_b] | |
| maxs_b = component_bbox_maxs[component_indices_b] | |
| axis_gap = np.maximum(mins_a - maxs_b, mins_b - maxs_a) | |
| axis_gap = np.maximum(axis_gap, 0.0) | |
| return np.einsum("ij,ij->i", axis_gap, axis_gap) | |
def _sorted_component_bbox_lower_bound_pairs(
    component_bbox_mins,
    component_bbox_maxs,
):
    """Return all component pairs sorted by their exact AABB distance lower bound.

    Every unordered pair ``(i, j)`` with ``i < j`` is enumerated.

    Returns:
        ``(rows, cols, lower_bound_sq)`` as int32, int32 and float64 arrays of
        length ``n * (n - 1) // 2``. They are ordered by ascending lower bound,
        with ties broken by row and then by column.
    """
    count = len(component_bbox_mins)
    total_pairs = count * (count - 1) // 2
    rows_out = np.empty((total_pairs,), dtype=np.int32)
    cols_out = np.empty((total_pairs,), dtype=np.int32)
    bounds_out = np.empty((total_pairs,), dtype=np.float64)
    # Fill the strict upper triangle one row at a time, so the temporaries
    # passed to the bound helper are sized per row rather than per pair set.
    cursor = 0
    for row in range(count - 1):
        partners = np.arange(row + 1, count, dtype=np.int32)
        row_slice = slice(cursor, cursor + len(partners))
        rows_out[row_slice] = row
        cols_out[row_slice] = partners
        bounds_out[row_slice] = _component_bbox_pair_lower_bound_sq(
            component_bbox_mins,
            component_bbox_maxs,
            np.full((len(partners),), row, dtype=np.int32),
            partners,
        )
        cursor = row_slice.stop
    # np.lexsort treats its LAST key as the primary sort key.
    order = np.lexsort((cols_out, rows_out, bounds_out))
    return rows_out[order], cols_out[order], bounds_out[order]
def _exact_component_bridge_edges(
    component_face_ids,
    *,
    face_verts,
    bbox_mins,
    bbox_maxs,
    face_centroids,
    component_bbox_mins,
    component_bbox_maxs,
):
    """Return exact MST face-pair bridges over disconnected face components.

    Components are connected by the exact closest triangle-to-triangle face pair.
    AABB distances are used only as lower bounds for lazy Kruskal ordering.

    Args:
        component_face_ids: One face-id array per component.
        face_verts, bbox_mins, bbox_maxs, face_centroids: Per-face geometry,
            forwarded to ``_closest_face_pair_between_components``.
        component_bbox_mins, component_bbox_maxs: Per-component AABB corners,
            used to order component pairs by a distance lower bound.

    Returns:
        A list of ``(face_i, face_j)`` int tuples. When more than one component
        is given, it holds ``n_components - 1`` bridges, one per accepted
        union. An empty list is returned for zero or one component.

    Raises:
        RuntimeError: If the pair stream and heap run out while more than one
            set remains (defensive; every cross-set pair is pushed when reached).

    Side effects:
        Prints a one-line timing/statistics summary to stdout.
    """
    n_components = len(component_face_ids)
    if n_components <= 1:
        return []
    lower_bound_start = time.time()
    pair_rows, pair_cols, pair_lower_bound_sq = _sorted_component_bbox_lower_bound_pairs(
        component_bbox_mins,
        component_bbox_maxs,
    )
    lower_bound_time = time.time() - lower_bound_start
    union_find = _UnionFind(n_components)
    bridge_edges = []
    # Min-heap of exactly evaluated edges:
    # (distance_sq, comp_a, comp_b, face_a, face_b).
    exact_edge_heap = []
    pair_cursor = 0
    exact_evaluations = 0
    skipped_same_set_lower_bound_pairs = 0
    discarded_same_set_exact_edges = 0
    exact_eval_time = 0.0
    n_pairs = len(pair_rows)
    while union_find.n_sets > 1:
        # Evaluate lower-bound pairs exactly while their bound is below the
        # cheapest exact edge found so far. Every pair not yet evaluated then
        # has a bound >= the heap top, so the heap top is a true minimum among
        # all remaining edges (Kruskal's invariant).
        while (
            pair_cursor < n_pairs
            and (
                not exact_edge_heap
                or float(pair_lower_bound_sq[pair_cursor]) < float(exact_edge_heap[0][0])
            )
        ):
            component_idx_a = int(pair_rows[pair_cursor])
            component_idx_b = int(pair_cols[pair_cursor])
            pair_cursor += 1
            # Already connected: an exact evaluation could never be accepted.
            if union_find.find(component_idx_a) == union_find.find(component_idx_b):
                skipped_same_set_lower_bound_pairs += 1
                continue
            exact_start = time.time()
            face_pair, distance_sq = _closest_face_pair_between_components(
                component_face_ids[component_idx_a],
                component_face_ids[component_idx_b],
                face_verts=face_verts,
                bbox_mins=bbox_mins,
                bbox_maxs=bbox_maxs,
                face_centroids=face_centroids,
            )
            exact_eval_time += time.time() - exact_start
            exact_evaluations += 1
            heapq.heappush(
                exact_edge_heap,
                (
                    float(distance_sq),
                    component_idx_a,
                    component_idx_b,
                    int(face_pair[0]),
                    int(face_pair[1]),
                ),
            )
        if not exact_edge_heap:
            raise RuntimeError("failed to connect face-distance adjacency components")
        _, component_idx_a, component_idx_b, face_i, face_j = heapq.heappop(
            exact_edge_heap
        )
        # Entries pushed before a later union may now join one set; drop those.
        if union_find.union(component_idx_a, component_idx_b):
            bridge_edges.append((face_i, face_j))
        else:
            discarded_same_set_exact_edges += 1
    print(
        "Exact component bridge search: "
        f"{lower_bound_time:.4f}s bbox lower bounds for {n_pairs:,} component pairs; "
        f"{exact_evaluations:,} exact component-pair evaluations in {exact_eval_time:.4f}s; "
        f"accepted {len(bridge_edges):,} bridges; "
        f"skipped {skipped_same_set_lower_bound_pairs:,} same-set lower-bound pairs; "
        f"discarded {discarded_same_set_exact_edges:,} same-set exact edges"
    )
    return bridge_edges
def ensure_face_adjacency_is_connected(mesh, face_adjacency):
    """Bridge disconnected face-distance components by component-level nearest links.

    Returns a copy of ``face_adjacency``; the input mapping is not mutated.
    If the face graph has more than one connected component, the copy gains
    one undirected edge per bridge returned by
    ``_exact_component_bridge_edges``. Progress and timing are printed to
    stdout.
    """
    adjacency_out = _copy_face_adjacency(face_adjacency)
    face_count = len(mesh.faces)
    component_count, labels = find_connected_components_fast(face_count, adjacency_out)
    if component_count <= 1:
        return adjacency_out

    started_at = time.time()
    print(
        "Face connectivity graph has "
        f"{component_count} connected components after threshold links; "
        "adding nearest component bridges"
    )
    vertex_positions = np.asarray(mesh.vertices, dtype=np.float64)
    face_vertex_ids = np.asarray(mesh.faces, dtype=np.int64)
    triangle_corners = vertex_positions[face_vertex_ids]
    face_lo = triangle_corners.min(axis=1)
    face_hi = triangle_corners.max(axis=1)
    centers = np.asarray(mesh.triangles_center, dtype=np.float64)

    faces_per_component = _component_face_ids_from_labels(labels, component_count)
    component_lo, component_hi = _component_bounds_from_face_bounds(
        face_lo,
        face_hi,
        labels,
        component_count,
    )
    bridges = _exact_component_bridge_edges(
        faces_per_component,
        face_verts=triangle_corners,
        bbox_mins=face_lo,
        bbox_maxs=face_hi,
        face_centroids=centers,
        component_bbox_mins=component_lo,
        component_bbox_maxs=component_hi,
    )
    for face_a, face_b in bridges:
        adjacency_out.setdefault(face_a, set()).add(face_b)
        adjacency_out.setdefault(face_b, set()).add(face_a)

    print(
        "Component connectivity bridge step: "
        f"{time.time() - started_at:.4f}s - added {len(bridges):,} "
        "nearest component bridges"
    )
    return adjacency_out
def build_face_connectivity_adjacency_for_inference(
    mesh,
    *,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
):
    """Build the inference connectivity graph from edge adjacency plus cross-CC distance links.

    The base map comes from ``build_face_edge_adjacency``. If that graph is
    already a single connected component, it is returned as-is (same object).
    Otherwise ``build_face_distance_adjacency`` adds links only between faces
    of different base components, within ``distance_threshold``. The merged
    result is returned as a new mapping.
    """
    edge_adjacency = build_face_edge_adjacency(mesh)
    face_count = len(mesh.faces)
    component_count, component_of_face = find_connected_components_fast(
        face_count,
        edge_adjacency,
    )
    if component_count > 1:
        print(
            "Base face-edge adjacency has "
            f"{component_count} connected components across {face_count} faces; "
            "adding face-level cross-component distance links"
        )
        distance_links = build_face_distance_adjacency(
            np.asarray(mesh.vertices, dtype=np.float64),
            np.asarray(mesh.faces, dtype=np.int64),
            distance_threshold=float(distance_threshold),
            component_labels=component_of_face,
            cross_component_only=True,
            max_distance_workers=max_distance_workers,
            log_prefix="Cross-component",
        )
        return _merge_face_adjacency(edge_adjacency, distance_links)
    return edge_adjacency
| def _compute_face_probability_statistics(num_faces, point_part_probabilities, face_indices): | |
| """Aggregate point softmax probabilities onto faces.""" | |
| point_part_probabilities = np.asarray(point_part_probabilities, dtype=np.float32) | |
| face_indices = np.asarray(face_indices, dtype=np.int64) | |
| if point_part_probabilities.ndim != 2: | |
| raise ValueError( | |
| "point_part_probabilities must have shape (num_points, num_parts), " | |
| f"got {point_part_probabilities.shape}" | |
| ) | |
| if face_indices.shape != (point_part_probabilities.shape[0],): | |
| raise ValueError( | |
| "face_indices must have one entry per point, " | |
| f"got {face_indices.shape} for {point_part_probabilities.shape[0]} points" | |
| ) | |
| num_parts = point_part_probabilities.shape[1] | |
| face_probability_sums = np.zeros((num_faces, num_parts), dtype=np.float64) | |
| face_probability_counts = np.zeros(num_faces, dtype=np.int64) | |
| valid_mask = face_indices >= 0 | |
| if np.any(valid_mask): | |
| valid_face_indices = face_indices[valid_mask] | |
| np.add.at( | |
| face_probability_sums, | |
| valid_face_indices, | |
| point_part_probabilities[valid_mask], | |
| ) | |
| np.add.at( | |
| face_probability_counts, | |
| valid_face_indices, | |
| np.ones(valid_face_indices.shape[0], dtype=np.int64), | |
| ) | |
| return face_probability_sums, face_probability_counts | |
| def _build_filled_face_probability_means(mesh, face_probability_sums, face_probability_counts): | |
| """Fill unsampled face probabilities from the nearest sampled face.""" | |
| num_faces, num_parts = face_probability_sums.shape | |
| filled_face_probability_means = np.zeros((num_faces, num_parts), dtype=np.float32) | |
| defined_mask = face_probability_counts > 0 | |
| if np.any(defined_mask): | |
| defined_faces = np.flatnonzero(defined_mask) | |
| filled_face_probability_means[defined_faces] = ( | |
| face_probability_sums[defined_faces] | |
| / face_probability_counts[defined_faces, np.newaxis] | |
| ).astype(np.float32, copy=False) | |
| else: | |
| filled_face_probability_means.fill(1.0 / max(num_parts, 1)) | |
| return filled_face_probability_means | |
| undefined_faces = np.flatnonzero(~defined_mask) | |
| if undefined_faces.size == 0: | |
| return filled_face_probability_means | |
| centroids = _get_face_centroids(mesh) | |
| tree = cKDTree(centroids[defined_faces]) | |
| _, nearest_local = _query_nearest(tree, centroids[undefined_faces]) | |
| nearest_local = np.atleast_1d(nearest_local) | |
| nearest_defined_faces = defined_faces[nearest_local] | |
| filled_face_probability_means[undefined_faces] = filled_face_probability_means[ | |
| nearest_defined_faces | |
| ] | |
| return filled_face_probability_means | |
| def _group_confidence_vector( | |
| group_face_ids, | |
| face_probability_sums, | |
| face_probability_counts, | |
| filled_face_probability_means, | |
| ): | |
| """Aggregate point softmax probabilities for one face group.""" | |
| group_face_ids = np.asarray(group_face_ids, dtype=np.int64) | |
| group_point_count = int(face_probability_counts[group_face_ids].sum()) | |
| if group_point_count > 0: | |
| return ( | |
| face_probability_sums[group_face_ids].sum(axis=0) | |
| / float(group_point_count) | |
| ).astype(np.float32, copy=False) | |
| return filled_face_probability_means[group_face_ids].mean(axis=0).astype( | |
| np.float32, | |
| copy=False, | |
| ) | |
| def _adjacent_group_indices_by_part_id_for_group(groups, face_adjacency, face_to_group, group_idx): | |
| adjacent_groups_by_part_id = defaultdict(set) | |
| for face_id in groups[group_idx][1]: | |
| for adjacent_face_id in face_adjacency.get(face_id, set()): | |
| adjacent_group_idx = int(face_to_group[int(adjacent_face_id)]) | |
| if adjacent_group_idx < 0 or adjacent_group_idx == group_idx: | |
| continue | |
| adjacent_part_id = int(groups[adjacent_group_idx][0]) | |
| if adjacent_part_id >= 0: | |
| adjacent_groups_by_part_id[adjacent_part_id].add(adjacent_group_idx) | |
| return adjacent_groups_by_part_id | |
def _iterative_single_group_reassignment(
    face_part_ids,
    *,
    face_adjacency,
    input_part_ids,
    face_probability_sums,
    face_probability_counts,
    filled_face_probability_means,
):
    """Iteratively enforce one face group per part ID.

    Each round splits the labelling into connected same-label groups over
    ``face_adjacency``. Then, for every non-negative part ID that owns more
    than one group:

    * The group to keep is the largest one. Ties go to the higher mean
      confidence for that part, then to the lower group index.
    * Every other group of that part is relabelled to its most confident
      candidate. Candidates are adjacent part IDs that are "safe" (the part
      has a single group, or its kept group touches this one), plus IDs from
      ``input_part_ids`` that no face currently carries. A missing ID is
      handed out at most once per round.
    * A group with no candidate keeps its label.

    Args:
        face_part_ids: Initial per-face part labels (copied; not mutated).
        face_adjacency: Mapping ``face_id -> iterable of neighbour face ids``.
        input_part_ids: Part IDs that are expected to appear.
        face_probability_sums, face_probability_counts,
        filled_face_probability_means: Per-face probability statistics used to
            score groups (see ``_group_confidence_vector``).

    Returns:
        An int32 array of per-face part labels. It is returned early, possibly
        with duplicates still present, if an update round makes no change or
        revisits an earlier labelling.

    Raises:
        RuntimeError: If duplicate groups remain after ``max_iterations``
            update rounds.
    """
    input_part_ids = np.asarray(input_part_ids, dtype=np.int64)
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32).copy()
    max_iterations = max(8, 4 * max(1, int(input_part_ids.size)))
    # find_face_groups only reads len(faces) when an adjacency map is given,
    # so a placeholder array of the right length is sufficient.
    placeholder_faces = np.empty((face_part_ids.shape[0], 3), dtype=np.int64)
    seen_states = set()
    # max_iterations update rounds, plus one final pass that only verifies the
    # outcome of the last round. Without it, a labelling fixed on the last
    # round would still raise.
    for iteration in range(max_iterations + 1):
        state_key = face_part_ids.tobytes()
        if state_key in seen_states:
            # The updates revisit an earlier labelling; stop instead of cycling.
            return face_part_ids
        seen_states.add(state_key)
        groups = find_face_groups(
            placeholder_faces,
            face_part_ids,
            adjacency=face_adjacency,
        )
        groups_by_part_id = defaultdict(list)
        for group_idx, (part_id, _) in enumerate(groups):
            groups_by_part_id[int(part_id)].append(group_idx)
        duplicate_part_ids = [
            int(part_id)
            for part_id, group_indices in groups_by_part_id.items()
            if len(group_indices) > 1 and part_id >= 0
        ]
        if not duplicate_part_ids:
            return face_part_ids
        if iteration == max_iterations:
            break

        # Confidence vectors are only ever read for groups of duplicated
        # parts, so skip computing them for every other group.
        group_confidences = {}
        for part_id in duplicate_part_ids:
            for group_idx in groups_by_part_id[part_id]:
                group_confidences[group_idx] = _group_confidence_vector(
                    np.asarray(groups[group_idx][1], dtype=np.int64),
                    face_probability_sums,
                    face_probability_counts,
                    filled_face_probability_means,
                )

        existing_part_ids = {
            int(part_id)
            for part_id in np.unique(face_part_ids)
            if int(part_id) >= 0
        }
        missing_part_ids = sorted(
            int(part_id) for part_id in input_part_ids if int(part_id) not in existing_part_ids
        )
        updates = {}
        groups_to_keep = {}
        for part_id in duplicate_part_ids:
            groups_to_keep[part_id] = max(
                groups_by_part_id[part_id],
                key=lambda group_idx: (
                    len(groups[group_idx][1]),
                    float(group_confidences[group_idx][part_id]),
                    -group_idx,
                ),
            )
        face_to_group = np.full(face_part_ids.shape[0], -1, dtype=np.int32)
        for group_idx, (_, group_face_ids) in enumerate(groups):
            face_to_group[np.asarray(group_face_ids, dtype=np.int64)] = int(group_idx)
        available_missing_part_ids = set(missing_part_ids)
        for part_id in duplicate_part_ids:
            group_to_keep = groups_to_keep[part_id]
            for group_to_update in groups_by_part_id[part_id]:
                if group_to_update == group_to_keep:
                    continue
                adjacent_groups_by_part_id = _adjacent_group_indices_by_part_id_for_group(
                    groups,
                    face_adjacency,
                    face_to_group,
                    group_to_update,
                )
                # A neighbouring part is "safe" when relabelling would merge
                # this group into that part's single or surviving group.
                safe_adjacent_part_ids = set()
                for adjacent_part_id, adjacent_group_indices in adjacent_groups_by_part_id.items():
                    if adjacent_part_id == part_id:
                        continue
                    target_group_indices = groups_by_part_id.get(adjacent_part_id, [])
                    if len(target_group_indices) == 1:
                        safe_adjacent_part_ids.add(adjacent_part_id)
                        continue
                    target_group_to_keep = groups_to_keep.get(adjacent_part_id)
                    if (
                        target_group_to_keep is not None
                        and target_group_to_keep in adjacent_group_indices
                    ):
                        safe_adjacent_part_ids.add(adjacent_part_id)
                replacement_candidates = sorted(
                    safe_adjacent_part_ids | available_missing_part_ids
                )
                if not replacement_candidates:
                    continue
                confidence_vector = group_confidences[group_to_update]
                best_replacement_part_id = max(
                    replacement_candidates,
                    key=lambda candidate_part_id: (
                        float(confidence_vector[candidate_part_id]),
                        -int(candidate_part_id),
                    ),
                )
                updates[group_to_update] = int(best_replacement_part_id)
                available_missing_part_ids.discard(int(best_replacement_part_id))
        if not updates:
            return face_part_ids
        updated_face_part_ids = face_part_ids.copy()
        for group_idx, replacement_part_id in updates.items():
            updated_face_part_ids[np.asarray(groups[group_idx][1], dtype=np.int64)] = replacement_part_id
        if np.array_equal(updated_face_part_ids, face_part_ids):
            return face_part_ids
        face_part_ids = updated_face_part_ids
    raise RuntimeError("single-group face post-processing did not converge")
def refine_face_part_ids_for_inference(
    mesh,
    face_part_ids,
    *,
    point_part_probabilities=None,
    face_indices=None,
    input_part_ids=None,
    strict=False,
    enforce_connectivity_per_part=False,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
):
    """Inference-time face post-processing layered on top of the base face pass.

    Always runs ``refine_face_part_ids``. When
    ``enforce_connectivity_per_part`` is set, it also builds a connected
    face-connectivity graph (edge adjacency, cross-component distance links
    and exact bridges). It then reassigns labels until each part ID forms a
    single connected face group.

    Args:
        mesh: Mesh with ``faces`` and ``vertices`` (trimesh-like).
        face_part_ids: Per-face part labels.
        point_part_probabilities: ``(num_points, num_parts)`` probabilities;
            required when enforcing connectivity.
        face_indices: Face id per point (negative = no face); required when
            enforcing connectivity.
        input_part_ids: Expected part IDs, each within the probability
            columns; required when enforcing connectivity.
        strict: Forwarded to ``refine_face_part_ids``.
        enforce_connectivity_per_part: Enable the single-group pass.
        distance_threshold: Cross-component face distance link threshold.
        max_distance_workers: Optional cap on the worker threads used for
            the cross-component distance computation. ``None`` keeps the
            callee's default.

    Returns:
        An int32 array of per-face part labels.

    Raises:
        ValueError: On missing or inconsistent inputs when enforcing
            connectivity. These checks run before any expensive graph
            construction.
        RuntimeError: If the single-group pass does not converge.
    """
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32)
    base_face_part_ids = refine_face_part_ids(
        mesh,
        face_part_ids,
        strict=bool(strict),
    )
    if not enforce_connectivity_per_part:
        return base_face_part_ids
    if point_part_probabilities is None:
        raise ValueError(
            "point_part_probabilities is required when enforce_connectivity_per_part is enabled"
        )
    if face_indices is None:
        raise ValueError("face_indices is required when enforce_connectivity_per_part is enabled")
    if input_part_ids is None:
        raise ValueError("input_part_ids is required when enforce_connectivity_per_part is enabled")
    point_part_probabilities = np.asarray(point_part_probabilities, dtype=np.float32)
    face_indices = np.asarray(face_indices, dtype=np.int64)
    input_part_ids = np.asarray(input_part_ids, dtype=np.int64)
    if point_part_probabilities.ndim != 2:
        raise ValueError("point_part_probabilities must have shape [num_points, num_parts]")
    if point_part_probabilities.shape[1] <= 0:
        raise ValueError("point_part_probabilities must contain at least one part column")
    if np.any(input_part_ids < 0):
        raise ValueError("input_part_ids must be non-negative")
    if np.any(input_part_ids >= point_part_probabilities.shape[1]):
        raise ValueError(
            "input_part_ids must be within the probability columns, "
            f"got max {int(input_part_ids.max())} for {point_part_probabilities.shape[1]} columns"
        )
    # Validate the point -> face mapping up front. The graph construction
    # below can be expensive, and these inputs would otherwise only fail
    # afterwards.
    num_points = point_part_probabilities.shape[0]
    if face_indices.shape != (num_points,):
        raise ValueError(
            "face_indices must have one entry per point, "
            f"got {face_indices.shape} for {num_points} points"
        )
    num_faces = len(mesh.faces)
    if face_indices.size > 0 and int(face_indices.max()) >= num_faces:
        raise ValueError(
            "face_indices must be smaller than the number of mesh faces, "
            f"got max {int(face_indices.max())} for {num_faces} faces"
        )
    face_connectivity_adjacency = build_face_connectivity_adjacency_for_inference(
        mesh,
        distance_threshold=float(distance_threshold),
        max_distance_workers=max_distance_workers,
    )
    face_connectivity_adjacency = ensure_face_adjacency_is_connected(
        mesh,
        face_connectivity_adjacency,
    )
    face_probability_sums, face_probability_counts = _compute_face_probability_statistics(
        num_faces,
        point_part_probabilities,
        face_indices,
    )
    filled_face_probability_means = _build_filled_face_probability_means(
        mesh,
        face_probability_sums,
        face_probability_counts,
    )
    return _iterative_single_group_reassignment(
        base_face_part_ids,
        face_adjacency=face_connectivity_adjacency,
        input_part_ids=input_part_ids,
        face_probability_sums=face_probability_sums,
        face_probability_counts=face_probability_counts,
        filled_face_probability_means=filled_face_probability_means,
    )