diff --git a/test.py b/test.py index a9e8866..def220e 100644 --- a/test.py +++ b/test.py @@ -1,22 +1,28 @@ from writer import * from data_array import * from vtk_files import * import numpy as np N = 10 -x = np.linspace(0, 1, 2*N) +x = np.linspace(0, 1, N) y = np.linspace(0, 1, N) z = np.linspace(0, 1, N) rect = RectilinearGrid('test.vtr', (x, y, z)) x, y, z = np.meshgrid(x, y, z, indexing='ij') r = np.sqrt(x**2 + y**2 + z**2) e_r = np.zeros(r.shape + (3,3)) e_r[0, 0, 0, :, :] = np.array([[0, 1, 0], [1, 0, 0], [0, 1, 1]]) e_r[-1, 0, 0, :, :] = np.eye(3) print(e_r.shape) -rect.addPointData(DataArray(e_r, range(3), 'e_r')) rect.addPointData(DataArray(r, range(3), 'R')) +rect.addPointData(DataArray(r, range(3), 'R2')) +rect.addPointData(DataArray(e_r, range(3), 'e_r')) rect.write() + +from evtk.hl import gridToVTK + +e_r = e_r.transpose(4, 3, 0, 1, 2).copy() +gridToVTK('grid', x, y, z, pointData={'R': r, 'R2':r, 'e_r': e_r}) diff --git a/vtk_files.py b/vtk_files.py index b3f85f0..713cd6d 100644 --- a/vtk_files.py +++ b/vtk_files.py @@ -1,65 +1,65 @@ from writer import * from data_array import * import functools class VTKFile: """Generic VTK file""" def __init__(self, filename, filetype, rank=None): self.filename = filename self.rank = rank self.writer = Writer(filetype) def addPointData(self, data_array): - self.point_data.registerDataArray(data_array, vtk_format='ascii') + self.point_data.registerDataArray(data_array, vtk_format='append') def addCellData(self, data_array): self.cell_data.registerDataArray(data_array, vtk_format='ascii') def write(self): self.writer.registerAppend() self.writer.write(self.filename) class RectilinearGrid(VTKFile): """VTK Rectilinear grid (coordinates are given by 3 seperate ranges)""" def __init__(self, filename, coordinates, rank=None): VTKFile.__init__(self, filename, 'RectilinearGrid', rank) # Checking that we actually have a list or tuple if type(coordinates).__name__ not in ('tuple', 'list'): coordinates = [coordinates] self.coordinates = list(coordinates) # Filling in missing coordinates for _ in range(len(coordinates), 3): self.coordinates.append(np.array([0.])) # Setting data extent extent = [] for coord in self.coordinates: if coord.ndim != 1: raise Exception('Coordinate array should have only one dimension') extent.append(coord.size-1) extent = functools.reduce(lambda x, y: x + "0 {} ".format(y), extent, "") self.writer.setDataNodeAttributes({ "WholeExtent": extent }) self.piece = self.writer.registerPiece({ "Extent": extent }) # Registering coordinates coordinate_component = self.piece.register('Coordinates') for coord, prefix in zip(self.coordinates, ('x', 'y', 'z')): array = DataArray(coord, [0], prefix + '_coordinates') - coordinate_component.registerDataArray(array, vtk_format='ascii') + coordinate_component.registerDataArray(array, vtk_format='binary') # Registering data elements self.point_data = self.piece.register('PointData') self.cell_data = self.piece.register('CellData') diff --git a/writer.py b/writer.py index b1a7dc9..a59a4cf 100644 --- a/writer.py +++ b/writer.py @@ -1,87 +1,104 @@ import xml.dom import xml.dom.minidom as dom import functools import base64 +import numpy as np def setAttributes(node, attributes): """Set attributes of a node""" for item in attributes.items(): node.setAttribute(*item) +def encodeArray(array): + # Mandatory number of bytes encoded as uint32 + nbytes = array.nbytes + bytes = base64.b64encode(np.array([nbytes], dtype=np.uint32)) + bytes += base64.b64encode(array) + return bytes + class Component: """Generic component class capable of registering sub-components""" def __init__(self, name, parent_node, writer): self.writer = writer self.document = writer.document self.node = self.document.createElement(name) parent_node.appendChild(self.node) def register(self, name, attributes=dict()): """Register a sub-component""" sub_component = Component(name, self.node, self.writer) setAttributes(sub_component.node, attributes) return sub_component def registerDataArray(self, data_array, vtk_format='append'): + """Register a DataArray object""" array_component = Component('DataArray', self.node, self.writer) attributes = data_array.attributes attributes['format'] = vtk_format if vtk_format == 'append': attributes['offset'] = str(self.writer.offset) - self.writer.offset += data_array.flat_data.size + array = data_array.flat_data + self.writer.offset += array.nbytes + self.writer.offset += self.writer.size_indicator_bytes self.writer.append_data_arrays.append(data_array) elif vtk_format == 'ascii': data_as_str = functools.reduce(lambda x, y: x + str(y) + ' ', data_array.flat_data, "") array_component.node.appendChild(self.document.createTextNode(data_as_str)) + + elif vtk_format == 'binary': + array_component.node.appendChild( + self.document.createTextNode( + encodeArray(data_array.flat_data).decode('ascii'))) setAttributes(array_component.node, attributes) class Writer: """Generic XML handler for VTK files""" def __init__(self, vtk_format, vtk_version='0.1', byte_order='LittleEndian'): self.document = dom.getDOMImplementation().createDocument(None, 'VTKFile', None) self.root = self.document.documentElement self.root.setAttribute('type', vtk_format) self.root.setAttribute('version', vtk_version) self.root.setAttribute('byte_order', byte_order) self.data_node = self.document.createElement(vtk_format) self.root.appendChild(self.data_node) self.offset = 0 # Global offset + self.size_indicator_bytes = np.dtype(np.uint32).itemsize self.append_data_arrays = [] def setDataNodeAttributes(self, attributes): """Set attributes for the entire dataset""" setAttributes(self.data_node, attributes) def registerPiece(self, attributes): """Register a piece element""" piece = Component('Piece', self.data_node, self) setAttributes(piece.node, attributes) return piece def registerAppend(self): append_node = Component('AppendedData', self.root, self) setAttributes(append_node.node, {'format': 'base64'}) self.root.appendChild(append_node.node) data_str = b"_" for data_array in self.append_data_arrays: - data_str += base64.b64encode(data_array.flat_data) + data_str += encodeArray(data_array.flat_data) text = self.document.createTextNode(data_str.decode('ascii')) append_node.node.appendChild(text) def write(self, filename): with open(filename, 'w') as file: - file.write(str(self)) + self.document.writexml(file, indent="\n ", addindent=" ") def __str__(self): """Print XML to string""" return self.document.toprettyxml()