File size: 11,299 Bytes
7fdaedc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
#!/usr/bin/python3
# _*_ coding: utf-8 _*_
# ---------------------------------------------------
# @Time    : 2026-03-10 8:58 p.m.
# @Author  : shangfeng
# @Organization: University of Calgary
# @File    : evaluate.py.py
# @IDE     : PyCharm
# -----------------Evaluation TASK---------------------
# Evaluation
# 1. Chamfer Distance (CD): Measures the geometric discrepancy between the predicted mesh and the ground-truth mesh, reflecting the overall reconstruction accuracy.
#
# 2. Edge Chamfer Distance (ECD): Evaluates the geometric similarity between the edges of the reconstructed mesh and those of the ground-truth mesh, serving as an indicator of edge sharpness and structural fidelity.
#
# 3. Normal Consistency (NC): Assesses the alignment between surface normals of the predicted mesh and the ground-truth mesh, indicating the consistency of local surface orientation.
#
# 4. V_Ratio: Defined as the ratio between the number of vertices in the predicted mesh and that of the ground-truth mesh, reflecting changes in geometric complexity.
#
# 5. F_Ratio: Defined as the ratio between the number of faces in the predicted mesh and that of the ground-truth mesh, indicating variations in mesh resolution.
# ---------------------------------------------------
import os
import trimesh
import numpy as np
from scipy.spatial import cKDTree
import faiss


# --------------------------- Load mesh using trimesh and normalization --------------------------------------
def load_mesh(p_file, gt_file):
    """
    Load the predicted and ground-truth meshes from disk.

    ``force='mesh'`` asks trimesh to coerce multi-object files into a single
    ``trimesh.Trimesh``. Without it, such files come back as a
    ``trimesh.Scene``, which has no ``.vertices`` / ``.faces`` and would break
    ``normalization`` and every metric downstream.

    :param p_file: path to the predicted mesh file
    :param gt_file: path to the ground-truth mesh file
    :return: (p_mesh, gt_mesh) as trimesh.Trimesh objects
    """
    p_mesh = trimesh.load(p_file, force='mesh')
    gt_mesh = trimesh.load(gt_file, force='mesh')
    return p_mesh, gt_mesh

def normalization(p_mesh, gt_mesh):
    """
    Express both meshes in a frame derived from the GT bounding box.

    GT vertices are translated so the GT bounding-box midpoint is at the
    origin, then divided by (longest GT bounding-box side + 1e-6).
    Predicted vertices are NOT translated. They are only multiplied by
    (GT bounding-box diagonal) / (longest GT side + 1e-6).

    NOTE(review): the prediction scaling presumably assumes submissions are
    already centered and scaled to a unit GT diagonal, so this factor moves
    them into the GT's longest-side frame. Confirm against the data spec.

    :param p_mesh: predicted mesh (trimesh.Trimesh)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :return: (predicted mesh, GT mesh) as newly built trimesh.Trimesh objects
    """
    gt_raw = np.asarray(gt_mesh.vertices)
    pred_raw = np.asarray(p_mesh.vertices)

    # Translate GT so its bounding-box midpoint lands on the origin.
    bbox_mid = 0.5 * (gt_raw.min(axis=0) + gt_raw.max(axis=0))
    gt_centered = gt_raw - bbox_mid

    # Bounding-box side lengths, measured on the centered GT copy.
    side_lengths = gt_centered.max(axis=0) - gt_centered.min(axis=0)
    longest_side = np.max(side_lengths)
    bbox_diagonal = np.sqrt(np.sum(side_lengths ** 2))

    gt_normalized = gt_centered / (longest_side + 1e-6)
    pred_scaled = pred_raw * bbox_diagonal / (longest_side + 1e-6)

    pred_out = trimesh.Trimesh(vertices=pred_scaled, faces=p_mesh.faces)
    gt_out = trimesh.Trimesh(vertices=gt_normalized, faces=gt_mesh.faces)
    return pred_out, gt_out


# --------------------------- L1 Chamfer distance --------------------------------------
def chamfer_l1_distance_kdtree(p, q):
    """
    Symmetric "L1" Chamfer distance between two point sets, using KD-trees.

    "L1" follows the usual reconstruction-benchmark convention: it averages
    the *unsquared* Euclidean nearest-neighbour distances, as opposed to the
    squared-L2 Chamfer variant.

    p: (N,3) prediction (any array-like; lists are accepted)

    q: (M,3) ground truth (any array-like)

    :return: float, mean NN distance p -> q plus mean NN distance q -> p.
             Rows containing NaN/inf are discarded before matching.
    """
    # Coerce to arrays so boolean-mask indexing below works for lists too.
    # cKDTree works in float64 internally, so this cast does not change results.
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)

    # --- Remove invalid points to ensure numerical stability
    p = p[np.isfinite(p).all(axis=1)]
    q = q[np.isfinite(q).all(axis=1)]

    # --- KDTree
    tree_p = cKDTree(p)
    tree_q = cKDTree(q)

    # --- Nearest-neighbour distances in both directions
    dist_pq, _ = tree_q.query(p)  # P -> Q
    dist_qp, _ = tree_p.query(q)  # Q -> P

    # Plain Python float, matching chamfer_l1_distance_faiss's return type.
    return float(np.mean(dist_pq) + np.mean(dist_qp))

def chamfer_l1_distance_faiss(p, q, use_gpu=False):
    """
    Symmetric "L1" Chamfer distance computed with exact FAISS flat indexes.

    Same quantity as the KD-tree version: the mean unsquared Euclidean
    nearest-neighbour distance in each direction, summed.

    p: (N,3) prediction

    q: (M,3) ground truth

    use_gpu: if True, both indexes are moved to GPU device 0.

    :return: float
    """
    def _finite_float32(points):
        # Drop rows with NaN/inf, then cast, since FAISS consumes float32.
        keep = np.isfinite(points).all(axis=1)
        return points[keep].astype(np.float32)

    pred32 = _finite_float32(p)
    gt32 = _finite_float32(q)

    # Exact (brute-force) L2 indexes over 3-D points.
    pred_index = faiss.IndexFlatL2(3)
    gt_index = faiss.IndexFlatL2(3)

    if use_gpu:
        gpu_res = faiss.StandardGpuResources()
        pred_index = faiss.index_cpu_to_gpu(gpu_res, 0, pred_index)
        gt_index = faiss.index_cpu_to_gpu(gpu_res, 0, gt_index)

    pred_index.add(pred32)
    gt_index.add(gt32)

    # IndexFlatL2 reports SQUARED distances; take sqrt before averaging.
    sq_pred_to_gt, _ = gt_index.search(pred32, 1)
    sq_gt_to_pred, _ = pred_index.search(gt32, 1)

    forward = np.sqrt(sq_pred_to_gt[:, 0]).mean()
    backward = np.sqrt(sq_gt_to_pred[:, 0]).mean()
    return float(forward + backward)

# --------------------------- Mesh sampling points --------------------------------------
def mesh_sample_points(p_mesh, gt_mesh, sample_points=1000000):
    """
    Draw the same number of surface samples from both meshes.

    :param p_mesh: predicted mesh (trimesh.Trimesh)

    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)

    :param sample_points: number of points requested from each mesh

    :return: tuple (pred_points, gt_points), each whatever ``mesh.sample``
             returns, e.g. a (sample_points, 3) array for trimesh meshes.
             The prediction is sampled first.
    """
    return tuple(mesh.sample(sample_points) for mesh in (p_mesh, gt_mesh))

# --------------------------- Edge Chamfer L1 Distance --------------------------------------
def extract_sharp_edges(mesh, angle_threshold_deg=30.0):
    """
    Return the undirected edges of ``mesh`` classified as "sharp".

    An edge (stored as (smaller vertex id, larger vertex id)) is sharp when:
      * it belongs to exactly one face (boundary edge), or
      * it belongs to more than two faces (non-manifold edge), or
      * it joins two faces whose unit normals satisfy
        |n1 . n2| < cos(angle_threshold_deg).
        The absolute value makes the test orientation-agnostic: two faces
        with exactly opposite normals count as flat, not sharp.

    Only ``mesh.faces`` and ``mesh.face_normals`` are read, so this does not
    depend on any trimesh edge-adjacency API.

    :param mesh: object exposing ``faces`` (F,3) and ``face_normals`` (F,3)
    :param angle_threshold_deg: dihedral threshold in degrees
    :return: (E,2) int64 array of sharp edges, in the order each edge is
             first seen while scanning faces; shape (0,2) if there are none.
    """
    tri = np.asarray(mesh.faces)
    normals = np.asarray(mesh.face_normals)
    # Re-normalize; the epsilon keeps zero-length normals from dividing by zero.
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)

    # Undirected edge -> list of incident face indices (dict keeps first-seen order).
    incident = {}
    for face_idx, row in enumerate(tri):
        a, b, c = row[0], row[1], row[2]
        for u, v in ((a, b), (b, c), (c, a)):
            key = (u, v) if u <= v else (v, u)
            incident.setdefault(key, []).append(face_idx)

    cos_limit = np.cos(np.deg2rad(angle_threshold_deg))

    def _is_sharp(face_ids):
        if len(face_ids) != 2:
            # Boundary (1 face) or non-manifold (>2 faces): always sharp.
            return True
        first, second = face_ids
        cos_angle = np.clip(np.dot(normals[first], normals[second]), -1.0, 1.0)
        return np.abs(cos_angle) < cos_limit

    sharp = [edge for edge, face_ids in incident.items() if _is_sharp(face_ids)]

    if not sharp:
        return np.zeros((0, 2), dtype=np.int64)
    return np.asarray(sharp, dtype=np.int64)


def sample_points_on_edges_global(vertices, edges, total_samples=100000):
    """
    Sample points uniformly along a set of edges, in proportion to edge length.

    Each sample first picks an edge with probability length / total_length,
    then picks a uniform position along that edge. If every edge has zero
    length (or the total is not finite), edges are picked uniformly instead
    of raising.

    Randomness comes from NumPy's global RNG (``np.random``), so results are
    reproducible only if the caller seeds it.

    Args:

        vertices (array-like): (V, 3) vertex positions

        edges (array-like): (E, 2) vertex-index pairs

        total_samples (int): total number of sampled points

    Returns:

        np.ndarray: (total_samples, 3) float32, or (0, 3) when there are no edges.
    """
    edges = np.asarray(edges)
    if edges.shape[0] == 0:
        return np.zeros((0, 3), dtype=np.float32)

    vertices = np.asarray(vertices)

    # --- 1. Endpoints of edges --------------
    p1 = vertices[edges[:, 0]]  # (E, 3)
    p2 = vertices[edges[:, 1]]  # (E, 3)

    # --- 2. Edge lengths --------------
    edge_lengths = np.linalg.norm(p2 - p1, axis=1)  # (E,)
    total_length = edge_lengths.sum()

    # --- 3. Length-proportional edge probabilities --------------
    # Divide by the exact total. An additive epsilon (the previous approach)
    # leaves the probabilities summing to noticeably less than 1 for small
    # meshes, and np.random.choice then raises "probabilities do not sum to 1".
    if np.isfinite(total_length) and total_length > 0:
        probs = edge_lengths / total_length
    else:
        probs = None  # all edges degenerate: fall back to a uniform edge choice

    # --- 4. Choose an edge per sample --------------
    edge_indices = np.random.choice(len(edges), size=total_samples, p=probs)

    # --- 5. Uniform position along the chosen edge --------------
    t = np.random.rand(total_samples, 1)  # (N, 1)
    points = (1 - t) * p1[edge_indices] + t * p2[edge_indices]

    return points.astype(np.float32)


def compute_edge_chamfer_distance(p_mesh, gt_mesh, angle_threshold_deg=30.0, num_samples=100000):
    """
    Edge Chamfer Distance (ECD) between the sharp edges of two meshes.

    Sharp edges are extracted from each mesh with the same dihedral
    threshold. ``num_samples`` points are sampled along each mesh's sharp
    edges (length-weighted), and the symmetric Chamfer distance between the
    two point sets is returned.

    NOTE(review): if either mesh has no sharp edges, its sample set is empty
    and the Chamfer average runs over an empty set, so the result is not a
    meaningful number. Decide on an explicit policy if this occurs in the data.

    :param p_mesh: predicted mesh (trimesh.Trimesh)

    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)

    :param angle_threshold_deg: dihedral-angle threshold (degrees) for sharp edges

    :param num_samples: number of points sampled on each mesh's sharp edges
                        (default 100000, the previously hard-coded value)

    :return: ECD value
    """
    # ---------- Extract sharp edges ----------
    sharp_edges_gt = extract_sharp_edges(gt_mesh, angle_threshold_deg)
    sharp_edges_pred = extract_sharp_edges(p_mesh, angle_threshold_deg)

    # ---------- Sample points on edges (GT first, same RNG order as before) ----------
    edge_pts_gt = sample_points_on_edges_global(
        gt_mesh.vertices, sharp_edges_gt, total_samples=num_samples
    )
    edge_pts_pred = sample_points_on_edges_global(
        p_mesh.vertices, sharp_edges_pred, total_samples=num_samples
    )

    # ---------- Compute ECD ----------
    return chamfer_l1_distance_kdtree(edge_pts_pred, edge_pts_gt)



# --------------------------- Normal Consistency (NC) --------------------------------------
def normal_consistency(p_mesh, gt_mesh, num_samples=100000):
    """
    Orientation-agnostic normal consistency (NC), measured from GT to prediction.

    ``num_samples`` points are sampled on the GT surface. For each one, the
    closest point on the predicted surface is found. The score is the mean of
    |cos| between the GT face normal at the sample and the predicted face
    normal at the closest point. Only the GT -> prediction direction is
    measured.

    :param p_mesh: predicted mesh (trimesh.Trimesh)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :param num_samples: number of GT surface samples
    :return: NC as a float in [0, 1] (1 = normals agree everywhere, up to sign)
    """
    # ---------- 1. sample surface points from GT ----------
    pts_gt, gt_face_ids = trimesh.sample.sample_surface(gt_mesh, num_samples)
    normals_gt = gt_mesh.face_normals[gt_face_ids]

    # ---------- 2. find closest face on pred mesh ----------
    _, _, pred_face_ids = p_mesh.nearest.on_surface(pts_gt)
    normals_pred = p_mesh.face_normals[pred_face_ids]

    # ---------- 3. normalize ----------
    # Epsilon as in extract_sharp_edges: a zero-length normal (e.g. from a
    # degenerate triangle) would otherwise give 0/0 = NaN and poison the mean.
    normals_gt = normals_gt / (np.linalg.norm(normals_gt, axis=1, keepdims=True) + 1e-12)
    normals_pred = normals_pred / (np.linalg.norm(normals_pred, axis=1, keepdims=True) + 1e-12)

    # ---------- 4. cosine similarity ----------
    # abs(): orientation-agnostic. Clip guards the documented [0, 1] range
    # against floating-point overshoot.
    cos_sim = np.clip(np.abs(np.sum(normals_gt * normals_pred, axis=1)), 0.0, 1.0)

    return float(cos_sim.mean())

# --------------------------- V_Ratio & F_Ratio --------------------------------------
def calculate_vertices_face_ratio(p_mesh, gt_mesh):
    """
    Vertex-count and face-count ratios of the prediction relative to GT.

    :param p_mesh: predicted mesh (anything with ``vertices`` and ``faces``)

    :param gt_mesh: ground-truth mesh (same requirements)

    :return: (v_ratio, f_ratio) as floats: pred count / GT count.
             Raises ZeroDivisionError if the GT has no faces or no vertices.
    """
    n_faces_pred, n_faces_gt = len(p_mesh.faces), len(gt_mesh.faces)
    n_verts_pred, n_verts_gt = len(p_mesh.vertices), len(gt_mesh.vertices)

    face_ratio = n_faces_pred / n_faces_gt
    vertex_ratio = n_verts_pred / n_verts_gt
    return vertex_ratio, face_ratio


# --------------------------- Mesh Evaluation For 3rd USM3D ----------------------------
def mesh_evaluation(p_file, gt_file):
    """
    Compute every evaluation metric for one predicted / ground-truth mesh pair.

    Both meshes are loaded with ``load_mesh`` and mapped into the GT-derived
    frame by ``normalization`` before any metric is computed. Several metrics
    rely on random sampling (e.g. ``np.random`` in the edge sampler), so
    results can vary slightly between runs unless the RNG is seeded.

    :param p_file: the path of predicted mesh

    :param gt_file: the path of ground truth mesh

    :return: tuple ``(mesh_chamfer_distance, edge_chamfer_distance,
             normals_consistency, v_ratio, f_ratio)``:
             - mesh_chamfer_distance: symmetric Chamfer distance between
               surface samples of the two meshes
             - edge_chamfer_distance: Chamfer distance between points sampled
               on the sharp edges (30 degree threshold) of the two meshes
             - normals_consistency: mean |cos| between GT normals and the
               normals of the nearest predicted faces, in [0, 1]
             - v_ratio, f_ratio: predicted / GT vertex and face counts
    """
    # --------------- Load Mesh using trimesh & normalization----------------
    p_mesh, gt_mesh = load_mesh(p_file, gt_file)
    p_mesh, gt_mesh = normalization(p_mesh, gt_mesh)

    # ----------------------- Mesh Chamfer Distance --------------------------
    p_points, gt_points = mesh_sample_points(p_mesh, gt_mesh)
    mesh_chamfer_distance = chamfer_l1_distance_kdtree(p_points, gt_points)

    # ---------------------- Edge Chamfer Distance ---------------------------
    edge_chamfer_distance = compute_edge_chamfer_distance(p_mesh, gt_mesh, angle_threshold_deg=30.0)

    # ---------------------- Normal Consistency --------------------------
    normals_consistency = normal_consistency(p_mesh, gt_mesh)

    # ---------------------- V_ratio & F_ratio ---------------------------
    v_ratio, f_ratio = calculate_vertices_face_ratio(p_mesh, gt_mesh)

    return mesh_chamfer_distance, edge_chamfer_distance, normals_consistency, v_ratio, f_ratio


# if __name__ == '__main__':
#     p_file = r'./pred/1a_0.obj'
#     gt_file = r'./gt/1a_0.obj'
#     print(mesh_evaluation(p_file, gt_file))