Page MenuHomec4science

Utils_Wavelet.py
No OneTemporary

File Metadata

Created
Sat, May 3, 12:35

Utils_Wavelet.py

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 5 15:31:13 2024
@author: Vigneashwara.P
email: vigneashwara.pandiyan@tii.ae
_status_: "Prototyping"
_maintainer_ = Vigneashwara Pandiyan
Modification and reuse of this code should be authorized by the first owner, code author(s)
"""
import pandas as pd
import glob
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import ntpath
import re
from scipy import signal
import pywt
import os
from matplotlib import ticker
#%%
def TwoDwaveletplot(data_new, time, sample_rate, scales, plotname, path, levels):
    """
    Plot a 2D (time-frequency) Morlet wavelet scalogram and save it as a PNG.

    Parameters:
    data_new (numpy.ndarray): The input signal. It is transposed before the
        transform, so a 1D array is used unchanged.
    time (numpy.ndarray): Time stamps of the samples. The spacing between the
        first two entries is used as the sampling period for ``pywt.cwt``.
    sample_rate (float): Sampling rate in Hz. The frequency axis is shown from
        1 % of this value up to ``sample_rate // 2``.
    scales (numpy.ndarray): Wavelet scales passed to ``pywt.cwt``.
    plotname (str): Prefix of the output file name.
    path (str): Directory in which ``<plotname>_Wavelet2D.png`` is written.
    levels (int): Number of contour levels spanning the fixed colour range [0, 1].

    Returns:
    None

    Raises:
    ValueError: If ``time`` has fewer than two samples (no sampling period).
    """
    if len(time) < 2:
        raise ValueError("time must contain at least two samples to derive the sampling period")

    # Start from matplotlib defaults so earlier rcParam changes do not leak in.
    plt.rcParams.update(plt.rcParamsDefault)

    vmin, vmax = 0, 1
    contour_levels = np.linspace(vmin, vmax, levels)

    signal_window = data_new.transpose()
    waveletname = 'morl'
    dt = time[1] - time[0]
    coefficients, frequencies = pywt.cwt(signal_window, scales, waveletname, dt)
    power = np.abs(coefficients)
    print(np.min(power))
    print(np.max(power))

    fig, ax = plt.subplots(figsize=(12, 7))
    fig.patch.set_visible(True)
    # NOTE(review): the colour levels are fixed to [0, 1]. contourf leaves
    # regions outside the level range unfilled, so this assumes the caller
    # passes data normalised so that |CWT| <= 1. Consider extend='max'.
    im = ax.contourf(time, frequencies, power, levels=contour_levels,
                     vmax=vmax, vmin=vmin, cmap=plt.cm.coolwarm)
    ax.axis('on')
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=20)
    ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    ax.yaxis.offsetText.set_fontsize(20)
    ax.xaxis.offsetText.set_fontsize(20)
    # Show from 1 % of the sampling rate up to the Nyquist frequency.
    ax.set_ylim(sample_rate * 0.01, sample_rate // 2)

    # Colour-bar tick formatter: scientific notation with mathtext offsets.
    # The power limits are exponent bounds (10**0 .. 10**1), unrelated to vmin/vmax.
    cbformat = ticker.ScalarFormatter()
    cbformat.set_scientific(True)
    cbformat.set_powerlimits((0, 1))
    cbformat.set_useMathText(True)
    cb = fig.colorbar(im, ax=ax, format=cbformat)
    cb.set_label(label='Intensity (a.u)', fontsize=20)
    cb.ax.tick_params(labelsize=20)

    fig.suptitle('2D Wavelet plot', fontsize=20)
    ax.set_xlabel('Time(sec)', fontsize=20)
    ax.set_ylabel('Frequency(Hz)', fontsize=20)

    graphname = str(plotname) + '_Wavelet2D' + '.png'
    fig.savefig(os.path.join(path, graphname), bbox_inches='tight', dpi=100)
    plt.show()
    # Release the figure; clf() alone leaves it registered with pyplot, so
    # repeated calls would accumulate open figures.
    plt.close(fig)
#%%
def ThreeDwaveletplot(rawspace, time, scales, plotname, windowsize, path, lowpass):
    """
    Create a 3D surface plot of the Morlet wavelet transform and save it as a PNG.

    Args:
        rawspace (array-like): The input signal passed to ``pywt.cwt``.
        time (array-like): Time stamps of the samples. The spacing between the
            first two entries is used as the sampling period.
        scales (array-like): Wavelet scales for the transform. Any ordering is
            accepted.
        plotname (str): Prefix of the output file name.
        windowsize (int): Currently unused; kept for backward compatibility.
        path (str): Directory in which ``<plotname>_Wavelet3D.png`` is written.
        lowpass (float): Only frequencies strictly below this value are
            plotted; it is also the upper y-axis limit.

    Returns:
        None

    Raises:
        ValueError: If ``time`` has fewer than two samples, or if no wavelet
            frequency lies below ``lowpass``.

    Note:
        Sets the global ``plt.rcParams['agg.path.chunksize']``, which remains
        in effect after this function returns.
    """
    if len(time) < 2:
        raise ValueError("time must contain at least two samples to derive the sampling period")

    # Start from matplotlib defaults so earlier rcParam changes do not leak in.
    plt.rcParams.update(plt.rcParamsDefault)

    waveletname = 'morl'
    dt = time[1] - time[0]
    coefficients, frequencies = pywt.cwt(rawspace, scales, waveletname, dt)
    power = np.abs(coefficients)

    # Keep only frequency rows below the cut-off. A boolean mask keeps
    # `frequencies` and the rows of `power` aligned for any scale ordering.
    # Deleting the first N rows only worked for descending frequencies.
    keep = frequencies < lowpass
    frequencies = frequencies[keep]
    power = power[keep]
    if power.size == 0:
        raise ValueError("no wavelet frequencies lie below lowpass=%r" % (lowpass,))

    timeplot, freq_grid = np.meshgrid(time, frequencies)
    print('timeplot', len(np.ravel(timeplot)))
    print('frequencies', len(np.ravel(freq_grid)))
    print('power', len(np.ravel(power)))

    power_min = np.min(power)
    power_max = np.max(power)

    fig = plt.figure(figsize=(10, 6))
    # Global rcParam: chunk long paths for the Agg renderer (persists after return).
    plt.rcParams['agg.path.chunksize'] = len(np.ravel(timeplot))
    ax = fig.add_subplot(projection='3d')
    surf = ax.plot_surface(timeplot, freq_grid, power, cmap=plt.cm.coolwarm,
                           linewidth=0, rstride=1, cstride=1, antialiased=True,
                           vmin=power_min, vmax=power_max, alpha=0.5)
    ax.set_ylim(0, lowpass)
    ax.set_zlim(0, power_max)
    ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    ax.set_xlabel('Time(sec)', fontsize=20, labelpad=15)
    ax.invert_xaxis()
    ax.set_ylabel('Frequency(Hz)', fontsize=20, labelpad=10)
    ax.zaxis.set_rotate_label(False)
    ax.set_zlabel('Power', rotation=90, fontsize=20, labelpad=10)
    ax.grid(False)
    ax.set_facecolor('white')

    # Make the panes transparent.
    transparent = (1.0, 1.0, 1.0, 0.0)
    ax.xaxis.set_pane_color(transparent)
    ax.yaxis.set_pane_color(transparent)
    ax.zaxis.set_pane_color(transparent)
    # Make the grid lines transparent (private matplotlib API).
    ax.xaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.yaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
    ax.zaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)

    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=20)
    ax.yaxis.offsetText.set_fontsize(20)
    ax.xaxis.offsetText.set_fontsize(20)
    ax.view_init(azim=20, elev=32)

    cb = fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
    cb.set_label(label='Energy intensity', fontsize=20)
    cb.ax.tick_params(labelsize=20)

    ax.set_title('3D Wavelet plot', fontsize=20)
    fig.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)

    graphname = str(plotname) + '_Wavelet3D' + '.png'
    fig.savefig(os.path.join(path, graphname), bbox_inches='tight')
    plt.show()
    # Release the figure; clf() alone leaves it registered with pyplot, so
    # repeated calls would accumulate open figures.
    plt.close(fig)

Event Timeline