diff --git a/tests/conftest.py b/tests/conftest.py index 67722e7..0896059 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,29 +1,34 @@ import pytest import numpy as np from vtk.util.numpy_support import vtk_to_numpy as v2n @pytest.fixture def threeD_data(): N = 4 x = np.linspace(0, 1, N*2) y = np.linspace(0, 1, N+2) z = np.linspace(0, 1, N) xx, yy, zz = np.meshgrid(x, y, z, indexing='ij', sparse=True) r = np.sqrt(xx**2 + yy**2 + zz**2) e_r = np.zeros([s - 1 for s in r.shape] + [3, 3]) e_r[0, 0, 0, :, :] = np.array([[0, 1, 0], [1, 0, 0], [0, 1, 1]]) e_r[-1, 0, 0, :, :] = np.eye(3) return [x, y, z], r, e_r +@pytest.fixture(params=[False, True]) +def compression_fixture(request): + return request + + def get_vtk_data(reader, sstream): reader.SetReadFromInputString(True) reader.SetInputString(sstream.getvalue()) reader.Update() output = reader.GetOutput() return v2n(output.GetPointData().GetArray('point')), \ v2n(output.GetCellData().GetArray('cell')) diff --git a/tests/test_vtk_files.py b/tests/test_vtk_files.py index 01a9643..3b8a9fb 100644 --- a/tests/test_vtk_files.py +++ b/tests/test_vtk_files.py @@ -1,91 +1,95 @@ import io import numpy as np from numpy import all, min, max from vtk import ( vtkXMLRectilinearGridReader, vtkXMLImageDataReader, vtkXMLStructuredGridReader, ) from vtk.util.numpy_support import vtk_to_numpy from conftest import get_vtk_data from uvw import ( ImageData, RectilinearGrid, StructuredGrid, DataArray, ) -def test_rectilinear_grid(threeD_data): +def test_rectilinear_grid(threeD_data, compression_fixture): coords, r, e_r = threeD_data f = io.StringIO() - with RectilinearGrid(f, coords) as rect: + compress = compression_fixture.param + with RectilinearGrid(f, coords, compression=compress) as rect: rect.addPointData(DataArray(r, range(3), 'point')) rect.addCellData(DataArray(e_r, range(3), 'cell')) reader = vtkXMLRectilinearGridReader() vtk_r, vtk_e_r = get_vtk_data(reader, f) vtk_r = vtk_r.reshape(r.shape, order='F') vtk_e_r = vtk_e_r.reshape(e_r.shape, order='F').transpose((0, 1, 2, 4, 3)) assert all(vtk_r == r) assert all(vtk_e_r == e_r) -def test_image_data(threeD_data): +def test_image_data(threeD_data, compression_fixture): coords, r, e_r = threeD_data f = io.StringIO() + compress = compression_fixture.param with ImageData( f, [(min(x), max(x)) for x in coords], - [x.size for x in coords]) as fh: + [x.size for x in coords], + compression=compress) as fh: fh.addPointData(DataArray(r, range(3), 'point')) fh.addCellData(DataArray(e_r, range(3), 'cell')) reader = vtkXMLImageDataReader() vtk_r, vtk_e_r = get_vtk_data(reader, f) vtk_r = vtk_r.reshape(r.shape, order='F') vtk_e_r = vtk_e_r.reshape(e_r.shape, order='F').transpose((0, 1, 2, 4, 3)) assert all(vtk_r == r) assert all(vtk_e_r == e_r) -def test_structured_grid(): +def test_structured_grid(compression_fixture): f = io.StringIO() N = 5 r = np.linspace(0, 1, N) theta = np.linspace(0, 2*np.pi, 5*N) theta, r = np.meshgrid(theta, r, indexing='ij') x = r*np.cos(theta) y = r*np.sin(theta) points = np.vstack([x.ravel(), y.ravel()]).T - grid = StructuredGrid(f, points, (N, 5*N)) + compress = compression_fixture.param + grid = StructuredGrid(f, points, (N, 5*N), compression=compress) data = np.exp(-4*r**2) grid.addPointData(DataArray(data, reversed(range(2)), 'data')) grid.write() reader = vtkXMLStructuredGridReader() reader.SetReadFromInputString(True) reader.SetInputString(f.getvalue()) reader.Update() output = reader.GetOutput() vtk_data = vtk_to_numpy(output.GetPointData().GetArray('data')) vtk_data = vtk_data.reshape(data.shape, order='C') assert all(vtk_data == data) diff --git a/uvw/writer.py b/uvw/writer.py index 669560d..5912d52 100644 --- a/uvw/writer.py +++ b/uvw/writer.py @@ -1,177 +1,180 @@ import xml.dom.minidom as dom import io import zlib import numpy as np from functools import reduce from base64 import b64encode def setAttributes(node, attributes): """Set attributes of a node""" for item in attributes.items(): node.setAttribute(*item) def encodeArray(array, level): """Encode array data and header in base64.""" def compress(array): """Compress array with zlib. Returns header and compressed data.""" raw_data = array.tobytes() max_block_size = 2**15 # Enough blocks to span whole data nblocks = len(raw_data) // max_block_size + 1 last_block_size = len(raw_data) % max_block_size # Compress regular blocks compressed_data = [ zlib.compress(raw_data[i*max_block_size:(i+1)*max_block_size], level) for i in range(nblocks-1) ] # Compress last (smaller) block compressed_data.append( zlib.compress(raw_data[-last_block_size:], level) ) # Header data (cf https://vtk.org/Wiki/VTK_XML_Formats#Compressed_Data) usize = max_block_size psize = last_block_size csize = [len(x) for x in compressed_data] header = np.array([nblocks, usize, psize] + csize, dtype=np.uint32) return header.tobytes(), b"".join(compressed_data) def raw(array): """Returns header and array data in bytes.""" header = np.array([array.nbytes], dtype=np.uint32) return header.tobytes(), array.tobytes() if level is not None: data = compress(array) else: data = raw(array) return "".join(map(lambda x: b64encode(x).decode(), data)) class Component: """Generic component class capable of registering sub-components""" def __init__(self, name, parent_node, writer): self.writer = writer self.document = writer.document self.node = self.document.createElement(name) parent_node.appendChild(self.node) def setAttributes(self, attributes): setAttributes(self.node, attributes) def register(self, name, attributes={}): """Register a sub-component""" if type(attributes) != dict: raise Exception( 'Cannot register attributes of type ' + str(type(attributes))) sub_component = Component(name, self.node, self.writer) setAttributes(sub_component.node, attributes) return sub_component def _addArrayNodeData(self, data_array, node, vtk_format): if vtk_format == 'ascii': data_as_str = reduce( lambda x, y: x + str(y) + ' ', data_array.flat_data, "") elif vtk_format == 'binary': data_as_str = encodeArray(data_array.flat_data, self.writer.compression) else: raise Exception('Unsupported VTK Format "{}"'.format(vtk_format)) node.appendChild(self.document.createTextNode(data_as_str)) def _registerArrayComponent(self, array, name, vtk_format): attributes = array.attributes attributes['format'] = vtk_format return self.register(name, attributes) def registerDataArray(self, data_array, vtk_format='binary'): """Register a DataArray object""" component = self._registerArrayComponent(data_array, 'DataArray', vtk_format) self._addArrayNodeData(data_array, component.node, vtk_format) def registerPDataArray(self, data_array, vtk_format='binary'): """Register a DataArray object in p-file""" self._registerArrayComponent(data_array, 'PDataArray', vtk_format) class Writer: """Generic XML handler for VTK files""" def __init__(self, vtk_format, compression=None, vtk_version='0.1', byte_order='LittleEndian'): self.document = dom.getDOMImplementation() \ .createDocument(None, 'VTKFile', None) self.root = self.document.documentElement self.root.setAttribute('type', vtk_format) self.root.setAttribute('version', vtk_version) self.root.setAttribute('byte_order', byte_order) self.data_node = self.document.createElement(vtk_format) self.root.appendChild(self.data_node) self.size_indicator_bytes = np.dtype(np.uint32).itemsize self.append_data_arrays = [] - if compression is not None: + if compression is not None and compression != False: self.root.setAttribute('compressor', 'vtkZLibDataCompressor') if type(compression) is not int: compression = -1 else: if compression not in list(range(-1, 10)): raise Exception(('compression level {} is not ' 'recognized by zlib').format(compression)) + elif not compression: + compression = None + self.compression = compression def setDataNodeAttributes(self, attributes): """Set attributes for the entire dataset""" setAttributes(self.data_node, attributes) def registerPiece(self, attributes={}): """Register a piece element""" return self.registerComponent('Piece', self.data_node, attributes) def registerComponent(self, name, parent, attributes={}): comp = Component(name, parent, self) setAttributes(comp.node, attributes) return comp def registerAppend(self): append_node = Component('AppendedData', self.root, self) setAttributes(append_node.node, {'format': 'base64'}) self.root.appendChild(append_node.node) data_str = b"_" for data_array in self.append_data_arrays: data_str += encodeArray(data_array.flat_data) text = self.document.createTextNode(data_str.decode('ascii')) append_node.node.appendChild(text) def write(self, fd): if type(fd) == str: with open(fd, 'w') as file: self.write(file) elif issubclass(type(fd), io.TextIOBase): self.document.writexml(fd, indent="\n ", addindent=" ") else: raise RuntimeError("Expected a path or " + "file handle, got {}".format(type(fd))) def __str__(self): """Print XML to string""" return self.document.toprettyxml()