diff --git a/tests/test_dumper.py b/tests/test_dumper.py index 895336a..d99529b 100644 --- a/tests/test_dumper.py +++ b/tests/test_dumper.py @@ -1,145 +1,131 @@ # -*- coding: utf-8 -*- # @file # @section LICENSE # # Copyright (©) 2016-2021 EPFL (École Polytechnique Fédérale de Lausanne), # Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published # by the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from __future__ import division import os -import shutil import tamaas as tm import numpy as np import pytest from tamaas.dumpers import NumpyDumper class Dumper(tm.ModelDumper): """Simple numpy dumper""" def __init__(self): tm.ModelDumper.__init__(self) def dump(self, model): np.savetxt('tractions.txt', np.ravel(model.traction)) np.savetxt('displacement.txt', np.ravel(model.displacement)) -def cleanup(): - for name in ['tractions.txt', - 'displacement.txt', - 'numpys', - 'hdf5', - 'netcdf']: - if os.path.exists(name) and os.path.isdir(name): - shutil.rmtree(name) - elif os.path.exists(name): - os.remove(name) - - -def test_dumpers(tamaas_fixture): +def test_dumpers(tamaas_fixture, tmp_path): + os.chdir(tmp_path) model = tm.ModelFactory.createModel(tm.model_type.volume_2d, [1., 1., 1.], [16, 4, 8]) dumper = Dumper() np_dumper = NumpyDumper('test_dump', 'traction', 'displacement') model.addDumper(np_dumper) model.dump() model.dump() dumper << model ref_t = model['traction'] ref_d = model['displacement'] tractions = np.loadtxt('tractions.txt') displacements = np.loadtxt('displacement.txt') assert np.all(tractions.reshape(ref_t.shape) == ref_t) assert np.all(displacements.reshape(ref_d.shape) == ref_d) with np.load('numpys/test_dump_0000.npz', allow_pickle=True) as npfile: tractions = npfile['traction'] displacements = npfile['displacement'] attributes = npfile['attrs'].item() assert np.all(tractions == ref_t) assert np.all(displacements == ref_d) assert str(model.type) == attributes['model_type'] assert os.path.isfile('numpys/test_dump_0001.npz') - cleanup() # Protecting test try: from tamaas.dumpers import H5Dumper import h5py - def test_h5dumper(tamaas_fixture): + def test_h5dumper(tamaas_fixture, tmp_path): + os.chdir(tmp_path) model = tm.ModelFactory.createModel(tm.model_type.volume_2d, [1., 1., 1.], [16, 4, 8]) model['displacement'][...] = 3.1415 dumper = H5Dumper('test_hdf5', 'traction', 'displacement') dumper << model assert os.path.isfile('hdf5/test_hdf5_0000.h5') fh = h5py.File('hdf5/test_hdf5_0000.h5', 'r') disp = np.array(fh['displacement'], dtype=tm.dtype) assert np.all(np.abs(disp - 3.1415) < 1e-15) assert list(fh.attrs['discretization']) == [16, 4, 8] assert fh.attrs['model_type'] == str(model.type) fh.close() - cleanup() except ImportError: pass try: from tamaas.dumpers import NetCDFDumper from netCDF4 import Dataset - def test_netcdfdumper(tamaas_fixture): + def test_netcdfdumper(tamaas_fixture, tmp_path): + os.chdir(tmp_path) model = tm.ModelFactory.createModel(tm.model_type.volume_2d, [1., 1., 1.], [16, 4, 8]) model['displacement'][...] = 3.1415 dumper = NetCDFDumper('test_netcdf', 'traction', 'displacement') # Dumping two frames dumper << model dumper << model assert os.path.isfile('netcdf/test_netcdf.nc') with Dataset('netcdf/test_netcdf.nc', 'r') as netcdf_file: disp = netcdf_file['displacement'][:] assert np.all(np.abs(disp - 3.1415) < 1e-15) assert disp.shape[1:] == (16, 4, 8, 3) assert disp.shape[0] == 2 with pytest.raises(Exception): model = tm.ModelFactory.createModel(tm.model_type.volume_2d, [1., 1., 1.], [15, 4, 8]) dumper << model - cleanup() - except ImportError: pass