""" This module implements several helper routines for plotting.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter
from . import lib
from .frequencies import convert
def plot_meshdata(ax, md, cmap="nipy_spectral"):
    """Draw a two-dimensional mesh-data object as a pseudocolor plot.

    Entry 1 of ``md.axes``, ``md.labels`` and ``md.units`` is used for the
    horizontal axis and entry 0 for the vertical axis. Both axes get an
    engineering-notation tick formatter with the corresponding unit.

    Parameters
    ----------
    ax : matplotlib axes
        The axes to draw into.
    md : object
        Must provide the attributes ``axes``, ``data``, ``labels`` and
        ``units``.
    cmap : str, optional
        Colormap handed to ``ax.pcolormesh``.

    Returns
    -------
    The object returned by ``ax.pcolormesh``.
    """
    # NOTE(review): ``lib.edges`` presumably turns the axis sample points
    # into the cell boundaries expected by pcolormesh -- confirm in ``lib``.
    horizontal = lib.edges(md.axes[1])
    vertical = lib.edges(md.axes[0])
    mesh = ax.pcolormesh(horizontal, vertical, md.data, cmap=cmap)

    ax.set_xlabel(md.labels[1])
    ax.set_ylabel(md.labels[0])

    ax.xaxis.set_major_formatter(EngFormatter(unit=md.units[1]))
    ax.yaxis.set_major_formatter(EngFormatter(unit=md.units[0]))
    return mesh
class MeshDataPlot:
    """Figure wrapper that shows a mesh-data object together with a colorbar.

    After :meth:`plot` has run, the created figure, axes and mesh image are
    available as the attributes ``fig``, ``ax`` and ``im``.
    """

    def __init__(self, mesh_data, plot=True, **kwargs):
        """Store ``mesh_data`` and optionally plot it immediately.

        Parameters
        ----------
        mesh_data : object
            The data object that is handed to ``plot_meshdata``.
        plot : bool, optional
            If true, :meth:`plot` is called right away.
        **kwargs
            Forwarded unchanged to :meth:`plot`.
        """
        self.md = mesh_data
        if plot:
            self.plot(**kwargs)

    def plot(self, show=True, cmap="nipy_spectral"):
        """Create a new figure and draw the stored mesh data into it.

        Parameters
        ----------
        show : bool, optional
            If true, apply a tight layout and call ``plt.show()``.
        cmap : str, optional
            Colormap forwarded to ``plot_meshdata``. The default is the
            colormap that was previously hard-coded here.
        """
        fig, ax = plt.subplots()
        im = plot_meshdata(ax, self.md, cmap)
        fig.colorbar(im, ax=ax)

        # keep the handles so that callers can modify the plot afterwards
        self.fig, self.ax = fig, ax
        self.im = im

        if show:
            fig.tight_layout()
            plt.show()

    def show(self):
        """Call ``plt.show()``."""
        plt.show()
def plot_complex(x, y, ax, ax2, yaxis='intensity', limit=False,
                 phase_blanking=False, phase_blanking_threshold=1e-3,
                 amplitude_line="r-", phase_line="b-"):
    """Plot magnitude and phase of the complex data ``y`` over ``x``.

    The magnitude curve is drawn on ``ax`` and the phase curve on ``ax2``.
    Before plotting, the phase is shifted in place by the value returned
    from ``lib.mean(phase, magnitude * magnitude)``.

    Parameters
    ----------
    x : array
        Horizontal coordinates.
    y : array
        Complex values to plot.
    ax, ax2 : matplotlib axes
        Axes for the magnitude and the phase curve, respectively.
    yaxis : {'intensity', 'amplitude'}, optional
        ``'intensity'`` plots ``lib.abs2(y)``, ``'amplitude'`` plots
        ``np.abs(y)``.
    limit : bool, optional
        If true, the x-range of ``ax`` is set to ``lib.limit(x, magnitude)``
        and the y-range of ``ax2`` is derived from the phase values that
        fall inside that x-range.
    phase_blanking : bool, optional
        If true, the plotted phase curve is replaced by the result of
        ``lib.mask_phase``.
    phase_blanking_threshold : float, optional
        Threshold handed to ``lib.mask_phase``.
    amplitude_line, phase_line : str, optional
        Matplotlib format strings of the two curves.

    Returns
    -------
    tuple
        ``(magnitude_line, phase_line, magnitude, phase)`` where ``phase``
        is the shifted but unmasked phase.

    Raises
    ------
    ValueError
        If ``yaxis`` is neither ``'intensity'`` nor ``'amplitude'``.
    """
    if yaxis not in ("intensity", "amplitude"):
        raise ValueError("yaxis mode '%s' is unknown!" % yaxis)
    magnitude = lib.abs2(y) if yaxis == "intensity" else np.abs(y)

    angle = lib.phase(y)
    # center phase by weighted mean (in place, the shifted array is returned)
    angle -= lib.mean(angle, magnitude * magnitude)

    if phase_blanking:
        # NOTE(review): presumably drops the phase where the magnitude is
        # below the threshold -- confirm in ``lib.mask_phase``.
        x_shown, angle_shown = lib.mask_phase(x, magnitude, angle,
                                              phase_blanking_threshold)
    else:
        x_shown, angle_shown = x, angle

    if limit:
        x_range = lib.limit(x, magnitude)
        ax.set_xlim(x_range)
        # only the phase inside the visible x-range determines the y-range
        visible = (x_shown >= x_range[0]) & (x_shown <= x_range[1])
        ax2.set_ylim(lib.limit(angle_shown[visible], padding=0.05))

    magnitude_curve, = ax.plot(x, magnitude, amplitude_line)
    phase_curve, = ax2.plot(x_shown, angle_shown, phase_line)
    return magnitude_curve, phase_curve, magnitude, angle
class PulsePlot:
    """Figure wrapper that shows a pulse in the time and frequency domain.

    After :meth:`plot` has run, the figure (``fig``), the four axes
    (``ax1``, ``ax12``, ``ax2``, ``ax22``), the four line objects
    (``li11``, ``li12``, ``li21``, ``li22``) and the plotted magnitude and
    phase arrays (``tamp``, ``tpha``, ``samp``, ``spha``) are available as
    attributes.
    """

    def __init__(self, pulse, plot=True, **kwargs):
        """Store ``pulse`` and optionally plot it immediately.

        Parameters
        ----------
        pulse : object
            The pulse to plot. :meth:`plot` reads its attributes ``t``,
            ``field``, ``w``, ``spectrum`` and ``w0`` and, if oversampling
            is requested, ``N``, ``field_at`` and ``spectrum_at``.
        plot : bool, optional
            If true, :meth:`plot` is called right away.
        **kwargs
            Forwarded unchanged to :meth:`plot`.
        """
        self.pulse = pulse
        if plot:
            self.plot(**kwargs)

    def plot(self, xaxis='wavelength', yaxis='intensity', limit=True,
             oversampling=False, phase_blanking=False,
             phase_blanking_threshold=1e-3, show=True):
        """Create a figure with a time-domain and a frequency-domain panel.

        Parameters
        ----------
        xaxis : {'wavelength', 'frequency'}, optional
            Horizontal axis of the frequency-domain panel. For
            ``'wavelength'`` the grid ``w + pulse.w0`` is passed through
            ``convert(..., "om", "wl")``; for ``'frequency'`` the grid
            ``w`` is used as is.
        yaxis : {'intensity', 'amplitude'}, optional
            Forwarded to ``plot_complex``; also used as y-axis label.
        limit : bool, optional
            Forwarded to ``plot_complex``.
        oversampling : number or False, optional
            If truthy, the field and the spectrum are evaluated on
            ``int(pulse.N * oversampling)`` equally spaced points via
            ``pulse.field_at`` and ``pulse.spectrum_at``.
        phase_blanking : bool, optional
            Forwarded to ``plot_complex``.
        phase_blanking_threshold : float, optional
            Forwarded to ``plot_complex``.
        show : bool, optional
            If true, apply a tight layout and call ``plt.show()``.

        Raises
        ------
        ValueError
            If ``xaxis`` is neither ``'wavelength'`` nor ``'frequency'``.
        """
        # Validate up front: otherwise ``unit``/``label`` would stay
        # undefined and the method would fail late with an
        # UnboundLocalError after a figure has already been created.
        if xaxis not in ("wavelength", "frequency"):
            raise ValueError("xaxis mode '%s' is unknown!" % xaxis)

        pulse = self.pulse
        fig, axs = plt.subplots(1, 2)
        ax1, ax2 = axs.flat
        # secondary y-axes carry the phase curves
        ax12 = ax1.twinx()
        ax22 = ax2.twinx()

        if oversampling:
            # np.linspace needs an integer sample count; the cast allows
            # fractional oversampling factors as well
            num = int(pulse.N * oversampling)

        # options shared by both panels
        complex_kwargs = dict(
            yaxis=yaxis, limit=limit, phase_blanking=phase_blanking,
            phase_blanking_threshold=phase_blanking_threshold)

        # time domain
        if oversampling:
            t = np.linspace(pulse.t[0], pulse.t[-1], num)
            field = pulse.field_at(t)
        else:
            t = pulse.t
            field = pulse.field
        li11, li12, tamp, tpha = plot_complex(t, field, ax1, ax12,
                                              **complex_kwargs)
        ax1.xaxis.set_major_formatter(EngFormatter(unit="s"))
        ax1.set_title("time domain")
        ax1.set_xlabel("time")
        ax1.set_ylabel(yaxis)
        ax12.set_ylabel("phase (rad)")

        # frequency domain
        if oversampling:
            w = np.linspace(pulse.w[0], pulse.w[-1], num)
            spectrum = pulse.spectrum_at(w)
        else:
            w = pulse.w
            spectrum = pulse.spectrum
        if xaxis == "wavelength":
            # pulse.w0 is added to the grid before the conversion
            w = convert(w + pulse.w0, "om", "wl")
            unit = "m"
            label = "wavelength"
        else:  # xaxis == "frequency", the grid is plotted unchanged
            unit = " rad Hz"
            label = "frequency"
        li21, li22, samp, spha = plot_complex(w, spectrum, ax2, ax22,
                                              **complex_kwargs)
        ax2.xaxis.set_major_formatter(EngFormatter(unit=unit))
        ax2.set_title("frequency domain")
        ax2.set_xlabel(label)
        ax2.set_ylabel(yaxis)
        ax22.set_ylabel("phase (rad)")

        # keep the handles so that callers can modify the plot afterwards
        self.fig = fig
        self.ax1, self.ax2 = ax1, ax2
        self.ax12, self.ax22 = ax12, ax22
        self.li11, self.li12, self.li21, self.li22 = li11, li12, li21, li22
        self.tamp, self.tpha = tamp, tpha
        self.samp, self.spha = samp, spha

        if show:
            fig.tight_layout()
            plt.show()