Page MenuHomec4science

cast.hh
No OneTemporary

File Metadata

Created
Fri, May 31, 19:58
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef CAST_HH
#define CAST_HH
/* -------------------------------------------------------------------------- */
#include "field_container.hh"
#include "grid.hh"
#include "grid_base.hh"
#include "model_type.hh"
#include "numpy.hh"
/* -------------------------------------------------------------------------- */
#include <pybind11/cast.h>
/* -------------------------------------------------------------------------- */
namespace pybind11 {
// Format descriptor necessary for correct wrap of tamaas complex type.
// Only enabled when T is a floating-point type. The descriptor string is the
// character 'Z' followed by the format character of the underlying type T.
// NOTE(review): 'Z' is presumably the buffer-protocol prefix for complex
// numbers, as pybind11 uses for std::complex -- confirm against
// pybind11/complex.h.
template <typename T>
struct format_descriptor<
tamaas::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
// Format character of the underlying floating-point type
static constexpr const char c = format_descriptor<T>::c;
// Null-terminated format string: 'Z' followed by c
static constexpr const char value[3] = {'Z', c, '\0'};
// Returns the format string as a std::string
static std::string format() { return std::string(value); }
};
// Out-of-class definition of the static member `value`, only compiled when
// PYBIND11_CPP17 is not defined (before C++17, static constexpr data members
// are not implicitly inline and need a namespace-scope definition when
// odr-used).
#ifndef PYBIND11_CPP17
template <typename T>
constexpr const char format_descriptor<
tamaas::complex<T>,
detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
namespace detail {
// Declare tamaas complex as a complex type for pybind11: the is_complex
// trait evaluates to std::true_type for any tamaas::complex<T>.
template <typename T>
struct is_complex<tamaas::complex<T>> : std::true_type {};
// Declare tamaas complex (of floating-point T) as a numeric format type for
// pybind11.
template <typename T>
struct is_fmt_numeric<tamaas::complex<T>,
detail::enable_if_t<std::is_floating_point<T>::value>>
: std::true_type {
// Index of the complex type: index of the underlying type T shifted by 3.
// NOTE(review): the offset presumably mirrors pybind11's own std::complex
// specialization (complex entries placed 3 slots after their real
// counterparts in the numeric type table) -- confirm against
// pybind11/complex.h.
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
/**
 * Select the handle that is handed to the numpy array constructors as
 * their last argument, depending on the requested return value policy:
 *  - copy, move: an empty handle
 *  - automatic_reference, reference: None
 *  - reference_internal: the parent object
 * Every other policy is reported through TAMAAS_EXCEPTION.
 *
 * NOTE(review): this relies on pybind11's array constructor copying the
 * data when it receives an empty base handle and referencing it otherwise
 * -- confirm against pybind11/numpy.h.
 */
static inline handle policy_switch(return_value_policy policy, handle parent) {
  using rvp = return_value_policy;

  if (policy == rvp::copy or policy == rvp::move)
    return handle();

  // automatic_reference happens in python-derived classes
  if (policy == rvp::automatic_reference or policy == rvp::reference)
    return none();

  if (policy == rvp::reference_internal)
    return parent;

  TAMAAS_EXCEPTION("Policy is not handled");
}
/**
 * Build a numpy array (of class `array`) from a grid of known dimension.
 *
 * The array shape is made of the `dim` grid sizes, followed by the number
 * of components when the grid has more than one component. The matching
 * strides are taken from the grid and converted to bytes.
 *
 * @param grid grid whose internal data is exposed
 * @param policy return value policy, translated by policy_switch
 * @param parent parent object, used for reference_internal
 * @return handle to the newly created array
 */
template <class array, typename T, tamaas::UInt dim>
handle grid_to_python(const tamaas::Grid<T, dim>& grid,
                      return_value_policy policy, handle parent) {
  const handle base = policy_switch(policy, parent);

  std::vector<tamaas::UInt> shape(dim);
  std::vector<tamaas::UInt> byte_strides(dim);
  std::copy_n(grid.sizes().begin(), dim, shape.begin());
  std::copy_n(grid.getStrides().begin(), dim, byte_strides.begin());

  // Grids with several components get one extra trailing axis
  const auto nb_components = grid.getNbComponents();
  if (nb_components != 1) {
    shape.push_back(nb_components);
    byte_strides.push_back(grid.getStrides().back());
  }

  // Numpy arrays have strides in bytes
  for (auto& stride : byte_strides)
    stride = stride * sizeof(T);

  return array(shape, byte_strides, grid.getInternalData(), base).release();
}
/**
 * Build a one-dimensional numpy array (of class `array`) from a GridBase.
 *
 * The array has a single axis whose length is the data size of the grid.
 *
 * @param grid grid whose internal data is exposed
 * @param policy return value policy, translated by policy_switch
 * @param parent parent object, used for reference_internal
 * @return handle to the newly created array
 */
template <class array, typename T>
handle grid_to_python(const tamaas::GridBase<T>& grid,
                      return_value_policy policy, handle parent) {
  const handle base = policy_switch(policy, parent);
  std::vector<tamaas::UInt> shape{grid.dataSize()};
  return array(shape, grid.getInternalData(), base).release();
}
/**
 * Type caster for grid classes
 * inspired by https://tinyurl.com/y8m47qh3 from T. De Geus
 * and pybind11/eigen.h
 */
template <template <typename, tamaas::UInt> class G, typename T,
          tamaas::UInt dim>
struct type_caster<G<T, dim>> {
  using type = G<T, dim>;
  using array_type =
      array_t<typename type::value_type, array::c_style | array::forcecast>;

public:
  // NOLINTNEXTLINE(readability-else-after-return)
  PYBIND11_TYPE_CASTER(type, _("GridWrap<T, dim>"));

  /**
   * Python -> C++: try to turn a Python object into a grid instance.
   *
   * Fails (returns false) when `src` does not pass `array_type::check_`,
   * when implicit conversions are not allowed (`convert` is false), or when
   * `array_type::ensure` does not yield an array. On success the numpy
   * wrapper grid built from the array is moved into `value`.
   */
  bool load(handle src, bool convert) {
    if (not(array_type::check_(src) and convert))
      return false;

    if (auto numpy_array = array_type::ensure(src)) {
      value.move(tamaas::wrap::GridNumpy<type>(numpy_array));
      return true;
    }
    return false;
  }

  /**
   * C++ -> Python: expose a grid instance as a numpy array. The return
   * value policy and the parent object (for
   * ``return_value_policy::reference_internal``) are forwarded to
   * grid_to_python.
   */
  static handle cast(const type& src, return_value_policy policy,
                     handle parent) {
    return grid_to_python<array_type>(src, policy, parent);
  }
};
/**
 * Type caster for GridBase classes
 */
template <typename T>
struct type_caster<tamaas::GridBase<T>> {
  using type = tamaas::GridBase<T>;
  using array_type =
      array_t<typename type::value_type, array::c_style | array::forcecast>;

public:
  // NOLINTNEXTLINE(readability-else-after-return)
  PYBIND11_TYPE_CASTER(type, _("GridBaseWrap<T>"));

  /**
   * Python -> C++: fails when `src` does not pass `array_type::check_`,
   * when `convert` is false, or when `array_type::ensure` yields nothing.
   * On success the numpy wrapper built from the array is moved into `value`.
   */
  bool load(handle src, bool convert) {
    if (not(array_type::check_(src) and convert))
      return false;

    if (auto numpy_array = array_type::ensure(src)) {
      value.move(tamaas::wrap::GridBaseNumpy<T>(numpy_array));
      return true;
    }
    return false;
  }

  /**
   * C++ -> Python: dispatch on the runtime dimension of `src`. When `src`
   * is actually a Grid<T, dim> for the dispatched dimension, it is exposed
   * through the dimensioned grid_to_python overload; in every other case
   * the GridBase overload is used.
   */
  static handle cast(const type& src, return_value_policy policy,
                     handle parent) {
    using namespace ::tamaas::detail;

    return tuple_dispatch_with_default<dims_t>(
        [&](auto&& dim_tag) {
          constexpr auto dim = std::decay_t<decltype(dim_tag)>::value;
          if (const auto* typed =
                  dynamic_cast<const tamaas::Grid<T, dim>*>(&src))
            return grid_to_python<array_type>(*typed, policy, parent);
          return grid_to_python<array_type>(src, policy, parent);
        },
        // Default case of the dispatch
        [&](auto&&) { return grid_to_python<array_type>(src, policy, parent); },
        src.getDimension());
  }
};
/**
 * Type caster for grid variant
 */
template <>
struct type_caster<typename tamaas::FieldContainer::Value> {
  using type = typename tamaas::FieldContainer::Value;
  using array_type = array;

public:
  // NOLINTNEXTLINE(readability-else-after-return)
  PYBIND11_TYPE_CASTER(type, _("GridVariant"));

  /**
   * Python -> C++: not supported for the variant, the conversion always
   * fails and no numpy array is ever inspected.
   */
  bool load(handle /*src*/, bool /*convert*/) { return false; }

  /**
   * C++ -> Python: visit the variant and expose the grid it points to.
   * For each alternative, the runtime dimension of the grid is dispatched:
   * when the pointee is a Grid<T, dim> of the dispatched dimension, the
   * dimensioned grid_to_python overload is used, otherwise the overload
   * taking the dereferenced pointer type is used.
   */
  static handle cast(const type& src, return_value_policy policy,
                     handle parent) {
    using namespace ::tamaas::detail;

    const auto to_python = [policy, parent](auto&& grid_ptr) {
      return tuple_dispatch_with_default<dims_t>(
          [&](auto&& dim_tag) {
            using T = typename std::decay_t<decltype(*grid_ptr)>::value_type;
            constexpr auto dim = std::decay_t<decltype(dim_tag)>::value;
            if (const auto* typed = dynamic_cast<const tamaas::Grid<T, dim>*>(
                    grid_ptr.get()))
              return grid_to_python<array_type>(*typed, policy, parent);
            return grid_to_python<array_type>(*grid_ptr, policy, parent);
          },
          // Default case of the dispatch
          [&](auto&&) {
            return grid_to_python<array_type>(*grid_ptr, policy, parent);
          },
          grid_ptr->getDimension());
    };

    return boost::apply_visitor(to_python, src);
  }
};
} // namespace detail
} // namespace pybind11
#endif // CAST_HH

Event Timeline