/////////////////////////////////////////////////////////////////////////////
//                   SOFTWARE COPYRIGHT NOTICE AGREEMENT                   //
//       This software and its documentation are copyright (2007) by the   //
//   Broad Institute/Massachusetts Institute of Technology.  All rights    //
//   are reserved.  This software is supplied without any warranty or      //
//   guaranteed support whatsoever. Neither the Broad Institute nor MIT    //
//   can be responsible for its use, misuse, or functionality.             //
/////////////////////////////////////////////////////////////////////////////

#ifndef FORCE_DEBUG
     #define NDEBUG
#endif
#define STRING_FAST_EXECUTE

#include "math/Functions.h"
#include "String.h"
#include "system/Types.h"
#include "Vec.h"

#include <sys/mman.h>

// Byte layout of the text header that precedes the raw element data.
// The numbers follow from the strings defined below:
//   first line (34 bytes) + element count field (13) + endian line (15) = 62
// for format 2; format 3 appends the 34-byte padding line, giving 96
// (a multiple of 8, which 62 is not).
const int binary2_header_size = 62;
const int binary3_header_size = 96;
// The element count is stored as decimal text in a fixed-width field that
// starts right after the first line and is padded on the right with spaces.
const int binary_element_count_begin = 34;
const int binary_element_count_size = 13;
const int binary_element_count_end = binary_element_count_begin + binary_element_count_size;

// First header line; identifies the format version.
const static String binary2_first_line( "binary format 2, header = 3 lines\n" );
const static String binary3_first_line( "binary format 3, header = 4 lines\n" );

// Endian marker lines.  Both have the same length (the big endian text is
// space padded), so the header size does not depend on the architecture.
const static String binary_little_endian_line( "\nlittle endian\n" );
const static String binary_big_endian_line(    "\nbig endian   \n" );

// Extra header line written only for format 3.
const static String binary3_padding_line( "padding to make long word aligned\n" );

inline
void BinaryWriteSize( const int fd, const longlong n ) {
  String length = ToString(n);
  int k = length.size( );
  ForceAssertLe( k, binary_element_count_size );
  length.resize( binary_element_count_size );
  for ( int i = k; i < binary_element_count_size; i++ )
    length[i] = ' ';
  WriteBytes( fd, length.c_str( ), binary_element_count_size );
}

template <class T>
void
BinaryWrite2Or3( const String& filename, const vec<T>& v,
                 const int version )
{    
  ForceAssert( version == 2 || version == 3 );
  longlong n = v.size( );
  static longlong ten_trillion = (longlong) 10000000 * (longlong) 10000000;
  ForceAssertLt( n, ten_trillion );
  Remove(filename);
  int fd = OpenForWrite(filename);
  if ( version == 2 )
    WriteBytes( fd, binary2_first_line.c_str( ), binary2_first_line.size( ) );
  else
    WriteBytes( fd, binary3_first_line.c_str( ), binary3_first_line.size( ) );
  BinaryWriteSize( fd, n );
  ForceAssertEq( binary_little_endian_line.size( ), binary_big_endian_line.size( ) );
#ifdef Little_Endian
  {    WriteBytes( fd, binary_little_endian_line.c_str( ), binary_little_endian_line.size( ) );    }
#else
  {    WriteBytes( fd, binary_big_endian_line.c_str( ), binary_big_endian_line.size( ) );    }
#endif
  if ( version == 3 )
    WriteBytes( fd, binary3_padding_line.c_str(), binary3_padding_line.size() );
  if ( n > 0 ) 
    WriteBytes( fd, &v[0], (longlong) sizeof(T) * n );
  close(fd);    
}

// Write v to filename in binary format 2.
template <class T>
void
BinaryWrite2( const String& filename, const vec<T>& v ) {
  const int format_version = 2;
  BinaryWrite2Or3( filename, v, format_version );
}

// Write v to filename in binary format 3.
template <class T>
void
BinaryWrite3( const String& filename, const vec<T>& v ) {
  const int format_version = 3;
  BinaryWrite2Or3( filename, v, format_version );
}


template <class T>
void
CheckHeader2Or3( const String& filename, int fd, longlong& n,
                 const int version )
{    
  ForceAssert( version == 2 || version == 3 );

  static String header;
  if ( version == 2 ) {
    header.resize( binary2_header_size );
    ReadBytes( fd, &header[0], binary2_header_size );
    for ( unsigned int i = 0; i < binary2_first_line.size( ); i++ )
      if ( header[i] != binary2_first_line[i] )
        FatalErr( "Binary read 2 of " << filename << " failed: "
                  << "first line doesn't match expected value." );
  }
  else {
    header.resize( binary3_header_size );
    ReadBytes( fd, &header[0], binary3_header_size );
    for ( unsigned int i = 0; i < binary3_first_line.size( ); i++ )
      if ( header[i] != binary3_first_line[i] )
        FatalErr( "Binary read 3 of " << filename << " failed: "
                  << "first line doesn't match expected value." );
  }

  Bool little = True, big = True;
  for ( unsigned int i = 0; i < binary_little_endian_line.size( ); i++ )
    if ( header[ binary_element_count_end + i ] != binary_little_endian_line[i] ) {
      little = False;
      break;
    }
  for ( unsigned int i = 0; i < binary_big_endian_line.size( ); i++ )
    if ( header[ binary_element_count_end + i ] != binary_big_endian_line[i] ) {    
      big = False;
      break;
    }
  if ( !little && !big ) 
    FatalErr( "Binary read " << version << " of " << filename << " failed: "
              "can't determine endian setting from header." );    
#ifdef Little_Endian
  if ( !little )
    FatalErr( "Binary read " << version << " of " << filename << " failed because file "
              << "was written on a big endian architecture,\nand read back "
              << "in on this little endian architecture.\nUnfortunately, "
              << "this is not possible at present." );   
#else
  if ( !big )
    FatalErr( "Binary read " << version << " of " << filename << " failed because file "
              << "was written on a little endian architecture,\nand read back "
              << "in on this big endian architecture.\nUnfortunately, "
              << "this is not possible at present." );
#endif

  int d;
  for ( d = binary_element_count_begin; d < binary_element_count_end; d++ ) {
    if ( header[d] == ' ' ) break;
    if ( !isdigit( header[d] ) )
      FatalErr( "Binary read 2 of " << filename << " failed: "
                << "didn't find record count where it should be." );   
  }
  for ( int d2 = d + 1; d2 < binary_element_count_end; d2++ ) {
    if ( header[d] != ' ' )
      FatalErr( "Binary read 2 of " << filename << " failed: "
                << "didn't find white space where expected." );
  }
  static String ns;
  ns = header.substr( binary_element_count_begin, d - binary_element_count_begin );
  n = ns.Int( ); 
  longlong N = FileSize(filename);
  if ( N != (longlong) header.size( ) + n * (longlong) sizeof(T) )
    FatalErr( "Binary read " << version << " of " << filename << " failed:\n"
              << "header size = " << header.size( )
              << ", filesize = " << N << ", record count = " << n
              << ", record size = " << sizeof(T) << "." ); 
}

// Validate a format 2 header on fd and return the element count in n.
template <class T>
void
CheckHeader2( const String& filename, int fd, longlong& n ) {
  const int format_version = 2;
  CheckHeader2Or3<T>( filename, fd, n, format_version );
}

// Validate a format 3 header on fd and return the element count in n.
template <class T>
void
CheckHeader3( const String& filename, int fd, longlong& n ) {
  const int format_version = 3;
  CheckHeader2Or3<T>( filename, fd, n, format_version );
}


template <class T> 
void
BinaryRead2Or3( const String& filename, vec<T>& v, int version = -1, 
     const Bool append = False )
{
  if( version == -1 ) version = WhichBinaryFormat(filename);
  ForceAssert( version == 2 || version == 3 );
  int fd = OpenForRead(filename);
  longlong n;
  if ( version == 2 )
    CheckHeader2<T>( filename, fd, n );
  else
    CheckHeader3<T>( filename, fd, n );
  int start = ( append ? v.isize( ) : 0 );
  v.resize( append ? n + v.isize( ) : n );
  if ( n > 0 ) ReadBytes( fd, &v[start], (longlong) sizeof(T) * n );
  close(fd);
}

// Read a whole binary file into v.  With strict set the file must be in
// format 2; otherwise -1 is passed on and the format is detected.
template <class T> 
void
BinaryRead2( const String& filename, vec<T>& v, bool strict ) {
  const int version = ( strict ? 2 : -1 );
  BinaryRead2Or3<T>( filename, v, version );
}

// Read a whole binary file into v, optionally appending to it.  With strict
// set the file must be in format 3; otherwise -1 is passed on and the format
// is detected.
template <class T> 
void
BinaryRead3( const String& filename, vec<T>& v, bool strict, const Bool append ) {
  const int version = ( strict ? 3 : -1 );
  BinaryRead2Or3<T>( filename, v, version, append );
}


template <class T> 
void
BinaryReadSubset2Or3( const String& filename, const vec<int>& ids, vec<T>& v, Bool append, 
                      int version = -1 )
{    
  if( version == -1 ) version = WhichBinaryFormat(filename);
  ForceAssert( version == 2 || version == 3 );
  int fd = OpenForRead(filename);
  longlong n;
  if ( version == 2 )
    CheckHeader2<T>( filename, fd, n );
  else
    CheckHeader3<T>( filename, fd, n );
  int newsize = ids.size( );
  if (append) newsize += v.size( );
  int start = ( append ? v.isize() : 0 );
  v.resize(newsize);
  const int header_size = ( version == 2 ? binary2_header_size : binary3_header_size );
  for ( int i = 0; i < ids.isize( ); i++ ) {
    lseek( fd, header_size + ids[i] * sizeof(T), SEEK_SET );
    read( fd, &v[start+i], sizeof(T) );    
  }
  close(fd);    
}

// Read the elements listed in ids.  With strict set the file must be in
// format 2; otherwise -1 is passed on and the format is detected.
template <class T> 
void
BinaryReadSubset2( const String& filename, const vec<int>& ids, vec<T>& v, 
                   Bool append, bool strict ) {
  const int version = ( strict ? 2 : -1 );
  BinaryReadSubset2Or3<T>( filename, ids, v, append, version );
}

// Read the elements listed in ids.  With strict set the file must be in
// format 3; otherwise -1 is passed on and the format is detected.
template <class T> 
void
BinaryReadSubset3( const String& filename, const vec<int>& ids, vec<T>& v, 
                   Bool append, bool strict ) {
  const int version = ( strict ? 3 : -1 );
  BinaryReadSubset2Or3<T>( filename, ids, v, append, version );
}


template <class T>
void
BinaryReadRange2Or3( const String& filename, longlong from, longlong to, vec<T>& v, 
                     int version = -1 )
{    
  if( version == -1 ) version = WhichBinaryFormat(filename);
  ForceAssert( version == 2 || version == 3 );
  ForceAssert( from <= to );
  int fd = OpenForRead(filename);
  longlong n;
  if ( version == 2 )
    CheckHeader2<T>( filename, fd, n );
  else
    CheckHeader3<T>( filename, fd, n );
  v.resize( to - from );
  const int header_size = ( version == 2 ? binary2_header_size : binary3_header_size );
  lseek( fd, header_size + from * sizeof(T), SEEK_SET );
  if ( to > from ) read( fd, &v[0], (to - from) * sizeof(T) );
  close(fd);
}

// Read elements [from, to).  With strict set the file must be in format 2;
// otherwise -1 is passed on and the format is detected.
template <class T>
void
BinaryReadRange2( const String& filename, longlong from, longlong to, 
                  vec<T>& v, bool strict ) {
  const int version = ( strict ? 2 : -1 );
  BinaryReadRange2Or3<T>( filename, from, to, v, version );
}

// Read elements [from, to).  With strict set the file must be in format 3;
// otherwise -1 is passed on and the format is detected.
template <class T> 
void
BinaryReadRange3( const String& filename, longlong from, longlong to, 
                  vec<T>& v, bool strict ) {
  const int version = ( strict ? 3 : -1 );
  BinaryReadRange2Or3<T>( filename, from, to, v, version );
}


// Return the number of elements recorded in the header of a format 2 or 3
// file, after validating the header.
//
//   filename - file to inspect
//   version  - 2, 3, or -1 to ask WhichBinaryFormat() about the file
template<class T> 
longlong 
BinarySize2Or3( const String& filename, int version = -1 )
{
  if( version == -1 ) version = WhichBinaryFormat(filename);
  ForceAssert( version == 2 || version == 3 );
  int fd = OpenForRead(filename);
  longlong n;
  if ( version == 2 )
    CheckHeader2<T>( filename, fd, n );
  else
    CheckHeader3<T>( filename, fd, n );
  // Release the descriptor; it used to be leaked on every call.
  close(fd);
  return n;
}

// Element count of a binary file.  With strict set the file must be in
// format 2; otherwise -1 is passed on and the format is detected.
template<class T> 
longlong 
BinarySize2( const String& filename, bool strict ) {
  const int version = ( strict ? 2 : -1 );
  return BinarySize2Or3<T>( filename, version );
}

// Element count of a binary file.  With strict set the file must be in
// format 3; otherwise -1 is passed on and the format is detected.
template<class T> 
longlong 
BinarySize3( const String& filename, bool strict ) {
  const int version = ( strict ? 3 : -1 );
  return BinarySize2Or3<T>( filename, version );
}


// I'm going straight to hell for this one...

// There is now a class mappedvec<T>, that inherits publicly from
// vec<T>, which in turn inherits from vector<T>, which in turn
// inherits from _Vector_base<T>, at least in the local implementation
// of vector.  This is where the pointers to the data actually lie.
// Scared yet?  Thought so.

// There is a new function, BinaryMmap3(), that takes a filename and a
// reference to an auto_ptr to a const mappedvec<T>.  (We'll get to
// the logic behind this peculiar syntax in a bit.)  In this function,
// we allocate a mappedvec<T>, then we explicitly call ~vec() on the
// pointer we just allocated (to deallocate any memory it may be
// using).  We then mmap() the file's contents, and set the
// baseclass's pointers by hand to the relevant offsets into the file.
// We then set the auto_ptr we handed in to point to the munged
// mappedvec<T>.

// The destructor is even weirder: we calculate from the baseclass's
// pointers where the mmapped region starts and how big it is, munmap
// it, and then we do an in-place new on "this", which resets all the
// internal pointers to something that the baseclasses' destructors
// (which will be called immediately after this) know what to do with.

// These are all insanely dependent upon the precise implementation of
// vector being used, so we have these checks for which version of
// GLIBCXX is in use.  If the version doesn't match, we implement
// BinaryMmap3() using BinaryRead3() and print out a warning.

#if ( __GLIBCXX__ == 20041105 || __GLIBCXX__ == 20050926 || __GLIBCXX__ == 20060524 )
#define USE_BINARYMMAP3
#endif 

// Make vp own a new mappedvec<T> holding the contents of a format 3 file.
//
// When USE_BINARYMMAP3 is defined the file is mapped read-only and the
// vector's internal pointers are aimed directly at the mapped data; the
// mappedvec destructor is responsible for unmapping it.  Otherwise the
// data is simply read into the vector with BinaryRead3().
//
//   filename - format 3 file to expose
//   vp       - receives ownership of the new mappedvec
template <typename T>
void
BinaryMmap3( const String& filename, auto_ptr< const mappedvec<T> >& vp ) {
  mappedvec<T>* pVec = new mappedvec<T>;

#ifdef USE_BINARYMMAP3

// Names of the libstdc++ vector internals.  These macros are also used by
// the mappedvec destructor, so they must stay defined past this function.
#define VEC_M_START _M_impl._M_start
#define VEC_M_FINISH _M_impl._M_finish
#define VEC_M_EOS _M_impl._M_end_of_storage

  // Manually call superclass destructor to deallocate any storage
  // that may have been allocated for this vector.

  pVec->vec<T>::~vec();
  
  // Open the file and check the header.
  int fd = OpenForRead(filename);
  longlong n;
  CheckHeader3<T>( filename, fd, n );

  // Mmap the file.  mmap() reports failure with MAP_FAILED, not with a
  // null pointer, so comparing against 0 never detected an error.
  size_t size = binary3_header_size + n * sizeof(T);
  void* pRegion = mmap( 0, size, PROT_READ, MAP_SHARED, fd, 0 );
  if ( pRegion == MAP_FAILED ) {
    cout << "Call to mmap() failed." << endl;
    TracebackThisProcess();
  }

  // The mapping stays valid after the descriptor is closed, so release the
  // descriptor now instead of leaking it.
  close( fd );

  char* pMapped = (char*) pRegion;
  
  // Directly set the internal pointers of the vector<T> baseclass.
  pVec->VEC_M_START = (T*) ( pMapped + binary3_header_size );
  pVec->VEC_M_FINISH = pVec->VEC_M_EOS = pVec->VEC_M_START + n;
#else

#warning Unable to implement BinaryMmap3(), defaulting to use BinaryRead3() instead.
  BinaryRead3( filename, *pVec );

#endif

  vp.reset( pVec );
}


// Destructor for mappedvec<T>.  When USE_BINARYMMAP3 is defined it unmaps
// the region the vector's pointers refer to and then re-constructs the
// object in place, so that the base-class destructors, which run after this
// body, see an ordinary vector.  Without USE_BINARYMMAP3 the body is empty.
//
// VEC_M_START and VEC_M_EOS are the macros #defined inside BinaryMmap3().
//
// NOTE(review): the address and size arithmetic assumes the pointers were
// set up by BinaryMmap3(), i.e. that the data begins binary3_header_size
// bytes into a mapping.  Confirm that no mappedvec is destroyed without
// having gone through that path.
template <typename T>
mappedvec<T>::~mappedvec() {
#ifdef USE_BINARYMMAP3
  // Figure out where the mmapped region starts.
  char* pMapped = (char*) this->VEC_M_START - binary3_header_size;
  size_t size = (char*)this->VEC_M_EOS - pMapped;
  
  // Unmap the region and check the return code.
  if ( munmap( pMapped, size ) != 0 ) {
    cout << "Call to munmap() failed." << endl;
    TracebackThisProcess();
  }
  
  // Now we have to reset the pointers to something that the
  // superclasses' destructors know what to do with.  We do this with
  // an in-place new.
  //
  // We need to undef then redefine any macro named "new" for this to
  // work.  The file system/MemTracker.h defines such a macro to do
  // its job.
#ifdef new
#undef new
#endif
  new (this) mappedvec<T>();
  // The "new" macro is restored only when TRACK_MEMORY is defined.
#ifdef TRACK_MEMORY
#define new new(__FILE__,__LINE__)
#endif
#endif // USE_BINARYMMAP3
}
  

// Start a new format 3 file called filename.
//
// An empty vector is written first so that the file holds a complete header
// (with a count of zero); elements are appended by later Write calls.
//
//   filename  - file to (re)create
//   keep_open - if true the descriptor stays open between calls; if false it
//               is closed here and reopened by each later operation
template <typename T>
void Binary3Writer<T>::Open( const String& filename, bool keep_open ) {
  m_filename = filename;
  m_keep_open = keep_open;
  m_objectCount = 0;

  Remove( filename );
  BinaryWrite3( filename, vec<T>() );

  // The file must now consist of exactly the header.
  m_fd = ::Open( filename, O_WRONLY );
  off_t currpos = lseek( m_fd, 0, SEEK_END );
  ForceAssertEq( binary3_header_size, currpos );

  if ( m_keep_open )
    return;
  ::Close(m_fd);
}

// Finish the file on destruction by delegating to Close().
template <typename T>
Binary3Writer<T>::~Binary3Writer() {
  Close();
}

// Append the raw bytes of one object to the file and bump the running
// element count.  When the writer is not in keep-open mode the file is
// opened for appending and closed again around the write.
template <typename T>
void Binary3Writer<T>::Write( const T& object ) {
  if( ! m_keep_open ) {
    // Combine open flags with bitwise OR.  The previous XOR only produced
    // the intended value because the two flags share no bits.
    m_fd = ::Open( m_filename, O_WRONLY | O_APPEND );
  }

  // Keep the pointer const; the object is only read.
  WriteBytes( m_fd, (const char*) &object, sizeof(T) );
  ++m_objectCount;

  if( ! m_keep_open ) {
    ::Close(m_fd);
  }
}

// Append the raw bytes of all objects to the file in one write and add
// their number to the running element count.  Does nothing for an empty
// vector.  When the writer is not in keep-open mode the file is opened for
// appending and closed again around the write.
template <typename T>
void Binary3Writer<T>::WriteMultiple( const vec<T>& objects ) {
  if( objects.empty() ) return;

  if( ! m_keep_open ) {
    // Combine open flags with bitwise OR.  The previous XOR only produced
    // the intended value because the two flags share no bits.
    m_fd = ::Open( m_filename, O_WRONLY | O_APPEND );
  }

  // Keep the pointer const; the objects are only read.
  WriteBytes( m_fd, (const char*) &objects[0], sizeof(T) * longlong(objects.size()) );
  m_objectCount += objects.size();

  if( ! m_keep_open ) {
    ::Close(m_fd);
  }
}

// Finish the file: overwrite the element count field in the header with the
// number of objects written, close the descriptor and reset the writer.
// Does nothing when m_fd is negative, which is the state this function
// leaves behind.
template <typename T>
void Binary3Writer<T>::Close() {
  if ( m_fd >= 0 ) {
    // In non-keep-open mode the descriptor was closed after the last
    // operation, so the file has to be reopened first.
    if ( ! m_keep_open ) {
      m_fd = ::Open( m_filename, O_WRONLY );
      lseek( m_fd, 0, SEEK_END );
    }

    // Patch the count field in place.
    off_t currpos = lseek( m_fd, binary_element_count_begin, SEEK_SET );
    ForceAssertEq( currpos, binary_element_count_begin );
    BinaryWriteSize( m_fd, m_objectCount );

    ::Close( m_fd );
    m_objectCount = 0;
    m_fd = -1;
  }
}



// Open a format 3 file for sequential reading through a bounded buffer.
//
//   filename   - format 3 file to iterate over
//   p_to_fill  - receives the first element, if the file has any
//   max_memory - approximate buffer budget in bytes; the buffer always
//                holds at least one element
template <typename T>
Binary3Iter<T>::Binary3Iter( const String& filename, T* p_to_fill, 
                             longlong max_memory ) {
  // A budget smaller than one element (or a negative one) used to give a
  // buffer capacity of zero, after which Next() indexed past the end of an
  // empty buffer.  Always allow at least one element.
  longlong max_elements = max_memory / (longlong) sizeof(T);
  if ( max_elements < 1 )
    max_elements = 1;
  m_maxsize = max_elements;

  m_fd = OpenForRead( filename );
  CheckHeader3<T>( filename, m_fd, m_globalSize );

  m_globalIndex = m_localIndex = 0;

  FillBuffer();
  
  // The buffer is empty only when the file has no elements.
  if( ! m_data.empty() )
    (*p_to_fill) = m_data[0];
}

// Close the descriptor that the constructor opened.
template <typename T>
Binary3Iter<T>::~Binary3Iter() {
  ::close( m_fd );
}

// Advance to the next element and copy it into *p_to_fill.  Once the end of
// the file has been passed, *p_to_fill is left untouched.
template <typename T>
void Binary3Iter<T>::Next( T* p_to_fill ) {
  ++m_globalIndex;
  if ( m_globalIndex >= m_globalSize )
    return;

  // Refill when the current buffer has been used up.
  ++m_localIndex;
  if ( m_localIndex == m_data.size() )
    FillBuffer();

  *p_to_fill = m_data[m_localIndex];
}

// Load the next chunk of elements from the file into m_data and restart the
// local index.  May only be called once the current buffer has been fully
// consumed (or is still empty).
template <typename T>
void Binary3Iter<T>::FillBuffer() {
  ForceAssertEq( m_localIndex, m_data.size() );

  // Read as many elements as the buffer allows, but not past the end of
  // the file.
  const longlong remaining = m_globalSize - m_globalIndex;
  const longlong capacity = longlong(m_maxsize);
  m_data.resize( min( capacity, remaining ) );

  if ( ! m_data.empty() )
    ReadBytes( m_fd, &m_data[0], sizeof(T) * longlong(m_data.size()) );

  m_localIndex = 0;
}


// BINARY2_DEF(T) expands to explicit instantiations of the format 2 entry
// points (write, read, subset read, range read, size) for element type T.
// The last instantiation has no trailing semicolon, so the macro is used as
// "BINARY2_DEF(T);".
#define BINARY2_DEF(T)                                                         \
     template void BinaryWrite2( const String& filename, const vec<T>& v );    \
     template void BinaryRead2( const String& filename, vec<T>& v, bool );     \
     template void BinaryReadSubset2( const String& filename,                  \
          const vec<int>& ids, vec<T>& v, Bool, bool );                        \
     template void BinaryReadRange2( const String& filename,                   \
          longlong from, longlong to, vec<T>& v, bool );                       \
     template longlong BinarySize2<T>( const String& filename, bool )

// BINARY3_DEF(T) expands to explicit instantiations of the format 3 entry
// points for element type T: the free functions (write, read, subset read,
// range read, size, mmap), the mappedvec<T> destructor, and the members of
// Binary3Writer<T> and Binary3Iter<T>.  The last instantiation has no
// trailing semicolon, so the macro is used as "BINARY3_DEF(T);".
//
// NOTE(review): the Binary3Writer<T> constructor instantiated here is not
// defined in the visible part of this file -- presumably it lives in the
// header; confirm.
#define BINARY3_DEF(T)                                                         \
     template void BinaryWrite3( const String& filename, const vec<T>& v );    \
     template void BinaryRead3( const String& filename, vec<T>& v, bool,       \
          const Bool  );                                                       \
     template void BinaryReadSubset3( const String& filename,                  \
          const vec<int>& ids, vec<T>& v, Bool, bool );                        \
     template void BinaryReadRange3( const String& filename,                   \
          longlong from, longlong to, vec<T>& v, bool );                       \
     template longlong BinarySize3<T>( const String& filename, bool );         \
     template void BinaryMmap3( const String&, auto_ptr< const mappedvec<T> >& ); \
     template mappedvec<T>::~mappedvec();                                      \
     template Binary3Writer<T>::Binary3Writer( const String&, bool );          \
     template Binary3Writer<T>::~Binary3Writer();                              \
     template void Binary3Writer<T>::Write( const T& );                        \
     template void Binary3Writer<T>::WriteMultiple( const vec<T>& );           \
     template void Binary3Writer<T>::Open( const String&, bool );              \
     template void Binary3Writer<T>::Close();                                  \
     template Binary3Iter<T>::Binary3Iter( const String&, T*, longlong );      \
     template Binary3Iter<T>::~Binary3Iter();                                  \
     template void Binary3Iter<T>::Next( T* );                                 \
     template void Binary3Iter<T>::FillBuffer()
