Skip to content

Commit a932ee9

Browse files
committed
multilabel QUIRE speed up a bit
1 parent 145e131 commit a932ee9

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

libact/query_strategies/multilabel/multilabel_quire.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,17 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1.,
97 97

98 98
self.random_state_ = seed_random_state(random_state)
99 99

100+
101+
_, lbled_Y = zip(*dataset.get_labeled_entries())
102+
n = len(X)
103+
m = np.shape(lbled_Y)[1]
104+
# label correlation matrix
105+
R = np.corrcoef(np.array(lbled_Y).T)
106+
R = np.nan_to_num(R)
107+
108+
self.L = lamba * (np.linalg.pinv(np.kron(R, self.K) \
109+
+ lamba * np.eye(n*m)))
110+
100 111
@inherit_docstring_from(QueryStrategy)
101 112
def make_query(self):
102 113
dataset = self.dataset
@@ -105,42 +116,36 @@ def make_query(self):
105 116

106 117
X = np.array(X)
107 118
K = self.K
108-
n = len(X)
119+
n_instance = len(X)
109 120
m = np.shape(lbled_Y)[1]
110 121
lamba = self.lamba
111 122

112 123
# index for labeled and unlabeled instance
113-
l = np.array([i for i in range(len(Y)) if Y[i] is not None])
114-
l = np.tile(l, m)
115-
u = np.array([i for i in range(len(Y)) if Y[i] is None])
116-
u = np.tile(u, m)
117-
118-
# label correlation matrix
119-
R = np.corrcoef(np.array(lbled_Y).T)
120-
R = np.nan_to_num(R)
121-
122-
L = lamba * (np.linalg.pinv(np.kron(R, K) + lamba * np.eye(n*m)))
123-
inv_L = np.linalg.pinv(L)
124-
125-
vecY = np.reshape(np.array([y for y in Y if y is not None]), (-1, 1))
126-
invLuu = np.linalg.pinv(L[np.ix_(u, u)])
127-
128-
score = np.zeros((n, m))
129-
for a in range(n):
130-
for b in range(m):
131-
s = b*n + a
132-
U = np.dot(L[np.ix_(u, l)], vecY) + L[np.ix_(u, [s])]
133-
temp1 = 2 * np.dot(L[[s], l], vecY) \
134-
- np.dot(np.dot(U.T, invLuu), U)
135-
U = np.dot(L[np.ix_(u, l)], vecY)
136-
temp0 = -(np.dot(np.dot(U.T, invLuu), U))
137-
score[a, b] = L[s, s] \
138-
+ np.dot(np.dot(vecY.T, L[np.ix_(l, l)]),
139-
vecY)[0, 0]\
140-
+ np.max((temp1[0, 0], temp0[0, 0]))
141-
142-
score = np.sum(score, axis=1)
143-
144-
ask_id = self.random_state_.choice(np.where(score == np.min(score))[0])
145-
146-
return ask_id
124+
l_id = []
125+
a_id = []
126+
for i in range(n_instance * m):
127+
if Y[i%n_instance] is None:
128+
a_id.append(i)
129+
else:
130+
l_id.append(i)
131+
132+
L = self.L
133+
vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1))
134+
detLaa = np.linalg.det(L[np.ix_(a_id, a_id)])
135+
136+
score = []
137+
for i, s in enumerate(a_id):
138+
u_id = a_id[:i] + a_id[i+1:]
139+
invLuu = L[np.ix_(u_id, u_id)] \
140+
- 1./L[s, s] * np.dot(L[u_id, s], L[u_id, s].T)
141+
score.append(L[s, s] - detLaa / L[s, s] \
142+
+ 2 * np.abs(np.dot(L[np.ix_([s], l_id)] \
143+
- np.dot(np.dot(L[s, u_id], invLuu),
144+
L[np.ix_(u_id, l_id)]), vecY))[0][0])
145+
146+
import ipdb; ipdb.set_trace()
147+
score = np.sum(np.array(score).reshape(m, -1).T, axis=1)
148+
149+
ask_idx = self.random_state_.choice(np.where(score == np.min(score))[0])
150+
151+
return a_id[ask_idx]

0 commit comments

Comments
 (0)