some new features
This commit is contained in:
@ -0,0 +1,38 @@
|
||||
# Author: Mathieu Blondel, Tom Dupre la Tour
|
||||
# License: BSD 3 clause
|
||||
|
||||
from cython cimport floating
|
||||
from libc.math cimport fabs
|
||||
|
||||
|
||||
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
|
||||
floating[:, :] XHt, Py_ssize_t[::1] permutation):
|
||||
cdef:
|
||||
floating violation = 0
|
||||
Py_ssize_t n_components = W.shape[1]
|
||||
Py_ssize_t n_samples = W.shape[0] # n_features for H update
|
||||
floating grad, pg, hess
|
||||
Py_ssize_t i, r, s, t
|
||||
|
||||
with nogil:
|
||||
for s in range(n_components):
|
||||
t = permutation[s]
|
||||
|
||||
for i in range(n_samples):
|
||||
# gradient = GW[t, i] where GW = np.dot(W, HHt) - XHt
|
||||
grad = -XHt[i, t]
|
||||
|
||||
for r in range(n_components):
|
||||
grad += HHt[t, r] * W[i, r]
|
||||
|
||||
# projected gradient
|
||||
pg = min(0., grad) if W[i, t] == 0 else grad
|
||||
violation += fabs(pg)
|
||||
|
||||
# Hessian
|
||||
hess = HHt[t, t]
|
||||
|
||||
if hess != 0:
|
||||
W[i, t] = max(W[i, t] - grad / hess, 0.)
|
||||
|
||||
return violation
|
||||
Reference in New Issue
Block a user