diff --git a/tests/test_rectilinear_grid.py b/tests/test_rectilinear_grid.py index db7510e..cd10064 100644 --- a/tests/test_rectilinear_grid.py +++ b/tests/test_rectilinear_grid.py @@ -1,28 +1,28 @@ from uvw import RectilinearGrid, DataArray import numpy as np + def test_rectilinear_grid(): N = 3 x = np.linspace(0, 1, N*2) y = np.linspace(0, 1, N+2) z = np.linspace(0, 1, N) out_name = 'test_rectilinear_grid.vtr' - rect = RectilinearGrid(out_name, (x, y, z)) - x, y, z = np.meshgrid(x, y, z, indexing='ij') - r = np.sqrt(x**2 + y**2 + z**2) - e_r = np.zeros(r.shape + (3,3)) + xx, yy, zz = np.meshgrid(x, y, z, indexing='ij', sparse=True) + r = np.sqrt(xx**2 + yy**2 + zz**2) + e_r = np.zeros(r.shape + (3, 3)) e_r[0, 0, 0, :, :] = np.array([[0, 1, 0], [1, 0, 0], [0, 1, 1]]) e_r[-1, 0, 0, :, :] = np.eye(3) - rect.addPointData(DataArray(r, range(3), 'R')) - rect.addPointData(DataArray(r, range(3), 'R2')) - rect.addPointData(DataArray(e_r, range(3), 'e_r')) - rect.write() + with RectilinearGrid(out_name, (x, y, z)) as rect: + rect.addPointData(DataArray(r, range(3), 'R')) + rect.addPointData(DataArray(r, range(3), 'R2')) + rect.addPointData(DataArray(e_r, range(3), 'e_r')) output = open(out_name, 'r') reference = open('test_rectilinear_grid.ref', 'r') assert output.read() == reference.read() diff --git a/uvw/data_array.py b/uvw/data_array.py index ce1bc83..9c538d1 100644 --- a/uvw/data_array.py +++ b/uvw/data_array.py @@ -1,44 +1,49 @@ import numpy as np import functools import operator class DataArray: """Class holding information on ndarray""" def __init__(self, data, spatial_axes, name='', components_order='C'): """ Data array constructor :param data: the numpy array containing the data (possibly a view) :param spatial_axes: a container of ints that indicate which axes of the array correspond to space dimensions (in order) :param name: the name of the data :param components_order: the order of the non-spatial axes of the array """ self.data = data axes = list(range(data.ndim)) + spatial_axes = list(spatial_axes) + + if data.ndim < len(spatial_axes): + raise Exception('Dimensions of data smaller than space dimensions') + for ax in spatial_axes: axes.remove(ax) nb_components = functools.reduce( lambda x, y: x * data.shape[y], axes, 1) if components_order == 'C': axes.reverse() else: raise Exception('Unrecognized components order') axes += spatial_axes # Hopefully this is a view self.flat_data = self.data.transpose(*axes).reshape(-1, order='F') self.attributes = { "Name": name, "type": str(self.flat_data.dtype).capitalize(), "NumberOfComponents": str(nb_components) } def __str__(self): return self.attributes.__str__() diff --git a/uvw/vtk_files.py b/uvw/vtk_files.py index fb4478c..4b0798a 100644 --- a/uvw/vtk_files.py +++ b/uvw/vtk_files.py @@ -1,139 +1,146 @@ from . import writer from . import data_array import functools import numpy as np class VTKFile: """Generic VTK file""" def __init__(self, filename, filetype, rank=None): self.filename = filename self.rank = rank self.writer = writer.Writer(filetype) # Center piece self.piece = self.writer.registerPiece() # Registering data elements self.point_data = self.piece.register('PointData') self.cell_data = self.piece.register('CellData') def addPointData(self, data_array): self.point_data.registerDataArray(data_array) def addCellData(self, data_array): self.cell_data.registerDataArray(data_array) def write(self): self.writer.registerAppend() self.writer.write(self.filename) + def __enter__(self): + return self + + def __exit__(self, *args): + self.write() + class ImageData(VTKFile): """VTK Image data (coordinates are given by a range and constant spacing)""" def __init__(self, filename, ranges, points, rank=None): VTKFile.__init__(self, filename, self.__class__.__name__, rank) # Computing spacings spacings = [(x[1] - x[0]) / (n - 1) for x, n in zip(ranges, points)] # Filling in missing coordinates for _ in range(len(points), 3): points.append(1) # Setting extents, spacings and origin extent = functools.reduce( lambda x, y: x + "0 {} ".format(y-1), points, "") spacings = functools.reduce( lambda x, y: x + "{} ".format(y), spacings, "") origins = functools.reduce( lambda x, y: x + "{} ".format(y[0]), ranges, "") self.writer.setDataNodeAttributes({ 'WholeExtent': extent, 'Spacing': spacings, 'Origin': origins }) self.piece.setAttributes({ "Extent": extent }) class RectilinearGrid(VTKFile): """VTK Rectilinear grid (coordinates are given by 3 seperate ranges)""" def __init__(self, filename, coordinates, rank=None): VTKFile.__init__(self, filename, self.__class__.__name__, rank) # Checking that we actually have a list or tuple if type(coordinates).__name__ == 'ndarray': coordinates = [coordinates] self.coordinates = list(coordinates) # Filling in missing coordinates for _ in range(len(self.coordinates), 3): self.coordinates.append(np.array([0.])) # Setting data extent extent = [] for coord in self.coordinates: if coord.ndim != 1: raise Exception( - 'Coordinate array should have only one dimension') + 'Coordinate array should have only one dimension' + + ' (has {})'.format(coord.ndim)) extent.append(coord.size-1) extent = functools.reduce( lambda x, y: x + "0 {} ".format(y), extent, "") self.writer.setDataNodeAttributes({ "WholeExtent": extent }) self.piece.setAttributes({ "Extent": extent }) # Registering coordinates coordinate_component = self.piece.register('Coordinates') for coord, prefix in zip(self.coordinates, ('x', 'y', 'z')): array = data_array.DataArray(coord, [0], prefix + '_coordinates') coordinate_component.registerDataArray(array) class StructuredGrid(VTKFile): """VTK Structured grid (coordinates given by a single array of points)""" def __init__(self, filename, points, shape, rank=None): VTKFile.__init__(self, filename, self.__class__.__name__, rank) if points.ndim != 2: raise 'Points should be a 2D array' # Completing the missing coordinates points_3d = np.zeros((points.shape[0], 3)) for i in range(points.shape[1]): points_3d[:, i] = points[:, i] extent = [n - 1 for n in shape] for i in range(len(extent), 3): extent.append(0) extent = functools.reduce( lambda x, y: x + "0 {} ".format(y), extent, "") self.writer.setDataNodeAttributes({ "WholeExtent": extent }) self.piece.setAttributes({ "Extent": extent }) points_component = self.piece.register('Points') points_component.registerDataArray( data_array.DataArray(points_3d, [0], 'points'))