Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F112917163
mexProximalPathCoding.cpp
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Tue, May 13, 16:37
Size
5 KB
Mime Type
text/x-c++
Expires
Thu, May 15, 16:37 (2 d)
Engine
blob
Format
Raw Data
Handle
26140932
Attached To
R1908 Research Scripts (Thomas Bolton)
mexProximalPathCoding.cpp
View Options
/* Software SPAMS v2.3 - Copyright 2009-2011 Julien Mairal
*
* This file is part of SPAMS.
*
* SPAMS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SPAMS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with SPAMS. If not, see <http://www.gnu.org/licenses/>.
*/
#include <mex.h>
#include <mexutils.h>
#include <fista.h>
using namespace FISTA;
/**
 * Typed worker behind the MEX gateway: validates the MATLAB inputs, wraps
 * them in SPAMS containers and runs FISTA::PROX with a path-coding graph
 * regularizer (GRAPH_PATH_L0 or GRAPH_PATH_CONV only).
 *
 * Inputs (prhs):
 *   prhs[0]  dense matrix of element type T, read as pAlpha x nAlpha.
 *   prhs[1]  struct describing the graph, with fields
 *              weights        sparse pAlpha x pAlpha matrix of type T,
 *              start_weights  dense array of type T with pAlpha elements,
 *              stop_weights   dense array of type T with pAlpha elements.
 *   prhs[2]  struct of options; fields read here: numThreads, pos, lambda,
 *            lambda2, regul, intercept, verbose, transpose, precision.
 *
 * Outputs (plhs):
 *   plhs[0]  pAlpha x nAlpha matrix of type T that receives the result.
 *   plhs[1]  only when nlhs == 2: a 1 x val.n() row holding the values
 *            that PROX stored in `val`.
 *
 * Every validation failure is reported through mexErrMsgTxt, which aborts
 * the MEX call and returns control to MATLAB.
 */
template <typename T>
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const int nlhs) {
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");
   if (!mxIsStruct(prhs[1]))
      mexErrMsgTxt("argument 2 should be struct");
   if (!mxIsStruct(prhs[2]))
      mexErrMsgTxt("argument 3 should be struct");

   // Wrap the input matrix without copying its data.
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[0]);
   int pAlpha=static_cast<int>(dimsAlpha[0]);
   int nAlpha=static_cast<int>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // --- graph.weights: sparse square matrix over the pAlpha variables ---
   mxArray* ppr_GG = mxGetField(prhs[1],0,"weights");
   // mxGetField yields NULL when the field is absent; fail cleanly instead
   // of handing a null pointer to the mx* accessors below.
   if (!ppr_GG)
      mexErrMsgTxt("field weights is missing");
   if (!mxIsSparse(ppr_GG))
      mexErrMsgTxt("field weights should be sparse");
   // The buffer is reinterpreted as T* below, so its class must match T.
   if (!mexCheckType<T>(ppr_GG))
      mexErrMsgTxt("type of field weights is not consistent");
   T* graph_weights = reinterpret_cast<T*>(mxGetPr(ppr_GG));
   mwSize* GG_r=mxGetIr(ppr_GG);
   mwSize* GG_pB=mxGetJc(ppr_GG);
   const mwSize* dims_GG=mxGetDimensions(ppr_GG);
   int GGm=static_cast<int>(dims_GG[0]);
   int GGn=static_cast<int>(dims_GG[1]);
   if (GGm != GGn || GGm != pAlpha)
      mexErrMsgTxt("size of field weights is not consistent");

   // --- graph.start_weights: dense, one entry per variable ---
   mxArray* ppr_weights = mxGetField(prhs[1],0,"start_weights");
   if (!ppr_weights)
      mexErrMsgTxt("field start_weights is missing");
   if (mxIsSparse(ppr_weights))
      mexErrMsgTxt("field start_weights should not be sparse");
   if (!mexCheckType<T>(ppr_weights))
      mexErrMsgTxt("type of field start_weights is not consistent");
   T* start_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights));
   const mwSize* dims_weights=mxGetDimensions(ppr_weights);
   int nweights=static_cast<int>(dims_weights[0])*static_cast<int>(dims_weights[1]);
   if (nweights != pAlpha)
      mexErrMsgTxt("size of field start_weights is not consistent");

   // --- graph.stop_weights: dense, one entry per variable ---
   mxArray* ppr_weights2 = mxGetField(prhs[1],0,"stop_weights");
   if (!ppr_weights2)
      mexErrMsgTxt("field stop_weights is missing");
   if (mxIsSparse(ppr_weights2))
      mexErrMsgTxt("field stop_weights should not be sparse");
   if (!mexCheckType<T>(ppr_weights2))
      mexErrMsgTxt("type of field stop_weights is not consistent");
   T* stop_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights2));
   const mwSize* dims_weights2=mxGetDimensions(ppr_weights2);
   int nweights2=static_cast<int>(dims_weights2[0])*static_cast<int>(dims_weights2[1]);
   if (nweights2 != pAlpha)
      mexErrMsgTxt("size of field stop_weights is not consistent");

   // Allocate the primary output and wrap it so PROX writes into it directly.
   plhs[0]=createMatrix<T>(pAlpha,nAlpha);
   T* pr_alpha=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha,pAlpha,nAlpha);

   // --- solver options read from the third argument ---
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prhs[2],"numThreads",-1);
   param.pos = getScalarStructDef<bool>(prhs[2],"pos",false);
   param.lambda= getScalarStructDef<T>(prhs[2],"lambda",T(1.0));
   param.lambda2= getScalarStructDef<T>(prhs[2],"lambda2",0);
   getStringStruct(prhs[2],"regul",param.name_regul,param.length_names);
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   param.intercept = getScalarStructDef<bool>(prhs[2],"intercept",false);
   param.verbose = getScalarStructDef<bool>(prhs[2],"verbose",false);
   param.transpose = getScalarStructDef<bool>(prhs[2],"transpose",false);
   // The second output is only filled when the caller asked for it.
   param.eval = nlhs==2;

   // This entry point only serves the two path-coding regularizers.
   if (param.regul != GRAPH_PATH_L0 && param.regul != GRAPH_PATH_CONV)
      mexErrMsgTxt("Use a different mexProximal* function");

   // numThreads == -1 means "choose automatically": one thread without
   // OpenMP, otherwise the processor count capped at MAX_THREADS.
   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads = MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }

   // Describe the graph using MATLAB's sparse storage arrays in place.
   GraphPathStruct<T> graph;
   graph.n=pAlpha;
   // Number of stored entries across the first n columns of the sparse matrix.
   graph.m=GG_pB[graph.n]-GG_pB[0];
   graph.weights=graph_weights;
   graph.start_weights=start_weights;
   graph.stop_weights=stop_weights;
   graph.ir=GG_r;
   graph.jc=GG_pB;
   graph.precision = getScalarStructDef<long long>(prhs[2],"precision",100000000000000000);

   Vector<T> val;
   FISTA::PROX<T>(alpha0,alpha,param,val,NULL,NULL,&graph);

   // Copy the values gathered by PROX into the optional second output.
   if (nlhs==2) {
      plhs[1]=createMatrix<T>(1,val.n());
      T* pr_val=reinterpret_cast<T*>(mxGetPr(plhs[1]));
      for (int i = 0; i<val.n(); ++i) pr_val[i]=val[i];
   }
}
/**
 * MATLAB entry point.
 *
 * Expects exactly three inputs and at most two outputs, then dispatches to
 * the double or float instantiation of callFunction based on the class of
 * the first argument.
 *
 * A call with no explicit output (nlhs == 0) is accepted: MATLAB still
 * provides room for plhs[0] in that case and stores it in `ans`.
 */
void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[]) {
   if (nrhs != 3)
      mexErrMsgTxt("Bad number of inputs arguments");
   // Only more than two requested outputs is an error; zero outputs is the
   // ordinary "result goes to ans" call form.
   if (nlhs > 2)
      mexErrMsgTxt("Bad number of output arguments");

   if (mxGetClassID(prhs[0]) == mxDOUBLE_CLASS) {
      callFunction<double>(plhs,prhs,nlhs);
   } else {
      // Every non-double class is routed to the float instantiation, which
      // begins by checking argument 1 against its own element type.
      callFunction<float>(plhs,prhs,nlhs);
   }
}
Event Timeline
Log In to Comment