#! /usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright © 2016-2019 Cyril Desjouy <cyril.desjouy@univ-lemans.fr>
#
# This file is part of nsfds2
#
# nsfds2 is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# nsfds2 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with nsfds2. If not, see <http://www.gnu.org/licenses/>.
#
# Creation Date : 2019-05-02 - 14:09:56
#
# pylint: disable=too-many-locals
# pylint: disable=too-many-instance-attributes
# pylint: disable=no-member
"""
----------------------------
Graphic utilities for nsfds2
----------------------------
"""
import os
import sys
import getpass
import pathlib
import h5py
import numpy as _np
from scipy import signal as _signal
import matplotlib.pyplot as _plt
import matplotlib.animation as _ani
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ofdlib2 import fdtd as _fdtd
from progressbar import ProgressBar, Bar, ReverseBar, ETA
from mplutils import modified_jet, MidPointNorm, set_figsize, get_subplot_shape
import fdgrid.graphics as _graphics
from nsfds2.utils.misc import nearest_index as _ne
# Public names exported by ``from <this module> import *``.
# NOTE(review): 'DataExtractor' is used by DataIterator and Plot but is not
# defined or imported in the visible source of this module; confirm it is
# provided elsewhere, otherwise ``import *`` would fail on this name.
__all__ = ['get_data', 'DataIterator', 'DataExtractor', 'Plot']
def get_data(filename):
    """ Open an hdf5 file read-only and return the h5py handle.

    Parameters
    ----------
    filename : str or pathlib.Path
        Path to the hdf5 file. A leading ``~`` is expanded.

    Returns
    -------
    h5py.File
        The file opened in ``'r'`` mode.

    Notes
    -----
    If opening fails with ``OSError``, an error message is printed and the
    interpreter exits with status 1.
    """
    try:
        path = pathlib.Path(filename).expanduser()
        handle = h5py.File(path, 'r')
    except OSError:
        print('You must provide a valid hdf5 file')
        sys.exit(1)
    return handle
class DataIterator:
    """ Data Generator

    Iterate over the frames of a nsfds2 result file. Each step yields a list
    ``[iteration, field_1, field_2, ...]`` where the fields are obtained with
    ``data.get(view=name, iteration=iteration)`` for each name in ``view``.
    Iterations advance by the ``ns`` attribute of the data, starting at 0.

    Parameters
    ----------
    data : DataExtractor or str or pathlib.Path
        DataExtractor instance or filename.
    view : tuple of str or str
        The variable(s) to display. A single string (e.g. ``'rho'``) is
        treated as one variable name, not as a sequence of characters.
    nt : int, optional
        The last frame to consider. Defaults to the ``nt`` attribute of the
        data.

    Raises
    ------
    TypeError
        If ``data`` is neither a DataExtractor nor a filename.
    """

    def __init__(self, data, view=('p',), nt=None):
        if isinstance(data, DataExtractor):
            self.data = data
        elif isinstance(data, (pathlib.Path, str)):
            self.data = DataExtractor(data)
        else:
            raise TypeError('data must be a DataExtractor instance or a filename')

        # Wrap a bare string so that 'rho' is not iterated as 'r', 'h', 'o'.
        self.view = (view,) if isinstance(view, str) else view
        self.ns = self.data.get_attr('ns')
        # Start one step before 0 so that the first __next__ yields frame 0.
        self.icur = -self.ns
        self.nt = self.data.get_attr('nt') if nt is None else nt

    def __iter__(self):
        """ Iterator """
        return self

    def __next__(self):
        """ Next element of iterator : [frame_number, var_1, var_2, ...] """
        self.icur += self.ns
        # nt == 0 is a valid bound (first frame only); only None disables it.
        if self.nt is not None and self.icur > self.nt:
            raise StopIteration
        try:
            fields = [self.data.get(view=var, iteration=self.icur)
                      for var in self.view]
        except KeyError:
            # Frame or variable missing from the file: end the iteration.
            raise StopIteration from None
        return [self.icur, *fields]
class Plot:
    """ Helper class to plot results from nsfds2.

    Parameters
    ----------
    filename : str or pathlib.Path
        hdf5 file. A leading ``~`` is expanded.
    quiet : bool, optional
        Quiet mode: disables the progress bar of :meth:`movie`.
    """

    def __init__(self, filename, quiet=False):
        self.filename = pathlib.Path(filename).expanduser()
        self.path = self.filename.parent
        self.quiet = quiet
        self.data = DataExtractor(self.filename)
        self.nt = self.data.get_attr('nt')
        self.ns = self.data.get_attr('ns')
        self._init_geo()
        self._init_fig()

    def _init_geo(self):
        """ Init coordinate system and geometry attributes from the file. """
        self.obstacles = self.data.get_attr('obstacles')
        self.Npml = self.data.get_attr('Npml')
        self.mesh = self.data.get_attr('mesh')
        self.bc = self.data.get_attr('bc')
        if self.mesh == 'curvilinear':
            # Curvilinear meshes store the full 2D physical coordinates.
            self.x = self.data.get_dataset('xp')
            self.z = self.data.get_dataset('zp')
        else:
            # Only 1D axes are stored: build 2D coordinate arrays, transposed
            # so that x varies along the first axis.
            x, z = _np.meshgrid(self.data.get_dataset('x'),
                                self.data.get_dataset('z'))
            self.x = _np.ascontiguousarray(x.T)
            self.z = _np.ascontiguousarray(z.T)

    def _init_fig(self):
        """ Init figure parameters (colormap and subplot titles). """
        self.cm = modified_jet()
        # Not used within this class; kept as a public attribute.
        self.title = r'{} -- iteration : {}'
        # NOTE(review): confirm the units of 'e' and 'vxz' against the solver
        # (if 'vxz' is the vorticity, its unit would be s^-1, not m/s).
        self.titles = {'p': r'$p_a$ [Pa]',
                       'e': r'$e$ [kg.m$^2$.s$^{-2}$]',
                       'vx': r'$v_x$ [m/s]',
                       'vz': r'$v_z$ [m/s]',
                       'rho': r'$\rho$ [kg.m$^{-3}$]',
                       'vxz': r'$\omega$ [m/s]'}

    def _probe_offsets(self, probes):
        """ Return, for each probe location, the p0 value to subtract.

        On curvilinear meshes p0 is divided by the dataset ``J`` at the probe
        location; otherwise p0 is used as is.
        """
        p0 = self.data.get_attr('p0')
        if self.mesh != 'curvilinear':
            return [p0] * len(probes)
        J = self.data.get_dataset('J')  # loaded once, not once per probe
        return [p0 / J[c[0], c[1]] for c in probes]

    def _norm_and_ticks(self, field, name, ref, midpoint, comp):
        """ Return (MidPointNorm, colorbar ticks) for one field.

        Color limits come from ``self.data.reference`` when ``ref`` is truthy,
        otherwise from the field extrema. Strictly positive fields are
        centred on their mean, other fields on ``midpoint``. Limits and
        extreme ticks are divided by ``comp``.
        """
        if ref:
            vmin, vmax = self.data.reference(view=name, ref=ref)
        else:
            vmin, vmax = field.min(), field.max()

        mid = field.mean() if vmin > 0 and vmax > 0 else midpoint
        lo, hi = vmin / comp, vmax / comp

        # Show the lower tick only when the lower part of the range is
        # significant (> 33 % of |vmax|). Written as a product so that
        # vmax == 0 does not divide by zero.
        if abs(vmin - mid) > 0.33 * abs(vmax):
            ticks = [lo, mid, hi]
        else:
            ticks = [mid, hi]
        return MidPointNorm(vmin=lo, vmax=hi, midpoint=mid), ticks

    def movie(self, view=('p', 'e', 'vx', 'vz'), nt=None, ref=None,
              figsize='auto', show_pml=False, show_probes=False,
              dpi=100, fps=24, comp=1):
        """ Make movie.

        Every saved frame up to ``nt`` is rendered with :meth:`fields` and
        encoded with ffmpeg (libx264) into ``<name>.mkv`` in the directory of
        the hdf5 file, ``<name>`` being the file name up to its first dot.

        Parameters
        ----------
        view : tuple of str or str
            Variable(s) to display, one subplot each.
        nt : int, optional
            Last iteration to render (snapped with ``nearest_index``).
            Defaults to the last iteration of the file.
        ref : optional
            Color-limit reference passed to :meth:`fields`; ``'auto'`` when
            not given.
        figsize, show_pml, show_probes, comp :
            Passed to :meth:`fields`.
        dpi : int
            Resolution of the frames.
        fps : int
            Frames per second.
        """
        if isinstance(view, str):
            # A single name such as 'rho' must not be split into characters.
            view = (view,)

        # Progress bar
        if not self.quiet:
            widgets = [Bar('>'), ' ', ETA(), ' ', ReverseBar('<')]
            pbar = ProgressBar(widgets=widgets, maxval=self.nt).start()

        # Movie parameters
        title = os.path.basename(self.filename).split('.')[0]
        metadata = dict(title=title, artist=getpass.getuser(),
                        comment='From nsfds2')
        writer = _ani.FFMpegWriter(fps=fps, metadata=metadata, bitrate=-1,
                                   codec="libx264")
        movie_filename = f'{title}.mkv'

        # Nb of iterations and reference
        nt = self.nt if not nt else _ne(nt, self.ns, self.nt)
        ref = 'auto' if not ref else ref

        # The first frame builds the figure; later frames only update data.
        data = DataIterator(self.data, view=view, nt=nt)
        i, *var = next(data)
        fig, axes, ims = self.fields(view=view, iteration=i, ref=ref,
                                     show_pml=show_pml,
                                     show_probes=show_probes,
                                     figsize=figsize,
                                     comp=comp)

        with writer.saving(fig, self.path / movie_filename, dpi=dpi):
            writer.grab_frame()
            for i, *var in data:
                for j, (ax, mesh, v) in enumerate(zip(axes.ravel(), ims, var)):
                    # Same trimming/transposition as the arrays given to
                    # pcolormesh in fields, flattened for set_array.
                    # StackOv : using-set-array-with-pyplot-pcolormesh-ruins-figure
                    mesh.set_array(v[:-1, :-1].T.flatten())
                    ax.set_title(self.titles[view[j]] + f' (n={i})')
                writer.grab_frame()
                if not self.quiet:
                    pbar.update(i)

        if not self.quiet:
            pbar.finish()

    def probes(self):
        """ Plot pressure at probes (one curve per probe).

        Nothing is plotted when the file contains no probe. Returns None.
        """
        probes = self.data.get_dataset('probe_locations').tolist()
        if not probes:
            return None
        p = self.data.get_dataset('probe_values')
        t = _np.arange(self.nt) * self.data.get_attr('dt')

        _, ax = _plt.subplots(figsize=(9, 4))
        offsets = self._probe_offsets(probes)
        for i, (c, p0) in enumerate(zip(probes, offsets)):
            ax.plot(t, p[i, :] - p0, label=f'@{tuple(c)}')
        ax.set_xlim(t.min(), t.max())
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Pressure [Pa]')
        ax.legend()
        ax.grid()
        return None

    def spectrogram(self):
        """ Plot spectrograms at probes (one subplot per probe).

        Nothing is plotted when the file contains no probe. Returns None.
        """
        probes = self.data.get_dataset('probe_locations').tolist()
        if not probes:
            return None
        p = self.data.get_dataset('probe_values')
        fs = 1 / self.data.get_attr('dt')
        M = 1024      # nominal segment length
        hop = 100     # nominal distance between segments

        # squeeze=False: with a single probe, subplots would otherwise return
        # a bare Axes that cannot be indexed.
        fig, axes = _plt.subplots(p.shape[0], 1, figsize=(9, 4), squeeze=False)
        axes = axes[:, 0]

        offsets = self._probe_offsets(probes)
        for i, (c, p0) in enumerate(zip(probes, offsets)):
            sig = p[i, :] - p0
            # scipy shrinks nperseg to the signal length for short signals;
            # noverlap must shrink with it since it must stay < nperseg.
            nperseg = min(M, sig.shape[-1])
            noverlap = max(nperseg - hop, 0)
            freqs, times, Sx = _signal.spectrogram(sig,
                                                   fs=fs,
                                                   window='hann',
                                                   nperseg=nperseg,
                                                   noverlap=noverlap,
                                                   detrend=False,
                                                   scaling='spectrum')
            im = axes[i].pcolormesh(times, freqs / 1000, 10 * _np.log10(Sx),
                                    cmap='viridis')
            axes[i].set_ylabel('Frequency [kHz]')
            if i != len(probes) - 1:
                axes[i].set_xticks([])
            fig.colorbar(im, ax=axes[i], label=f'probe {i}')
        axes[-1].set_xlabel('Time [s]')
        axes[0].set_title('Square spectrum magnitude')
        _plt.tight_layout()
        return None

    def fields(self, view=('p', 'e', 'vx', 'vz'), iteration=None, ref=None,
               show_pml=False, show_probes=True, figsize='auto',
               midpoint=0, comp=1):
        """ Make figure: one colormap subplot per variable of ``view``.

        Parameters
        ----------
        view : tuple of str or str
            Variable(s) to display.
        iteration : int, optional
            Iteration to display (snapped with ``nearest_index``). Defaults
            to ``self.nt``.
        ref : optional
            If truthy, color limits come from ``self.data.reference``;
            otherwise from the field extrema.
        show_pml : bool
            Draw the PML areas.
        show_probes : bool
            Mark probe locations with red dots.
        figsize : optional
            Passed to ``set_figsize``.
        midpoint : float
            Colormap centre for fields that are not strictly positive
            (strictly positive fields are centred on their mean).
        comp : float
            Color limits are divided by ``comp``.

        Returns
        -------
        fig, axes, ims
            The figure, the (raveled-able) array of axes and the list of
            QuadMesh objects, one per variable.
        """
        if isinstance(view, str):
            view = (view,)
        if iteration is None:
            iteration = self.nt
        else:
            iteration = _ne(iteration, self.ns, self.nt)

        arrays = []
        norms = []
        ticks = []
        for name in view:
            arrays.append(self.data.get(view=name, iteration=iteration).T)
            norm, tck = self._norm_and_ticks(arrays[-1], name, ref,
                                             midpoint, comp)
            norms.append(norm)
            ticks.append(tck)

        fig, axes = _plt.subplots(*get_subplot_shape(len(arrays)))
        if not isinstance(axes, _np.ndarray):  # only one variable in view
            axes = _np.array(axes)

        # Probe locations are the same for every subplot: read them once.
        probes = (self.data.get_dataset('probe_locations').tolist()
                  if show_probes else [])

        ims = []
        for i, ax in enumerate(axes.ravel()):
            if i >= len(arrays):
                # Unused cell of the subplot grid.
                ax.remove()
                continue
            im = ax.pcolormesh(self.x, self.z, arrays[i][:-1, :-1],
                               cmap=self.cm, norm=norms[i])
            ims.append(im)
            ax.set_title(self.titles[view[i]] + f' (n={iteration})')
            ax.set_xlabel(r'$x$ [m]')
            ax.set_ylabel(r'$z$ [m]')
            ax.set_aspect('equal')
            cax = make_axes_locatable(ax).append_axes("right", size="5%",
                                                      pad=0.05)
            _plt.colorbar(im, cax=cax, ticks=ticks[i])
            for pi, pj in probes:
                ax.plot(self.x[pi, pj], self.z[pi, pj], 'ro')
            _graphics.plot_subdomains(ax, self.x, self.z, self.obstacles)
            if show_pml:
                _graphics.plot_pml(ax, self.x, self.z, self.bc, self.Npml)

        fig.set_size_inches(*set_figsize(axes, figsize))
        _plt.tight_layout()
        return fig, axes, ims

    @staticmethod
    def show():
        """ Show all figures. """
        _plt.show()