将下列代码转换成matlab代码%update(self,iiter,H,Y,eta,loss): %"""Update the trace_var in new iteration""" if iiter <= self.niter_trace+1: self.H[iiter] = H self.Y[iiter] = Y elif iiter >self.niter - self.niter_trace + 1: self.H[self.ltrace+iiter-self.niter-1] = H self.Y[self.ltrace+iiter-self.niter-1] = Y self.etas[iiter] = eta self.loss[iiter] = loss if self.loss[iiter] < self.lmin: self.Yh = Y self.lmin = self.loss[iiter] self.miniter = iiter if not iiter == -1 else self.niter + 1 end end end end
时间: 2023-07-14 18:14:16 浏览: 55
function obj = update(obj, iiter, H, Y, eta, loss)
% UPDATE Record the trace variables for one iteration.
%
%   obj = update(obj, iiter, H, Y, eta, loss)
%
%   - Stores H and Y in the trace cells obj.H / obj.Y when iiter lies in the
%     leading window (iiter <= niter_trace + 1) or in the trailing window
%     (iiter > niter - niter_trace + 1); iterations in between are not traced.
%   - Stores eta and loss at position iiter of obj.etas / obj.loss.
%   - When loss is lower than obj.lmin, updates obj.Yh, obj.lmin and
%     obj.miniter.
%
%   iiter == -1 follows the convention of the Python original, where index
%   -1 addresses the last element: the values are written to the last slot
%   of each container and obj.miniter is set to obj.niter + 1.
%
%   NOTE(review): all other index values are used exactly as passed in.
%   The Python source is 0-based; confirm that callers pass 1-based
%   iteration numbers here.
%
%   The updated obj is returned; callers must keep the return value unless
%   obj is a handle object.

% Python's index -1 ("last element") has no direct MATLAB equivalent.
is_last = (iiter == -1);

if iiter <= obj.niter_trace + 1
    if is_last
        obj.H{end} = H;
        obj.Y{end} = Y;
    else
        obj.H{iiter} = H;
        obj.Y{iiter} = Y;
    end
elseif iiter > obj.niter - obj.niter_trace + 1
    % Trailing window: shift the iteration number into the tail of the trace.
    slot = obj.ltrace + iiter - obj.niter - 1;
    obj.H{slot} = H;
    obj.Y{slot} = Y;
end

if is_last
    obj.etas(end) = eta;
    obj.loss(end) = loss;
    current_loss = obj.loss(end);
else
    obj.etas(iiter) = eta;
    obj.loss(iiter) = loss;
    current_loss = obj.loss(iiter);
end

if current_loss < obj.lmin
    obj.Yh = Y;
    obj.lmin = current_loss;
    % Mirrors: miniter = iiter if not iiter == -1 else niter + 1
    if is_last
        obj.miniter = obj.niter + 1;
    else
        obj.miniter = iiter;
    end
end
end
相关问题
转成matlab:self.eta = self.eta_max if kwargs.get('first_iter',False) and not self.linesearch_first: self.eta = kwargs.get('eta_first',1) loss_diff = 1 while loss_diff > 0: loss_diff, temp_embedding, delta = self._linesearch_once( update_embedding_with,grad,calc_loss,loss,**kwargs) if self.eta <= self.eta_min and loss_diff > 0: loss_diff, temp_embedding, delta = self._linesearch_once( update_embedding_with,grad,calc_loss,loss,**kwargs) loss_diff = -1 self.eta *= 2 update_embedding_with(new_embedding=temp_embedding) return delta
% Line-search driver (script fragment).
%
% Expects in the workspace: self, kwargs (struct of optional settings),
% update_embedding_with, grad, calc_loss and loss. Leaves the accepted step
% in the variable `delta`, which is the value the Python original returns.
%
% NOTE(review): the loop relies on self.linesearch_once changing self.eta.
% That is only visible here if self is a handle object - confirm.
% MATLAB identifiers must start with a letter, so the Python method
% `_linesearch_once` is called as `linesearch_once`.

% Start every search from the largest step size.
self.eta = self.eta_max;

% Optional override for the very first iteration.
use_first_eta = isfield(kwargs, 'first_iter') && kwargs.first_iter && ~self.linesearch_first;
if use_first_eta
    if isfield(kwargs, 'eta_first')
        self.eta = kwargs.eta_first;
    else
        self.eta = 1;  % same default as kwargs.get('eta_first', 1)
    end
end

% Repeat trial steps until the loss no longer increases.
loss_diff = 1;
while loss_diff > 0
    [loss_diff, temp_embedding, delta] = self.linesearch_once(update_embedding_with, grad, calc_loss, loss, kwargs);
    if self.eta <= self.eta_min && loss_diff > 0
        % Step size hit its floor: take one last trial and force the loop to stop.
        [loss_diff, temp_embedding, delta] = self.linesearch_once(update_embedding_with, grad, calc_loss, loss, kwargs);
        loss_diff = -1;
    end
end

% Each trial halves eta; doubling restores the value used by the last trial.
self.eta = self.eta * 2;

% Commit the accepted embedding.
update_embedding_with('new_embedding', temp_embedding);

% Result of the search: delta (already assigned by the last trial above).
转成matlab: def _apply_linesearch_optimzation(self, update_embedding_with, grad, calc_loss, loss, **kwargs): self.eta = self.eta_max if kwargs.get('first_iter',False) and not self.linesearch_first: self.eta = kwargs.get('eta_first',1) loss_diff = 1 while loss_diff > 0: loss_diff, temp_embedding, delta = self._linesearch_once( update_embedding_with,grad,calc_loss,loss,**kwargs) if self.eta <= self.eta_min and loss_diff > 0: loss_diff, temp_embedding, delta = self._linesearch_once( update_embedding_with,grad,calc_loss,loss,**kwargs) loss_diff = -1 self.eta *= 2 update_embedding_with(new_embedding=temp_embedding) return delta def _linesearch_once(self, update_embedding_with, grad, calc_loss, loss, **kwargs): delta = self._calc_delta(grad) temp_embedding = update_embedding_with(delta=delta,copy=True) loss_diff = calc_loss(temp_embedding) - loss self.eta /= 2 return loss_diff, temp_embedding, delta
function [delta, self] = apply_linesearch_optimzation(self, update_embedding_with, grad, calc_loss, loss, varargin)
% APPLY_LINESEARCH_OPTIMZATION Halve the step size until the loss stops increasing.
%
%   [delta, self] = apply_linesearch_optimzation(self, update_embedding_with, ...
%                                                grad, calc_loss, loss, first_iter, eta_first)
%
%   self                  object/struct with fields eta, eta_max, eta_min,
%                         linesearch_first and a callable calc_delta(grad)
%   update_embedding_with callable; invoked with name/value pairs
%                         ('delta', d, 'copy', true) to obtain a trial
%                         embedding and ('new_embedding', e) to commit one
%   grad                  gradient forwarded to self.calc_delta
%   calc_loss             callable returning the loss of a trial embedding
%   loss                  reference loss the trial loss is compared against
%   first_iter (optional) true on the first iteration (default false)
%   eta_first  (optional) step size used when first_iter is true and
%                         self.linesearch_first is false (default 1)
%
%   Returns the accepted step delta. The updated self is returned as a
%   second output so that the changes to self.eta are not lost when self
%   is a value type; handle objects may ignore it.
%
%   MATLAB identifiers must start with a letter, so the leading underscore
%   of the Python names is dropped here, in linesearch_once and in
%   calc_delta.

% Start every search from the largest step size.
self.eta = self.eta_max;

% Optional arguments are read defensively: both may be omitted.
first_iter = numel(varargin) >= 1 && varargin{1};
if first_iter && ~self.linesearch_first
    if numel(varargin) >= 2
        self.eta = varargin{2};
    else
        self.eta = 1;  % same default as kwargs.get('eta_first', 1)
    end
end

loss_diff = 1;
while loss_diff > 0
    [loss_diff, temp_embedding, delta, self] = linesearch_once(self, update_embedding_with, grad, calc_loss, loss, varargin{:});
    if self.eta <= self.eta_min && loss_diff > 0
        % Step size hit its floor: take one last trial and force the loop to stop.
        [loss_diff, temp_embedding, delta, self] = linesearch_once(self, update_embedding_with, grad, calc_loss, loss, varargin{:});
        loss_diff = -1;
    end
end

% Each trial halves eta; doubling restores the value used by the last trial.
self.eta = self.eta * 2;

% Commit the accepted embedding.
update_embedding_with('new_embedding', temp_embedding);
end

function [loss_diff, temp_embedding, delta, self] = linesearch_once(self, update_embedding_with, grad, calc_loss, loss, varargin)
% LINESEARCH_ONCE Evaluate one trial step and halve the step size.
%
%   Computes delta via self.calc_delta(grad), asks update_embedding_with
%   for a trial embedding built from that delta, and returns the difference
%   between the trial loss and the reference loss. self.eta is halved
%   afterwards; the updated self is returned as the fourth output.
%   Extra arguments are accepted for call compatibility and are not used.

delta = self.calc_delta(grad);
temp_embedding = update_embedding_with('delta', delta, 'copy', true);
loss_diff = calc_loss(temp_embedding) - loss;
self.eta = self.eta / 2;
end