Page MenuHomec4science

mexFistaPathCoding.cpp
No OneTemporary

File Metadata

Created
Tue, May 13, 20:37

mexFistaPathCoding.cpp

/* Software SPAMS v2.3 - Copyright 2009-2011 Julien Mairal
*
* This file is part of SPAMS.
*
* SPAMS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SPAMS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with SPAMS. If not, see <http://www.gnu.org/licenses/>.
*/
#include <mex.h>
#include <mexutils.h>
#include <fista.h>
using namespace FISTA;
/* Typed body of the MEX gateway (instantiated with T = double or float).
 *
 * Inputs (prhs):
 *   [0] X       : dense matrix of type T
 *   [1] D       : dense or sparse matrix of type T, with as many rows as X
 *   [2] alpha0  : dense matrix of type T, starting point handed to the solver
 *   [3] struct  : fields "weights" (sparse), "start_weights" and
 *                 "stop_weights" (dense) describing the path-coding graph
 *   [4] struct  : solver options (see the getScalarStructDef calls below for
 *                 the recognised names and their defaults)
 * Outputs (plhs):
 *   [0] matrix of the same size as alpha0, filled by FISTA::solver
 *   [1] only when nlhs == 2: copy of the duality_gap matrix filled by the solver
 *
 * All input validation is performed before any memory is obtained with
 * `new`, so that an early mexErrMsgTxt does not leave C++ allocations behind.
 */
template <typename T>
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const int nlhs) {
   // ---- basic type checks on the five arguments -------------------------
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");
   if (!mexCheckType<T>(prhs[1]))
      mexErrMsgTxt("type of argument 2 is not consistent");
   if (!mexCheckType<T>(prhs[2]))
      mexErrMsgTxt("type of argument 3 is not consistent");
   if (mxIsSparse(prhs[2]))
      mexErrMsgTxt("argument 3 should not be sparse");
   if (!mxIsStruct(prhs[3]))
      mexErrMsgTxt("argument 4 should be struct");
   if (!mxIsStruct(prhs[4]))
      mexErrMsgTxt("argument 5 should be struct");

   // ---- X: wraps the MATLAB buffer ---------------------------------------
   T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsX=mxGetDimensions(prhs[0]);
   int m=static_cast<int>(dimsX[0]);
   int n=static_cast<int>(dimsX[1]);
   Matrix<T> X(prX,m,n);

   // ---- D: only the dimensions for now, the object is built later --------
   const mwSize* dimsD=mxGetDimensions(prhs[1]);
   int mD=static_cast<int>(dimsD[0]);
   int p=static_cast<int>(dimsD[1]);
   // The dense branch below wraps D with m rows; reject a mismatching D
   // instead of reading its buffer with the wrong shape.
   if (mD != m)
      mexErrMsgTxt("argument 2 should have the same number of rows as argument 1");
   const bool sparseD = mxIsSparse(prhs[1]);

   // ---- alpha0 -----------------------------------------------------------
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[2]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[2]);
   int pAlpha=static_cast<int>(dimsAlpha[0]);
   int nAlpha=static_cast<int>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // ---- graph description (argument 4) -----------------------------------
   // mxGetField returns NULL for an absent field, so each field is checked
   // before it is dereferenced; the element type is checked because the
   // buffers are reinterpreted as T* below.
   mxArray* ppr_GG = mxGetField(prhs[3],0,"weights");
   if (!ppr_GG)
      mexErrMsgTxt("Missing field weights");
   if (!mxIsSparse(ppr_GG))
      mexErrMsgTxt("field weights should be sparse");
   if (!mexCheckType<T>(ppr_GG))
      mexErrMsgTxt("type of field weights is not consistent");
   T* graph_weights = reinterpret_cast<T*>(mxGetPr(ppr_GG));
   mwSize* GG_r=mxGetIr(ppr_GG);
   mwSize* GG_pB=mxGetJc(ppr_GG);
   const mwSize* dims_GG=mxGetDimensions(ppr_GG);
   int GGm=static_cast<int>(dims_GG[0]);
   int GGn=static_cast<int>(dims_GG[1]);

   mxArray* ppr_weights = mxGetField(prhs[3],0,"start_weights");
   if (!ppr_weights)
      mexErrMsgTxt("Missing field start_weights");
   if (mxIsSparse(ppr_weights))
      mexErrMsgTxt("field start_weights should not be sparse");
   if (!mexCheckType<T>(ppr_weights))
      mexErrMsgTxt("type of field start_weights is not consistent");
   T* start_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights));
   const mwSize* dims_weights=mxGetDimensions(ppr_weights);
   int nweights=static_cast<int>(dims_weights[0])*static_cast<int>(dims_weights[1]);

   mxArray* ppr_weights2 = mxGetField(prhs[3],0,"stop_weights");
   if (!ppr_weights2)
      mexErrMsgTxt("Missing field stop_weights");
   if (mxIsSparse(ppr_weights2))
      mexErrMsgTxt("field stop_weights should not be sparse");
   if (!mexCheckType<T>(ppr_weights2))
      mexErrMsgTxt("type of field stop_weights is not consistent");
   T* stop_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights2));
   const mwSize* dims_weights2=mxGetDimensions(ppr_weights2);
   int nweights2=static_cast<int>(dims_weights2[0])*static_cast<int>(dims_weights2[1]);

   // ---- first output: same size as alpha0 --------------------------------
   plhs[0]=createMatrix<T>(pAlpha,nAlpha);
   T* pr_alpha=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha,pAlpha,nAlpha);

   // ---- solver options (argument 5) --------------------------------------
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prhs[4],"numThreads",-1);
   param.pos = getScalarStructDef<bool>(prhs[4],"pos",false);
   param.max_it = getScalarStructDef<int>(prhs[4],"max_it",1000);
   param.tol = getScalarStructDef<T>(prhs[4],"tol",0.000001);
   param.it0 = getScalarStructDef<int>(prhs[4],"it0",100);
   param.compute_gram = getScalarStructDef<bool>(prhs[4],"compute_gram",false);
   param.max_iter_backtracking = getScalarStructDef<int>(prhs[4],"max_iter_backtracking",1000);
   param.L0 = getScalarStructDef<T>(prhs[4],"L0",1.0);
   // gamma is clamped from below to 1.01
   param.gamma = MAX(1.01,getScalarStructDef<T>(prhs[4],"gamma",1.5));
   param.c= getScalarStructDef<T>(prhs[4],"c",1.0);
   param.lambda= getScalarStructDef<T>(prhs[4],"lambda",T(1.0));
   getStringStruct(prhs[4],"regul",param.name_regul,param.length_names);
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   getStringStruct(prhs[4],"loss",param.name_loss,param.length_names);
   param.loss = loss_from_string(param.name_loss);
   if (param.loss==INCORRECT_LOSS)
      mexErrMsgTxt("Unknown loss");
   param.intercept = getScalarStructDef<bool>(prhs[4],"intercept",false);
   param.verbose = getScalarStructDef<bool>(prhs[4],"verbose",false);
   // the duality gap is only evaluated when a second output is requested
   param.eval = nlhs==2;
   param.delta = getScalarStructDef<T>(prhs[4],"delta",1.0);
   param.lambda2= getScalarStructDef<T>(prhs[4],"lambda2",0.0);
   param.lambda3= getScalarStructDef<T>(prhs[4],"lambda3",0.0);
   param.size_group= getScalarStructDef<int>(prhs[4],"size_group",1);
   param.admm = getScalarStructDef<bool>(prhs[4],"admm",false);
   param.lin_admm = getScalarStructDef<bool>(prhs[4],"lin_admm",false);
   param.sqrt_step = getScalarStructDef<bool>(prhs[4],"sqrt_step",true);
   param.is_inner_weights = getScalarStructDef<bool>(prhs[4],"is_inner_weights",false);
   param.resetflow = getScalarStructDef<bool>(prhs[4],"resetflow",false);
   param.clever = getScalarStructDef<bool>(prhs[4],"clever",false);
   param.ista= getScalarStructDef<bool>(prhs[4],"ista",false);
   param.subgrad= getScalarStructDef<bool>(prhs[4],"subgrad",false);
   param.transpose = getScalarStructDef<bool>(prhs[4],"transpose",false);
   param.log= getScalarStructDef<bool>(prhs[4],"log",false);
   param.a= getScalarStructDef<T>(prhs[4],"a",T(1.0));
   param.b= getScalarStructDef<T>(prhs[4],"b",0);

   // ---- graph size must match the relevant dimension of alpha0 -----------
   // (number of columns when "transpose" is set, number of rows otherwise)
   const int graphSize = param.transpose ? nAlpha : pAlpha;
   if (GGm != GGn || GGm != graphSize)
      mexErrMsgTxt("size of field weights is not consistent");
   if (nweights != graphSize)
      mexErrMsgTxt("size of field start_weights is not consistent");
   if (nweights2 != graphSize)
      mexErrMsgTxt("size of field stop_weights is not consistent");

   // ---- regularizers handled by other entry points -----------------------
   if (param.regul==GRAPH || param.regul==GRAPHMULT)
      mexErrMsgTxt("Error: mexFistaGraph should be used instead");
   if (param.regul==TREE_L0 || param.regul==TREEMULT || param.regul==TREE_L2 || param.regul==TREE_LINF)
      mexErrMsgTxt("Error: mexFistaTree should be used instead");

   // ---- optional log file name: last validation, first `new` -------------
   if (param.log) {
      mxArray *stringData = mxGetField(prhs[4],0,"logName");
      if (!stringData)
         mexErrMsgTxt("Missing field logName");
      int stringLength = static_cast<int>(mxGetN(stringData))+1;
      param.logName= new char[stringLength];
      mxGetString(stringData,param.logName,stringLength);
   }

   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads = MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }

   // ---- build D now that every input has been validated ------------------
   AbstractMatrixB<T>* D = NULL;
   double* D_v = NULL;
   mwSize* D_r = NULL, *D_pB = NULL, *D_pE = NULL;
   int* D_r2 = NULL, *D_pB2 = NULL, *D_pE2 = NULL;
   T* D_v2 = NULL;
   if (sparseD) {
      D_v=static_cast<double*>(mxGetPr(prhs[1]));
      D_r=mxGetIr(prhs[1]);
      D_pB=mxGetJc(prhs[1]);
      D_pE=D_pB+1;
      createCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r,D_pB,D_pE,p);
      D = new SpMatrix<T>(D_v2,D_r2,D_pB2,D_pE2,mD,p,D_pB2[p]);
   } else {
      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      D = new Matrix<T>(prD,m,p);
   }

   // ---- graph structure handed to the solver -----------------------------
   GraphPathStruct<T> graph;
   graph.n=graphSize;
   // number of stored entries in the first graph.n columns of "weights"
   graph.m=GG_pB[graph.n]-GG_pB[0];
   graph.weights=graph_weights;
   graph.start_weights=start_weights;
   graph.stop_weights=stop_weights;
   graph.ir=GG_r;
   graph.jc=GG_pB;
   graph.precision = getScalarStructDef<long long>(prhs[4],"precision",10000000000);

   Matrix<T> duality_gap;
   FISTA::solver<T>(X,*D,alpha0,alpha,param,duality_gap,NULL,NULL,&graph);

   // ---- optional second output -------------------------------------------
   if (nlhs==2) {
      plhs[1]=createMatrix<T>(duality_gap.m(),duality_gap.n());
      T* pr_dualitygap=reinterpret_cast<T*>(mxGetPr(plhs[1]));
      const int gapSize = duality_gap.n()*duality_gap.m();
      for (int i = 0; i<gapSize; ++i) pr_dualitygap[i]=duality_gap[i];
   }

   // ---- cleanup ----------------------------------------------------------
   if (param.logName) delete[](param.logName);
   if (sparseD) {
      deleteCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r);
   }
   delete(D);
}
/* MEX entry point.
 *
 * Expects exactly five inputs and one or two outputs, then forwards to
 * callFunction<double> when the first input is of class double and to
 * callFunction<float> for any other class.
 */
void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[]) {
   const bool inputCountOk = (nrhs == 5);
   if (!inputCountOk)
      mexErrMsgTxt("Bad number of inputs arguments");

   const bool outputCountOk = (nlhs == 1 || nlhs == 2);
   if (!outputCountOk)
      mexErrMsgTxt("Bad number of output arguments");

   // Dispatch on the numeric class of the first argument.
   const bool firstArgIsDouble = (mxGetClassID(prhs[0]) == mxDOUBLE_CLASS);
   if (firstArgIsDouble) {
      callFunction<double>(plhs,prhs,nlhs);
      return;
   }
   callFunction<float>(plhs,prhs,nlhs);
}

Event Timeline