import numpy as np
from astropy.table import Table


class TestHelper(object):
  """
  Several utility assertions that can be reused in various tests.
  Inherit actual test case classes from this class together with unittest.TestCase (i.e. use it like a mixin).
  The public ``assert_*`` methods raise AssertionError, so failures are reported like regular test failures.
  See http://stackoverflow.com/questions/6655724/how-to-write-a-custom-assertfoo-method-in-python

  The file-based assertions read tables with ``astropy.table.Table.read``. The private
  helpers take any mapping of column name -> array-like (e.g. ``table.columns``), so the
  checking logic does not depend on file I/O.

  @TODO: make project-wide test suite and reorganize tests in a proper way
  """

  @staticmethod
  def _missing_mask(values):
    """
    Return a boolean array that is True where an entry of `values` is NaN or masked.

    NaN only exists for floating-point and complex data. For any other dtype (integers,
    strings, booleans, ...) only masked entries count as missing, which avoids the
    TypeError ``np.isnan`` raises on e.g. string columns.
    """
    arr = np.ma.asarray(values)
    data = np.asarray(arr.data)
    if data.dtype.kind in 'fc':
      missing = np.isnan(data)
    else:
      missing = np.zeros(data.shape, dtype=bool)
    # Masked entries (e.g. null values in a masked column) are missing as well.
    return missing | np.ma.getmaskarray(arr)

  @classmethod
  def _empty_columns(cls, columns):
    """
    Return the names, in iteration order, of columns whose values are all missing.

    `columns` is a mapping of column name -> array-like (e.g. ``Table.columns``).
    A column with no values at all also counts as empty.
    """
    return [name for name, values in columns.items()
            if np.all(cls._missing_mask(values))]

  @classmethod
  def _null_id_columns(cls, columns, id_names):
    """
    Return the names, in iteration order, of columns that match one of `id_names`
    (case-insensitively) and contain at least one missing (NaN or masked) value.

    `columns` is a mapping of column name -> array-like (e.g. ``Table.columns``).
    """
    wanted = set(name.lower() for name in id_names)
    return [name for name, values in columns.items()
            if name.lower() in wanted and np.any(cls._missing_mask(values))]

  def assert_table_non_empty_columns(self, filename):
    """
    Check that no column of a (FITS) table is empty, i.e. all NaN / masked.

    Non-float columns (strings, integers, ...) are supported. A column with no rows
    is reported as empty.

    :param filename: path of the table file, read with ``Table.read``.
    :raises AssertionError: naming the first empty column.
    """
    t = Table.read(filename)
    empty = self._empty_columns(t.columns)
    if empty:
      raise AssertionError('Empty column %s in %s table' % (empty[0], filename))

  def assert_non_null_ids(self, filename, columns=('objid', 'mjd', 'plate', 'fiberid', 'specobjid')):
    """
    Check that common ID columns contain no NaN or masked values.

    Table column names are matched against `columns` case-insensitively. IDs that
    are not present in the table are skipped.

    :param filename: path of the table file, read with ``Table.read``.
    :param columns: names of the ID columns to check.
    :raises AssertionError: naming the first ID column that has missing values.
    """
    t = Table.read(filename)
    bad = self._null_id_columns(t.columns, columns)
    if bad:
      raise AssertionError('Empty values in %s column of %s table' % (bad[0], filename))

  def assert_less_rows(self, join_result_filename, input_filename):
    """
    Check that `join_result_filename` has the same number of rows as `input_filename`.

    NOTE(review): despite the method name, this asserts *equal* row counts, not fewer.
    The name is kept so existing callers keep working.

    :raises AssertionError: if the two tables have different row counts.
    """
    n_result = len(Table.read(join_result_filename))
    n_input = len(Table.read(input_filename))
    if n_result != n_input:
      raise AssertionError('%s and %s have different number of rows' % (join_result_filename, input_filename))
      