Source code for specviz.core.threads

"""
Thread Helpers
"""
from qtpy.QtCore import QThread, Signal
import os
import logging

from ..core.data import Spectrum1DRef, Spectrum1DRefModelLayer
from ..interfaces.factories import FitterFactory

import astropy.io.registry as io_registry

# Public API of this module: the two QThread workers defined in this file.
__all__ = ['FileLoadThread', 'FitModelThread']

class FileLoadThread(QThread):
    """
    Asynchronously read in a file

    Parameters
    ----------
    parent: QtWidget
        The parent widget or None

    Call
    ----
    file_name: str
        Name of the file to read.

    file_filter: str
        Type of file to read.

    Attributes
    ----------
    file_name: str
        Name of the file to read.

    file_filter: str
        Type of file to read. If `Auto`, try all known formats.

    Signals
    -------
    status(message, timeout)
        State of the thread
            message: The status message
            timeout: Time (msec) to display message

    result(Spectrum1DRef)
        The file's data
    """
    status = Signal(str, int)
    result = Signal(Spectrum1DRef)

    def __init__(self, parent=None):
        super(FileLoadThread, self).__init__(parent)
        self.file_name = ""
        self.file_filter = ""

    def __call__(self, file_name, file_filter):
        """
        Initialize the thread

        Stores the file name and filter that the next ``run`` will use.
        """
        self.file_name = file_name
        self.file_filter = file_filter

    def run(self):
        """
        Start thread to read the file.

        Emits ``status`` before and after loading. On success, also emits
        ``result`` with the loaded data; on failure, logs an error instead.
        """
        self.status.emit("Loading file...", 0)
        data = self.read_file(self.file_name, self.file_filter)

        if data is None:
            self.status.emit("An error occurred while loading file.", 5000)
            logging.error("Could not open file.")
            return

        self.status.emit("File loaded successfully!", 5000)
        self.result.emit(data)

    def read_file(self, file_name, file_filter):
        """
        Convenience method that directly reads a spectrum from a file.

        Parameters
        ----------
        file_name: str
            Name of the file to read.

        file_filter: str
            Type of file to read. If `Auto`, try all known formats.
            ``None`` is treated as ``'Auto (*)'``.

        Returns
        -------
        data: Spectrum1DRef
            The file's data, or None if the arguments are empty or no
            registered loader could read the file.

        Notes
        -----
        This exists mostly to facilitate development workflow. In time
        it could be augmented to support fancier features such as
        wildcards, file lists, mixed file types, and the like.
        Note that the filter string is hard coded here; its details
        might depend on the intricacies of the registries, loaders,
        and data classes. In other words, this is brittle code.
        """
        file_filter = 'Auto (*)' if file_filter is None else file_filter
        logging.info("Attempting to read file {} with {}.".format(file_name, file_filter))

        if not (file_name and file_filter):
            return None

        file_name = str(file_name)
        formats = io_registry.get_formats(Spectrum1DRef)['Format']

        if file_filter == 'Auto (*)':
            # Try loaders in descending order of the `priority` attribute set
            # on their reader functions (0 when absent). `sorted` is stable, so
            # loaders with equal priority keep their registry order.
            def reader_priority(fmt):
                reader = io_registry.get_reader(fmt, Spectrum1DRef)
                return getattr(reader, 'priority', 0)

            formats = sorted(formats, key=reader_priority, reverse=True)
        else:
            # Keep only formats whose name contains the filter string.
            formats = [fmt for fmt in formats if file_filter in fmt]

        for fmt in formats:
            logging.info("Trying to load with {}".format(fmt))
            try:
                return Spectrum1DRef.read(file_name, format=fmt)
            except Exception as e:
                # Any loader may reject the file for arbitrary reasons; this is
                # a deliberate best-effort scan, so log which loader failed and
                # move on to the next candidate.
                logging.error("Incompatible loader for selected data: {"
                              "} because {}".format(fmt, e))

        return None
class FitModelThread(QThread):
    """
    Asynchronously fit a model to a layer

    Parameters
    ----------
    parent: QtWidget
        The parent widget or None

    Call
    ----
    model_layer: Spectrum1DRefLayer
        The layer to fit to.

    fitter_name: An `~astropy.modeling` fitter
        The fitter to use

    mask: numpy.ndarray
        The mask to apply

    Attributes
    ----------
    model_layer: Spectrum1DRefLayer
        The layer to fit to.

    fitter_name: An `~astropy.modeling` fitter
        The fitter to use

    mask: numpy.ndarray
        The mask to apply

    Signals
    -------
    status(message, timeout)
        State of the thread
            message: The status message
            timeout: Time (msec) to display message

    result(Spectrum1DRefModelLayer)
        The fit
    """
    status = Signal(str, int)
    result = Signal(Spectrum1DRefModelLayer)

    def __init__(self, parent=None):
        super(FitModelThread, self).__init__(parent)
        self.model_layer = None
        self.fitter_name = ""
        # Initialised here so that `run` does not raise AttributeError when
        # the thread is started before `__call__` has been invoked.
        self.mask = None

    def __call__(self, model_layer, fitter_name, mask=None):
        """
        Store the layer, fitter name and mask used by the next ``run``.
        """
        self.model_layer = model_layer
        self.fitter_name = fitter_name
        self.mask = mask

    def run(self):
        """
        Start thread to fit the model

        Emits ``status`` before and after fitting. ``result`` is emitted
        only when a fit was actually performed.
        """
        self.status.emit("Fitting model...", 0)
        fit_result = self.fit_model(self.model_layer, self.fitter_name,
                                    self.mask)

        # `fit_model` returns None when the fit could not be attempted; it
        # already logged the reason, so report the failure and stop instead
        # of crashing on tuple unpacking.
        if fit_result is None:
            self.status.emit("Unable to perform fit.", 5000)
            return

        model_layer, message = fit_result

        if not message:
            self.status.emit("Fit completed successfully!", 5000)
        else:
            self.status.emit("Fit completed, but with warnings.", 5000)

        self.result.emit(model_layer)

    def fit_model(self, model_layer, fitter_name, mask=None):
        """
        Fit the model

        Parameters
        ----------
        model_layer: Spectrum1DRefLayer
            The layer to fit to.

        fitter_name: An `~astropy.modeling` fitter
            The fitter to use. If empty, ``FitterFactory.default_fitter``
            is used.

        mask: numpy.ndarray
            The mask to apply. If None, no extra selection is applied.

        Returns
        -------
        (model_layer, fitter_message): Spectrum1DRefLayer, str
            The model_layer.model is updated in place with the fit
            parameters. The message is from the fitter itself ("" if the
            fitter reported none).
            Returns None instead if the layer has no model, has no parent
            layer, or has fewer data points than model parameters.
        """
        if not hasattr(model_layer, 'model'):
            logging.warning("This layer has no model to fit.")
            return None

        # When fitting, the selected layer is a ModelLayer, thus
        # the data to be fitted resides in the parent
        parent_layer = model_layer._parent

        if parent_layer is None:
            return None

        flux = parent_layer.data
        dispersion = parent_layer.dispersion
        model = model_layer.model

        # Restrict the fit to the points selected by `mask` (when given), then
        # drop masked entries with `compressed()`; presumably `data` and
        # `dispersion` are masked Quantity arrays — confirm against the layer
        # class.
        if mask is not None:
            flux = flux[mask]
            dispersion = dispersion[mask]

        flux = flux.compressed().value
        dispersion = dispersion.compressed().value

        # If the number of parameters is greater than the number of data
        # points, bail
        if len(model.parameters) > flux.size:
            logging.warning("Unable to perform fit; number of parameters is "
                            "greater than the number of data points.")
            return None

        # Perform fitting of model
        if fitter_name:
            fitter = FitterFactory.all_fitters[fitter_name]()
        else:
            fitter = FitterFactory.default_fitter()

        fitted_model = fitter(model, dispersion, flux, maxiter=2000)

        if 'message' in fitter.fit_info:
            # The fitter 'message' should probably be logged at INFO level.
            # Problem is, info messages do not display in the error console,
            # and we, ideally, want the user to see the message immediately
            # after the fit is executed.
            logging.warning(fitter.fit_info['message'])

        # Copy the fitted values back into the original model so the layer
        # shown in the GUI reflects the fit results.
        if hasattr(fitted_model, '_submodels'):
            # Compound models expose sub-model parameters as "<name>_<index>".
            for i in range(len(fitted_model._submodels)):
                submodel = model._submodels[i]
                for pname in submodel.param_names:
                    value = getattr(fitted_model, "{}_{}".format(pname, i))
                    setattr(submodel, pname, value.value)
                    setattr(model[i], pname, value.value)
        else:
            for pname in model.param_names:
                value = getattr(fitted_model, "{}".format(pname))
                setattr(model, pname, value.value)

        return model_layer, fitter.fit_info.get('message', "")