Page MenuHomec4science

cast.hh
No OneTemporary

File Metadata

Created
Sun, Apr 28, 11:33
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2022 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2022 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef CAST_HH
#define CAST_HH
/* -------------------------------------------------------------------------- */
#include "grid.hh"
#include "grid_base.hh"
#include "numpy.hh"
/* -------------------------------------------------------------------------- */
#include <boost/preprocessor/seq.hpp>
#include <pybind11/cast.h>
/* -------------------------------------------------------------------------- */
namespace pybind11 {
// Buffer-protocol format descriptor for tamaas::complex<T>, so that pybind11
// wraps tamaas complex buffers the same way it wraps std::complex<T>.
template <typename T>
struct format_descriptor<
    tamaas::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
  /// Format character of the underlying real scalar type T
  static constexpr const char c = format_descriptor<T>::c;
  /// Null-terminated format string: 'Z' (complex marker) then the scalar code
  static constexpr const char value[3] = {'Z', c, '\0'};
  /// Format string as a std::string (built from `value`)
  static std::string format() { return value; }
};
#ifndef PYBIND11_CPP17
// Before C++17, static constexpr data members are not implicitly inline, so an
// out-of-class definition of `value` is required whenever it is odr-used (for
// instance when format() converts it to a pointer). PYBIND11_CPP17 is
// pybind11's C++17 detection macro; in C++17 and later this definition is
// redundant and is compiled out.
template <typename T>
constexpr const char format_descriptor<
tamaas::complex<T>,
detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
namespace detail {
// Register tamaas::complex<T> as a complex type for pybind11's numpy support.
template <typename T>
struct is_complex<tamaas::complex<T>> : std::integral_constant<bool, true> {};

// Numeric format index of tamaas::complex<T>. As in pybind11's own
// std::complex specialisation, the complex entries come right after the three
// real floating-point entries in pybind11's format tables, hence the +3 offset
// from the index of the real scalar type T.
template <typename T>
struct is_fmt_numeric<tamaas::complex<T>,
                      enable_if_t<std::is_floating_point<T>::value>>
    : std::integral_constant<bool, true> {
  static constexpr int index = 3 + is_fmt_numeric<T>::index;
};
/**
 * Translate a pybind11 return value policy into the `base` handle passed to
 * the numpy array constructor in grid_to_python.
 *
 * With pybind11's array constructor:
 *  - a null handle makes the new array copy the grid data;
 *  - None makes the array reference the grid data without keeping any Python
 *    object alive (the C++ side must keep the grid alive);
 *  - `parent` makes the array reference the grid data and keep `parent` alive.
 *
 * @param policy return value policy requested by pybind11
 * @param parent Python object owning the grid (used for reference_internal)
 * @return handle to use as the numpy array base
 * Unknown policies are reported through TAMAAS_EXCEPTION.
 */
inline handle policy_switch(return_value_policy policy, handle parent) {
  switch (policy) {
  // Custom casters receive `automatic` unchanged (default policy on by-value
  // or lvalue-reference returns). They receive `take_ownership` from the
  // PYBIND11_TYPE_CASTER pointer overload, which deletes the source after
  // cast() returns. A const grid cannot be adopted by numpy, so copy in both
  // cases, as pybind11's Eigen caster does.
  case return_value_policy::automatic:
  case return_value_policy::take_ownership:
  case return_value_policy::copy:
  case return_value_policy::move:
    return handle();
  case return_value_policy::automatic_reference: // happens in python-derived
                                                 // classes
  case return_value_policy::reference:
    return none();
  case return_value_policy::reference_internal:
    return parent;
  default:
    TAMAAS_EXCEPTION("Policy is not handled");
  }
}
/**
 * Convert a tamaas::Grid into a numpy array.
 *
 * The array shape is the grid sizes, followed by an extra trailing dimension
 * holding the components when the grid has more than one component.
 *
 * @tparam array numpy array type to construct (an array_t specialisation)
 * @param grid grid to convert
 * @param policy return value policy, translated by policy_switch
 * @param parent Python owner of the grid (used for reference_internal)
 * @return new reference to the created numpy array (released handle)
 */
template <class array, typename T, tamaas::UInt dim>
handle grid_to_python(const tamaas::Grid<T, dim>& grid,
                      return_value_policy policy, handle parent) {
  const handle base = policy_switch(policy, parent);

  // Evaluate sizes() once so that begin() and end() are guaranteed to come
  // from the same container.
  const auto& grid_sizes = grid.sizes();

  std::vector<tamaas::UInt> shape;
  shape.reserve(dim + 1); // room for the optional components dimension
  shape.assign(grid_sizes.begin(), grid_sizes.end());

  if (grid.getNbComponents() != 1)
    shape.push_back(grid.getNbComponents());

  return array(shape, grid.getInternalData(), base).release();
}
/**
 * Convert a tamaas::GridBase into a flat (one-dimensional) numpy array.
 *
 * GridBase carries no shape information here, so the array has a single
 * dimension of length grid.dataSize().
 *
 * @tparam array numpy array type to construct (an array_t specialisation)
 * @param grid grid to convert
 * @param policy return value policy, translated by policy_switch
 * @param parent Python owner of the grid (used for reference_internal)
 * @return new reference to the created numpy array (released handle)
 */
template <class array, typename T>
handle grid_to_python(const tamaas::GridBase<T>& grid,
                      return_value_policy policy, handle parent) {
  const handle base = policy_switch(policy, parent);
  const std::vector<tamaas::UInt> flat_shape(1, grid.dataSize());
  return array(flat_shape, grid.getInternalData(), base).release();
}
/**
 * pybind11 type caster for tamaas grid class templates G<T, dim>.
 *
 * Inspired by https://tinyurl.com/y8m47qh3 (T. De Geus) and pybind11/eigen.h.
 */
template <template <typename, tamaas::UInt> class G, typename T,
          tamaas::UInt dim>
struct type_caster<G<T, dim>> {
  using type = G<T, dim>;
  using array_type =
      array_t<typename type::value_type, array::c_style | array::forcecast>;

public:
  // NOLINTNEXTLINE(readability-else-after-return)
  PYBIND11_TYPE_CASTER(type, _("GridWrap<T, dim>"));

  /**
   * Python -> C++: move a tamaas::wrap::GridNumpy built on the numpy array
   * into `value`.
   *
   * The object is accepted only if array_type::check_ succeeds and
   * `convert` is true. check_ requires a numpy array of matching dtype and
   * C-contiguous layout. The `convert` condition means the object is never
   * accepted during pybind11's non-converting overload pass.
   *
   * @return true when `value` was set, false to let pybind11 try elsewhere
   */
  bool load(handle src, bool convert) {
    const bool accepted = array_type::check_(src) && convert;
    if (!accepted)
      return false;

    array_type buffer = array_type::ensure(src);
    if (!buffer)
      return false;

    value.move(tamaas::wrap::GridNumpy<G<T, dim>>(buffer));
    return true;
  }

  /**
   * C++ -> Python: build a numpy array from the grid. `policy` and `parent`
   * select whether the data is copied or referenced (see policy_switch).
   */
  static handle cast(const type& src, return_value_policy policy,
                     handle parent) {
    return grid_to_python<array_type>(src, policy, parent);
  }
};
/**
* Type caster for GridBase classes
*/
template <typename T>
struct type_caster<tamaas::GridBase<T>> {
using type = tamaas::GridBase<T>;
using array_type =
array_t<typename type::value_type, array::c_style | array::forcecast>;
public:
// NOLINTNEXTLINE(readability-else-after-return)
PYBIND11_TYPE_CASTER(type, _("GridBaseWrap<T>"));
bool load(handle src, bool convert) {
if (!array_type::check_(src) or !convert)
return false;
auto buf = array_type::ensure(src);
if (!buf)
return false;
value.move(tamaas::wrap::GridBaseNumpy<T>(buf));
return true;
}
static handle cast(const type& src, return_value_policy policy,
handle parent) {
#define GRID_BASE_CASE(unused1, unused2, dim) \
case dim: { \
const auto* conv = dynamic_cast<const tamaas::Grid<T, dim>*>(&src); \
if (conv) \
return grid_to_python<array_type>(*conv, policy, parent); \
else \
return grid_to_python<array_type>(src, policy, parent); \
}
switch (src.getDimension()) {
BOOST_PP_SEQ_FOR_EACH(GRID_BASE_CASE, ~, (1)(2)(3));
default:
return grid_to_python<array_type>(src, policy, parent);
}
#undef GRID_BASE_CASE
}
};
} // namespace detail
} // namespace pybind11
#endif // CAST_HH

Event Timeline