#include <stdio.h>
/* NOTE(review): presumably defined here to expose GNU-only declarations
 * (REG_RIP/REG_EIP used by signal_error) to the headers below; defining
 * _GNU_SOURCE before the first #include is the supported way to do this. */
#define __USE_GNU
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <sys/types.h>

#include <pthread.h>
#include <stdint.h>
#include <getopt.h>
#include <errno.h>
#include <assert.h>
#include <time.h>


#ifdef WIN32
#include <Winsock2.h>
#include <signal.h>
// #include <ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <signal.h>
#include <sys/select.h>
#include <sys/wait.h>
#include <sys/socket.h>
#include <sys/resource.h>
#include <netdb.h>
#include <unistd.h>
#include <uuid/uuid.h>
#include <execinfo.h>
#include <netinet/in.h>
#endif

#include <tdata/usrtc.h>
#include <sexpr/sexp.h>
#include <sntl/connection.h>

/* Test configuration knobs. */
/* Server port used when -p is not given on the command line. */
#define DEFAULT_PORT  13133
/* NOTE(review): not referenced anywhere in the code visible here. */
#define CHANNEL_COUNT 200
/* test_channels(): thread slots with index below this run
 * test_correct_channel, the remaining slots run test_invalid_channel. */
#define CLIENT_COUNT 100
/* Number of ar-add requests sent on each opened channel in
 * test_correct_channel. */
#define MESSAGES_PER_SESSION 10000
/* Loop count for the channel/message tests. A negative value makes
 * test_invalid_channel and test_correct_channel loop indefinitely; the
 * other test loops then run zero times. */
#define ITERATION_COUNT 1000

/* When defined, log_begin/log_end/log_info are compiled out so only
 * [FAILED] lines reach the log file. */
#define FAILS_ONLY
/* When defined, main() omits its unused rc/i locals. */
//#define SIMPLE_TESTING

/* Destination of all test log output; NULL means logging is disabled. */
static FILE *log_file = NULL;

/*
 * Open `file` for writing (truncating it) as the new log destination,
 * closing any previously opened log first.
 * Returns 0 on success, EINVAL if `file` is NULL, or EIO if the file cannot
 * be opened; on failure logging is left disabled.
 *
 * All log helpers are `static inline`: a plain C99 `inline` definition does
 * not emit an external symbol, so any call the compiler decided not to
 * inline (e.g. at -O0) failed to link.
 */
static inline int log_init(const char *file)
{
  if(log_file) {
    fclose(log_file);
    log_file = NULL;
  }
  if(!file) return EINVAL;
  log_file = fopen(file, "w");
  return log_file ? 0 : EIO;
}

/* Write one "[prefix]: data" line to the log, if a log is open. */
static inline void log_msg(const char *prefix, const char *data)
{
  if(log_file) fprintf(log_file, "[%s]: %s\n", prefix, data);
}

/* Log a BEGIN marker; compiled out when FAILS_ONLY is defined. */
static inline void log_begin(const char *data)
{
#ifndef FAILS_ONLY
  log_msg("BEGIN", data);
#else
  (void)data;
#endif
}

/* Log an END marker; compiled out when FAILS_ONLY is defined. */
static inline void log_end(const char *data)
{
#ifndef FAILS_ONLY
  log_msg("END", data);
#else
  (void)data;
#endif
}

/* Log an informational line; compiled out when FAILS_ONLY is defined. */
static inline void log_info(const char *data)
{
#ifndef FAILS_ONLY
  log_msg("INFO", data);
#else
  (void)data;
#endif
}

/* Log a failure line; always written, regardless of FAILS_ONLY. */
static inline void log_error(const char *data)
{
  log_msg("FAILED", data);
}

/* Log a [FAILED] line for `info` when the result `rc` differs from `exp`. */
static inline void log_assert(const char *info, int rc, int exp)
{
  if(log_file && rc != exp) {
    fprintf(log_file, "[FAILED]: %s result: %d, expected: %d\n", info, rc, exp);
  }
}

/*
 * Flush and close the log. Safe to call more than once: the handle is reset
 * so later log calls become no-ops instead of touching a closed FILE.
 */
static inline void log_close(void)
{
  if(log_file) {
    fflush(log_file);
    fclose(log_file);
    log_file = NULL;
  }
}

/* Flush pending log output; a no-op when no log is open (fflush(NULL)
 * would otherwise flush every output stream). */
static inline void log_flush(void)
{
  if(log_file) fflush(log_file);
}

/*
 * Fatal-signal handler, installed by main() via sigaction() with SA_SIGINFO
 * for SIGFPE, SIGILL, SIGSEGV and SIGBUS. Prints the signal number, the
 * si_addr value and a backtrace to stderr, then terminates with status 1.
 *
 * The backtrace is written with backtrace_symbols_fd(), which does not
 * malloc() (backtrace_symbols() does, which can deadlock if the crash
 * happened inside the allocator). The process ends with _exit() so no
 * atexit handlers or stdio flushing run in a possibly corrupted process.
 * NOTE(review): fprintf() is still not async-signal-safe; it is kept only
 * for the short header lines. Register access below is x86/x86-64 Linux
 * specific.
 */
void signal_error(int sig, siginfo_t *si, void *ptr)
{
  void *trace[16];
  void *error_addr;
  int trace_size;
  uintptr_t fptr = (uintptr_t)(si->si_addr);

  fprintf(stderr, "Something is wrong: backtrace: \n");
  /* %lX requires an unsigned long argument; uintptr_t is not unsigned long
   * on every target, so convert explicitly. */
  fprintf(stderr, "Signal: %d, function pointer: 0x%.12lX \n", sig,
          (unsigned long)fptr);
#if __WORDSIZE == 64
  error_addr = (void *)((ucontext_t *)ptr)->uc_mcontext.gregs[REG_RIP];
#else
  error_addr = (void *)((ucontext_t *)ptr)->uc_mcontext.gregs[REG_EIP];
#endif

  trace_size = backtrace(trace, 16);
  if(trace_size > 1) {
    /* Frame 0 is this handler; frame 1 (presumably the signal trampoline)
     * is replaced by the instruction pointer saved at the time of the fault. */
    trace[1] = error_addr;
    backtrace_symbols_fd(trace + 1, trace_size - 1, STDERR_FILENO);
  }

  fprintf(stderr, "end of backtrace\n");

  _exit(1);
}

/* State shared by the threaded channel tests. */
typedef struct {
  pthread_t **threads;  /* thread_count individually heap-allocated handles */
  int thread_count;     /* number of slots in `threads` */
  conn_t *co;           /* connection handed to every worker thread */
} test_data_t;

/*
 * Parse the S-expression text `req` into a new sexp_t.
 * The text is copied first, presumably because parse_sexp() takes a
 * writable (non-const) buffer; the copy is released before returning.
 * Returns NULL if `req` is NULL or the copy cannot be allocated; otherwise
 * whatever parse_sexp() returns.
 */
/*static*/ sexp_t *make_request(const char *req)
{
  char *request;
  sexp_t *sx;

  if(!req) return NULL;
  request = strdup(req);
  if(!request) return NULL;

  sx = parse_sexp(request, strlen(request));
  free(request);
  return sx;
}

/*
 * Allocate `count` pthread_t handles for `data`, one heap object per slot,
 * and record the count in data->thread_count.
 * Returns 0 on success, EINVAL if `data` is NULL or `count` is negative, and
 * ENOMEM if any allocation fails. On failure nothing stays allocated:
 * data->threads is NULL and data->thread_count is 0.
 */
/*static */int allocate_threads(int count, test_data_t *data)
{
  int i;

  if(!data || count < 0) return EINVAL;
  data->threads = NULL;
  data->thread_count = 0;
  if(count == 0) return 0;

  /* calloc also guards the count * size multiplication against overflow. */
  data->threads = calloc((size_t)count, sizeof *data->threads);
  if(!data->threads) return ENOMEM;

  for(i = 0; i < count; ++i) {
    data->threads[i] = malloc(sizeof *data->threads[i]);
    if(!data->threads[i]) {
      /* Roll back the slots allocated so far. */
      while(i-- > 0) free(data->threads[i]);
      free(data->threads);
      data->threads = NULL;
      return ENOMEM;
    }
  }
  data->thread_count = count;
  return 0;
}
/*
 * Join every thread in `data`, free each handle and the handle array, and
 * reset data->threads / data->thread_count.
 * Returns 0 on success; EINVAL if `data` is NULL, if the array is missing
 * while thread_count > 0, or if any slot was NULL (the remaining slots are
 * still joined and freed); otherwise the first non-zero pthread_join()
 * result.
 */
/*static */int deallocate_threads(test_data_t *data)
{
  int i, jrc, rc = 0;

  if(!data) return EINVAL;
  if(!data->threads && data->thread_count > 0) return EINVAL;

  for(i = 0; i < data->thread_count; ++i) {
    if(!data->threads[i]) {
      /* Remember the error but keep cleaning up the other slots. */
      if(!rc) rc = EINVAL;
      continue;
    }
    jrc = pthread_join(*data->threads[i], NULL);
    if(jrc && !rc) rc = jrc;
    free(data->threads[i]);
    data->threads[i] = NULL;
  }
  free(data->threads);
  data->threads = NULL;
  data->thread_count = 0;

  return rc;
}

void *test_invalid_channel(void *ctx)
{
  log_begin("Invalid channel testing");
  conn_t *co = (conn_t *)ctx;
  chnl_t *channel = NULL;
  int rc = 0, i;
  for(i = 0; i < ITERATION_COUNT || ITERATION_COUNT < 0; ++i) {
    rc = channel_open(co, &channel, 1);
    log_assert("channel_open with type 1", rc, EINVAL);
    // TODO: segmentation fault below 
    //rc = channel_close(channel);
    //log_assert("channel_close with type 1", rc, EINVAL);
  }
  log_end("Invalid channel testing");
  return 0x00;
}

/*
 * Worker thread body. `ctx` is the conn_t * under test.
 * For ITERATION_COUNT rounds (forever if negative): open a type-12 channel,
 * send MESSAGES_PER_SESSION "(ar-add (a b))" requests checking that each
 * msg_send() result equals a + b, then close the channel.
 * A round whose channel_open() fails is logged and skipped: sending on or
 * closing a channel that never opened is invalid.
 * Always returns NULL.
 */
void *test_correct_channel(void *ctx)
{
  conn_t *co = (conn_t *)ctx;
  chnl_t *channel = NULL;
  sexp_t *add_request = NULL;
  sxmsg_t *msg = NULL;
  char buf[128];
  time_t start, end;
  double exec_time;
  int rc = 0, i, j, a, b;

  log_begin("Channel testing");

  for(j = 0; j < ITERATION_COUNT || ITERATION_COUNT < 0; ++j) {
    channel = NULL;
    rc = channel_open(co, &channel, 12);
    log_assert("channel_open with type 12", rc, 0);
    if(rc) continue;

    log_begin("Test messaging");
    for(i = 0; i < MESSAGES_PER_SESSION; ++i) {
      a = rand() % 100;
      b = rand() % 100;
      snprintf(buf, sizeof buf, "(ar-add (%d %d))", a, b);
      add_request = make_request(buf);
      if(!add_request) {
        log_error("failed to build ar-add request");
        continue;
      }
      /* NOTE(review): time() has one-second resolution, so this is a very
       * coarse measurement. */
      time(&start);
      rc = msg_send(channel, add_request, &msg);
      time(&end);
      exec_time = difftime(end, start);
      snprintf(buf, sizeof buf, "rpc execution time: %lf", exec_time);
      log_info(buf);
      log_assert("rpc execution", rc, a + b);
      /* NOTE(review): add_request is not destroyed here (a destroy_sexp()
       * call was disabled in the original); confirm whether msg_send()
       * takes ownership before freeing it. */
    }
    log_end("Test messaging");

    rc = channel_close(channel);
    log_assert("channel_close with type 12", rc, 0);
  }

  log_end("Channel testing");
  return NULL;
}

/*
 * Start worker thread `index` of `data`, passing data->co as its argument.
 * Slots below CLIENT_COUNT run test_correct_channel; the rest run
 * test_invalid_channel.
 * Returns the pthread_create() result, or EINVAL if `data` is NULL or
 * `index` does not name an allocated thread slot (which would otherwise be
 * an out-of-bounds write).
 */
int test_channels(test_data_t *data, int index)
{
  void *(*worker)(void *);

  if(!data || !data->threads || index < 0 || index >= data->thread_count
     || !data->threads[index])
    return EINVAL;

  worker = (index < CLIENT_COUNT) ? test_correct_channel : test_invalid_channel;
  return pthread_create(data->threads[index], NULL, worker, data->co);
}

/*
 * Open and immediately close a type-12 channel on `co`, ITERATION_COUNT
 * times, logging any failure. The close step is skipped (the failure is
 * already logged) when the open failed, instead of closing an unopened or
 * previously closed handle.
 */
void test_channel_handling(conn_t *co)
{
  chnl_t *channel;
  int rc, i;

  for(i = 0; i < ITERATION_COUNT; ++i) {
    channel = NULL;
    rc = channel_open(co, &channel, 12);
    log_assert("channel open function", rc, 0);
    if(rc) continue;

    rc = channel_close(channel);
    log_assert("channel close function", rc, 0);
  }
}

/*
 * Open one type-12 channel on `co` and send ITERATION_COUNT
 * "(ar-add (a b))" requests over it, checking that each msg_send() result
 * equals a + b; then close the channel. Returns early (after logging) if
 * the channel cannot be opened.
 */
void test_message_handling(conn_t* co)
{
  chnl_t *channel = NULL;
  sxmsg_t *msg = NULL;
  sexp_t *sx;
  /* Largest request is "(ar-add (99 99))"; a small stack buffer suffices. */
  char buf[64];
  int rc, i, a, b;

  rc = channel_open(co, &channel, 12);
  log_assert("channel open function", rc, 0);
  if(rc) return;

  for(i = 0; i < ITERATION_COUNT; ++i) {
    a = rand() % 100;
    b = rand() % 100;
    snprintf(buf, sizeof buf, "(ar-add (%d %d))", a, b);
    sx = parse_sexp(buf, strlen(buf));
    if(!sx) {
      log_error("failed to parse ar-add request");
      continue;
    }
    rc = msg_send(channel, sx, &msg);
    log_assert("message send function", rc, a + b);
    /* NOTE(review): sx is not destroyed (destroy_sexp was disabled in the
     * original); confirm msg_send() ownership semantics. */
  }

  rc = channel_close(channel);
  log_assert("channel close function", rc, 0);
}

/* Route fatal signals (SIGFPE, SIGILL, SIGSEGV, SIGBUS) to signal_error(). */
static void install_crash_handler(void)
{
  struct sigaction sigact;

  /* Zero the whole struct so no field is left indeterminate. */
  memset(&sigact, 0, sizeof sigact);
  sigact.sa_flags = SA_SIGINFO;
  sigact.sa_sigaction = signal_error;
  sigemptyset(&sigact.sa_mask);
  sigaction(SIGFPE, &sigact, 0);
  sigaction(SIGILL, &sigact, 0);
  sigaction(SIGSEGV, &sigact, 0);
  sigaction(SIGBUS, &sigact, 0);
}

/* Parse a TCP port; returns it, or -1 unless `s` is a decimal integer in 1..65535. */
static int parse_port(const char *s)
{
  char *end;
  long v;

  errno = 0;
  v = strtol(s, &end, 10);
  if(end == s || *end != '\0' || errno == ERANGE || v < 1 || v > 65535)
    return -1;
  return (int)v;
}

/*
 * Test client entry point.
 * Options: -p port, -r root CA path, -a server address, -u user certificate
 * path, -l login, -w password (all but -p are required).
 * Initialises the connection subsystem, opens two connections to the server,
 * closes the ones that were established, and logs failures to "test.log".
 * Returns 0 when the test ran, EINVAL for bad arguments, ENOMEM on
 * allocation failure, or the subsystem's error code if initialisation fails.
 */
int main(int argc, char **argv)
{
  char *rootca = NULL, *cert = NULL;
  char *addr = NULL, *login = NULL, *password = NULL;
  conn_t *co = NULL, *co2 = NULL;
  perm_ctx_t *ctx = NULL;
  int port = DEFAULT_PORT;
  int opt, rc = EINVAL;
  int co_up, co2_up;

  install_crash_handler();

  /* Repeated options replace (and free) the earlier value. */
  while((opt = getopt(argc, argv, "p:r:a:u:l:w:")) != -1) {
    switch(opt) {
    case 'p':
      port = parse_port(optarg);
      if(port < 0) {
        fprintf(stderr, "Invalid port number: %s\n", optarg);
        goto out;
      }
      break;
    case 'r':
      free(rootca);
      rootca = strdup(optarg);
      break;
    case 'a':
      free(addr);
      addr = strdup(optarg);
      break;
    case 'u':
      free(cert);
      cert = strdup(optarg);
      break;
    case 'l':
      free(login);
      login = strdup(optarg);
      break;
    case 'w':
      free(password);
      password = strdup(optarg);
      break;
    default:
      fprintf(stderr, "usage: %s [-p <PORTNUM>] -r <PATH to Root CA> -a <Server ip address> -u <PATH"
              " to SSL certificate> -l <User login> -w <User password>\n", argv[0]);
      goto out;
    }
  }

  if(!rootca) {
    fprintf(stderr, "Root CA not pointed.\n Failure.\n");
    goto out;
  }

  if(!addr) {
    fprintf(stderr, "Server address not pointed.\n Failure.\n");
    goto out;
  }

  if(!cert) {
    fprintf(stderr, "User certificate not pointed.\n Failure.\n");
    goto out;
  }

  if(!login) {
    fprintf(stderr, "User login not pointed.\n Failure.\n");
    goto out;
  }

  if(!password) {
    fprintf(stderr, "User password not pointed.\n Failure.\n");
    goto out;
  }

  /* all is fine let's init connection subsystem */
  rc = connections_subsystem_init();
  if(rc) {
    fprintf(stderr, "Subsystem init failed: %d\n", rc);
    goto out;
  }
  /* set working certificates */
  rc = connections_subsystem_setsslserts(rootca, cert, cert);
  if(rc) {
    fprintf(stderr, "Subsystem init failed (set SSL x.509 pems): %d\n", rc);
    goto out;
  }

  /* Tests: open two connections, then close them. calloc leaves no field
   * of the context/connection objects uninitialised. */
  co = calloc(1, sizeof *co);
  co2 = calloc(1, sizeof *co2);
  ctx = calloc(1, sizeof *ctx);
  if(!co || !co2 || !ctx) {
    fprintf(stderr, "Out of memory.\n");
    rc = ENOMEM;
    goto out;
  }
  ctx->login = login;
  ctx->passwd = password;

  if(log_init("test.log"))
    fprintf(stderr, "Cannot open test.log; test results will not be logged.\n");

  log_begin("Connection initiate");
  rc = connection_initiate(co, addr, port, cert, ctx);
  log_assert("Connection initiate", rc, 0);
  co_up = (rc == 0);
  log_end("Connection initiate");

  log_begin("Connection initiate (second one for test)");
  rc = connection_initiate(co2, addr, port, cert, ctx);
  log_assert("Connection initiate (second)", rc, 0);
  co2_up = (rc == 0);
  log_end("Connection initiate (second one for test)");

  /* Only close connections that were actually established. */
  if(co_up) {
    log_begin("Connection close");
    log_assert("Connection close", connection_close(co), 0);
    log_end("Connection close");
  } else {
    log_error("Connection close skipped: first connection was not established");
  }

  if(co2_up) {
    log_begin("Connection close (second)");
    log_assert("Connection close (second)", connection_close(co2), 0);
    log_end("Connection close (second)");
  } else {
    log_error("Connection close skipped: second connection was not established");
  }

  log_close();
  rc = 0;

out:
  free(ctx);
  free(co2);
  free(co);
  free(password);
  free(login);
  free(addr);
  free(cert);
  free(rootca);

  return rc;
}