/* -*- mode: c++; c-basic-offset: 3; -*- */
#ifndef __EXTENSIONS__
#define __EXTENSIONS__
#endif
#ifndef _BSD_SOURCE
#define _BSD_SOURCE
#endif

// Define DEBUG to enable log messages on stdout:
// DEBUG = 1: basic messages
// DEBUG = 2: verbose messages
// DEBUG = 3: verbose messages plus randomly injected packet receive errors
// #define DEBUG 1


#include "PConfig.h"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sys/types.h>
#include <sys/socket.h>
#if defined (P__LINUX)
// #include <sys/time.h>
#include <linux/types.h>
#include <linux/sockios.h>
#include <sys/ioctl.h>
typedef u_int32_t in_addr_t;
#elif defined (P__WIN32)
#include <sys/ioctl.h>
#define IFF_POINTOPOINT 0
#elif defined (P__DARWIN)
#include <sys/ioctl.h>
#else
#include <sys/sockio.h>
#endif
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h> 
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include <sstream>
#include "framexmit/framerecv.hh"
#include "framexmit/matchInterface.hh"
#include "checksum_crc32.hh"
#ifdef DEBUG
#include <iostream>
#include "tconv.h"
#endif

using namespace std;

namespace framexmit {

   //===================================  Open the receiver socket.
   /// Open the receiver.
   ///
   /// Creates the UDP data socket, binds it to \a port on any local address
   /// and either joins the multicast group \a mcast_addr (on the interface
   /// selected by matchInterface()) or, when \a mcast_addr is null, enables
   /// broadcast reception. A second UDP socket (rq_sock) is opened for
   /// sending retransmission requests; failure to set it up is reported on
   /// stderr but does not fail the call.
   ///
   /// Any connection that is already open is closed first.
   ///
   /// @param mcast_addr dotted multicast group address, or null for broadcast
   /// @param interface  interface selector handed to matchInterface()
   /// @param port       UDP port to listen on
   /// @return true if the data socket is ready, false otherwise (the data
   ///         socket is closed again on every failure path)
   bool frameRecv::open (const char* mcast_addr, const char* interface,
                     int port)
   {
      if (sock >= 0) {
         close();
      }
   
      // open socket
      sock = socket (PF_INET, SOCK_DGRAM, 0);
      if (sock < 0) {
         return false;
      }
   
      // set reuse socket option 
      int reuse = 1;
      if (setsockopt (sock, SOL_SOCKET, SO_REUSEADDR, 
                     (char*) &reuse, sizeof (int)) != 0) {
         close ();
         return false;
      }
   
      // set receive buffer size: probe upwards in steps of rcvInBuffersize
      // until the system refuses or rcvInBufferMax is reached, then apply
      // the largest size that was accepted
      int bufsize = rcvInBuffersize;
      for (int bTest=bufsize; bTest<=rcvInBufferMax; bTest+=rcvInBuffersize) {
         if (setsockopt (sock, SOL_SOCKET, SO_RCVBUF, 
                         (char*) &bTest, sizeof (int)) != 0) break;
         bufsize = bTest;
      }
      if (setsockopt (sock, SOL_SOCKET, SO_RCVBUF, 
                     (char*) &bufsize, sizeof (int)) != 0) {
         close();
         return false;
      }
  
      // bind socket (zero the structure first so the padding is defined)
      memset (&name, 0, sizeof (name));
      name.sin_family = AF_INET;
      name.sin_port = htons (port);
      name.sin_addr.s_addr = htonl (INADDR_ANY);
      if (::bind(sock, (struct sockaddr*) &name, sizeof (name))) {
         close();
         return false;
      }
   
      // options
      if (mcast_addr != 0) {
      
         // check multicast address
         if (!IN_MULTICAST (ntohl (inet_addr (mcast_addr)))) {
            close();
            return false;
         }
      
         // get interface address
         in_addr i_addr;
         if (!matchInterface (sock, interface, i_addr)) {
            close();
            return false;
         }
      
         // multicast: join
         group.imr_multiaddr.s_addr = inet_addr (mcast_addr);
         group.imr_interface.s_addr = i_addr.s_addr;
         if (setsockopt (sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, 
                        (char*) &group, sizeof (ip_mreq)) == -1) {
            close();
            return false;
         }
         multicast = true;

         if (logison) {
            // bounded formatting: mcast_addr is caller supplied and may be
            // longer than the log buffer
            char buf[256];
            snprintf (buf, sizeof (buf), "Join multicast %s at %s",
                     mcast_addr, inet_ntoa(i_addr));
            addLog (buf);
         }
      #ifdef DEBUG
         cout << "join multicast " << mcast_addr << " at " 
              << inet_ntoa (i_addr) << endl;
      #endif
      }
      else {
         multicast = false;
         // broadcast: enable
         int bset = 1;
         if (setsockopt (sock, SOL_SOCKET, SO_BROADCAST, 
                        (char*) &bset, sizeof (bset)) == -1) {
            close();
            return false;
         }
      #ifdef DEBUG
         cout << "broadcast port "  << port << endl;
      #endif
      }

      //--------------------------------  Open a request socket.
      rq_sock = socket(PF_INET, SOCK_DGRAM, 0);
      if (rq_sock < 0) {
         perror("Error opening request port, rebroadcast requests disabled");
      }

      //  Configure the socket and bind it to an ephemeral port
      else {
         sockaddr_in rqname;
         memset (&rqname, 0, sizeof (rqname));
         rqname.sin_family = AF_INET;
         rqname.sin_port = htons(0);
         rqname.sin_addr.s_addr = htonl(INADDR_ANY);

         // set send buffer size 
         bufsize = rcvOutBuffersize;
         if (setsockopt (rq_sock, SOL_SOCKET, SO_SNDBUF, 
                         (char*) &bufsize, sizeof (int)) != 0) {
            perror("Error setting request port buffer size, "
                   "rebroadcast requests disabled");
            ::close(rq_sock);
            rq_sock = -1;
         }
         // only bind a socket that is still open
         else if (::bind(rq_sock, (struct sockaddr*) &rqname, 
                         sizeof (rqname))) {
            perror("Error binding request port, rebroadcast requests disabled");
            ::close(rq_sock);
            rq_sock = -1;
         }
      }

      //--------------------------------  check quality of service
      if ((qos < 0) || (qos > 2)) {
         qos = 2;
      }
      pkts.clear();
      first = true;

      return true;
   }


   /// Close the receiver.
   ///
   /// Leaves the multicast group if one was joined, discards all queued
   /// packets and closes both the data socket and the request socket.
   /// Does nothing when the data socket is not open.
   void frameRecv::close ()
   {
      if (sock >= 0) {
         // leave the multicast group; the outcome is deliberately ignored
         // because the socket is closed regardless
         if (multicast) {
            (void) setsockopt (sock, IPPROTO_IP, IP_DROP_MEMBERSHIP, 
                              (char*) &group, sizeof (ip_mreq));
         }

         // drop whatever is still waiting in the packet queue
         pkts.clear();

         // release the data socket
         ::close (sock);
         sock = -1;

         // release the retransmission request socket, if any
         if (rq_sock >= 0) {
            ::close(rq_sock);
            rq_sock = -1;
         }
      }
   }


   /// Per-packet bookkeeping for one sequence: records whether the packet
   /// with the corresponding index has been received.
   class packetControl {
   public:
      /// true once the packet's payload has been stored
      bool 		received;

      /// A freshly created entry is marked as not yet received.
      packetControl () {
         received = false;
      }
   };

   /// Predicate for counting missing packets: true when the entry has not
   /// been marked as received.
   bool isNotValidPacket (const packetControl& cntrl) {
      if (cntrl.received) {
         return false;
      }
      return true;
   }


   /// Drain the socket's input queue into the packet list without blocking.
   ///
   /// Reads until no packet is available or 10 packets have been read.
   /// @return true if the packet list is non-empty afterwards
   bool frameRecv::purge ()
   {
      int remaining = 10;
      while ((remaining > 0) && getPacket (false)) {
         --remaining;
      }
      const bool havePackets = !pkts.empty();
      return havePackets;
   }


   /// Signed distance between the sequence number of the first queued
   /// packet and \a newseq, folded so that a wrap of the 32 bit sequence
   /// counter yields a small difference instead of a huge one.
   ///
   /// @param pkts   packet queue; only the first element is inspected
   /// @param newseq sequence number to compare against
   /// @return pkts[0] sequence minus newseq after folding, or 0 when the
   ///         queue is empty
   int64_t frameRecv::calcDiff (const frameRecv::packetlist& pkts, 
                     unsigned int newseq) {
      // nothing queued: no difference to report
      if (pkts.empty()) {
         return 0;
      }

      const int64_t halfRange = (int64_t)0x80000000LL;
      const int64_t fullRange = (int64_t)0x100000000LL;

      int64_t delta = int64_t(pkts[0]->header.seq) - int64_t(newseq);
      // fold differences of more than half the counter range
      if (delta > halfRange) {
         delta -= fullRange;
      }
      if (delta < -halfRange) {
         delta += fullRange;
      }
      return delta;
   }


   // Receive one complete data sequence (all packets sharing a sequence
   // number) and assemble the payloads into a contiguous buffer.
   //
   // Parameters:
   //   data      in/out buffer. If null on entry, the buffer is allocated
   //             here with new[] once the first packet of the sequence is
   //             seen and handed back through this reference (presumably
   //             the caller releases it with delete[] - confirm against the
   //             class header).
   //   maxlen    size of the caller's buffer; ignored when data is null.
   //   sequence  optional out: sequence number of the received data.
   //   time      optional out: header.timestamp of the sequence's first
   //             processed packet.
   //   duration  optional out: header.duration of the sequence's first
   //             processed packet.
   //
   // Returns the number of payload bytes received, or a negative code:
   //   -1       receiver socket is not open
   //   -2       data is non-null but maxlen <= 0
   //   -3       packet queue was empty and getPacket() failed
   //   -5       no valid data buffer (allocation failed)
   //   -6       sending a retransmission request failed
   //   -needed  caller's buffer is too small; needed is
   //            pktTotal * packetSize of the incoming sequence
   //
   // NOTE(review): header.pktTotal comes from the network and is used for
   // the buffer size and pktCntrl.resize() without a range check in this
   // function - confirm it is validated elsewhere.
   int frameRecv::receive (char*& data, int maxlen, unsigned int* sequence,
			   unsigned int* time, unsigned int* duration)
   {
      typedef std::vector<packetControl>  ctrl_vect;
      typedef ctrl_vect::iterator         ctrl_iter;

      int64_t lastTimeSequenceDone = 0;	// the last time a sequence finished
      // when the first packet arrived; assigned below but not otherwise
      // read in this function
      int64_t firstPktTime = 0;


      // check socket
      if (sock < 0) {
         return -1;
      }
   
      // check on data array and max. length
      // (maxlen == -1 marks "buffer is allocated by this function")
      if (data == 0) {
         maxlen = -1;
      }
      else if (maxlen <= 0) {
         return -2;
      }
   
      // determine qos parameter
      // NOTE(review): qosFrac is not used further down because the QoS
      // check in the give-up condition is commented out.
      double qosFrac;
      switch (qos) {
         case 0: 
            qosFrac = 0.3;
            break;
         case 1: 
            qosFrac = 0.1;
            break;
         default: 
            qosFrac = 0.03;
            break;
      }
   
      // read a packet if buffer empty
      if (pkts.size() == 0) {
         // request a packet (getPacket called with its default argument)
         if (!getPacket()) {
            return -3;
         }
	 firstPktTime = get_timestamp();
      }
   #ifdef DEBUG
      cout << "first packet received" << endl;
   #endif
      // calculate new sequence number from old
      unsigned int newseq;

      // if first, set the old sequence number to one of the first packet
      if (first) {
         oldseq = pkts[0]->header.seq;
         retransmissionRate = 0.0;
	 newseq = oldseq;	// actually, we expect the sequence to start with oldseq
	                        // if this is the first time
	 first = false;

      }
      else {
	 newseq = oldseq + 1;	// otherwise expect the sequence following the last completed one
      }

      // main receiver loop
      //bool requestPacket = false;	// request a new packet
      bool firstPacket = true;		// first packet received of this seq.?
      bool seqDone = false;		// sequence transmission done
      int len = 0;			// bytes received so far
      int retry = 0;			// # of rebroadcasts
      int64_t timestamp = 0;		// time stamp (in us) of retry timer
      // When retryExpired == true, then we request packets to be retransmitted
      bool retryExpired = true;		// keeps track of retry timer
      double newRetransmissionRate = 0;	// new retransmission rate
      ctrl_vect pktCntrl;	        // marks received packets
   
      //--------------------------------  Loop over bursts of packets
      while (1) {
	 seqDone = false;		// reset seqDone to false

      	 //-----------------------------  Fill the packet list from the input 
	 //                               socket (maximum of 10 packets)
         int max = 10;
         while (getPacket (false)) {
	    if(firstPktTime == 0)
	       	 firstPktTime = get_timestamp();


            if (--max <= 0) {
               break;
            }
         }
      
         // calculate difference between pkts[0].header.seq - newseq.
	 // this is where it loses the first packet. If this is the first packet,
	 // i.e. if first was set, then we should be using oldseq NOT newseq
         int64_t diff = calcDiff (pkts, newseq);

         //  Wait a receiver tick if no packets were received or
         //  only more recent packets are in the queue
         if (pkts.empty() || (!firstPacket && (diff > 0))) {
	    micro_delay(rcvDelayTick);
	 }
      
         //  Process the packets at front of queue if some are there
         if (!pkts.empty()) {

            // check if sequence out of sync: either the queue front is
            // further than maxSequenceOutOfSync away from the expected
            // sequence, or it is already ahead of it while nothing of the
            // expected sequence has been stored. In that case the whole
            // queue is dropped and reception restarts after its last entry.
            if (((diff<0 ? -diff : diff) > maxSequenceOutOfSync) ||
               ((diff >= 1) && (len == 0))) {
	       newseq = pkts.back()->header.seq + 1;
	       pkts.clear();
 
               firstPacket = true;
               retry = 0;
               len = 0;
               seqDone = false;

               newRetransmissionRate = 0.0;
               if (logison) {
		  ostringstream msg;
                  msg << "Have to skip " << int(diff) << " sequence numbers."
		      << "New sequence number is " << newseq;
                  addLog(msg.str());
               }
            #ifdef DEBUG
	       if (diff) cout << "have to skip (" << int(diff) << ")" << endl;
            #endif
               continue;
            } // end if(out of sequence)
         
            // reject negative sequence difference (i.e. old packets)
            while (!pkts.empty() && (diff < 0)) {
               // skip old packets
               if (logison) {
		  ostringstream msg;
                  msg << "Skip packet " <<  pkts[0]->header.pktNum 
		      << " of out-of-order sequence " << pkts[0]->header.seq;
                  addLog (msg.str());
               }
               pkts.pop_front();
               diff = calcDiff (pkts, newseq);
            }
            if (pkts.empty()) {
               continue;
            }

            // found a valid packet: consume every queued packet that
            // belongs to the expected sequence
            while (!pkts.empty() && (diff == 0)) {
               // allocate memory and setup qos packet control if first packet
               if (firstPacket) {
                  // check if data array is long enough
                  int needed = pkts[0]->header.pktTotal * packetSize;
                  if ((maxlen > 0) && (maxlen < needed)) {
                     return -needed;
                  }
                  // allocate data array if necessary
                  if (maxlen == -1) {
                     // a buffer allocated for an abandoned sequence earlier
                     // in this call is released first
                     if (data != 0) {
                        delete [] data;
                     }
                     data = new (std::nothrow) char [needed];
                  }
                  if (logison && !data) {
                     char buf[256];
                     sprintf (buf, "Memory allocation for %i failed",
                             needed);
                     addLog (buf);
                  }
                  // check if data array is valid
                  if (data == 0) {
                     return -5;
                  }
                  // set sequence, time stamp & duration on caller's request
                  if (sequence != 0) *sequence = newseq;
                  if (time != 0)     *time     = pkts[0]->header.timestamp;
                  if (duration != 0) *duration = pkts[0]->header.duration;

                  // setup qos & packet control
		  pktCntrl.clear();
                  pktCntrl.resize(pkts[0]->header.pktTotal);
                  newRetransmissionRate = 
                     (1.0 - 1.0/retransmissionAverage) * retransmissionRate;
                  firstPacket = false;
               }
            
               // copy packet payload into receiver buffer; packets with an
               // index outside pktCntrl or already received are ignored
               int pktNum = pkts[0]->header.pktNum;
               if ((pktNum >= 0) && (pktNum < (int)pktCntrl.size()) &&
                  (!pktCntrl[pktNum].received)) {
                  memcpy (data + pktNum * packetSize, pkts[0]->payload,
                         pkts[0]->header.pktLen);
                  len += pkts[0]->header.pktLen;
                  pktCntrl[pktNum].received = true;
               #if DEBUG - 0 > 2 // add a random error rate!
                  if ((double)rand() < rcvErrorRate * 32768.0) {
                     pktCntrl[pktNum].received = false;
                     len -= pkts[0]->header.pktLen;
                  }
               #endif
               }
            
               // check if last packet
               if (pkts[0]->header.pktTotal == pkts[0]->header.pktNum + 1) {
                  seqDone = true;
               #ifdef DEBUG
                  //  cout << "received last package at " <<
                  //     (TAInow() % 1000000000000) / 1E9 << endl;
               #endif
               }
               // remove packet
               pkts.pop_front();
               diff = calcDiff (pkts, newseq);
            }  // while (!pkts.empty() && !diff)

            // sequence done if next sequence is already in buffer
	    // --- condition on len != 0  (alex?)
            if (diff > 0 && len) {
               seqDone = true;
            }
         } // if !pkts.empty()


	 //-----------------------------  Time to give up waiting?
	 //
	 //  Give up waiting for another packet after par.lostPacketWait
	 //  (an earlier note here said 10ms)
	 int64_t timestampNow = get_timestamp();

	 ////////////////////////////////////////////////////////////
	 // NOTE: the very first packet of a sequence will be coming in on the order of seconds from
	 // the last transmission, so we do NOT want to say things have timed out if it's been more
	 // than a couple of milliseconds!
	 // So we do NOT check the time stamp while no packet of the expected
	 // sequence has been processed yet and no retransmission has been
	 // requested (firstPacket still set and retry == 0).
	 // ALSO:
	 // Can reduce the number of calls if make use of lastTimeSequenceDone:
	 // Just take the maximum of (lastPktTime, lastTimeSequenceDone) and compare timestampNow to that
	 ////////////////////////////////////////////////////////////
	 if(!(firstPacket && (retry==0))) {
	    // when trying to figure out if we've lost a packet, take the later of:
	    //  (1) the last packet arrival time; or
	    //  (2) the last time we decided a sequence was done
	    int64_t lastInterestingTime = lastPktTime;
	    if(lastTimeSequenceDone > lastPktTime) {
	       lastInterestingTime = lastTimeSequenceDone;
	    }
	    if(timestampNow >= lastInterestingTime + par.lostPacketWait) {

	       seqDone = true;
	    } // if timestamp
	 } // if !(firstPacket && (retry==0))


         //-----------------------------  Continue if sequence not yet finished
         if (!seqDone) {
            continue;
         }

         // calculate # of missing packets
         int missingPackets = count_if (pktCntrl.begin(), pktCntrl.end(), 
                              isNotValidPacket);
	 ////////////////////////////////////////////////////////////
	 // Sequence done: either actually finished, or timed out
	 ////////////////////////////////////////////////////////////
	 lastTimeSequenceDone = get_timestamp();	// save the last time a sequence finished
      
         //-----------------------------  Check if complete
	 //         if (missingPackets == 0) {
	 // An empty pktCntrl also counts zero missing packets, so it is
	 // excluded explicitly.
	 if (missingPackets == 0 && pktCntrl.size() != 0) {

            oldseq = newseq;
            retransmissionRate = newRetransmissionRate;
            return len; // done!
         }
         #ifdef DEBUG
         // cout << "missing packets " << missingPackets << 
            // " (packets pending " << pkts.size() << ")" << endl;
         #endif

	 ////////////////////////////////////////////////////////////
	 // Retransmit logic
	 // UWM: request retransmit 2.2s after sequence has started to
	 // arrive, then at 360ms intervals thereafter
	 ////////////////////////////////////////////////////////////
         if ((retry > 0) && !retryExpired) {

            // check if retry is expired
	    retryExpired = !timestampNow || 
	                   (timestampNow - timestamp) > retryTimeout;
         } // if retry > 0 && !retryExpired


         // check qos: abandon the current sequence when too many packets
         // are missing, when all retries are used up and the last one has
         // expired, or when the packet queue is full
         if ((missingPackets > maximumRetransmit) ||
	     // TURN OFF QoS checking
	     // (retransmissionRate > qosFrac) || 
            ((retry >= maxRetry) && (retryExpired)) ||
            ((int)pkts.size() >= rcvpacketbuffersize)) {

	       // Save the original length of the packet buffer to pkts_size
	       int pkts_size = pkts.size();
	       unsigned int sequence_delete = newseq;
	       int newseq_found = 0;
	       // Search through packet buffer starting from the head of the buffer
	       for(int i = 0; i < pkts_size; ) {
		  // Is this packet of the same sequence we want to delete?
		  if(pkts[i]->header.seq == sequence_delete) {
		     // Yes: erase this packet from the buffer

		     pkts.erase(pkts.begin() + i);
		     // AND show that this buffer is now one packet shorter
		     pkts_size --;
		     // Note: we don't have to update i here since i will again point
		     // to the next element (since we deleted this element)
		  }
		  else {
		     // This is a packet we want to keep. Have we already found a new sequence number?
		     if(!newseq_found) {

			newseq = pkts[i]->header.seq;
			newseq_found = 1;
		     }
		     // Move on to the next packet in the buffer
		     i++;
		  }
	       }
	       // Did we not find any new sequence? If so, just increment the sequence number by 1
	       if(!newseq_found) {
		  newseq ++;
	       }
            // restart reception for the new expected sequence
            firstPacket = true;
            retry = 0;
            len = 0;
            seqDone = false;
            retransmissionRate = newRetransmissionRate;

	    // NOTE(review): diff has no further reader in this iteration;
	    // it is recomputed at the top of the next one.
	    if(!pkts.empty()) {
	        diff = calcDiff (pkts, newseq);
	    }
	    else {
	       diff = 0;
	    }

         }
         else if ((retry == 0) || (retryExpired)) {


            // calculate new retransmission rate
            // NOTE(review): divides by pktCntrl.size(), which is zero if
            // no packet of this sequence set up pktCntrl - confirm this
            // cannot be reached in that state.
            if (retry == 0) {
               newRetransmissionRate = 
                  (1.0 - 1.0/retransmissionAverage) * retransmissionRate + 
                  1.0/retransmissionAverage * 
                  (double) missingPackets / pktCntrl.size();
            }
            // ask for retransmit
            retransmitpacket	rpkt;
            memset (&rpkt.header, 0, sizeof (packetHeader));
            rpkt.header.pktType = PKT_REQUEST_RETRANSMIT;
            rpkt.header.seq = newseq;
            rpkt.header.pktNum = 0;
            rpkt.header.pktTotal = 1;
            // collect the indices of all packets not yet received
            int n = 0;
            int i = 0;
            for (ctrl_iter iter = pktCntrl.begin(); 
		 iter != pktCntrl.end(); iter++, i++) {
               // check if packet has to be resent
               if (!iter->received) {
                  rpkt.pktResend[n++] = i;
                  // check if packet is full
                  if (n >= packetSize / (int)sizeof (int32_t)) {
                     break;
                  }
               }
            }
            if (logison) {
               char buf[256];
               sprintf (buf, "Ask for retransmission of %i packets\n"
                       "New retransmission rate is %g",
                       n, newRetransmissionRate);
               addLog (buf);
            }
         #ifdef DEBUG
            cout << "ask for retransmission " << n << " (rate " <<
               newRetransmissionRate << "; seq " << newseq << ") at " << 
               (TAInow() % 1000000000000LL) / 1E9 << endl;
         #endif
         #if DEBUG - 0 > 1
            cout << "Packets =";
            for (int jj = 0; jj < n; ++jj) {
               cout << " " << rpkt.pktResend[jj];
            }
            cout << endl;
         #endif
            rpkt.header.pktLen = n * sizeof (int32_t);
            if (!putPacket (rpkt)) {
               return -6;
            }
         
            // start new timer
            retryExpired = false;
	    timestamp = get_timestamp();
            retry++;
         }
      
         //requestPacket = true;
      } // end loop over bursts of packets, while (1)

      // never reached (false converts to 0 for the int return type)
      return false;
   }

   //===================================  Receive a packet and put it into pkts
   bool frameRecv::getPacket (bool block) {
      // check if enough space for a new packet
      if ((int)pkts.size() >= rcvpacketbuffersize) {
         if (logison) {
            addLog ("Packet buffer is full");
         }
      #ifdef DEBUG
         cout << "packet buffer full" << endl;
      #endif
         return false;
      }
   
      // allocate packet buffer
      auto_pkt_ptr	pkt = auto_pkt_ptr (new (nothrow) packet);
      if (pkt.get() == 0) {
         return false;
      }
   
      // receive a packet from socket
      checksum_crc32 crc;
      uint32_t save_crc = 0;
      int n;
      do {
         // if not blocking poll socket
         if (!block) {
            timeval 	wait;		// timeout=0
            wait.tv_sec = 0;
            wait.tv_usec = 0;
            fd_set 		readfds;	
            FD_ZERO (&readfds);
            FD_SET (sock, &readfds);
            int nset = select (FD_SETSIZE, &readfds, 0, 0, &wait);
            if (nset < 0) {
            #ifdef DEBUG
               cout << "select on receiver socket failed" << endl;
            #endif
               return false;
            }
            else if (nset == 0) {
               return false;
            }
         }
      	 // read packet
         socklen_t max = sizeof (name);
         n = recvfrom (sock, (char*) pkt.get(), sizeof (packet), 0, 
                      (struct sockaddr*) &name, &max);
         if (n < 0) {
            return false;
         }
	 if (size_t(n) < sizeof (packetHeader)) continue;

	 //-----------------------------  Calculate the checksum
	 save_crc = ntohl(pkt->header.checksum);
	 if (save_crc != 0) {
	    crc.reset();
	    pkt->header.checksum = 0;
	    crc.add(reinterpret_cast<unsigned char*>(pkt.get()), n);
	    pkt->header.checksum = crc.result();
	 }

         //-----------------------------  Swap if necessary
         pkt->ntoh();
      // repeat until valid packet
      } while (pkt->header.checksum != save_crc || 
	       n < (int)sizeof(packetHeader) || 
              ((pkt->header.pktType != PKT_BROADCAST) && 
              (pkt->header.pktType != PKT_REBROADCAST)) ||
              (n != (int)sizeof (packetHeader) + pkt->header.pktLen));
   
      #if DEBUG - 0 > 1
         cout << "received packet " << pkt->header.pktNum << 
             (pkt->header.pktType == PKT_REBROADCAST ? "R " : " ") <<
             " seq=" << pkt->header.seq << endl;
      #endif 
      // now add packet to packet buffer
      packetlist::iterator pos = lower_bound (pkts.begin(), pkts.end(), pkt);
      // check if end of list
      if (pos == pkts.end()) {
         pkts.push_back (pkt);
      }
      // check if duplicate
      else if (*pos == pkt) {
         // do nothing
      }

      //--------------------------------  Check if front of list
      else if (pos == pkts.begin()) {
         pkts.push_front (pkt);
      }

      //--------------------------------  Otherwise insert at position
      else {
         pkts.insert (pos, pkt);
      }

      //--------------------------------  Record the packet arrival time
      lastPktTime = get_timestamp();
      return true;
   }

   //===================================  Send a rebroadcast request packet.
   /// Send a retransmission request over the request socket to the address
   /// stored in name (filled in by recvfrom in getPacket()).
   ///
   /// The packet is converted to network byte order in place. If the
   /// request socket is not available the call fails without touching any
   /// descriptor; if the send fails the request socket is closed and
   /// disabled.
   ///
   /// @param pkt request packet; header.pktLen must hold the payload length
   ///            in host byte order on entry
   /// @return true if the whole packet was sent, false otherwise
   bool frameRecv::putPacket (retransmitpacket& pkt) {
      // wire size must be taken before the header is byte swapped
      const int sbytes = sizeof (packetHeader) + pkt.header.pktLen;
      // swap if necessary
      pkt.hton();
      // request socket disabled (never opened or closed after an earlier
      // failure): do not call sendto/close on an invalid descriptor
      if (rq_sock < 0) {
         return false;
      }
      if (sendto (rq_sock, (char*) &pkt, sbytes, 0, 
                 (struct sockaddr*) &name, 
                 sizeof (struct sockaddr_in)) != sbytes) {
         ::close (rq_sock);
         rq_sock = -1;
         return false;
      }
      return true;
   }


   /// Return the accumulated log as a C string (see copyLog()).
   const char* frameRecv::logmsg () {
      const char* text = copyLog();
      return text;
   }

   void frameRecv::addLog (const string& s)
   {
      logs.push_back (s + "\n");
      if ((int)logs.size() > maxlog) {
         logs.pop_front();
      }
   }

   /// Convenience overload: append a C string message to the log.
   void frameRecv::addLog (const char* s)
   {
      addLog (string (s));
   }

   const char* frameRecv::copyLog ()
   {
      logbuf = "framexmit log:\n";
      for (deque<string>::iterator i = logs.begin(); 
          i != logs.end(); i++) {
         logbuf += *i;
      }
      return logbuf.c_str();
   }

   /// Discard all log entries and empty the assembled log text.
   void frameRecv::clearlog ()
   {
      logbuf.clear();
      logs.clear();
   }


}
