Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F121755966
mexFistaPathCoding.cpp
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sun, Jul 13, 16:06
Size
8 KB
Mime Type
text/x-c++
Expires
Tue, Jul 15, 16:06 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
27385478
Attached To
R1908 Research Scripts (Thomas Bolton)
mexFistaPathCoding.cpp
View Options
/* Software SPAMS v2.3 - Copyright 2009-2011 Julien Mairal
*
* This file is part of SPAMS.
*
* SPAMS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SPAMS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with SPAMS. If not, see <http://www.gnu.org/licenses/>.
*/
#include <mex.h>
#include <mexutils.h>
#include <fista.h>
using
namespace
FISTA
;
/* callFunction<T>: body of mexFistaPathCoding for element type T
 * (double or float, chosen by mexFunction from the class of argument 1).
 *
 * Expected inputs:
 *   prhs[0]  X       dense m x n matrix of class T
 *   prhs[1]  D       mD x p matrix of class T, dense or sparse
 *   prhs[2]  alpha0  dense pAlpha x nAlpha matrix of class T
 *   prhs[3]  struct describing the graph, with fields
 *              weights        sparse square matrix,
 *              start_weights  dense array,
 *              stop_weights   dense array;
 *            their size must equal nAlpha when param.transpose is set and
 *            pAlpha otherwise.
 *   prhs[4]  struct of solver options read with getScalarStructDef /
 *            getStringStruct (regul, loss, lambda, max_it, ...).
 * Outputs:
 *   plhs[0]  alpha, pAlpha x nAlpha matrix of class T
 *   plhs[1]  duality_gap, produced only when nlhs == 2 (sets param.eval).
 *
 * Every input check that may call mexErrMsgTxt runs before the heap
 * allocations made with new/new[] (the dictionary wrapper, its sparse copy
 * and the log file name): mexErrMsgTxt does not return, so an error raised
 * after those allocations would leak them.
 */
template <typename T>
inline void callFunction(mxArray* plhs[], const mxArray* prhs[],
      const int nlhs) {
   // ---------------------------------------------------------------------
   // Type / structure checks of the five arguments.
   // ---------------------------------------------------------------------
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");
   if (!mexCheckType<T>(prhs[1]))
      mexErrMsgTxt("type of argument 2 is not consistent");
   if (!mexCheckType<T>(prhs[2]))
      mexErrMsgTxt("type of argument 3 is not consistent");
   if (mxIsSparse(prhs[2]))
      mexErrMsgTxt("argument 3 should not be sparse");
   if (!mxIsStruct(prhs[3]))
      mexErrMsgTxt("argument 4 should be struct");
   if (!mxIsStruct(prhs[4]))
      mexErrMsgTxt("argument 5 should be struct");

   // Signals X: dense m x n, wrapped in place (no copy).
   T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsX = mxGetDimensions(prhs[0]);
   const int m = static_cast<int>(dimsX[0]);
   const int n = static_cast<int>(dimsX[1]);
   Matrix<T> X(prX, m, n);

   // Dictionary shape only; the D wrapper is allocated once all checks pass.
   const bool sparseD = mxIsSparse(prhs[1]);
   const mwSize* dimsD = mxGetDimensions(prhs[1]);
   const int mD = static_cast<int>(dimsD[0]);
   const int p = static_cast<int>(dimsD[1]);

   // Initial point alpha0: dense pAlpha x nAlpha.
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[2]));
   const mwSize* dimsAlpha = mxGetDimensions(prhs[2]);
   const int pAlpha = static_cast<int>(dimsAlpha[0]);
   const int nAlpha = static_cast<int>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0, pAlpha, nAlpha);

   // ---------------------------------------------------------------------
   // Graph description (argument 4). mxGetField returns NULL for a missing
   // field, which must be rejected before any mx* query dereferences it.
   // ---------------------------------------------------------------------
   mxArray* ppr_GG = mxGetField(prhs[3], 0, "weights");
   if (!ppr_GG)
      mexErrMsgTxt("Missing field weights");
   if (!mxIsSparse(ppr_GG))
      mexErrMsgTxt("field weights should be sparse");
   // NOTE(review): MATLAB sparse arrays only hold double (or logical) data,
   // yet the values are reinterpreted as T here; with T = float the edge
   // weights are read incorrectly. A converted copy is needed for single.
   T* graph_weights = reinterpret_cast<T*>(mxGetPr(ppr_GG));
   mwSize* GG_r = mxGetIr(ppr_GG);
   mwSize* GG_pB = mxGetJc(ppr_GG);
   const mwSize* dims_GG = mxGetDimensions(ppr_GG);
   const int GGm = static_cast<int>(dims_GG[0]);
   const int GGn = static_cast<int>(dims_GG[1]);

   mxArray* ppr_weights = mxGetField(prhs[3], 0, "start_weights");
   if (!ppr_weights)
      mexErrMsgTxt("Missing field start_weights");
   if (mxIsSparse(ppr_weights))
      mexErrMsgTxt("field start_weights should not be sparse");
   // The buffer is reinterpreted as T*, so its class must match T.
   if (!mexCheckType<T>(ppr_weights))
      mexErrMsgTxt("type of field start_weights is not consistent");
   T* start_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights));
   const mwSize* dims_weights = mxGetDimensions(ppr_weights);
   const int nweights = static_cast<int>(dims_weights[0])
      * static_cast<int>(dims_weights[1]);

   mxArray* ppr_weights2 = mxGetField(prhs[3], 0, "stop_weights");
   if (!ppr_weights2)
      mexErrMsgTxt("Missing field stop_weights");
   if (mxIsSparse(ppr_weights2))
      mexErrMsgTxt("field stop_weights should not be sparse");
   if (!mexCheckType<T>(ppr_weights2))
      mexErrMsgTxt("type of field stop_weights is not consistent");
   T* stop_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights2));
   const mwSize* dims_weights2 = mxGetDimensions(ppr_weights2);
   const int nweights2 = static_cast<int>(dims_weights2[0])
      * static_cast<int>(dims_weights2[1]);

   // ---------------------------------------------------------------------
   // Solver options (argument 5), each read once with its default.
   // ---------------------------------------------------------------------
   const mxArray* prm = prhs[4];
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prm, "numThreads", -1);
   param.pos = getScalarStructDef<bool>(prm, "pos", false);
   param.max_it = getScalarStructDef<int>(prm, "max_it", 1000);
   param.tol = getScalarStructDef<T>(prm, "tol", 0.000001);
   param.it0 = getScalarStructDef<int>(prm, "it0", 100);
   param.compute_gram = getScalarStructDef<bool>(prm, "compute_gram", false);
   param.max_iter_backtracking =
      getScalarStructDef<int>(prm, "max_iter_backtracking", 1000);
   param.L0 = getScalarStructDef<T>(prm, "L0", 1.0);
   // gamma is clamped from below at 1.01.
   param.gamma = MAX(1.01, getScalarStructDef<T>(prm, "gamma", 1.5));
   param.c = getScalarStructDef<T>(prm, "c", 1.0);
   param.lambda = getScalarStructDef<T>(prm, "lambda", T(1.0));

   getStringStruct(prm, "regul", param.name_regul, param.length_names);
   param.regul = regul_from_string(param.name_regul);
   if (param.regul == INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   getStringStruct(prm, "loss", param.name_loss, param.length_names);
   param.loss = loss_from_string(param.name_loss);
   if (param.loss == INCORRECT_LOSS)
      mexErrMsgTxt("Unknown loss");

   param.intercept = getScalarStructDef<bool>(prm, "intercept", false);
   param.verbose = getScalarStructDef<bool>(prm, "verbose", false);
   // The duality gap is only tracked when the caller asks for it.
   param.eval = nlhs == 2;
   param.delta = getScalarStructDef<T>(prm, "delta", 1.0);
   param.lambda2 = getScalarStructDef<T>(prm, "lambda2", 0.0);
   param.lambda3 = getScalarStructDef<T>(prm, "lambda3", 0.0);
   param.size_group = getScalarStructDef<int>(prm, "size_group", 1);
   param.admm = getScalarStructDef<bool>(prm, "admm", false);
   param.lin_admm = getScalarStructDef<bool>(prm, "lin_admm", false);
   param.sqrt_step = getScalarStructDef<bool>(prm, "sqrt_step", true);
   param.is_inner_weights =
      getScalarStructDef<bool>(prm, "is_inner_weights", false);
   param.resetflow = getScalarStructDef<bool>(prm, "resetflow", false);
   param.clever = getScalarStructDef<bool>(prm, "clever", false);
   param.ista = getScalarStructDef<bool>(prm, "ista", false);
   param.subgrad = getScalarStructDef<bool>(prm, "subgrad", false);
   param.transpose = getScalarStructDef<bool>(prm, "transpose", false);
   param.log = getScalarStructDef<bool>(prm, "log", false);
   param.a = getScalarStructDef<T>(prm, "a", T(1.0));
   param.b = getScalarStructDef<T>(prm, "b", 0);

   // ---------------------------------------------------------------------
   // Consistency checks.
   // ---------------------------------------------------------------------
   // The graph has one node per column of alpha when transposed, otherwise
   // one node per row.
   const int graphSize = param.transpose ? nAlpha : pAlpha;
   if (GGm != GGn || GGm != graphSize)
      mexErrMsgTxt("size of field weights is not consistent");
   if (nweights != graphSize)
      mexErrMsgTxt("size of field start_weights is not consistent");
   if (nweights2 != graphSize)
      mexErrMsgTxt("size of field stop_weights is not consistent");

   // Only look the log file name up here; its buffer is allocated later.
   mxArray* stringData = NULL;
   if (param.log) {
      stringData = mxGetField(prm, 0, "logName");
      if (!stringData)
         mexErrMsgTxt("Missing field logName");
   }

   if (param.regul == GRAPH || param.regul == GRAPHMULT)
      mexErrMsgTxt("Error: mexFistaGraph should be used instead");
   if (param.regul == TREE_L0 || param.regul == TREEMULT
         || param.regul == TREE_L2 || param.regul == TREE_LINF)
      mexErrMsgTxt("Error: mexFistaTree should be used instead");

   if (param.num_threads == -1) {
      param.num_threads = 1;
#ifdef _OPENMP
      param.num_threads = MIN(MAX_THREADS, omp_get_num_procs());
#endif
   }

   // ---------------------------------------------------------------------
   // All checks passed: allocate, solve, copy results, release.
   // ---------------------------------------------------------------------
   if (stringData) {
      // Element count (not column count) so any char array shape fits,
      // plus one for the terminating NUL written by mxGetString.
      const int stringLength =
         static_cast<int>(mxGetNumberOfElements(stringData)) + 1;
      param.logName = new char[stringLength];
      mxGetString(stringData, param.logName, stringLength);
   }

   AbstractMatrixB<T>* D;
   double* D_v = NULL;
   mwSize* D_r = NULL;
   int* D_r2 = NULL;
   int* D_pB2 = NULL;
   int* D_pE2 = NULL;
   T* D_v2 = NULL;
   if (sparseD) {
      // Sparse D is copied into T values / int indices by createCopySparse.
      D_v = mxGetPr(prhs[1]);
      D_r = mxGetIr(prhs[1]);
      mwSize* D_pB = mxGetJc(prhs[1]);
      mwSize* D_pE = D_pB + 1;
      createCopySparse<T>(D_v2, D_r2, D_pB2, D_pE2, D_v, D_r, D_pB, D_pE, p);
      D = new SpMatrix<T>(D_v2, D_r2, D_pB2, D_pE2, mD, p, D_pB2[p]);
   } else {
      // Dense D is wrapped in place with its own row count mD.
      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      D = new Matrix<T>(prD, mD, p);
   }

   plhs[0] = createMatrix<T>(pAlpha, nAlpha);
   T* pr_alpha = reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha, pAlpha, nAlpha);

   GraphPathStruct<T> graph;
   graph.n = graphSize;
   // Number of stored edges: nnz of the sparse weight matrix.
   graph.m = GG_pB[graph.n] - GG_pB[0];
   graph.weights = graph_weights;
   graph.start_weights = start_weights;
   graph.stop_weights = stop_weights;
   graph.ir = GG_r;
   graph.jc = GG_pB;
   graph.precision =
      getScalarStructDef<long long>(prm, "precision", 10000000000);

   Matrix<T> duality_gap;
   FISTA::solver<T>(X, *D, alpha0, alpha, param, duality_gap, NULL, NULL,
         &graph);

   if (nlhs == 2) {
      plhs[1] = createMatrix<T>(duality_gap.m(), duality_gap.n());
      T* pr_dualitygap = reinterpret_cast<T*>(mxGetPr(plhs[1]));
      const int numel = duality_gap.n() * duality_gap.m();
      for (int i = 0; i < numel; ++i)
         pr_dualitygap[i] = duality_gap[i];
   }

   if (param.logName) {
      delete[](param.logName);
      // Reset so a later release of param (if ParamFISTA's destructor
      // frees logName -- not visible here) cannot double-free it.
      param.logName = NULL;
   }
   if (sparseD)
      deleteCopySparse<T>(D_v2, D_r2, D_pB2, D_pE2, D_v, D_r);
   delete (D);
}
/* MEX gateway for mexFistaPathCoding.
 *
 * Requires exactly five right-hand-side arguments and one or two outputs,
 * then forwards to callFunction<double> when argument 1 is of class double
 * and to callFunction<float> for any other class.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
   if (nrhs != 5)
      mexErrMsgTxt("Bad number of inputs arguments");

   const bool acceptedOutputCount = (nlhs == 1) || (nlhs == 2);
   if (!acceptedOutputCount)
      mexErrMsgTxt("Bad number of output arguments");

   // Precision of the whole computation follows the class of the signals.
   const bool useDouble = mxGetClassID(prhs[0]) == mxDOUBLE_CLASS;
   if (useDouble)
      callFunction<double>(plhs, prhs, nlhs);
   else
      callFunction<float>(plhs, prhs, nlhs);
}
Event Timeline
Log In to Comment