Source code for oda_api.plot_tools

# mypy: ignore-errors
# pylint: skip-file
# pylint: disable-all

from __future__ import absolute_import, division, print_function

import os.path
from builtins import (str, open, range,
                      zip, round, input, int, pow, object, zip)


__author__ = "Carlo Ferrigno"

import json
import numpy
import copy

from matplotlib import pylab as plt
from matplotlib.widgets import Slider
from matplotlib import cm

from astropy import table
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astroquery.simbad import Simbad

import oda_api.api as api

import time as _time
import astropy.wcs as wcs

# NOTE GW, optional
try:
    import ligo.skymap.plot
except ModuleNotFoundError:
    pass 

import logging

logger = logging.getLogger("oda_api.plot_tools")

__all__ = ['OdaImage', 'OdaLightCurve', 'OdaGWContours', 'OdaSpectrum']


class OdaProduct(object):

    def __init__(self, data):
        self.data = data
        self.meta = None
        self.logger = logger.getChild(self.__class__.__name__.lower())
        self.progress_logger = self.logger.getChild("progress")
        try:
            self.instrument = data._p_list[0].data_unit[1].header['INSTRUME']
        except:
            self.logger.warning('No instrument in data collection')
            self.instrument = 'none'

[docs] class OdaImage(OdaProduct): name = 'image' def get_image_for_gallery(self, ext_sig=None, meta=None, header=None, sources=None, levels=None, cmap=cm.gist_earth, unit_ID=4, det_sigma=3, output_folder=None): plt = self.build_fig(ext_sig=ext_sig, meta=meta, header=header, sources=sources, levels=levels, cmap=cmap, unit_ID=unit_ID, det_sigma=det_sigma, sliders=False) request_time = _time.time() pic_name = str(request_time) + '_image.png' pic_fn = pic_name if output_folder is not None: pic_fn = os.path.join(output_folder, pic_name) plt.savefig(pic_fn) return pic_fn
[docs] def show(self, ext_sig=None, meta=None, header=None, sources=None, levels=None, cmap=cm.gist_earth, unit_ID=4, det_sigma=3, sliders=True): """ OdaImage.show :param ext_sig: ODA data products extension, takes from class initialisation by default :param meta: ODA data products metadata, takes from class initialisation by default :param header: ODA data product image header, takes from class initialisation by default :param sources: ODA catalog table, takes from class initialisation by default :param levels: levels for contour plot, default is numpy.linspace(1, 10, 10) :param cmap: colormap default is cm.gist_earth, :param unit_ID: the unit to plot image default is 4 :param det_sigma: limit detection sigma to lot from catalog, note that :param sliders: plot sliders, set to false to upload images in gallery. :return: matplotlib figure instance """ plt = self.build_fig(ext_sig=ext_sig, meta=meta, header=header, sources=sources, levels=levels, cmap=cmap, unit_ID=unit_ID, det_sigma=det_sigma, sliders=sliders) plt.show()
def build_fig(self, ext_sig=None, meta=None, header=None, sources=None, levels=None, cmap=cm.gist_earth, unit_ID=4, det_sigma=3, sliders=True): if levels is None: levels = numpy.linspace(1, 10, 10) if ext_sig is None: ext_sig = self.data.mosaic_image_0_mosaic.data_unit[unit_ID] if meta is None: self.meta = self.data.mosaic_image_0_mosaic.meta_data if header is None: header = self.data.mosaic_image_0_mosaic.data_unit[unit_ID].header if sources is None: sources = self.data.dispatcher_catalog_1.table w = wcs.WCS(header) fig = plt.figure(figsize=(8, 8./1.62)) ax = plt.subplot(projection=w) data = ext_sig.data data = numpy.ma.masked_equal(data, numpy.NaN) self.cs = plt.contourf(data, cmap=cmap, levels=levels, extend="both", zorder=0) self.cs.cmap.set_under('k') self.cs.set_clim(numpy.min(levels), numpy.max(levels)) self.cb = plt.colorbar(self.cs) if len(sources) > 0: ras = numpy.array([x for x in sources['ra']]) decs = numpy.array([x for x in sources['dec']]) if 'src_names' in sources.columns: names = numpy.array([x for x in sources['src_names']]) # Defines relevant indexes for plotting regions m_new = numpy.array(['NEW' in name for name in names]) if 'significance' in sources.columns: sigmas = numpy.array([x for x in sources['significance']]) # plot new sources as pink circles m = m_new & (sigmas > det_sigma) if numpy.sum(m) > 0: ra_coord = ras[m] dec_coord = decs[m] new_names = names[m] plt.scatter(ra_coord, dec_coord, s=100, marker="o", facecolors='none', edgecolors='pink', lw=3, label="NEW any", zorder=5, transform=ax.get_transform('world')) else: ra_coord = [] dec_coord = [] new_names = [] for i in range(len(ra_coord)): plt.text(ra_coord[i], dec_coord[i] + 0.5, new_names[i], color="pink", size=15, transform=ax.get_transform('world')) # fallback for general catalog (e.g. legacysurvey) if not 'src_names' in sources.columns or not 'significance' in sources.columns: ra_coord = ras dec_coord = decs plt.scatter(ra_coord, dec_coord, s=30, marker="o", facecolors='none', edgecolors='magenta', lw=0.5, zorder=5, transform=ax.get_transform('world')) m = ~m_new & (sigmas > det_sigma - 1) if numpy.sum(m) > 0: ra_coord = ras[m] dec_coord = decs[m] cat_names = names[m] plt.scatter(ra_coord, dec_coord, s=100, marker="o", facecolors='none', edgecolors='magenta', lw=3, label="known", zorder=5, transform=ax.get_transform('world')) else: ra_coord = [] dec_coord = [] cat_names = [] for i in range(len(ra_coord)): plt.text(ra_coord[i], dec_coord[i] + 0.5, cat_names[i], color="magenta", size=15, transform=ax.get_transform('world')) plt.grid(color="grey", zorder=10) plt.xlabel("RA") plt.ylabel("Dec") if sliders: # Nice to have : slider cmin = plt.axes([0.85, 0.05, 0.02, 0.4]) cmax = plt.axes([0.85, 0.55, 0.02, 0.4]) data_min = data[numpy.isfinite(data)].min() data_max = data[numpy.isfinite(data)].max() self.smin = Slider(cmin, 'Min', data_min, data_max, valinit=1., orientation='vertical') self.smax = Slider(cmax, 'Max', data_min, data_max, valinit=10., orientation='vertical') self.smin.on_changed(self.update) self.smax.on_changed(self.update) return fig @staticmethod def get_js9_html(file_path, region_file=None, js9_id='myJS9', base_url='/mmoda/gallery/sites/default/files'): region = '' file = 'JS9.Preload("%s/%s"' % (base_url, file_path) if region_file is not None: file += ', {scale: \'log\', colormap: \'plasma\', onload: function(im){JS9.SetZoom(1); ' + \ 'JS9.DisplayCoordGrid(true); JS9.LoadRegions("%s/%s");}}, {display: "%s"});' % ( base_url, region_file, js9_id) else: file += ', {scale: \'log\', colormap: \'plasma\', onload: function(im){JS9.DisplayCoordGrid(true);}, {display: "%s"});' % ( js9_id) t = ''' <html> <head> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> <meta http-equiv="X-UA-Compatible" content="IE=Edge;chrome=1" > <meta name="viewport" content="width=device-width, initial-scale=1"> <link type="image/x-icon" rel="shortcut icon" href="./favicon.ico"> <link type="text/css" rel="stylesheet" href="../libraries/js9/js9support.css"> <link type="text/css" rel="stylesheet" href="../libraries/js9/js9.css"> <script type="text/javascript" src="../libraries/js9/js9prefs.js"></script> <script type="text/javascript" src="../libraries/js9/js9support.min.js"></script> <script type="text/javascript" src="../libraries/js9/js9.min.js"></script> <script type="text/javascript" src="../libraries/js9/js9plugins.js"></script> </head> <body> <center><font size="+1"> </font></center> <table cellspacing="30"> <tr valign="top"> <td> </td> <td> <tr valign="top"> <td> <div class="JS9Menubar" id="%sMenubar" ></div> <div class="JS9Colorbar" id="%sColorbar" ></div> <div class="JS9" id="%s"></div> </td> <td> <p> </td> </tr> </table> <script type="text/javascript"> function init(){ var idx, obj; JS9.imageOpts.wcsunits = "degrees"; %s } $(document).ready(function(){ init(); }); </script> </body> </html> ''' % (js9_id, js9_id, js9_id, file) return t
[docs] def update(self, x): if self.smin.val < self.smax.val: self.cs.set_clim(self.smin.val, self.smax.val)
def write_fits(self, file_prefix='', output_dir='.'): file_fn = os.path.join(output_dir, f'{file_prefix}mosaic.fits') self.data.mosaic_image_0_mosaic.write_fits_file(file_fn, overwrite=True) return file_fn def extract_catalog_from_image(self, include_new_sources=False, det_sigma=5, objects_of_interest=[], flag=1, isgri_flag=2, update_catalog=False): catalog_str = self.extract_catalog_string_from_image(include_new_sources, det_sigma, objects_of_interest, flag, isgri_flag, update_catalog) return json.loads(catalog_str) def extract_catalog_string_from_image(self, include_new_sources=False, det_sigma=5, objects_of_interest=None, flag=1, isgri_flag=2, update_catalog=True) -> str: """ Example: objects_of_interest=['Her X-1'] objects_of_interest=[('Her X-1', Simbad.query )] objects_of_interest=[('Her X-1', Skycoord )] objects_of_interest=[ Skycoord(....) ] """ if objects_of_interest is None: objects_of_interest = [] image = self.data if image.dispatcher_catalog_1.table is None: self.logger.warning("No sources in the catalog") if objects_of_interest != []: return OdaImage.add_objects_of_interest(None, objects_of_interest, flag, isgri_flag) else: return 'none' sources = image.dispatcher_catalog_1.table[image.dispatcher_catalog_1.table['significance'] >= det_sigma] if len(sources) == 0: self.logger.warning('No sources in the catalog with det_sigma > %.1f' % det_sigma) if objects_of_interest != []: return self.add_objects_of_interest(None, objects_of_interest, flag, isgri_flag) else: return 'none' if not include_new_sources: ind = [not 'NEW' in ss for ss in sources['src_names']] clean_sources = sources[ind] self.logger.debug(ind) self.logger.debug(sources) self.logger.debug(clean_sources) else: clean_sources = sources unique_sources = self.add_objects_of_interest(clean_sources, objects_of_interest, flag, isgri_flag) copied_image = copy.deepcopy(image) copied_image.dispatcher_catalog_1.table = unique_sources if update_catalog: image.dispatcher_catalog_1.table = unique_sources return copied_image.dispatcher_catalog_1.get_api_dictionary() @staticmethod def make_one_source_catalog_string(name, ra, dec, isgri_flag, flag): out_str_templ ='{"cat_frame": "fk5", "cat_coord_units": "deg", "cat_column_list": [[1], ["%s"], [0.0], [%f], [%f], [-32768], [%d], [%d], [0.001]], "cat_column_names": ["meta_ID", "src_names", "significance", "ra", "dec", "NEW_SOURCE", "ISGRI_FLAG", "FLAG", "ERR_RAD"], "cat_column_descr": [["meta_ID", "<i8"], ["src_names", "<U7"], ["significance", "<f8"], ["ra", "<f8"], ["dec", "<f8"], ["NEW_SOURCE", "<i8"], ["ISGRI_FLAG", "<i8"], ["FLAG", "<i8"], ["ERR_RAD", "<f8"]], "cat_lat_name": "dec", "cat_lon_name": "ra"}' return out_str_templ % (name, ra, dec, isgri_flag, flag) def add_objects_of_interest(self, clean_sources, objects_of_interest, flag=1, isgri_flag=2, tolerance = 1./60.): for ooi in objects_of_interest: if isinstance(ooi, tuple): ooi, t = ooi if isinstance(t, SkyCoord): source_coord = t elif isinstance(ooi, str): t = Simbad.query_object(ooi) else: raise Exception("fail to elaborate object of interest") if isinstance(t, table.Table): source_coord = SkyCoord(t['RA'], t['DEC'], unit=(u.hourangle, u.deg), frame="fk5") self.logger.info("Elaborating object of interest: %s %f %f" % (ooi, source_coord.ra.deg, source_coord.dec.deg)) ra = source_coord.ra.deg dec = source_coord.dec.deg self.logger.info("RA=%g Dec=%g" % (ra, dec)) if clean_sources is not None: #Look for the source of interest in NEW sources by coordinates for ss in clean_sources: if 'NEW' in ss['src_names']: if numpy.abs(ra - ss['ra']) <= tolerance and numpy.abs(dec - ss['dec']) <= tolerance: self.logger.info('Found ' + ooi + ' in catalog as ' + ss['src_names']) ind = clean_sources['src_names'] == ss['src_names'] clean_sources['FLAG'][ind] = flag clean_sources['ISGRI_FLAG'][ind] = isgri_flag clean_sources['src_names'][ind] = ooi #Look for the source of interest in ind = clean_sources['src_names'] == ooi if numpy.count_nonzero(ind) > 0: self.logger.info('Found ' + ooi + ' in catalog') clean_sources['FLAG'][ind] = flag if 'ISGRI_FLAG' in clean_sources.keys(): clean_sources['ISGRI_FLAG'][ind] = isgri_flag if 'JEMX_FLAG' in clean_sources.keys(): clean_sources['JEMX_FLAG'][ind] = isgri_flag else: self.logger.info('Adding ' + ooi + ' to catalog') if ('flux' in clean_sources.colnames or 'Flux' in clean_sources.colnames or \ 'FLUX' in clean_sources.colnames) and 'ISGRI_FLAG' in clean_sources.colnames: self.logger.debug('Flux is present') clean_sources.add_row((0, ooi, 0, ra, dec, 0, isgri_flag, flag, 1e-3, 0, 0)) elif 'ISGRI_FLAG' in clean_sources.colnames: self.logger.debug('Flux is NOT present but ISGRI_FLAG is present') clean_sources.add_row((0, ooi, 0, ra, dec, 0, isgri_flag, flag, 1e-3)) else: self.logger.debug('Flux and ISGRI_FLAG are NOT present') clean_sources.add_row((0, ooi, 0, ra, dec, flag, 1e-3)) unique_sources = table.unique(clean_sources, keys=['src_names']) return unique_sources else: return self.make_one_source_catalog_string(ooi, ra, dec, isgri_flag, flag)
[docs] class OdaLightCurve(OdaProduct): name = 'lightcurve' used_source_name = '' def get_lc(self, source_name, systematic_fraction=0): """_summary_ Args: source_name (str): Source name to get the LC, for SPI-ACS, use 'query' systematic_fraction (int, optional): relative systematic error to add in quadrature. Defaults to 0. Returns: numpy array time numpy array delta_time, numpy array rate numpy array rate_error float e_min float e_max _type_: _description_ """ combined_lc = self.data # In LC name has no "-" nor "+" ?????? patched_source_name = source_name.replace('-', ' ').replace('+', ' ') hdu = None for j, dd in enumerate(combined_lc._p_list): self.logger.debug(dd.meta_data['src_name']) if dd.meta_data['src_name'] in source_name or dd.meta_data['src_name'] in patched_source_name or \ dd.meta_data['src_name'] == 'query': self.used_source_name = dd.meta_data['src_name'] for ii, du in enumerate(dd.data_unit): if 'LC' in du.name or 'RATE' in du.name: hdu = du.to_fits_hdu() if hdu is None: self.logger.info('Source ' + source_name + ' not found in the light curves') return None, None, None, None, None, None x = hdu.data['TIME'] y = hdu.data['RATE'] dy = hdu.data['ERROR'] self.logger.debug("Original length of light curve %d" % len(x)) ind = numpy.argsort(x) x = x[ind] y = y[ind] dy = dy[ind] dy = numpy.sqrt(dy ** 2 + (y * systematic_fraction) ** 2) ind = numpy.logical_and(numpy.isfinite(y), numpy.isfinite(dy)) ind = numpy.logical_and(ind, dy > 0) self.logger.debug("Final length of light curve %d " % numpy.sum(ind)) if 'E_MIN' in hdu.header: e_min = hdu.header['E_MIN'] else: if self.instrument == 'SPI-ACS': self.logger.debug('e_min set to 75 keV as the instrument is SPI-ACS') e_min = 75 else: e_min = 0 if 'E_MAX' in hdu.header: e_max = hdu.header['E_MAX'] else: if self.instrument == 'SPI-ACS': self.logger.debug('e_max set to 2000 keV as the instrument is SPI-ACS') e_max = 2000 else: e_max = 0 if self.instrument == 'SPI-ACS': self.timezero = hdu.header['TIMEZERO'] / 86400. + hdu.header['MJDREF'] from astropy.time import Time self.timezero_utc = Time(self.timezero, format='mjd').iso #This could only be valid for ISGRI try: dt_lc = hdu.data['XAX_E'] self.logger.debug('Get time bin directly from light curve') except: timedel = hdu.header['TIMEDEL'] if 'TIMEPIXR' in hdu.header: timepix = hdu.header['TIMEPIXR'] else: timepix = 0.5 t_lc = hdu.data['TIME'] + (0.5 - timepix) * timedel dt_lc = t_lc.copy() * 0.0 + timedel / 2 for i in range(len(t_lc) - 1): dt_lc[i + 1] = numpy.fabs(min(timedel / 2, t_lc[i + 1] - t_lc[i] - dt_lc[i])) self.logger.debug('Computed time bin from TIMEDEL') m_negative_bins = dt_lc < 0 if numpy.sum(m_negative_bins) > 0: self.logger.debug('found negative time bins at %s: disabling them', x[m_negative_bins]) x[m_negative_bins] = numpy.NaN dt_lc[m_negative_bins] = numpy.NaN y[m_negative_bins] = numpy.NaN dy[m_negative_bins] = numpy.NaN return x[ind], dt_lc[ind], y[ind], dy[ind], e_min, e_max def get_image_for_gallery(self, in_source_name='', systematic_fraction=0, ng_sig_limit=0, find_excesses=False, output_folder=None): plts = self.build_fig(in_source_name=in_source_name, systematic_fraction=systematic_fraction, ng_sig_limit=ng_sig_limit, find_excesses=find_excesses) request_time = _time.time() pic_name = str(request_time) + '_image.png' pic_fn = pic_name if output_folder is not None: pic_fn = os.path.join(output_folder, pic_name) if len(plts) == 1: plts[0].savefig(pic_fn) return pic_fn
[docs] def show(self, in_source_name='', systematic_fraction=0, ng_sig_limit=0, find_excesses=False): plt = self.build_fig(in_source_name=in_source_name, systematic_fraction=systematic_fraction, ng_sig_limit=ng_sig_limit, find_excesses=find_excesses) for p in plt: p.show()
def build_fig(self, in_source_name='', systematic_fraction=0, ng_sig_limit=0, find_excesses=False): #if ng_sig_limit <1 does not plot range combined_lc = self.data from scipy import stats if in_source_name == '': source_names = [dd.meta_data['src_name'] for dd in combined_lc._p_list] else: source_names = [in_source_name] figs = [] for source_name in source_names: x, dx, y, dy, e_min, e_max = self.get_lc(source_name, systematic_fraction) if x is None: return meany = numpy.sum(y / dy ** 2) / numpy.sum(1. / dy ** 2) err_mean = numpy.sum(1 / dy ** 2) std_dev = numpy.std(y) figs.append(plt.figure(figsize=(8, 8./1.62))) _ = plt.errorbar(x, y, xerr=dx, yerr=dy, marker='o', capsize=0, linestyle='', label='Lightcurve') _ = plt.axhline(meany, color='green', linewidth=3) if self.instrument == 'SPI-ACS': _ = plt.xlabel('seconds since %s UTC' % self.timezero_utc) else: _ = plt.xlabel('Time [IJD]') if e_min == 0 or e_max ==0: _ = plt.ylabel('Rate') else: _ = plt.ylabel('Rate %.1f-%.1f keV' % (e_min, e_max)) if ng_sig_limit >= 1: ndof = len(y) - 1 prob_limit = stats.norm().sf(ng_sig_limit) chi2_limit = stats.chi2(ndof).isf(prob_limit) band_width = numpy.sqrt(chi2_limit / err_mean) _ = plt.axhspan(meany - band_width, meany + band_width, color='green', alpha=0.3, label=f'{ng_sig_limit} $\sigma_m$, {100 * systematic_fraction}% syst') _ = plt.axhspan(meany - std_dev*ng_sig_limit, meany + std_dev*ng_sig_limit, color='cyan', alpha=0.3, label=f'{ng_sig_limit} $\sigma_d$, {100 * systematic_fraction}% syst') _ = plt.legend() plot_title = source_name _ = plt.title(plot_title) if find_excesses: ind = (y - band_width)/dy > ng_sig_limit if numpy.sum(ind) > 0: _ = plt.plot(x[ind], y[ind], marker='x', color='red', linestyle='', markersize=10) self.logger.info('We found positive excesses on the lightcurve at times') good_ind = numpy.where(ind) #print(good_ind[0][0:-1], good_ind[0][1:]) old_time = -1 if len(good_ind[0]) == 1: self.logger.info('%f' % (x[good_ind[0][0]])) else: for i,j in zip(good_ind[0][0:-1], good_ind[0][1:]): #print(i,j) if j-i > 2: if x[i] != old_time : self.logger.info('%f' % x[i]) _ = OdaLightCurve.plot_zoom(x,y,dy,i) self.logger.info('%f' % (x[j])) _ = OdaLightCurve.plot_zoom(x, y, dy, j) # else: # self.logger.debug('%f' % ((x[i]+x[j])/2)) old_time = x[j] return figs @staticmethod def plot_zoom(x, y, dy, i, n_before=5, n_after=15, save_plot=True, name_base='burst_at_'): fig = plt.figure(figsize=(8, 8./1.62)) _ = plt.errorbar(x[i-n_before:i+n_after], y[i-n_before:i+n_after], yerr=dy[i-n_before:i+n_after], marker='o', capsize=0, linestyle='', label='Lightcurve') _ = plt.xlabel('Time') _ = plt.ylabel('Rate') if save_plot: _ = plt.savefig(name_base+'%d.png' % i) return fig def write_fits(self, source_name, file_suffix='', output_dir='.'): # In LC name has no "-" nor "+" ?????? lc = self.data patched_source_name = source_name.replace('-', ' ').replace('+', ' ') lcprod = [l for l in lc._p_list if l.meta_data['src_name'] == source_name or \ l.meta_data['src_name'] == patched_source_name] if (len(lcprod) < 1): self.logger.warning("source %s not found in light curve products" % source_name) return "none", 0, 0, 0 if (len(lcprod) > 1): self.logger.warning( "source %s is found more than once light curve products, writing only the first one" % source_name) instrument = lcprod[0].data_unit[1].header['INSTRUME'] if instrument == 'IBIS': ind_extension = 1 else: ind_extension = 2 lc_fn = output_dir + "/%s_lc_%s%s.fits" % (instrument, source_name.replace(' ', '_'), file_suffix) hdu = lcprod[0].data_unit[ind_extension].to_fits_hdu() timedel = hdu.header['TIMEDEL'] timepixr = hdu.header['TIMEPIXR'] dt = timedel * timepixr hdu.header['TSTART'] = hdu.data['TIME'][0] - dt hdu.header['TSTOP'] = hdu.data['TIME'][-1] + dt hdu.header['TFIRST'] = hdu.data['TIME'][0] - dt hdu.header['TLAST'] = hdu.data['TIME'][-1] + dt hdu.header['TELAPSE'] = hdu.header['TLAST'] - hdu.header['TFIRST'] ontime=0 for x in hdu.data['FRACEXP']: ontime += x * timedel hdu.header['ONTIME'] = ontime fits.writeto(lc_fn, hdu.data, header=hdu.header, overwrite=True) mjdref = float(hdu.header['MJDREF']) tstart = float(hdu.header['TSTART']) + mjdref tstop = float(hdu.header['TSTOP']) + mjdref try: exposure = float(hdu.header['EXPOSURE']) except: exposure = -1 return lc_fn, tstart, tstop, exposure @staticmethod def check_product_for_gallery(**kwargs): if 'src_name' not in kwargs: logger.warning('The src_name parameter is mandatory for a light-curve product\n') raise api.UserError('the src_name parameter is mandatory for a light-curve product') logger.info('Policy for a light-curve product successfully verified\n') return True def get_html_image(self, source_name, systematic_fraction, color='blue'): import cdci_data_analysis.analysis.plot_tools x, dx, y, dy, e_min, e_max = self.get_lc(source_name, systematic_fraction) mask = numpy.logical_not(numpy.isnan(y)) x = x[mask] dx = dx[mask] y = y[mask] dy = dy[mask] if self.instrument == 'SPI-ACS': xlabel = 'seconds since %s UTC' % self.timezero_utc else: xlabel = 'Time [IJD]' sp = cdci_data_analysis.analysis.plot_tools.ScatterPlot(w=800, h=600, x_label=xlabel, y_label='Rate (%.0f - %.0f keV)' % (e_min, e_max), title=source_name) sp.add_errorbar(x, y, yerr=dy, xerr=dx, color=color) html_dict = sp.get_html_draw() html_str = html_dict['div'] + '\n' html_str += '<script src="https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js"></script>\n' + \ '<script src="https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.2.min.js"></script>\n' html_str += html_dict['script'] return html_str
[docs] class OdaSpectrum(OdaProduct): name = 'spectrum'
[docs] def show_spectral_products(self): summed_data = self.data for dd, nn in zip(summed_data._p_list, summed_data._n_list): self.logger.debug(nn) dd.show_meta() # for kk in dd.meta_data.items(): if 'spectrum' in dd.meta_data['product']: self.logger.debug(dd.data_unit[1].header['EXPOSURE']) dd.show()
[docs] def get_spectrum_products(self, in_source_name='none'): if in_source_name == 'none': return None specprod = [l for l in self.data._p_list if l.meta_data['src_name'] in in_source_name] if len(specprod) < 1: self.logger.warning("source %s not found in spectral products" % in_source_name) return None return specprod
[docs] def show(self, in_source_name='', systematic_fraction=0, xlim=[]): plt = self.build_fig(in_source_name=in_source_name, systematic_fraction=systematic_fraction, xlim=xlim) if plt is not None: plt.show()
[docs] def get_html_image(self, in_source_name, systematic_fraction, x_range=None, y_range=None, color='blue'): import cdci_data_analysis.analysis.plot_tools x, dx, y, dy = self.get_values(in_source_name, systematic_fraction) if len(x) == 0: logger.warning('Returning empty HTML string, as no data are retrieved') return '' if x_range is None: x_range = [x.min(), x.max()] if y_range is None: y_range = [numpy.max([1e-4, (y-dy)[x < x_range[1]].min()]), (y+dy).max()] sp = cdci_data_analysis.analysis.plot_tools.ScatterPlot(w=800, h=600, x_label="Energy [keV]", y_label='Counts/s/keV', x_axis_type='log', y_axis_type='log', x_range=x_range, y_range=y_range, title=in_source_name) if len(x) == 0: return '' sp.add_errorbar(x, y, yerr=dy, xerr=dx, color=color) html_dict = sp.get_html_draw() html_str = html_dict['div'] + '\n' html_str += '<script src="https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js"></script>\n' + \ '<script src="https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.2.min.js"></script>\n' html_str += html_dict['script'] return html_str
[docs] def get_values(self, in_source_name='', systematic_fraction=0): if in_source_name == '': return numpy.array([]), numpy.array([]), numpy.array([]), numpy.array([]) specprod = self.get_spectrum_products(in_source_name) if specprod is None: return numpy.array([]), numpy.array([]), numpy.array([]), numpy.array([]) spec = specprod[0].data_unit[1].to_fits_hdu() for hh in specprod[2].data_unit: if hh.to_fits_hdu().header['EXTNAME'] == 'EBOUNDS': ebounds = hh.to_fits_hdu() x = (ebounds.data['E_MAX'] + ebounds.data['E_MIN'])/2. dx = (ebounds.data['E_MAX'] - ebounds.data['E_MIN']) / 2. y = spec.data['RATE'] dy = numpy.sqrt(spec.data['STAT_ERR']**2 + spec.data['SYS_ERR']**2 + (y*systematic_fraction)**2) mask = numpy.logical_not(numpy.isnan(y)) x = x[mask] dx = dx[mask] y = y[mask] dy = dy[mask] y /= dx dy /= dx return x, dx, y, dy
[docs] def build_fig(self, in_source_name='', systematic_fraction=0, xlim=[]): if in_source_name == '': self.show_spectral_products() return x , dx, y, dy = self.get_values(in_source_name, systematic_fraction) if len(x) == 0: return fig = plt.figure(figsize=(8, 8./1.62)) _ = plt.errorbar(x, y, xerr=dx, yerr=dy, marker='o', capsize=0, linestyle='', label='spectrum') _ = plt.xlabel('Energy [keV]') _ = plt.xscale('log') _ = plt.yscale('log') _ = plt.ylabel('$dN/dE$ [keV$^{-1}s$^{-1}cm$^{-2}$]') _ = plt.title(in_source_name) if len(xlim) == 2: _ = plt.xlim(xlim) return fig
[docs] def write_fits(self, source_name='', file_suffix='', grouping=[0, 0, 0], systematic_fraction=0, output_dir='.'): """ Grouping argument is [minimum_energy, maximum_energy, number_of_bins] number of bins > 0, linear grouping number_of_bins < 0, logarithmic binning """ if source_name == '': self.show_spectral_products() self.logger.warning('PLease specify a source to save the spectral products') return "none", 0, 0, 0 specprod = self.get_spectrum_products(source_name) if specprod is None: return "none", 0, 0, 0 instrument = specprod[0].data_unit[1].header['INSTRUME'] out_name = source_name.replace(' ', '_').replace('+', 'p') spec_fn = output_dir + "/%s_spectrum_%s%s.fits" % (instrument, out_name, file_suffix) arf_fn = output_dir + "/%s_arf_%s%s.fits" % (instrument, out_name, file_suffix) rmf_fn = output_dir + "/%s_rmf_%s%s.fits" % (instrument, out_name, file_suffix) self.logger.info("Saving spectrum %s with rmf %s and arf %s" % (spec_fn, rmf_fn, arf_fn)) specprod[0].write_fits_file(spec_fn) specprod[1].write_fits_file(arf_fn) specprod[2].write_fits_file(rmf_fn) ff = fits.open(spec_fn, mode='update') ff[1].header['RESPFILE'] = rmf_fn ff[1].header['ANCRFILE'] = arf_fn mjdref = ff[1].header['MJDREF'] tstart = float(ff[1].header['TSTART']) + mjdref tstop = float(ff[1].header['TSTOP']) + mjdref exposure = ff[1].header['EXPOSURE'] ff[1].data['SYS_ERR'] = numpy.zeros(len(ff[1].data['SYS_ERR'])) + systematic_fraction ind = numpy.isfinite(ff[1].data['RATE']) ff[1].data['QUALITY'][ind] = 0 if numpy.sum(grouping) != 0: if grouping[1] <= grouping[0] or grouping[2] == 0: raise RuntimeError('Wrong grouping arguments') ff_rmf = fits.open(rmf_fn) e_min = ff_rmf['EBOUNDS'].data['E_MIN'] e_max = ff_rmf['EBOUNDS'].data['E_MAX'] ff_rmf.close() ind1 = numpy.argmin(numpy.abs(e_min - grouping[0])) ind2 = numpy.argmin(numpy.abs(e_max - grouping[1])) n_bins = numpy.abs(grouping[2]) ff[1].data['GROUPING'][0:ind1] = 0 ff[1].data['GROUPING'][ind2:] = 0 ff[1].data['QUALITY'][0:ind1] = 1 ff[1].data['QUALITY'][ind2:] = 1 if grouping[2] > 0: step = int((ind2 - ind1 + 1) / n_bins) self.logger.info('Linear grouping with step %d' % step) for i in range(1, step): j = range(ind1 + i, ind2, step) ff[1].data['GROUPING'][j] = -1 else: ff[1].data['GROUPING'][ind1:ind2] = -1 e_step = (e_max[ind2] / e_min[ind1]) ** (1.0 / n_bins) self.logger.info('Geometric grouping with step %.3f' % e_step) loc_e = e_min[ind1] while (loc_e < e_max[ind2]): ind_loc_e = numpy.argmin(numpy.abs(e_min - loc_e)) ff[1].data['GROUPING'][ind_loc_e] = 1 loc_e *= e_step ff.flush() ff.close() return spec_fn, tstart, tstop, exposure
[docs] class OdaGWContours(OdaProduct): # TODO to clarify the name, also for the gallery name = 'contour' @staticmethod def _plot_single_contour(contour_coords, ax, color='r'): coords = numpy.array(contour_coords) try: ax.plot(coords[:,0], coords[:,1], '-', transform=ax.get_transform('world'), color = color) except TypeError: ax.plot(coords[:,0], coords[:,1], '-', transform=ax.get_transform(), color = color) @staticmethod def _plot_contour_list(contour_list, ax, color=None): kwargs = {} if color is not None: kwargs['color'] = color for contour_coords in contour_list: OdaGWContours._plot_single_contour(contour_coords, ax, **kwargs)
[docs] def plot_event_contours(self, event, legend=True, name_in_legend=True, colors = [], ax = None): if ax is None: ax = plt.axes(projection='astro hours mollweide') if not colors: prop_cycle = plt.rcParams['axes.prop_cycle'] colors = prop_cycle.by_key()['color'] lpr = [] names = [] if event in self.data.contours.keys(): for i in range(len(self.data.contours[event].levels)): color = colors[i%len(colors)] OdaGWContours._plot_contour_list(self.data.contours[event].contours[i], ax, color) lpr.append(plt.Rectangle((0, 0), 1, 1, fc = color)) names.append(f"{self.data.contours[event].name+' ' if name_in_legend else ''}{self.data.contours[event].levels[i]}%") else: raise ValueError(f'Wrong event name: {event}') if legend is True: ax.legend(lpr, names)
[docs] def plot_contours(self, legend=True, colors = [], ax = None): if ax is None: ax = plt.axes(projection='astro hours mollweide') if not colors: prop_cycle = plt.rcParams['axes.prop_cycle'] colors = prop_cycle.by_key()['color'] lpr = [] names = [] i = 0 for event, data in self.data.contours.items(): color = colors[i%len(colors)] self.plot_event_contours(event, ax = ax, colors = [color], legend = False) i+=1 lpr.append(plt.Rectangle((0, 0), 1, 1, fc = color)) names.append(event) if legend is True: ax.legend(lpr, names, numpoints=1, bbox_to_anchor=(1.05, 1), loc='upper left')
[docs] def show(self, event_name = None): fig = self.build_fig(event_name=event_name) if fig is not None: fig.show()
[docs] def build_fig(self, event_name=None): fig = plt.figure(figsize=(8, 8./1.62)) if event_name is None: self.plot_contours() else: self.plot_event_contours(event_name) return fig
# TODO can an implementation of this method provided?
[docs] def write_fits(self): raise NotImplementedError