""" This module provides the basic classes for the pulse retrieval algorithms.
"""
import numpy as np
from types import SimpleNamespace
from .. import io
from ..mesh_data import MeshData
from ..pulse_error import pulse_error
from .. import lib
from ..pnps import BasePNPS
# Global registry that maps a retriever method name (e.g. 'copra') to the
# Retriever class implementing it. It is filled by ``MetaRetriever`` when a
# class is created and read by the ``Retriever`` factory function.
_RETRIEVER_CLASSES = {}
# =============================================================================
# Metaclass and factory
# =============================================================================
class MetaRetriever(type):
    """ Metaclass that records Retriever classes in the module-level
    registry ``_RETRIEVER_CLASSES``, keyed by their ``method`` attribute.

    Classes whose ``method`` is ``None`` are created normally but are not
    registered.
    """

    def __new__(cls, clsmethod, bases, attrs):
        """ Create the class and register it under its ``method`` name.

        Raises
        ------
        ValueError
            If another class is already registered under the same name.
        """
        # ``clsmethod`` is the name of the class that is being created.
        created = super().__new__(cls, clsmethod, bases, attrs)
        key = created.method
        if key is not None:
            # each retriever name, e.g. 'copra', may only be claimed once
            if key in _RETRIEVER_CLASSES:
                raise ValueError("Two retriever classes implement retriever '%s'."
                                 % key)
            _RETRIEVER_CLASSES[key] = created
        return created
class MetaIORetriever(io.MetaIO, MetaRetriever):
    """ Combined metaclass of ``io.MetaIO`` and ``MetaRetriever``.

    It adds no behavior of its own. It only exists so that a class can use
    both metaclasses at once without running into a metaclass conflict.
    """
# =============================================================================
# Retriever Base class
# =============================================================================
class BaseRetriever(io.IO, metaclass=MetaIORetriever):
    """ The abstract base class for pulse retrieval.

    This class implements common functionality for different retrieval
    algorithms: validation of the input, bookkeeping of the retrieval state,
    calculation of the trace error and packaging of the result.

    Subclasses set ``method`` to a unique name to get registered by the
    metaclass. :meth:`retrieve` calls ``self._retrieve()``, which is not
    defined here and therefore has to be provided by the subclass.
    """
    # name under which the class is registered; ``None`` means "do not
    # register" (used for abstract classes such as this one)
    method = None
    # container of supported PNPS schemes; ``None`` disables the check
    supported_schemes = None
    # names of the attributes listed for the ``io.IO`` storage mechanism
    _io_store = ['pnps', 'options', 'logging', 'log',
                 '_retrieval_state', '_result']

    def __init__(self, pnps, logging=False, verbose=False, **kwargs):
        """ Creates a retriever for the given PNPS instance.

        Parameters
        ----------
        pnps : PNPS
            The PNPS instance that is used to simulate the measurement. Its
            ``ft`` attribute is reused as Fourier transform.
        logging : bool, optional
            If `True` a log namespace is created at the start of each
            retrieval. Default is `False`.
        verbose : bool, optional
            If `True` status information is printed. Default is `False`.
        kwargs : dict
            Retrieval options. They are stored as attributes of
            ``self.options``.

        Raises
        ------
        ValueError
            If ``supported_schemes`` is set and does not contain
            ``pnps.scheme``.
        """
        self.pnps = pnps
        self.ft = self.pnps.ft
        self.options = SimpleNamespace(**kwargs)
        self._result = None
        self.logging = logging
        self.verbose = verbose
        self.log = None
        rs = self._retrieval_state = SimpleNamespace()
        rs.running = False
        if (self.supported_schemes is not None and
                pnps.scheme not in self.supported_schemes):
            raise ValueError("Retriever '%s' does not support scheme '%s'. "
                             "It only supports %s." %
                             (self.method, pnps.scheme, self.supported_schemes)
                             )

    def retrieve(self, measurement, initial_guess, weights=None,
                 **kwargs):
        """ Retrieve pulse from ``measurement`` starting at ``initial_guess``.

        Parameters
        ----------
        measurement : MeshData
            A MeshData instance that contains the PNPS measurement. The first
            axis has to correspond to the PNPS parameter, the second to the
            frequency. The data has to be the measured _intensity_ over the
            frequency (not wavelength!). The second axis has to match exactly
            the frequency axis of the underlying PNPS instance. No
            interpolation is done.
        initial_guess : 1d-array
            The spectrum of the pulse that is used as initial guess in the
            iterative retrieval.
        weights : array, optional
            Weights that are attributed to the measurement for retrieval.
            They are multiplied element-wise with the measured trace and
            therefore have to be broadcastable to its shape. If `None`
            (default) all weights are set to one.
            In the case of (assumed) Gaussian uncertainties with standard
            deviation sigma they should correspond to 1/sigma.
            Not all algorithms support using the weights.
        kwargs : dict
            Can override retrieval options specified in :func:`__init__`.
            The overrides are written to ``self.options`` and persist after
            the call.

        Raises
        ------
        ValueError
            If ``measurement`` is not a MeshData instance or if its second
            axis does not match the frequency grid of the PNPS instance.

        Notes
        -----
        This function provides no interpolation or data processing. You have
        to write a retriever wrapper for that purpose.

        The function returns nothing. Use :meth:`result` to access the
        retrieved pulse.
        """
        vars(self.options).update(kwargs)
        if not isinstance(measurement, MeshData):
            raise ValueError("measurement has to be a MeshData instance!")
        self._retrieve_begin(measurement, initial_guess, weights)
        self._retrieve()
        self._retrieve_end()

    def _retrieve_begin(self, measurement, initial_guess, weights):
        """ Validates the input and initializes state, result and log.
        """
        pnps = self.pnps
        # array_equal also covers axes of different length, where an
        # element-wise ``==`` warns or fails with a broadcasting error
        # depending on the NumPy version.
        if not np.array_equal(pnps.process_w, measurement.axes[1]):
            raise ValueError("Measurement has to lie on simulation grid!")
        # Store measurement
        self.measurement = measurement
        self.parameter = measurement.axes[0]
        self.Tmn_meas = measurement.data
        self.initial_guess = initial_guess
        # set the size
        self.M, self.N = self.Tmn_meas.shape
        # Setup the weights
        if weights is None:
            self._weights = np.ones((self.M, self.N))
        else:
            # copy to decouple from later changes by the caller
            self._weights = weights.copy()
        # Retrieval state
        rs = self._retrieval_state
        rs.approximate_error = False
        rs.running = True
        rs.steps_since_improvement = 0
        # Initialize result with the initial guess
        res = self._result = SimpleNamespace()
        res.trace_error = self.trace_error(self.initial_guess)
        res.approximate_error = False
        res.spectrum = self.initial_guess.copy()
        # Setup the logger
        if self.logging:
            log = self.log = SimpleNamespace()
            log.trace_error = []
            log.initial_guess = self.initial_guess.copy()
        else:
            self.log = None
        if self.verbose:
            print("Started retriever '%s'" % self.method)
            print("Options:")
            print(self.options)
            print("Initial trace error R = {:.10e}".format(res.trace_error))
            print("Starting retrieval...")
            print()

    def _retrieve_end(self):
        """ Marks the retrieval as finished and finalizes the trace error.
        """
        rs = self._retrieval_state
        rs.running = False
        res = self._result
        # an error flagged as approximate is replaced by the exact value
        if res.approximate_error:
            res.trace_error = self.trace_error(res.spectrum)
            res.approximate_error = False

    def _project(self, measured, Smk):
        """ Performs the projection on the measured intensity.

        The signal ``Smk`` is transformed with ``self.ft.forward``, its
        amplitude is replaced by the square root of ``measured`` while its
        phase is kept, and the result is transformed back with
        ``self.ft.backward``.
        """
        # in frequency domain
        Smn = self.ft.forward(Smk)
        # project and specially handle values with zero amplitude:
        # they have no defined phase and are set to the bare amplitude.
        # Adding 0.0j makes np.sqrt return a complex number instead of NaN
        # for negative intensities.
        absSmn = np.abs(Smn)
        f = (absSmn > 0.0)
        Smn[~f] = np.sqrt(measured[~f] + 0.0j)
        Smn[f] = Smn[f] / absSmn[f] * np.sqrt(measured[f] + 0.0j)
        # back in time domain
        Smk2 = self.ft.backward(Smn)
        return Smk2

    def _objective_function(self, spectrum):
        """ Calculates the minimization objective from the pulse spectrum.

        This is Eq. 11 in the paper:

            r = sum_mn [w_mn * (Tmn^meas - mu * Tmn)]^2

        where ``w_mn`` are the weights and ``mu`` is the scaling factor
        calculated in :meth:`_error_vector`.
        """
        # calculate the PNPS trace
        Tmn = self.pnps.calculate(spectrum, self.parameter)
        return self._r(Tmn)

    def trace_error(self, spectrum, store=True):
        """ Calculates the trace error from the pulse spectrum.

        Parameters
        ----------
        spectrum : 1d-array
            The pulse spectrum for which the trace is simulated.
        store : bool, optional
            If `True` (default) the scaling factor, the simulated trace and
            ``self.pnps.Smk`` are stored in the retrieval state.

        Returns
        -------
        The trace error R.
        """
        Tmn = self.pnps.calculate(spectrum, self.parameter)
        return self._R(Tmn, store=store)

    def _r(self, Tmn, store=True):
        """ Calculates the minimization objective r from a simulated trace Tmn.
        """
        diff = self._error_vector(Tmn, store=store)
        return np.sum(diff * diff)

    def _error_vector(self, Tmn, store=True):
        """ Calculates the residual vector from measured to simulated
        intensity.

        Returns the flattened, weighted difference between the measured
        trace and the simulated trace scaled by the factor ``mu``.
        """
        # rename
        rs = self._retrieval_state
        Tmn_meas = self.Tmn_meas
        # scaling factor: the mu that minimizes the weighted squared residual
        w2 = self._weights * self._weights
        mu = np.sum(Tmn_meas * Tmn * w2) / np.sum(Tmn * Tmn * w2)
        # store intermediate results in current retrieval state
        if store:
            rs.mu = mu
            rs.Tmn = Tmn
            rs.Smk = self.pnps.Smk
        return np.ravel((Tmn_meas - mu * Tmn) * self._weights)

    def _R(self, Tmn, store=True):
        """ Calculates the trace error from a simulated trace Tmn.
        """
        r = self._r(Tmn, store=store)
        return self._Rr(r)

    def _Rr(self, r):
        """ Calculates the trace error from the minimization objective r.

        The objective is normalized by the number of data points and by the
        squared maximum of the weighted measured trace:

            R = sqrt(r / (M * N * max(Tmn^meas * w)^2))
        """
        return np.sqrt(r / (self.M * self.N *
                            (self.Tmn_meas * self._weights).max()**2))

    def result(self, pulse_original=None, full=True):
        """ Analyzes the retrieval results in one retrieval instance
        and processes it for plotting or storage.

        Parameters
        ----------
        pulse_original : 1d-array, optional
            The spectrum of the original test pulse. If given, the trace
            error and the trace of the test pulse as well as the pulse error
            of the retrieved pulse are added to the result.
        full : bool, optional
            If `True` (default) the PNPS instance is stored in the result,
            otherwise ``pnps`` is set to `None`.

        Returns
        -------
        SimpleNamespace or None
            The collected results, or `None` if no retrieval was run yet or
            if a retrieval is still running.
        """
        rs = self._retrieval_state
        if self._result is None or rs.running:
            return None
        res = SimpleNamespace()
        # the meta data
        res.parameter = self.parameter
        res.options = self.options
        res.logging = self.logging
        res.measurement = self.measurement
        # store the PNPS instance only on request
        if full:
            res.pnps = self.pnps
        else:
            res.pnps = None
        # the pulse spectra
        # 1 - the retrieved pulse
        res.pulse_retrieved = self._result.spectrum
        # 2 - the original test pulse, optional
        res.pulse_original = pulse_original
        # 3 - the initial guess
        res.pulse_initial = self.initial_guess
        # the measurement traces
        # 1 - the original data used for retrieval
        res.trace_input = self.Tmn_meas
        # 2 - the trace error and the trace calculated from the retrieved pulse
        # (trace_error() stores mu and Tmn of the retrieved pulse in rs)
        res.trace_error = self.trace_error(res.pulse_retrieved)
        res.trace_retrieved = rs.mu * rs.Tmn
        res.response_function = rs.mu
        # the weights
        res.weights = self._weights
        # this is set if the original spectrum is provided
        if res.pulse_original is not None:
            # the trace error of the test pulse (non-zero for noisy input)
            # NOTE: this call overwrites rs.mu and rs.Tmn with the values of
            # the test pulse, which the next line relies on.
            res.trace_error_optimal = self.trace_error(res.pulse_original)
            # 3 - the optimal trace calculated from the test pulse
            res.trace_original = rs.mu * rs.Tmn
            dot_ambiguity = False
            if self.pnps.method == "ifrog" or self.pnps.scheme == "shg-frog":
                dot_ambiguity = True
            # the pulse error to the test pulse
            res.pulse_error, res.pulse_retrieved = pulse_error(
                res.pulse_retrieved, res.pulse_original, self.ft,
                dot_ambiguity=dot_ambiguity)
        if res.logging:
            # the logged trace errors
            res.trace_errors = np.array(self.log.trace_error)
            # the running minimum of the trace errors (for plotting)
            res.rm_trace_errors = np.minimum.accumulate(res.trace_errors,
                                                        axis=-1)
        if self.verbose:
            lib.retrieval_report(res)
        return res
def Retriever(pnps: BasePNPS, method: str = "copra", maxiter=300, maxfev=None,
              logging=False, verbose=False, **kwargs) -> BaseRetriever:
    """ Creates a retriever instance.

    Parameters
    ----------
    pnps : PNPS
        A PNPS instance that is used to simulate a PNPS measurement.
    method : str, optional
        Type of solver. The name is matched case-insensitively.
        Should be one of

            - 'copra'       :class:`(see here) <COPRARetriever>`
            - 'gpa'         :class:`(see here) <GPARetriever>`
            - 'gp-dscan'    :class:`(see here) <GPDSCANRetriever>`
            - 'pcgpa'       :class:`(see here) <PCGPARetriever>`
            - 'pie'         :class:`(see here) <PIERetriever>`
            - 'lm'          :class:`(see here) <LMRetriever>`
            - 'bfgs'        :class:`(see here) <BFGSRetriever>`
            - 'de'          :class:`(see here) <DERetriever>`
            - 'nelder-mead' :class:`(see here) <NMRetriever>`

        'copra' is the default choice.
    maxiter : int, optional
        The maximum number of algorithm iterations. The default is 300.
    maxfev : int, optional
        The maximum number of function evaluations. If given, the algorithms
        stop before this number is reached. Not all algorithms support this
        feature. Default is ``None``, in which case it is ignored.
    logging : bool, optional
        Stores trace errors and pulses over the iterations if supported
        by the retriever class. Default is `False`.
    verbose : bool, optional
        Prints out trace errors during the iteration if supported by the
        retriever class. Default is `False`.
    kwargs : dict
        Further options that are passed on to the retriever class.

    Returns
    -------
    An instance of the retriever class registered under ``method``.

    Raises
    ------
    ValueError
        If no retriever class is registered under ``method``.
    """
    method = method.lower()
    try:
        cls = _RETRIEVER_CLASSES[method]
    except KeyError:
        # the failed dictionary lookup is an implementation detail, so it is
        # hidden from the traceback shown to the user
        raise ValueError("Retriever '%s' is unknown!" % (method)) from None
    return cls(pnps, maxiter=maxiter, maxfev=maxfev,
               logging=logging, verbose=verbose, **kwargs)