Skip to content

Commit c789f05

Browse files
committed
Added support for 3D boxes with function weighted_boxes_fusion_3d
1 parent 3d40923 commit c789f05

File tree

5 files changed

+370
-1
lines changed

5 files changed

+370
-1
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ boxes, scores, labels = weighted_boxes_fusion([boxes_list], [scores_list], [labe
7272

7373
More examples can be found in [example.py](./example.py)
7474

75+
#### 3D version
76+
77+
The WBF method also supports 3D boxes through the `weighted_boxes_fusion_3d` function. See [example_3d.py](./example_3d.py) for a usage example.
78+
7579
## Accuracy and speed comparison
7680

7781
The comparison was made on an ensemble of predictions from 5 different object detection models trained on the [Open Images Dataset](https://storage.googleapis.com/openimages/web/index.html) (500 classes).

ensemble_boxes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .ensemble_boxes_nms import nms_method
77
from .ensemble_boxes_nms import nms
88
from .ensemble_boxes_nms import soft_nms
9+
from .ensemble_boxes_wbf_3d import weighted_boxes_fusion_3d
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# coding: utf-8
2+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3+
4+
5+
import warnings
6+
import numpy as np
7+
from numba import jit
8+
9+
10+
@jit(nopython=True)
def bb_intersection_over_union_3d(A, B) -> float:
    """
    Intersection over union of two axis-aligned 3D boxes.

    :param A: first box, indexed as (x1, y1, z1, x2, y2, z2)
    :param B: second box, same layout as A
    :return: intersection volume divided by union volume, 0.0 when the boxes do not overlap
    """
    # Overlap extent along each axis, clamped at zero when the boxes are
    # disjoint on that axis.
    overlap_x = max(0, min(A[3], B[3]) - max(A[0], B[0]))
    overlap_y = max(0, min(A[4], B[4]) - max(A[1], B[1]))
    overlap_z = max(0, min(A[5], B[5]) - max(A[2], B[2]))

    interVol = overlap_x * overlap_y * overlap_z
    if interVol == 0:
        return 0.0

    # Volumes of the two boxes; the union is their sum minus the shared part.
    vol_a = (A[3] - A[0]) * (A[4] - A[1]) * (A[5] - A[2])
    vol_b = (B[3] - B[0]) * (B[4] - B[1]) * (B[5] - B[2])
    union_vol = float(vol_a + vol_b - interVol)

    return interVol / union_vol
29+
30+
31+
def prefilter_boxes(boxes, scores, labels, weights, thr):
    """
    Validate, clean up and group the raw predictions of all models by label.

    Every kept box is repaired before use: swapped corners are put back in
    order, coordinates are clipped to [0, 1] (a warning is issued for each
    repair) and boxes whose volume ends up being zero are dropped.

    :param boxes: per-model lists of boxes, each box indexed as (x1, y1, z1, x2, y2, z2)
    :param scores: per-model lists of confidence scores, parallel to `boxes`
    :param labels: per-model lists of labels, parallel to `boxes`
    :param weights: per-model weights; each score is multiplied by the weight of its model
    :param thr: boxes whose raw score is below this value are skipped
    :return: dict mapping label -> numpy array with rows
             [label, weighted score, x1, y1, z1, x2, y2, z2] sorted by descending weighted score
    :raises ValueError: if a model provides a different number of boxes than scores or labels
    """
    new_boxes = dict()

    for t in range(len(boxes)):
        model_boxes, model_scores, model_labels = boxes[t], scores[t], labels[t]

        # Raise instead of print + exit(): a library must not terminate the
        # interpreter of the program that calls it.
        if len(model_boxes) != len(model_scores):
            raise ValueError(
                'Length of boxes arrays not equal to length of scores array: {} != {}'.format(
                    len(model_boxes), len(model_scores)))

        if len(model_boxes) != len(model_labels):
            raise ValueError(
                'Length of boxes arrays not equal to length of labels array: {} != {}'.format(
                    len(model_boxes), len(model_labels)))

        # Lengths were validated above, so zip() cannot silently truncate.
        for box_part, score, raw_label in zip(model_boxes, model_scores, model_labels):
            if score < thr:
                continue
            label = int(raw_label)
            lower = [float(box_part[i]) for i in range(3)]     # x1, y1, z1
            upper = [float(box_part[i]) for i in range(3, 6)]  # x2, y2, z2

            # Box data checks: first make sure corners are ordered on every axis ...
            for axis, name in enumerate('XYZ'):
                if upper[axis] < lower[axis]:
                    warnings.warn('{0}2 < {0}1 value in box. Swap them.'.format(name))
                    lower[axis], upper[axis] = upper[axis], lower[axis]

            # ... then clip every coordinate into the expected [0, 1] range.
            for axis, name in enumerate('XYZ'):
                for suffix, corner in (('1', lower), ('2', upper)):
                    if corner[axis] < 0:
                        warnings.warn('{}{} < 0 in box. Set it to 0.'.format(name, suffix))
                        corner[axis] = 0
                    if corner[axis] > 1:
                        warnings.warn(
                            '{}{} > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'.format(
                                name, suffix))
                        corner[axis] = 1

            x1, y1, z1 = lower
            x2, y2, z2 = upper
            if (x2 - x1) * (y2 - y1) * (z2 - z1) == 0.0:
                warnings.warn("Zero volume box skipped: {}.".format(box_part))
                continue

            b = [int(label), float(score) * weights[t], x1, y1, z1, x2, y2, z2]
            new_boxes.setdefault(label, []).append(b)

    # Sort each list in dict by score and transform it to numpy array
    for k in new_boxes:
        current_boxes = np.array(new_boxes[k])
        new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]

    return new_boxes
119+
120+
121+
def get_weighted_box(boxes, conf_type='avg'):
    """
    Fuse a cluster of boxes into a single confidence-weighted box.

    :param boxes: boxes to fuse, each a row [label, score, x1, y1, z1, x2, y2, z2]
    :param conf_type: how the fused score is computed: 'avg' (mean of scores) or 'max'
    :return: float32 array of 8 values laid out like the input rows
    """
    fused = np.zeros(8, dtype=np.float32)
    confidences = []
    total_conf = 0

    # Accumulate coordinates weighted by the score of each member box.
    for member in boxes:
        member_conf = member[1]
        fused[2:] += (member_conf * member[2:])
        total_conf += member_conf
        confidences.append(member_conf)

    # The label is taken from the first box of the cluster.
    fused[0] = boxes[0][0]

    if conf_type == 'max':
        fused[1] = np.array(confidences).max()
    elif conf_type == 'avg':
        fused[1] = total_conf / len(boxes)

    # Normalise the accumulated coordinates by the total score.
    fused[2:] /= total_conf
    return fused
143+
144+
145+
def find_matching_box(boxes_list, new_box, match_iou):
    """
    Find the box in `boxes_list` that overlaps `new_box` the most.

    :param boxes_list: candidate boxes, each a row [label, score, x1, y1, z1, x2, y2, z2]
    :param new_box: box to match, same layout as the candidates
    :param match_iou: IoU a candidate must strictly exceed to count as a match
    :return: (index of the best candidate or -1 if none matched,
              IoU of the best candidate or `match_iou` if none matched)
    """
    best_index, best_iou = -1, match_iou
    for idx, candidate in enumerate(boxes_list):
        # Only boxes carrying the same label may be matched.
        if candidate[0] == new_box[0]:
            overlap = bb_intersection_over_union_3d(candidate[2:], new_box[2:])
            if overlap > best_iou:
                best_index, best_iou = idx, overlap

    return best_index, best_iou
158+
159+
160+
def weighted_boxes_fusion_3d(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    '''
    Ensemble 3D box predictions of several models with Weighted Boxes Fusion.

    :param boxes_list: list of boxes predictions from each model, each box is 6 numbers.
    It has 3 dimensions (models_number, model_preds, 6)
    Order of boxes: x1, y1, z1, x2, y2, z2. We expect float normalized coordinates [0; 1]
    :param scores_list: list of scores for each model
    :param labels_list: list of labels for each model
    :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
    :param iou_thr: IoU value for boxes to be a match
    :param skip_box_thr: exclude boxes with score lower than this variable
    :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
    :param allows_overflow: false if we want confidence score not exceed 1.0

    :return: boxes: boxes coordinates (Order of boxes: x1, y1, z1, x2, y2, z2).
    :return: scores: confidence scores
    :return: labels: boxes labels
    '''

    models_count = len(boxes_list)
    if weights is None:
        weights = np.ones(models_count)
    if len(weights) != models_count:
        print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), models_count))
        weights = np.ones(models_count)
    weights = np.array(weights)

    if conf_type not in ['avg', 'max']:
        print('Error. Unknown conf_type: {}. Must be "avg" or "max". Use "avg"'.format(conf_type))
        conf_type = 'avg'

    filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if len(filtered_boxes) == 0:
        return np.zeros((0, 6)), np.zeros((0,)), np.zeros((0,))

    weight_total = weights.sum()
    fused_per_label = []
    for label in filtered_boxes:
        clusters = []  # clusters[i]: the source boxes merged into fused[i]
        fused = []     # current fused box of every cluster

        # Clusterize boxes: each box joins the best matching fused box, or
        # starts a new cluster when nothing overlaps enough.
        for candidate in filtered_boxes[label]:
            index, _ = find_matching_box(fused, candidate, iou_thr)
            if index == -1:
                clusters.append([candidate.copy()])
                fused.append(candidate.copy())
            else:
                clusters[index].append(candidate)
                fused[index] = get_weighted_box(clusters[index], conf_type)

        # Rescale confidence based on number of models and boxes
        for fused_box, cluster in zip(fused, clusters):
            support = len(cluster)
            if not allows_overflow:
                support = min(weight_total, support)
            fused_box[1] = fused_box[1] * support / weight_total
        fused_per_label.append(np.array(fused))

    merged = np.concatenate(fused_per_label, axis=0)
    merged = merged[merged[:, 1].argsort()[::-1]]
    return merged[:, 2:], merged[:, 1], merged[:, 0]

example_3d.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# coding: utf-8
2+
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3+
4+
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8+
from ensemble_boxes import *
9+
10+
11+
def plot_cube(ax, cube_definition, lbl, thickness):
    """
    Draw one box on a 3D axes object.

    :param ax: 3D axes to draw on
    :param cube_definition: four corner points: an origin followed by the three corners adjacent to it
    :param lbl: box label; label 0 gets red edges, any other label blue edges
    :param thickness: edge line width is `thickness + 1`
    """
    corners = [np.array(list(item)) for item in cube_definition]
    origin = corners[0]
    # Edge vectors leaving the origin corner.
    edge_a, edge_b, edge_c = (corners[k] - origin for k in (1, 2, 3))

    # The given corners followed by the four remaining corners of the box.
    points = np.array(corners + [
        origin + edge_a + edge_b,
        origin + edge_a + edge_c,
        origin + edge_b + edge_c,
        origin + edge_a + edge_b + edge_c,
    ])

    # Vertex indices (into `points`) of the six faces.
    face_vertex_ids = [
        (0, 3, 5, 1),
        (1, 5, 7, 4),
        (4, 2, 6, 7),
        (2, 6, 3, 0),
        (0, 2, 4, 1),
        (3, 6, 7, 5),
    ]
    edges = [[points[v] for v in ids] for ids in face_vertex_ids]

    faces = Poly3DCollection(edges, linewidths=thickness + 1)
    edge_color = (1, 0, 0) if lbl == 0 else (0, 0, 1)
    faces.set_edgecolor(edge_color)
    faces.set_facecolor((0, 0, 1, 0.1))

    ax.add_collection3d(faces)
    # Zero-size scatter of the corner points.
    # NOTE(review): presumably here so the axes limits include the box - confirm.
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0)
50+
51+
52+
def show_boxes(boxes_list, scores_list, labels_list, image_size=800):
    """
    Plot the boxes of every model in a single 3D figure and show it.

    :param boxes_list: per-model lists of boxes, each box indexed as (x1, y1, z1, x2, y2, z2)
    :param scores_list: per-model lists of scores; edge thickness passed to plot_cube is int(4 * score)
    :param labels_list: per-model lists of labels, forwarded to plot_cube
    :param image_size: factor the box coordinates are multiplied by before plotting
    """
    # The previous version also allocated an (image_size, image_size, 3) array
    # here that was never read; it has been removed.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for model_boxes, model_scores, model_labels in zip(boxes_list, scores_list, labels_list):
        for box, score, lbl in zip(model_boxes, model_scores, model_labels):
            x1, y1, z1, x2, y2, z2 = (int(image_size * box[k]) for k in range(6))
            # Origin corner plus the three corners adjacent to it.
            cube_definition = [
                (x1, y1, z1), (x1, y2, z1), (x2, y1, z1), (x1, y1, z2)
            ]
            plot_cube(ax, cube_definition, lbl, int(4 * score))

    plt.show()
73+
74+
75+
def example_wbf_3d_2_models(iou_thr=0.55, draw_image=True):
    """
    This example shows how to ensemble boxes from 2 models using WBF_3D method

    :param iou_thr: IoU threshold forwarded to weighted_boxes_fusion_3d
    :param draw_image: when true, plot the boxes before and after fusion
    :return:
    """

    # Predictions of the first model (5 boxes).
    model_1_boxes = [
        [0.00, 0.51, 0.41, 0.81, 0.91, 0.78],
        [0.10, 0.31, 0.45, 0.71, 0.61, 0.85],
        [0.01, 0.32, 0.55, 0.83, 0.93, 0.95],
        [0.02, 0.53, 0.11, 0.11, 0.94, 0.55],
        [0.03, 0.24, 0.34, 0.12, 0.35, 0.67],
    ]
    model_1_scores = [0.9, 0.8, 0.2, 0.4, 0.7]
    model_1_labels = [0, 1, 0, 1, 1]

    # Predictions of the second model (4 boxes).
    model_2_boxes = [
        [0.04, 0.56, 0.36, 0.84, 0.92, 0.82],
        [0.12, 0.33, 0.46, 0.72, 0.64, 0.75],
        [0.38, 0.66, 0.55, 0.79, 0.95, 0.90],
        [0.08, 0.49, 0.15, 0.21, 0.89, 0.67],
    ]
    model_2_scores = [0.5, 0.8, 0.7, 0.3]
    model_2_labels = [1, 1, 1, 0]

    boxes_list = [model_1_boxes, model_2_boxes]
    scores_list = [model_1_scores, model_2_scores]
    labels_list = [model_1_labels, model_2_labels]
    weights = [2, 1]

    if draw_image:
        show_boxes(boxes_list, scores_list, labels_list)

    boxes, scores, labels = weighted_boxes_fusion_3d(
        boxes_list,
        scores_list,
        labels_list,
        weights=weights,
        iou_thr=iou_thr,
        skip_box_thr=0.0,
    )

    if draw_image:
        show_boxes([boxes], [scores], [labels.astype(np.int32)])

    print(len(boxes))
    print(boxes)
137+
138+
139+
if __name__ == '__main__':
    # Run the two-model demo with an IoU threshold of 0.2 and plotting enabled.
    draw_image = True
    example_wbf_3d_2_models(draw_image=draw_image, iou_thr=0.2)
142+

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='ensemble_boxes',
8-
version='1.0.3',
8+
version='1.0.4',
99
author='Roman Solovyev (ZFTurbo)',
1010
packages=['ensemble_boxes', ],
1111
url='https://github.com/ZFTurbo/Weighted-Boxes-Fusion',

0 commit comments

Comments
 (0)