# Source code for pypret.retrieval.retriever

""" This module provides the basic classes for the pulse retrieval algorithms.
"""
import numpy as np
from types import SimpleNamespace
from .. import io
from ..mesh_data import MeshData
from ..pulse_error import pulse_error
from .. import lib
from ..pnps import BasePNPS

# global dictionary that contains all Retriever classes, keyed by their
# ``method`` name (e.g. 'copra'); filled by MetaRetriever, read by Retriever()
_RETRIEVER_CLASSES = {}

# =============================================================================
# Metaclass and factory
# =============================================================================
class MetaRetriever(type):
    """ Metaclass that registers Retriever classes in a global dictionary.

    Every class created with this metaclass whose ``method`` attribute is
    not ``None`` is stored in the module-level ``_RETRIEVER_CLASSES`` under
    that name (e.g. 'copra'). Classes with ``method = None`` (or without a
    ``method`` attribute) are treated as abstract and are not registered.

    Raises
    ------
    ValueError
        If a class tries to register a method name that is already taken.
    """
    def __new__(cls, clsmethod, bases, attrs):
        newclass = super().__new__(cls, clsmethod, bases, attrs)
        # abstract base classes leave ``method`` unset or None
        method = getattr(newclass, "method", None)
        if method is None:
            return newclass
        # register the Retriever method, e.g. 'copra'
        # (the dict is only mutated, so no ``global`` statement is needed)
        if method in _RETRIEVER_CLASSES:
            raise ValueError("Two retriever classes implement retriever '%s'."
                             % method)
        _RETRIEVER_CLASSES[method] = newclass
        return newclass

class MetaIORetriever(io.MetaIO, MetaRetriever):
    """ Combined metaclass of ``io.MetaIO`` and ``MetaRetriever``.

    Exists only to fix metaclass conflicts: a class that derives from a
    base using ``io.MetaIO`` and should also be registered by
    ``MetaRetriever`` needs a single metaclass deriving from both.
    """

# =============================================================================
# Retriever Base class
# =============================================================================
class BaseRetriever(io.IO, metaclass=MetaIORetriever):
    """ The abstract base class for pulse retrieval.

    This class implements common functionality for different retrieval
    algorithms: option handling, bookkeeping of the retrieval state, the
    (weighted) trace error and the post-processing of the results.
    Concrete retrievers set ``method`` (the name they are registered under,
    e.g. 'copra'), may restrict ``supported_schemes`` and implement
    :meth:`_retrieve`.
    """
    # name under which a subclass is registered; None = not registered
    method = None
    # PNPS schemes the retriever can handle; None means no restriction
    supported_schemes = None
    # presumably the attributes persisted by the io.IO machinery
    _io_store = ['pnps', 'options', 'logging', 'log',
                 '_retrieval_state', '_result']

    def __init__(self, pnps, logging=False, verbose=False, **kwargs):
        """ Creates a retriever that works with the given PNPS instance.

        Parameters
        ----------
        pnps : PNPS
            The PNPS instance used to simulate traces. Its ``ft`` attribute
            is used for the Fourier transforms.
        logging : bool, optional
            Whether trace errors are recorded in ``self.log``.
        verbose : bool, optional
            Whether progress information is printed.
        kwargs : dict
            Retrieval options, stored in ``self.options``.

        Raises
        ------
        ValueError
            If ``supported_schemes`` is set and does not contain
            ``pnps.scheme``.
        """
        self.pnps = pnps
        self.ft = self.pnps.ft
        self.options = SimpleNamespace(**kwargs)
        self._result = None
        self.logging = logging
        self.verbose = verbose
        self.log = None
        rs = self._retrieval_state = SimpleNamespace()
        rs.running = False
        if (self.supported_schemes is not None and
                pnps.scheme not in self.supported_schemes):
            raise ValueError("Retriever '%s' does not support scheme '%s'. "
                             "It only supports %s." %
                             (self.method, pnps.scheme,
                              self.supported_schemes))

    def retrieve(self, measurement, initial_guess, weights=None, **kwargs):
        """ Retrieve pulse from ``measurement`` starting at ``initial_guess``.

        Parameters
        ----------
        measurement : MeshData
            A MeshData instance that contains the PNPS measurement. The first
            axis has to correspond to the PNPS parameter, the second to the
            frequency. The data has to be the measured _intensity_ over the
            frequency (not wavelength!). The second axis has to match exactly
            the frequency axis of the underlying PNPS instance. No
            interpolation is done.
        initial_guess : 1d-array
            The spectrum of the pulse that is used as initial guess in the
            iterative retrieval.
        weights : array_like, optional
            Weights that are attributed to the measurement for retrieval.
            They are multiplied element-wise with the (M, N) trace, so they
            have to be broadcastable to its shape. In the case of (assumed)
            Gaussian uncertainties with standard deviation sigma they should
            correspond to 1/sigma. Not all algorithms support using the
            weights. By default all points are weighted equally.
        kwargs : dict
            Can override retrieval options specified in :func:`__init__`.

        Raises
        ------
        ValueError
            If ``measurement`` is not a MeshData instance or its frequency
            axis does not match the PNPS frequency grid.

        Notes
        -----
        This function provides no interpolation or data processing. You have
        to write a retriever wrapper for that purpose. The outcome can be
        obtained afterwards with :meth:`result`.
        """
        # validate before modifying any state of the retriever
        if not isinstance(measurement, MeshData):
            raise ValueError("measurement has to be a MeshData instance!")
        vars(self.options).update(kwargs)
        self._retrieve_begin(measurement, initial_guess, weights)
        self._retrieve()
        self._retrieve_end()

    def _retrieve(self):
        """ Runs the actual retrieval algorithm; implemented by subclasses.

        Called by :meth:`retrieve` between :meth:`_retrieve_begin` and
        :meth:`_retrieve_end`. Presumably it updates ``self._result``
        (``spectrum``, ``trace_error``, ``approximate_error``), which
        ``_retrieve_end`` and ``result`` read.
        """
        raise NotImplementedError(
            "Retriever subclasses have to implement _retrieve().")

    def _retrieve_begin(self, measurement, initial_guess, weights):
        """ Validates the input and initializes state, result and log.

        Stores the measured trace, its size (M, N) and the weights, resets
        the retrieval state and evaluates the trace error of the initial
        guess.
        """
        pnps = self.pnps
        # array_equal returns False (instead of raising or warning) when the
        # two frequency axes have different lengths
        if not np.array_equal(pnps.process_w, measurement.axes[1]):
            raise ValueError("Measurement has to lie on simulation grid!")
        # Store measurement
        self.measurement = measurement
        self.parameter = measurement.axes[0]
        # the measured trace is the 2d data of the MeshData (parameter x
        # frequency), not the 1d initial guess spectrum
        self.Tmn_meas = measurement.data
        self.initial_guess = initial_guess
        # set the size
        self.M, self.N = self.Tmn_meas.shape
        # Setup the weights (copied so the caller's array is not aliased)
        if weights is None:
            self._weights = np.ones((self.M, self.N))
        else:
            self._weights = np.array(weights)
        # Retrieval state
        rs = self._retrieval_state
        rs.approximate_error = False
        rs.running = True
        rs.steps_since_improvement = 0
        # Initialize result
        res = self._result = SimpleNamespace()
        res.trace_error = self.trace_error(self.initial_guess)
        res.approximate_error = False
        res.spectrum = self.initial_guess.copy()
        # Setup the logger
        if self.logging:
            log = self.log = SimpleNamespace()
            log.trace_error = []
            log.initial_guess = self.initial_guess.copy()
        else:
            self.log = None
        if self.verbose:
            print("Started retriever '%s'" % self.method)
            print("Options:")
            print(self.options)
            print("Initial trace error R = {:.10e}".format(res.trace_error))
            print("Starting retrieval...")
            print()

    def _retrieve_end(self):
        """ Finalizes a retrieval run.

        Marks the retrieval as finished and, if the stored trace error is
        only approximate, recomputes it exactly from the final spectrum.
        """
        rs = self._retrieval_state
        rs.running = False
        res = self._result
        if res.approximate_error:
            res.trace_error = self.trace_error(res.spectrum)
            res.approximate_error = False

    def _project(self, measured, Smk):
        """ Performs the projection on the measured intensity.

        ``Smk`` is transformed with ``self.ft.forward``. Its modulus is
        replaced by ``sqrt(measured)`` and its phase is kept. The result is
        transformed back with ``self.ft.backward``.
        """
        # in frequency domain
        Smn = self.ft.forward(Smk)
        # project and specially handle values with zero amplitude: there the
        # phase is undefined, so the plain square root is used.
        # ``+ 0.0j`` makes negative (noisy) intensities yield imaginary roots
        # instead of NaN.
        absSmn = np.abs(Smn)
        f = (absSmn > 0.0)
        Smn[~f] = np.sqrt(measured[~f] + 0.0j)
        Smn[f] = Smn[f] / absSmn[f] * np.sqrt(measured[f] + 0.0j)
        # back in time domain
        Smk2 = self.ft.backward(Smn)
        return Smk2

    def _objective_function(self, spectrum):
        """ Calculates the minimization objective from the pulse spectrum.

        This is Eq. 11 in the paper:

            r = sum_mn [w_mn * (Tmn^meas - mu * Tmn)]^2

        where ``w`` are the weights and ``mu`` the optimal scaling factor.
        """
        # calculate the PNPS trace
        Tmn = self.pnps.calculate(spectrum, self.parameter)
        return self._r(Tmn)

    def trace_error(self, spectrum, store=True):
        """ Calculates the trace error from the pulse spectrum.

        If ``store`` is True, the intermediate results (mu, Tmn, Smk) are
        written to the retrieval state.
        """
        Tmn = self.pnps.calculate(spectrum, self.parameter)
        return self._R(Tmn, store=store)

    def _r(self, Tmn, store=True):
        """ Calculates the minimization objective r from a simulated trace Tmn.

        r is the sum of squares of the weighted residual vector.
        """
        diff = self._error_vector(Tmn, store=store)
        return np.sum(diff * diff)

    def _error_vector(self, Tmn, store=True):
        """ Calculates the residual vector from measured to simulated intensity.

        Returns the flattened ``w * (Tmn_meas - mu * Tmn)``. ``mu`` is the
        scaling factor that minimizes the weighted sum of squares of the
        residual.
        """
        # rename
        rs = self._retrieval_state
        Tmn_meas = self.Tmn_meas
        # scaling factor
        w2 = self._weights * self._weights
        mu = np.sum(Tmn_meas * Tmn * w2) / np.sum(Tmn * Tmn * w2)
        # store intermediate results in current retrieval state
        if store:
            rs.mu = mu
            rs.Tmn = Tmn
            rs.Smk = self.pnps.Smk
        return np.ravel((Tmn_meas - mu * Tmn) * self._weights)

    def _R(self, Tmn, store=True):
        """ Calculates the trace error from a simulated trace Tmn.
        """
        r = self._r(Tmn, store=store)
        return self._Rr(r)

    def _Rr(self, r):
        """ Calculates the trace error from the minimization objective r.

        R = sqrt(r / (M * N * max(w * Tmn_meas)^2))
        """
        return np.sqrt(r / (self.M * self.N *
                            (self.Tmn_meas * self._weights).max()**2))

    def result(self, pulse_original=None, full=True):
        """ Analyzes the retrieval results in one retrieval instance and
        processes it for plotting or storage.

        Parameters
        ----------
        pulse_original : 1d-array, optional
            The spectrum of the original test pulse. If given, the optimal
            trace error and the pulse error are calculated as well.
        full : bool, optional
            Whether the PNPS instance is included in the result.

        Returns
        -------
        SimpleNamespace or None
            The collected results. ``None`` is returned if no retrieval was
            run or if one is still running.

        Notes
        -----
        This recomputes trace errors with ``store=True`` and therefore
        overwrites ``mu``, ``Tmn`` and ``Smk`` in the retrieval state.
        """
        rs = self._retrieval_state
        if self._result is None or rs.running:
            return None
        res = SimpleNamespace()
        # the meta data
        res.parameter = self.parameter
        res.options = self.options
        res.logging = self.logging
        res.measurement = self.measurement
        # store the PNPS instance used for the retrieval
        if full:
            res.pnps = self.pnps
        else:
            res.pnps = None
        # the pulse spectra
        # 1 - the retrieved pulse
        res.pulse_retrieved = self._result.spectrum
        # 2 - the original test pulse, optional
        res.pulse_original = pulse_original
        # 3 - the initial guess
        res.pulse_initial = self.initial_guess
        # the measurement traces
        # 1 - the original data used for retrieval
        res.trace_input = self.Tmn_meas
        # 2 - the trace error and the trace calculated from the retrieved
        # pulse (trace_error stores mu and Tmn of that pulse in rs)
        res.trace_error = self.trace_error(res.pulse_retrieved)
        res.trace_retrieved = rs.mu * rs.Tmn
        res.response_function = rs.mu
        # the weights
        res.weights = self._weights
        # this is set if the original spectrum is provided
        if res.pulse_original is not None:
            # the trace error of the test pulse (non-zero for noisy input)
            res.trace_error_optimal = self.trace_error(res.pulse_original)
            # 3 - the optimal trace calculated from the test pulse
            res.trace_original = rs.mu * rs.Tmn
            dot_ambiguity = False
            if self.pnps.method == "ifrog" or self.pnps.scheme == "shg-frog":
                dot_ambiguity = True
            # the pulse error to the test pulse
            res.pulse_error, res.pulse_retrieved = pulse_error(
                res.pulse_retrieved, res.pulse_original, self.ft,
                dot_ambiguity=dot_ambiguity)
        if res.logging:
            # the logged trace errors
            res.trace_errors = np.array(self.log.trace_error)
            # the running minimum of the trace errors (for plotting)
            res.rm_trace_errors = np.minimum.accumulate(res.trace_errors,
                                                        axis=-1)
        if self.verbose:
            lib.retrieval_report(res)
        return res
def Retriever(pnps: BasePNPS, method: str = "copra", maxiter=300, maxfev=None,
              logging=False, verbose=False, **kwargs) -> BaseRetriever:
    """ Creates a retriever instance.

    Parameters
    ----------
    pnps : PNPS
        A PNPS instance that is used to simulate a PNPS measurement.
    method : str, optional
        Type of solver (case-insensitive). Should be one of

        - 'copra'   :class:`(see here) <COPRARetriever>`
        - 'gpa'     :class:`(see here) <GPARetriever>`
        - 'gp-dscan' :class:`(see here) <GPDSCANRetriever>`
        - 'pcgpa'   :class:`(see here) <PCGPARetriever>`
        - 'pie'     :class:`(see here) <PIERetriever>`
        - 'lm'      :class:`(see here) <LMRetriever>`
        - 'bfgs'    :class:`(see here) <BFGSRetriever>`
        - 'de'      :class:`(see here) <DERetriever>`
        - 'nelder-mead' :class:`(see here) <NMRetriever>`

        'copra' is the default choice.
    maxiter : int, optional
        The maximum number of algorithm iterations. The default is 300.
    maxfev : int, optional
        The maximum number of function evaluations. If given, the algorithms
        stop before this number is reached. Not all algorithms support this
        feature. Default is ``None``, in which case it is ignored.
    logging : bool, optional
        Stores trace errors and pulses over the iterations if supported
        by the retriever class. Default is `False`.
    verbose : bool, optional
        Prints out trace errors during the iteration if supported by the
        retriever class. Default is `False`.
    kwargs : dict
        Further options that are forwarded to the retriever class.

    Returns
    -------
    BaseRetriever
        An instance of the retriever class registered for ``method``.

    Raises
    ------
    ValueError
        If no retriever class is registered for ``method``.
    """
    method = method.lower()
    try:
        cls = _RETRIEVER_CLASSES[method]
    except KeyError:
        # suppress the internal KeyError context; the ValueError says it all
        raise ValueError("Retriever '%s' is unknown!" % (method)) from None
    return cls(pnps, maxiter=maxiter, maxfev=maxfev, logging=logging,
               verbose=verbose, **kwargs)