diff --git a/python/tamaas/dumpers/__init__.py b/python/tamaas/dumpers/__init__.py
index d223d84..16ef2f3 100644
--- a/python/tamaas/dumpers/__init__.py
+++ b/python/tamaas/dumpers/__init__.py
@@ -1,163 +1,229 @@
# @file
# @section LICENSE
#
# Copyright (©) 2016-19 EPFL (École Polytechnique Fédérale de Lausanne),
# Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
"""
Dumpers for the class tamaas.Model
"""
from __future__ import print_function
-from .. import ModelDumper
+from .. import ModelDumper, model_type
from ._helper import PeriodicHelper, step_dump, directory_dump
import numpy as np
class FieldDumper(ModelDumper):
"""Abstract dumper for python classes using fields"""
postfix = ''
extension = ''
name_format = "{basename}{postfix}.{extension}"
def __init__(self, basename, *fields, all_fields=False):
"""Construct with desired fields"""
super().__init__()
self.basename = basename
self.fields = fields
self.all_fields = all_fields
self.make_periodic = PeriodicHelper()
def add_field(self, field):
"""Add another field to the dump"""
if field not in self.fields:
self.fields.append(field)
def dump_to_file(self, fd, model):
"""Dump to a file (name or handle)"""
pass
def get_fields(self, model):
"""Get the desired fields"""
if not self.all_fields:
requested_fields = self.fields
else:
requested_fields = model.getFields()
return {field: model.getField(field) for field in requested_fields}
def get_attributes(self, model):
"""Get model attributes"""
return {
'model_type': str(model.type),
'system_size': model.getSystemSize(),
'discretization': model.getDiscretization()
}
def dump(self, model):
self.dump_to_file(self.file_path, model)
@property
def file_path(self):
"""Get the default filename"""
return self.name_format.format(basename=self.basename,
postfix=self.postfix,
extension=self.extension)
@directory_dump('numpys')
@step_dump
class NumpyDumper(FieldDumper):
"""Dumper to compressed numpy files"""
extension = 'npz'
def dump_to_file(self, fd, model):
"""Saving to compressed multi-field Numpy format"""
np.savez_compressed(fd, attrs=self.get_attributes(model),
**self.get_fields(model))
try:
import h5py
@directory_dump('hdf5')
@step_dump
class H5Dumper(FieldDumper):
"""Dumper to HDF5 file format"""
extension = 'h5'
def dump_to_file(self, fd, model):
"""Saving to HDF5 with metadata about the model"""
with h5py.File(fd, 'w') as fh:
# Writing data
for name, field in self.get_fields(model).items():
dset = fh.create_dataset(name, field.shape, field.dtype,
compression='gzip',
compression_opts=7)
dset[:] = field
# Writing metadata
for name, attr in self.get_attributes(model).items():
fh.attrs[name] = attr
except ImportError:
pass
try:
import uvw # noqa
@directory_dump('paraview')
@step_dump
class UVWDumper(FieldDumper):
"""Dumper to VTK files for elasto-plastic calculations"""
extension = 'vtr'
forbidden_fields = ['traction', 'gap'] # TODO make generic
def dump_to_file(self, fd, model):
"""Dump displacements, plastic deformations and stresses"""
discretization = model.getDiscretization().copy()
# Because we make fields periodic
discretization[1] += 1
discretization[2] += 1
# Space coordinates
coordinates = [np.linspace(0, L, N)
for L, N in zip(model.getSystemSize(),
discretization)]
# Correct order of coordinate dimensions
dimension_indices = [1, 2, 0]
# Creating rectilinear grid with correct order for components
grid = uvw.RectilinearGrid(fd, (coordinates[i]
for i in dimension_indices))
# Iterator over fields we want to dump
# Avoid 2D fields (TODO make generic)
fields_it = filter(lambda t: t[0] not in self.forbidden_fields,
self.get_fields(model).items())
# We make fields periodic for visualization
for name, field in fields_it:
array = uvw.DataArray(np.array(self.make_periodic[name](field),
dtype=np.double),
dimension_indices, name)
grid.addPointData(array)
grid.write()
except ImportError as e:
print(e)
+
+
+try:
+ from netCDF4 import Dataset
+
+ @directory_dump('netcdf')
+ @step_dump
+ class NetCDFDumper(FieldDumper):
+ """Dumper to netCDF4 files"""
+
+ extension = "nc"
+ boundary_fields = ['traction', 'gap']
+
+ def dump_to_file(self, fd, model):
+ with Dataset(fd, 'w', format='NETCDF4') as rootgrp:
+ self._dump_boundary(rootgrp, model)
+
+ if model.type in {model_type.volume_1d, model_type.volume_2d}:
+ self._dump_volume(rootgrp, model)
+
+ def _dump_boundary(self, grp, model):
+ self._dump_generic(grp, model,
+ lambda f: f[0] in self.boundary_fields,
+ 'boundary',
+ model.getBoundaryDiscretization(),
+ "xy")
+
+ def _dump_volume(self, grp, model):
+ self._dump_generic(grp, model,
+ lambda f: f[0] not in self.boundary_fields,
+ 'volume', model.getDiscretization(), "zxy")
+
+ def _dump_generic(self, grp, model, predicate,
+ group_name, shape, dimensions):
+ model_dim = len(model.getDiscretization())
+ field_dim = len(shape)
+
+ grp = grp.createGroup(group_name)
+
+ for size, label in zip(shape, dimensions):
+ grp.createDimension(label, size)
+
+ vec = grp.createDimension('vec', model_dim)
+ tens = grp.createDimension('tens', 2*model_dim)
+
+ fields = filter(predicate, self.get_fields(model).items())
+ dim_labels = list(dimensions[:field_dim])
+
+ print(field_dim)
+
+ for label, data in fields:
+ dims = dim_labels
+
+ # If we have an extra component
+ if data.ndim > field_dim:
+ if data.shape[-1] == tens.size:
+ dims.append(tens.name)
+ elif data.shape[-1] == vec.size:
+ dims.append(vec.name)
+
+ print(label, data.shape, dims)
+ var = grp.createVariable(label, 'f8', dims)
+ var[:] = np.array(data, dtype=np.double).flatten()
+
+except ImportError:
+ pass
diff --git a/python/tamaas/dumpers/netcdf_dumper.py b/python/tamaas/dumpers/netcdf_dumper.py
deleted file mode 100644
index 48d2f15..0000000
--- a/python/tamaas/dumpers/netcdf_dumper.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# @file
-# @section LICENSE
-#
-# Copyright (©) 2016-19 EPFL (École Polytechnique Fédérale de Lausanne),
-# Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published
-# by the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-
-from ._helper import FieldDumper
-
-import numpy as np
-
-from netCDF4 import Dataset
-from ._helper import step_dump
-
-
-__all__ = ['NetCDFDumper']
-
-
-@step_dump
-class NetCDFDumper(FieldDumper):
- """Dumper to NetCDF format files"""
- def __init__(self, basename, *fields):
- default_fields = ['traction', 'displacement',
- 'stress', 'plastic_strain']
- super().__init__(basename, default_fields, *fields)
-
- self.rootgrp = Dataset("{}.nc".format(self.basename),
- "w", format="NETCDF4_CLASSIC")
-
- def __del__(self):
- self.rootgrp.close()
-
- def dump_boundary(self, model):
- group = self.rootgrp
- shape = model.getDiscretization()
- tags = ["Z", "X", "Y"]
-
- # Creating dimensions
- group.createDimension("T", None)
- for tag, dim in zip(tags, shape):
- group.createDimension(tag, dim)
-
- # Dimension variables
- for dim, L in zip(group.dimensions.values(),
- model.getSystemSize()):
- tag = dim.name
- var = group.createVariable(tag, 'f8', (tag,))
- var[...] = np.linspace(0, L, dim.size)
-
- # Field variables
- for field in self.fields:
- var = group.createVariable(field, 'f8', ['T'] + tags)
- var[self.count, ...] = model.getField(field)[..., 2]
-
- def dump_volume(self, model):
- pass
-
- def dump(self, model):
- self.dump_boundary(model)
-
- self.rootgrp.model_type = str(model.type)
- self.rootgrp.sync()
diff --git a/tests/test_dumper.py b/tests/test_dumper.py
index d21e2e0..bf563ab 100644
--- a/tests/test_dumper.py
+++ b/tests/test_dumper.py
@@ -1,110 +1,138 @@
# -*- coding: utf-8 -*-
# @file
# @section LICENSE
#
# Copyright (©) 2016-19 EPFL (École Polytechnique Fédérale de Lausanne),
# Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
from __future__ import print_function, division
-import tamaas as tm
-import numpy as np
import os
import shutil
+import tamaas as tm
+import numpy as np
from tamaas.dumpers import NumpyDumper
class Dumper(tm.ModelDumper):
"""Simple numpy dumper"""
+
def __init__(self):
tm.ModelDumper.__init__(self)
def dump(self, model):
np.savetxt('tractions.txt', np.ravel(model.getField('traction')))
- np.savetxt('displacement.txt', np.ravel(model.getField('displacement')))
+ np.savetxt('displacement.txt',
+ np.ravel(model.getField('displacement')))
def cleanup():
for name in ['tractions.txt',
'displacement.txt',
'numpys',
- 'hdf5']:
+ 'hdf5',
+ 'netcdf']:
if os.path.exists(name) and os.path.isdir(name):
shutil.rmtree(name)
elif os.path.exists(name):
os.remove(name)
def test_dumpers(tamaas_fixture):
model = tm.ModelFactory.createModel(tm.model_type.volume_2d, [1., 1., 1.],
[16, 4, 8])
dumper = Dumper()
np_dumper = NumpyDumper('test_dump', 'traction', 'displacement')
model.addDumper(np_dumper)
model.dump()
model.dump()
dumper << model
tractions = np.loadtxt('tractions.txt')
displacements = np.loadtxt('displacement.txt')
assert tractions.size == model.getTraction().size
assert displacements.size == model.getDisplacement().size
with np.load('numpys/test_dump_0000.npz') as npfile:
tractions = npfile['traction']
displacements = npfile['displacement']
attributes = npfile['attrs'].item()
t_shape = list(model.getTraction().shape)
d_shape = list(model.getDisplacement().shape)
assert tractions.shape == tuple(t_shape)
assert displacements.shape == tuple(d_shape)
assert str(model.type) == attributes['model_type']
assert os.path.isfile('numpys/test_dump_0001.npz')
cleanup()
# Protecting test
try:
from tamaas.dumpers import H5Dumper
import h5py
def test_h5dumper(tamaas_fixture):
model = tm.ModelFactory.createModel(tm.model_type.volume_2d,
[1., 1., 1.],
[16, 4, 8])
model.getDisplacement()[...] = 3.1415
dumper = H5Dumper('test_hdf5', 'traction', 'displacement')
dumper << model
assert os.path.isfile('hdf5/test_hdf5_0000.h5')
fh = h5py.File('hdf5/test_hdf5_0000.h5', 'r')
disp = np.array(fh['displacement'], dtype=tm.dtype)
assert np.all(np.abs(disp - 3.1415) < 1e-15)
assert list(fh.attrs['discretization']) == [16, 4, 8]
assert fh.attrs['model_type'] == str(model.type)
fh.close()
cleanup()
except ImportError:
pass
+
+try:
+ from tamaas.dumpers import NetCDFDumper
+ import netCDF4
+
+ def test_netcdfdumper(tamaas_fixture):
+ model = tm.ModelFactory.createModel(tm.model_type.volume_2d,
+ [1., 1., 1.],
+ [16, 4, 8])
+ model.getDisplacement()[...] = 3.1415
+ dumper = NetCDFDumper('test_netcdf', 'traction', 'displacement')
+ dumper << model
+
+ assert os.path.isfile('netcdf/test_netcdf_0000.nc')
+
+ fh = netCDF4.Dataset('netcdf/test_netcdf_0000.nc', 'r')
+ disp = fh['volume/displacement'][:]
+ assert np.all(np.abs(disp - 3.1415) < 1e-15)
+ assert disp.shape == (16, 4, 8, 3)
+
+ fh.close()
+ cleanup()
+
+except ImportError:
+ pass