/*  
  Copyright 2002 Andreas Rottmann

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This library is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public  License
  along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
*/
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>

#include <errno.h>
#include <string.h>

#include <glib.h>

#include <sigcx/util.h>

#include "yehia/net/socket.h"

namespace Yehia
{

namespace Net
{

using SigCX::errno_string;

// Build an address holding a private copy of the `size` bytes at `data`.
Address::Address(const void *data, int size)
  : data_(0)
{
  // data_ has to be null before copy_data() runs, because copy_data()
  // releases any buffer it finds there.
  copy_data(data, size);
}

// Build an empty address: no buffer and a length of zero.
Address::Address()
{
  size_ = 0;
  data_ = 0;
}

// Deep copy: the new object owns its own buffer with the same bytes as addr.
Address::Address(const Address& addr)
  : data_(0)
{
  copy_data(addr.data_, addr.size_);
}

// Release the owned buffer. set_data() frees whatever data_ currently points
// to before storing the (empty) replacement.
Address::~Address()
{
  set_data(NULL, 0);
}

// Take ownership of `data` (`size` bytes) and free the previously owned buffer.
// `data` must have been allocated with new unsigned char[] (or be null),
// because it is later released with delete[] on an unsigned char pointer.
void Address::set_data(void *data, int size)
{
  // Re-setting the buffer we already own must not free it. Otherwise data_
  // would end up pointing at freed memory.
  if (data_ != data)
    delete[] (unsigned char *)data_;   // delete[] on null is a no-op
  data_ = data;
  size_ = size;
}

// Replace the owned buffer with a private copy of `size` bytes from `data`.
// The new buffer is allocated and filled before the old one is released, so
// that:
//  - copying from (part of) our own buffer does not read freed memory, and
//  - if the allocation throws, the object keeps its previous contents.
void Address::copy_data(const void *data, int size)
{
  unsigned char *copy = new unsigned char[size];
  if (size > 0)                  // avoid memcpy() with a null source for size 0
    memcpy(copy, data, size);

  if (data_)
    delete[] (unsigned char *)data_;
  data_ = copy;
  size_ = size;
}

// Build an internet socket address from a textual host `a` and a port number
// in host byte order.
//
// `a` is first parsed as a numeric IPv4 or IPv6 literal. If that fails, it is
// resolved once through host_by_name() and the result is parsed again.
// Throws Invalid(a) if resolution fails or the result is still not numeric.
InetAddress::InetAddress(const std::string& a, unsigned short port)
{
  std::string address = a;

  for (int attempt = 0; attempt < 2; ++attempt)
  {
    // Zero the whole struct first. This clears sin_zero and any
    // implementation-specific fields such as sin_len on BSD.
    struct sockaddr_in addr4;
    memset(&addr4, 0, sizeof(addr4));
    // inet_pton() returns 1 on success, 0 for an unparsable string and -1 for
    // an unsupported family. Only a positive result is a success.
    if (inet_pton(AF_INET, address.c_str(), &addr4.sin_addr) > 0)
    {
      addr4.sin_family = AF_INET;
      addr4.sin_port = htons(port);
      copy_data(&addr4, sizeof(addr4));
      return;
    }

    // POSIX requires sockaddr_in6 to be zeroed before use. The memset also
    // covers sin6_flowinfo and sin6_scope_id.
    struct sockaddr_in6 addr6;
    memset(&addr6, 0, sizeof(addr6));
    if (inet_pton(AF_INET6, address.c_str(), &addr6.sin6_addr) > 0)
    {
      addr6.sin6_family = AF_INET6;
      addr6.sin6_port = htons(port);
      copy_data(&addr6, sizeof(addr6));
      return;
    }

    // Only one name lookup is attempted. After that, give up.
    if (attempt > 0)
      break;
    address = host_by_name(address);
    if (address.empty())
      break;
  }
  throw Invalid(a);
}

// Convert a generic Address into an InetAddress by copying its bytes.
// Throws Invalid if addr is empty, too short for its family, or not an
// AF_INET/AF_INET6 address.
InetAddress::InetAddress(const Address& addr)
{
  const struct sockaddr *saddr = (const struct sockaddr *)addr.data();

  // Every internet socket address is at least as large as a generic struct
  // sockaddr. Anything smaller (or a null buffer from a default-constructed
  // Address) cannot even have its family read safely.
  if (!saddr || addr.size() < (int)sizeof(struct sockaddr))
    throw Invalid("cannot convert to internet address");

  int needed;
  switch (saddr->sa_family)
  {
    case AF_INET:
      needed = sizeof(struct sockaddr_in);
      break;
    case AF_INET6:
      needed = sizeof(struct sockaddr_in6);
      break;
    default:
      throw Invalid("cannot convert to internet address");
  }

  // Later code reinterprets the buffer as the family-specific struct, so it
  // must be at least that large.
  if (addr.size() < needed)
    throw Invalid("cannot convert to internet address");

  copy_data(addr.data(), addr.size());
}

// Build the wildcard ("any") address of the given family with port 0.
// This is typically used to bind to all local interfaces.
InetAddress::InetAddress(Type type)
{
  switch (type)
  {
    case IPV4:
    {
      // Zero first so sin_zero (and sin_len where present) is not garbage.
      struct sockaddr_in addr;
      memset(&addr, 0, sizeof(addr));
      addr.sin_family = AF_INET;
      addr.sin_port = htons(0);
      addr.sin_addr.s_addr = htonl(INADDR_ANY);
      copy_data(&addr, sizeof(addr));
      break;
    }
    case IPV6:
    {
      // POSIX requires sockaddr_in6 to be zeroed before use. The memset also
      // clears sin6_flowinfo and sin6_scope_id.
      struct sockaddr_in6 addr;
      memset(&addr, 0, sizeof(addr));
      addr.sin6_family = AF_INET6;
      addr.sin6_port = htons(0);
      addr.sin6_addr = in6addr_any;
      copy_data(&addr, sizeof(addr));
      break;
    }
    default:
      g_assert_not_reached();
  }
}

// Report the family of the stored address: IPV4 for AF_INET, IPV6 for
// AF_INET6.
InetAddress::Type InetAddress::type() const
{
  const struct sockaddr *saddr = (const struct sockaddr *)data();
  const int family = saddr->sa_family;

  if (family == AF_INET)
    return IPV4;
  if (family == AF_INET6)
    return IPV6;

  // The constructors in this file only ever store AF_INET/AF_INET6 data.
  g_assert_not_reached();
  return IPV4; // NOTREACHED, just to silence compiler
}

// Return the numeric text form of the stored IP address, for example
// "192.0.2.1" or "2001:db8::1". Returns an empty string if conversion fails.
// Previously this was an unimplemented stub that always returned "".
std::string InetAddress::address() const
{
  const struct sockaddr *saddr = (const struct sockaddr *)data();
  const void *raw;

  switch (saddr->sa_family)
  {
    case AF_INET:
      raw = &((const struct sockaddr_in *)saddr)->sin_addr;
      break;
    case AF_INET6:
      raw = &((const struct sockaddr_in6 *)saddr)->sin6_addr;
      break;
    default:
      g_assert_not_reached();
      return std::string();
  }

  // INET6_ADDRSTRLEN is large enough for both IPv4 and IPv6 text forms.
  char buffer[INET6_ADDRSTRLEN];
  if (!inet_ntop(saddr->sa_family, raw, buffer, sizeof(buffer)))
    return std::string();
  return buffer;
}

// Look up the port number (host byte order) registered for service `name`
// in the system services database. Returns -1 if the name is unknown.
// Previously this was a stub that always returned -1.
// NOTE: getservbyname() returns static storage and is not thread-safe.
int InetAddress::port_by_name(const std::string& name)
{
  struct servent *se = getservbyname(name.c_str(), NULL);
  if (!se)
    return -1;
  // s_port is stored in network byte order.
  return ntohs((unsigned short)se->s_port);
}

// Look up the protocol number for protocol `name` (for example "tcp") in the
// system protocols database. Returns -1 if the name is unknown.
// Previously this was a stub that always returned -1.
// NOTE: getprotobyname() returns static storage and is not thread-safe.
int InetAddress::proto_by_name(const std::string& name)
{
  struct protoent *pe = getprotobyname(name.c_str());
  if (!pe)
    return -1;
  return pe->p_proto;
}

// Resolve host `name` and return its first address in numeric text form.
// Returns an empty string on any failure.
//
// Uses the reentrant gethostbyname_r() when HAVE_GETHOSTBYNAME_R is defined,
// and falls back to gethostbyname() otherwise.
std::string InetAddress::host_by_name(const std::string& name)
{
  struct hostent *hp = NULL;
  char buffer[128];
  std::string result;

#ifdef HAVE_GETHOSTBYNAME_R
  struct hostent hostbuf;
  size_t hstbuflen = 1024;
  char *tmphstbuf = g_new(char, hstbuflen);
  int herr;
  int res;

  // gethostbyname_r() reports ERANGE when the scratch buffer is too small.
  while ((res = gethostbyname_r(name.c_str(), &hostbuf, tmphstbuf, hstbuflen,
                                &hp, &herr)) == ERANGE)
  {
    hstbuflen *= 2;
    // g_renew() returns a correctly typed pointer. A bare g_realloc() result
    // (void *) cannot be assigned to char * in C++.
    tmphstbuf = g_renew(char, tmphstbuf, hstbuflen);
  }
  if (res != 0)
    hp = NULL;
#else
  hp = gethostbyname(name.c_str());
#endif

  // hp may point into tmphstbuf, so convert before that buffer is freed.
  if (hp && hp->h_addr_list[0] &&
      inet_ntop(hp->h_addrtype, hp->h_addr_list[0],
                buffer, sizeof(buffer)) == buffer)
    result = buffer;

#ifdef HAVE_GETHOSTBYNAME_R
  // Single exit point: the scratch buffer is freed on every path.
  g_free(tmphstbuf);
#endif

  return result;
}

// Scratch buffer for socket addresses returned by the kernel (getsockname(),
// accept()). The sockaddr_storage member ensures the union is large enough,
// and suitably aligned, for any address family the system supports.
union SockAddr
{
    struct sockaddr sa;
    struct sockaddr_un un;
    struct sockaddr_in in;
    struct sockaddr_in6 in6;
    struct sockaddr_storage storage;
};

// Create a socket of the given style/protocol matching address's family and
// bind it to `address`. Throws Error if creation or binding fails.
Socket::Socket(const Address& address, Style style, int proto)
{
  struct sockaddr *saddr = (struct sockaddr *)address.data();

  fd_ = -1;

  ensure_open(address, style, proto);

  if (bind(fd_, saddr, address.size()) < 0)
  {
    // The destructor does not run for a constructor that throws, so the
    // freshly created descriptor has to be released here. Capture errno
    // first, because close() may overwrite it.
    std::string err = errno_string();
    close();
    throw Error(err);
  }
}

Socket::~Socket()
{
  // Intentionally does not close fd_. Socket objects are handed around by
  // value (e.g. accept() returns one), so presumably several copies may
  // share one descriptor; the owner is expected to call close() explicitly.
  // NOTE(review): confirm this ownership model against the class header.
}

// Read up to `size` bytes into `buffer`. Returns the result of ::read()
// unchanged: the byte count, 0 at end of stream, or -1 with errno set.
ssize_t Socket::read(void *buffer, size_t size)
{
  // On a closed socket glib emits a critical warning (unless checks are
  // compiled out) and 0 is returned.
  g_return_val_if_fail(fd_ >= 0, 0);

  const ssize_t nread = ::read(fd_, buffer, size);
  return nread;
}

// Write all `size` bytes of `buffer`, looping over partial writes.
// Returns `size` on success, or the negative ::write() result (errno set) on
// the first hard error. Progress already made before such an error is not
// reported.
ssize_t Socket::write(const void *buffer, size_t size)
{
  const char *cursor = static_cast<const char *>(buffer);
  size_t bytesleft = size;

  g_return_val_if_fail(fd_ >= 0, 0);

  while (bytesleft > 0)
  {
    ssize_t written = ::write(fd_, cursor, bytesleft);
    if (written < 0)
    {
      if (errno == EINTR)
        continue;          // interrupted by a signal before writing; retry
      return written;
    }
    // Advance past what was sent. The old code re-sent from the start of
    // the buffer after every partial write.
    cursor += written;
    bytesleft -= written;
  }

  return size;
}

// Connect this socket to the peer `address`. Throws Error (with the errno
// text) if ::connect() fails.
void Socket::connect(const Address& address) throw (Error)
{
  const struct sockaddr *peer = (const struct sockaddr *)address.data();

  const int rc = ::connect(fd_, peer, address.size());
  if (rc < 0)
    throw Error(errno_string());
}

// Report the socket's communication style by asking the kernel (SO_TYPE).
// Datagram sockets now report DATAGRAM. Everything else falls back to
// STREAM, which is what this method unconditionally returned before; that
// includes a closed socket and a failed query.
Socket::Style Socket::style() const
{
  int type = 0;
  socklen_t len = sizeof(type);

  if (fd_ >= 0 &&
      getsockopt(fd_, SOL_SOCKET, SO_TYPE, &type, &len) == 0 &&
      type == SOCK_DGRAM)
    return DATAGRAM;

  // XXX: map SOCK_RAW once the corresponding Style value is confirmed.
  return STREAM;
}

// Return the local address this socket is bound to (getsockname()).
// Throws Error if the query fails.
Address Socket::address() const
{
  SockAddr addr;
  socklen_t len = sizeof(addr);

  if (getsockname(fd_, (struct sockaddr *)&addr, &len) < 0)
    throw Error(errno_string());

  // If the address did not fit, the kernel truncates it but may report the
  // full length. Never copy past the end of our buffer.
  if (len > sizeof(addr))
    len = sizeof(addr);

  return Address(&addr, len);
}

// (Re)create the underlying descriptor with a protocol family matching
// addr's address family. Any descriptor already held is closed first.
// Throws Error for an unsupported family or when ::socket() fails.
void Socket::ensure_open(const Address& addr, Style style, int proto)
{
  const struct sockaddr *saddr = (const struct sockaddr *)addr.data();

  int c_style;
  if (style == STREAM)
    c_style = SOCK_STREAM;
  else if (style == DATAGRAM)
    c_style = SOCK_DGRAM;
  else
    c_style = SOCK_RAW;

  if (fd_ >= 0)
    close();

  int domain;
  switch (saddr->sa_family)
  {
    case AF_INET:
      domain = PF_INET;
      break;
    case AF_INET6:
      domain = PF_INET6;
      break;
    default:
      throw Error("invalid address family");
  }

  fd_ = ::socket(domain, c_style, proto);
  if (fd_ < 0)
    throw Error(errno_string());
}

// Close the descriptor, if one is open, and mark the socket as closed.
// Errors from ::close() are ignored, as before.
void Socket::close()
{
  // XXX: err check, exception
  if (fd_ < 0)
    return;

  // Forget the descriptor whatever ::close() reports. On Linux the descriptor
  // is released even when close() fails. Keeping it would risk a later
  // close() hitting an unrelated descriptor that reused the same number.
  ::close(fd_);
  fd_ = -1;
}

// Shut down one or both directions of the connection: RECEIVE, TRANSMIT, or
// (for any other value) both. Does nothing on a closed socket.
void Socket::shutdown(Direction d)
{
  // XXX : err check, exception
  if (fd_ < 0)
    return;

  int how;
  if (d == RECEIVE)
    how = SHUT_RD;
  else if (d == TRANSMIT)
    how = SHUT_WR;
  else
    how = SHUT_RDWR;

  ::shutdown(fd_, how);
}

// Mark the socket as passive with a backlog of `n` pending connections.
// Throws Error if the socket is closed or ::listen() fails.
void Socket::listen(int n)
{
  const bool is_open = (fd_ >= 0);
  if (!is_open)
    throw Error("Socket::listen(): closed socket");

  const int rc = ::listen(fd_, n);
  if (rc < 0)
    throw Error("Socket::listen(): " + errno_string());
}

// Accept a pending connection, store the peer's address in `address`, and
// return a Socket wrapping the new descriptor. Throws Error on failure.
Socket Socket::accept(Address& address) throw (Error)
{
  SockAddr addr;

  // The length is an in/out argument. On entry it must hold the buffer size
  // (it was previously left uninitialized); on return it holds the address
  // length.
  socklen_t size = sizeof(addr);
  int newsock = ::accept(fd_, (struct sockaddr *)&addr, &size);
  if (newsock < 0)
    throw Error(errno_string());

  // A truncated address may report a length larger than the buffer.
  if (size > sizeof(addr))
    size = sizeof(addr);
  address.copy_data(&addr, size);

  return Socket(newsock);
}

// Accept a pending connection and return a Socket wrapping the new
// descriptor. The peer address is discarded. Throws Error on failure.
Socket Socket::accept() throw (Error)
{
  // A properly sized scratch buffer is still handed to ::accept().
  SockAddr peer;
  socklen_t peer_len = sizeof(peer);

  const int conn_fd = ::accept(fd_, (struct sockaddr *)&peer, &peer_len);
  if (conn_fd < 0)
    throw Error(errno_string());

  return Socket(conn_fd);
}

// Wrap an already-open descriptor (used by accept()).
Socket::Socket(int fd)
  : fd_(fd)
{
}

// Switch the socket into (or out of) non-blocking mode.
// Returns whether it was non-blocking before the call.
// Throws std::runtime_error if fcntl() fails.
bool Socket::set_nonblocking(bool nonblock)
{
  // O_NONBLOCK is a file *status* flag, so it must be read and written with
  // F_GETFL/F_SETFL. F_GETFD/F_SETFD only manage descriptor flags such as
  // FD_CLOEXEC, so the old code never changed the blocking mode.
  int flags = fcntl(fd_, F_GETFL, 0);
  if (flags < 0)
    throw std::runtime_error(errno_string());

  const bool had_nonblock = (flags & O_NONBLOCK) != 0;
  if (nonblock == had_nonblock)
    return had_nonblock;            // already in the requested mode

  if (nonblock)
    flags |= O_NONBLOCK;
  else
    flags &= ~O_NONBLOCK;

  if (fcntl(fd_, F_SETFL, flags) < 0)
    throw std::runtime_error(errno_string());

  return had_nonblock;
}

// Enable or disable SO_KEEPALIVE. Returns the previous setting.
// Throws Error if reading or writing the option fails.
bool Socket::set_keepalive(bool keepalive)
{
  int old_value;
  socklen_t old_len = sizeof(old_value);
  if (getsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &old_value, &old_len) != 0)
    throw Error(errno_string());

  int value = keepalive ? 1 : 0;
  if (setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value)) != 0)
    throw Error(errno_string());

  return old_value != 0;
}

}

}
