/**
* Copyright 1981-2007 ECMWF
* 
* Licensed under the GNU Lesser General Public License which
* incorporates the terms and conditions of version 3 of the GNU
* General Public License.
* See LICENSE and gpl-3.0.txt for details.
*/

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <errno.h>

#include <sys/stat.h>
#include <fcntl.h>
#include <sys/time.h>

#include <sys/sem.h>

#include "sharedlib.h"

/* Non-zero enables progress and timing messages on stdout. */
int sharedlib_dbg = 0;

/* Chunk size of each read() when loading a file into shared memory: 10 MB. */

#define BUFFER_SIZE 10485760

/* Value stored in struct info::magic once a segment has been loaded. */
#define MAGIC       1234567890

/*
 * semop() operations that acquire the per-file semaphore: wait until its
 * value is 0 ("test"), then add 1 ("lock"); semop() applies the array as
 * one atomic set.  SEM_UNDO asks the kernel to revert the adjustment if
 * the process exits while still holding the lock.
 *
 * Designated initializers are used because POSIX names the members of
 * struct sembuf but does not specify their order.
 *
 * NOTE(review): identifiers starting with '_' are reserved at file scope;
 * the names are kept unchanged in case other translation units use them.
 */
struct sembuf _lock[] = {
    { .sem_num = 0, .sem_op = 0, .sem_flg = SEM_UNDO }, /* test */
    { .sem_num = 0, .sem_op = 1, .sem_flg = SEM_UNDO }, /* lock */
};

/* semop() operation that releases the semaphore taken with _lock. */
struct sembuf _unlock[] = {
    { .sem_num = 0, .sem_op = -1, .sem_flg = SEM_UNDO }, /* ulck */
};

/* Capacity of struct info::path, including the terminating NUL. */
#define INFO_PATH 1024

/*
 * Trailer stored directly after the page-rounded file data in each shared
 * segment.  Its layout is shared between processes; do not change it.
 */
struct info {
    int ready;              /* non-zero once the file data has been loaded */
    int magic;              /* MAGIC when the trailer is valid */
    char path[INFO_PATH];   /* resolved path of the loaded file */
};



/*
 * Seconds elapsed between two gettimeofday() samples.  Used only for the
 * optional debug timing output of share_file().
 */
static double sharedmem_elapsed(const struct timeval *from, const struct timeval *to)
{
    /* Seconds and microseconds are subtracted separately; a negative
       microsecond difference simply lowers the total, so no borrow is needed. */
    return (double)(to->tv_sec - from->tv_sec)
         + (double)(to->tv_usec - from->tv_usec) / 1000000.;
}

/*
 * share_file - load a file into System V shared memory and attach to it.
 *
 * The resolved path of `file` is turned into an IPC key with ftok(); the
 * same key names a semaphore (serialising loaders) and a shared memory
 * segment.  The segment holds the file data padded up to whole pages,
 * followed by a `struct info` trailer.  The first caller reads the file
 * into the segment and marks the trailer ready; later callers validate the
 * trailer (magic and path) and reuse the data without reading the file.
 *
 * NOTE(review): the segment size is derived from the file size at the time
 * of the call; if the file changes size after the segment was created, the
 * existing segment no longer matches -- presumably remove_shared_file() is
 * expected to be called first.
 *
 * Returns a pointer to the start of the file data inside the attached
 * segment, or NULL on failure (a diagnostic is printed on stderr).  Detach
 * with release_shared_file().
 */
void *share_file(const char* file) {
    char message[1024];
    char *path = NULL;      /* malloc'ed by realpath(); freed before return */
    struct  stat64 s;
    long long file_size;
    size_t data_size;       /* file size rounded up to whole pages */
    int fd = -1;
    int err = 0;
    void *ptr = NULL;
    int shmid = 0;
    key_t key;
    char *addr;
    struct timeval start, end;
    int sem = -1;
    int locked = 0;
    int page_size = getpagesize();
    struct info *nfo;

    /* page_size is a divisor below, so zero must be rejected too. */
    if(page_size <= 0)
    {
        fprintf(stderr,"ERR: sharedmem:get_page_size faile\n");
        return NULL;
    }

    if(strlen(file) + 1 > INFO_PATH)
    {
        fprintf(stderr,"ERR: sharedmem:path too long(%lu) max is %d\n",
                (unsigned long)strlen(file), INFO_PATH);
        return NULL;
    }

    if(sizeof(s.st_size) < 8) {
        fprintf(stderr,"ERR: sharedmem:stat.st_size(%lu) is too small for 64bits files\n",
                (unsigned long)sizeof(s.st_size));
        return NULL;
    }

    /* Let realpath() allocate the result: with a fixed-size buffer, a
       symlink can resolve to a path longer than the buffer and overrun it. */
    if((path = realpath(file, NULL)) == NULL) {
        snprintf(message, sizeof(message), "ERR: sharedmem:realpath(%s)", file);
        err = -1;
        goto error;
    }

    /* The resolved path is later copied into the fixed-size info trailer. */
    if(strlen(path) + 1 > INFO_PATH)
    {
        errno = ENAMETOOLONG;
        snprintf(message, sizeof(message), "ERR: sharedmem:path too long(%s)", path);
        err = -1;
        goto error;
    }

    if (sharedlib_dbg)
        fprintf(stdout,"sharedmem: sharing %s\n",path);

    /* One key for both the semaphore and the segment of this file. */
    key = ftok(path,1);

    if(key == (key_t)-1) {
        snprintf(message, sizeof(message), "ERR: sharedmem:ftok(%s)", path);
        err = -1;
        goto error;
    }

    if((sem = semget(key,1,IPC_CREAT|0600)) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:semget(%s)", path);
        err = -1;
        goto error;
    }

    if (sharedlib_dbg) gettimeofday( &start, NULL );

    /* Serialise access: only one process loads or checks the segment at a time. */
    if(semop(sem,_lock, 2 ) < 0) {
        snprintf(message, sizeof(message), "ERR: sharedmem:semop:lock(%s)", path);
        err = -1;
        goto error;
    }
    locked = 1;

    if (sharedlib_dbg)
    {
        gettimeofday( &end, NULL );
        fprintf( stdout, "sharedmem:semop:lock wait %g secs\n",
                 sharedmem_elapsed(&start, &end));
    }

    if((fd = open(path,O_RDONLY | O_LARGEFILE)) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:open(%s)", path);
        err = -1;
        goto error;
    }

    if(stat64(path,&s))
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:stat(%s)", path);
        err = -1;
        goto error;
    }

    file_size = (long long)s.st_size;

    /* shmget() takes a size_t: reject files that cannot fit in this address
       space (32-bit hosts).  Otherwise the segment size would be truncated
       and the read loop below would write past its end. */
    if(file_size < 0 ||
       (unsigned long long)file_size >
       (unsigned long long)((size_t)-1 - sizeof(struct info) - (size_t)page_size))
    {
        errno = EFBIG;
        snprintf(message, sizeof(message), "ERR: sharedmem:file too large(%s)", path);
        err = -1;
        goto error;
    }
    data_size = (((size_t)file_size + (size_t)page_size - 1) / (size_t)page_size)
                * (size_t)page_size;

    /* Segment layout: file data padded to whole pages, then struct info. */
    if((shmid = shmget(key, data_size + sizeof(struct info), IPC_CREAT|0600)) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:shmget(%s)", path);
        err = -1;
        goto error;
    }

#ifdef SHM_PAGESIZE
    {
        /* Best effort: ask for 64K pages to back the segment; on failure
           the default page size is simply kept. */
        struct shmid_ds shm_buf = { 0 };
        psize_t psize_64k = 64 * 1024;

        shm_buf.shm_pagesize = psize_64k;
        (void)shmctl(shmid, SHM_PAGESIZE, &shm_buf);
    }
#endif

    /* attach shared memory */
    ptr = shmat( shmid, NULL, 0 );
    if (ptr == (void*)-1) {
        ptr = NULL;         /* never hand the failure sentinel to shmdt() */
        snprintf(message, sizeof(message), "ERR: sharedmem:shmat(%s)", path);
        err = -1;
        goto error;
    }

    addr = (char*)ptr;
    nfo  = (struct info*)(addr + data_size);

    if(nfo->ready)
    {
        /* Another process already loaded the file: validate the trailer. */
        if(nfo->magic != MAGIC)
        {
            errno = EINVAL;
            snprintf(message, sizeof(message),
                     "ERR: sharedmem:check: bad magic %d", nfo->magic);
            err = -1;
            goto error;
        }

        /* Bounded compare/print: the trailer comes from shared memory and
           might not be NUL-terminated. */
        if(strncmp(nfo->path, path, INFO_PATH) != 0)
        {
            errno = EINVAL;
            snprintf(message, sizeof(message),
                     "ERR: sharedmem:check: invalid path [%.*s]", INFO_PATH, nfo->path);
            err = -1;
            goto error;
        }

        if (sharedlib_dbg)
            fprintf( stdout, "sharedmem:read file already loaded\n");
    }
    else
    {
        long long remaining = file_size;

        if (sharedlib_dbg) gettimeofday( &start, NULL );

        while(remaining > 0)
        {
            size_t len = remaining > BUFFER_SIZE ? BUFFER_SIZE : (size_t)remaining;
            ssize_t got = read(fd, addr, len);

            if(got < 0 && errno == EINTR)
                continue;               /* interrupted by a signal: retry */
            if(got <= 0)
            {
                if(got == 0)
                    errno = EIO;        /* unexpected EOF: the file shrank */
                snprintf(message, sizeof(message), "ERR: sharedmem:read(%s)", path);
                err = -1;
                goto error;
            }
            /* read() may legitimately return fewer bytes than requested. */
            remaining -= got;
            addr      += got;
        }

        if (sharedlib_dbg)
        {
            gettimeofday( &end, NULL );
            fprintf( stdout, "sharedmem:read %lld bytes in %g secs\n",
                     file_size, sharedmem_elapsed(&start, &end));
        }

        /* The ready flag is set only after the data and trailer are written. */
        nfo->magic = MAGIC;
        strcpy(nfo->path,path);         /* fits: length checked against INFO_PATH */
        nfo->ready = 1;
    }

error:
    /* Shared exit path for success (err == 0) and failure. */

    /* Report first, before close()/shmdt()/semop() can change errno. */
    if(err)
        perror(message);

    if(fd >= 0)
        close(fd);                      /* closed exactly once */

    if(err && ptr) {
        shmdt(ptr);
        ptr = NULL;
    }

    if(locked) {
        if(semop(sem,_unlock,1) < 0)
        {
            snprintf(message, sizeof(message), "ERR: sharedmem:semop:unlock(%s)", path);
            perror(message);
        }
    }

    free(path);
    return ptr;
}

/*
 * Detach a segment previously returned by share_file().  The segment is
 * only detached from this process, not removed; the status returned by
 * shmdt() is deliberately ignored.
 */
void release_shared_file(void *ptr)
{
    (void)shmdt(ptr);
}

/*
 * remove_shared_file - destroy the IPC objects share_file() created for a file.
 *
 * Derives the same ftok() key as share_file() from the resolved path of
 * `file`, marks the shared memory segment for removal (IPC_RMID) and then
 * removes the matching semaphore.  If no segment exists for the key, the
 * call fails and the semaphore is left in place.
 *
 * Returns 0 on success, -1 on failure (a diagnostic is printed on stderr).
 */
int remove_shared_file(const char* file) {
    char message[1024];
    char *path = NULL;      /* malloc'ed by realpath(); freed before return */
    int err = 0;
    int shmid = 0;
    key_t key;
    int sem;

    /* Input limit kept from the former 1024-byte path buffer, so callers
       see the same rejections as before. */
    if(strlen(file) + 1 > 1024)
    {
        fprintf(stderr,"ERR: sharedmem:path too long(%lu) max is %d\n",
                (unsigned long)strlen(file), 1024);
        return -1;
    }

    /* Let realpath() allocate the result: a symlink can resolve to a path
       longer than any fixed-size buffer. */
    if((path = realpath(file, NULL)) == NULL) {
        snprintf(message, sizeof(message), "ERR: sharedmem:realpath(%s)", file);
        err = -1;
        goto error;
    }

    /* Must match the key share_file() computes for the same path. */
    key  = ftok(path,1);

    if(key ==  (key_t)-1) {
        snprintf(message, sizeof(message), "ERR: sharedmem:ftok(%s)", path);
        err = -1;
        goto error;
    }

    if((shmid = shmget(key,0,0600)) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:shmget(%s)", path);
        err = -1;
        goto error;
    }

    fprintf(stdout,"sharedmem: removing shared memory for %s\n",path);

    if(shmctl(shmid, IPC_RMID, NULL) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:shmctl:IPC_RMID,(%s)", path);
        err = -1;
        goto error;
    }

    if((sem = semget(key,1,0600)) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:semget(%s)", path);
        err = -1;
        goto error;
    }

    /* IPC_RMID takes no fourth (union semun) argument. */
    if(semctl(sem, 0, IPC_RMID) < 0)
    {
        snprintf(message, sizeof(message), "ERR: sharedmem:semctl:IPC_RMID,(%s)", path);
        err = -1;
        goto error;
    }

error:
    /* Shared exit path; err is 0 on success.  Report before free(). */
    if(err)
        perror(message);

    free(path);
    return err;
}
