Source code for pysme.gui.plot_pyplot

# -*- coding: utf-8 -*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, SpanSelector
from scipy.constants import c

clight = c * 1e-3

# from ..sme.sme import SME_Struct

from .plot_colors import PlotColors

fmt = PlotColors()


[docs]class MaskPlot: """A plot that can be used to define the mask""" # Controls: # a, d keys: Switch between segments # Left, Right Mouse button: Select sections to change the mask depending on current mode # mode == "good/bad" : left -> line mask # right -> bad mask # Does not override existing good line mask # # mode == "line/cont" : left -> line mask # right -> continuum mask # Does not change bad line mask # Shift key : Switch between "good/bad" and "line/cont" modes def __init__(self, sme, segment=0, axes=None, show=True): self.wave = sme.wave self.spec = sme.spec self.mask = sme.mask self.smod = sme.synth self.segment = segment self.nsegments = len(self.wave) self.mode = "line/cont" self.lines = sme.linelist self.vrad = sme.vrad self.vrad = [v if v is not None else 0 for v in self.vrad] self.line_plot = None self.lock = False if axes is None: self.im = plt.subplots()[1] else: self.im = axes self.selector_line = SpanSelector( self.im, self.section_line_callback, direction="horizontal", useblit=True, button=(1,), ) self.selector_cont = SpanSelector( self.im, self.section_continuum_callback, direction="horizontal", useblit=True, button=(3,), ) self.im.figure.canvas.mpl_connect("key_press_event", self.key_event) self.im.callbacks.connect("xlim_changed", self.resize_event) ax_next = plt.axes([0.8, 0.025, 0.1, 0.04]) self.button_next = Button(ax_next, "-->") self.button_next.on_clicked(self.next_segment) ax_prev = plt.axes([0.7, 0.025, 0.1, 0.04]) self.button_prev = Button(ax_prev, "<--") self.button_prev.on_clicked(self.previous_segment) self.plot() if show: plt.show()
[docs] def resize_event(self, event): if self.line_plot is not None and not self.lock: xlim = np.array(self.im.get_xlim()) xlim *= 1 - self.vrad[self.segment] / clight idx = (self.lines_segment.wlcent >= xlim[0]) & ( self.lines_segment.wlcent <= xlim[1] ) importance = self.lines_segment.depth - min(self.lines_segment.depth[idx]) if max(importance[idx]) != 0: importance /= max(importance[idx]) else: importance[idx] = 1 for i in np.where(idx)[0]: self.line_plot[i][0].set_visible(True) self.line_plot[i][1].set_visible(True) self.line_plot[i][0].set_alpha(importance[i]) self.line_plot[i][1].set_alpha(importance[i]) for i in np.where(~idx)[0]: self.line_plot[i][0].set_visible(False) self.line_plot[i][1].set_visible(False)
[docs] def key_event(self, event): if event.key in ["shift"]: if self.mode == "good/bad": self.mode = "line/cont" else: self.mode = "good/bad" print("Switch to mode: %s" % self.mode) if event.key in ["a", "left"]: self.goto_segment(self.segment - 1) if event.key in ["d", "right"]: self.goto_segment(self.segment + 1)
[docs] def section_line_callback(self, min, max): mask_type = "line" if self.mode == "line/cont" else "good" self.section_select(min, max, mask_type)
[docs] def section_continuum_callback(self, min, max): mask_type = "cont" if self.mode == "line/cont" else "bad" self.section_select(min, max, mask_type)
[docs] def section_select(self, min, max, mask_type): print("{} {:.3f} - {:.3f}".format(mask_type, min, max)) # find points idx = (self.wave[self.segment] <= max) & (self.wave[self.segment] >= min) # update masks if mask_type == "line": mask_value = 1 elif mask_type == "cont": mask_value = 2 elif mask_type == "bad": mask_value = 0 elif mask_type == "good": mask_value = 1 if mask_type in ["line", "cont"]: idx = idx & (self.mask[self.segment] != 0) if mask_type == "good": idx = idx & (self.mask[self.segment] == 0) self.mask[self.segment][idx] = mask_value # update plot self.lock = True self.update() self.lock = False
[docs] def plot(self, update=False): if self.mask is not None: mask = self.mask[self.segment] if self.spec is not None and not update: self.im.plot( self.wave[self.segment], self.spec[self.segment], label="Observation", **fmt["Obs"], ) if self.smod is not None and not update: self.im.plot( self.wave[self.segment], self.smod[self.segment], label="Synthethic", **fmt["Syn"], ) if self.spec is not None: self.fill_line = self.im.fill_between( self.wave[self.segment], 0, self.spec[self.segment], where=mask == 1, label="Mask Line", **fmt["LineMask"], ) m = mask == 2 m[1:] = m[:-1] | m[1:] m[:-1] = m[:-1] | m[1:] self.fill_cont = self.im.fill_between( self.wave[self.segment], 0, self.spec[self.segment], where=m, label="Mask Continuum", **fmt["ContMask"], ) try: if self.lines is not None and not update: self.lock = True xlim = self.wave[self.segment][[0, -1]] xlim *= 1 - self.vrad[self.segment] / clight self.lines_segment = self.lines[ (self.lines.wlcent >= xlim[0]) & (self.lines.wlcent <= xlim[1]) ] importance = self.lines_segment.depth - min(self.lines_segment.depth) importance /= max(importance) self.line_plot = [[None, None] for _ in self.lines_segment] for i, line in enumerate(self.lines_segment): # if i > threshold: wl = line.wlcent * (1 + self.vrad[self.segment] / clight) self.line_plot[i][0] = self.im.text( wl, 1.1, f"{line.species} {line.wlcent:.2f}", rotation="vertical", horizontalalignment="right", verticalalignment="top", alpha=importance[i], ) if self.spec is not None: depth = np.interp( wl, self.wave[self.segment], self.spec[self.segment], ) else: depth = np.interp( wl, self.wave[self.segment], self.smod[self.segment], ) self.line_plot[i][1] = self.im.vlines( wl, ymin=depth, ymax=1.1, alpha=importance[i] ) self.lock = False except ValueError: pass self.im.figure.suptitle("SME Fit\nSegment %i" % self.segment) self.im.set_xlabel("Wavelength [Å]") self.im.set_ylabel("normalized Intensity") self.im.set_ylim((0, 1.2)) self.im.set_xlim(self.im.get_xlim()) self.im.legend(loc="lower left") self.im.figure.canvas.draw()
[docs] def update(self, reset_view=False): if not reset_view: xlim = self.im.get_xlim() ylim = self.im.get_ylim() # Remove filled between if reset_view: self.im.collections.clear() elif isinstance(self.im.collections[0], mpl.collections.PolyCollection): del self.im.collections[:2] else: del self.im.collections[-2:] # del self.im.collections[:2] self.plot(update=True) if not reset_view: self.im.set_xlim(xlim) self.im.set_ylim(ylim)
[docs] def connect_axes(self): self.im.callbacks.connect("xlim_changed", self.resize_event)
[docs] def goto_segment(self, segment): if segment >= 0 and segment < len(self.wave) - 1: self.segment = segment self.im.cla() self.plot() self.connect_axes()
[docs] def next_segment(self, _=None): self.goto_segment(self.segment + 1)
[docs] def previous_segment(self, _=None): self.goto_segment(self.segment - 1)