/*
 * sys_socket.c - implement the various socket related system calls
 *
 * Author    Alain Greiner (2016,2017,2018,2019,2020)
 *  
 * Copyright (c) UPMC Sorbonne Universites
 *
 * This file is part of ALMOS-MKH.
 *
 * ALMOS-MKH is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2.0 of the License.
 *
 * ALMOS-MKH is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with ALMOS-MKH; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <hal_kernel_types.h>
#include <hal_uspace.h>
#include <hal_vmm.h>
#include <errno.h>
#include <vmm.h>
#include <cluster.h>
#include <thread.h>
#include <process.h>
#include <ksocket.h>
#include <string.h>
#include <shared_syscalls.h>
#include <shared_socket.h> 
#include <remote_barrier.h>
#include <vfs.h>
#include <mapper.h>

#include <syscalls.h>

/////////////////////////////////////////////////////////////////////////////////
// This function returns a printable string for the socket related command type.
/////////////////////////////////////////////////////////////////////////////////

#if DEBUG_SYS_SOCKET
// Debug helper : translates the socket command code <type> into a printable
// name. Any code that matches none of the entries below yields "undefined".
static char* socket_cmd_type_str( uint32_t type )
{
    // association between each command code and its printable name,
    // scanned in order so that the first matching entry wins
    static const struct
    {
        uint32_t   code;
        char     * name;
    }
    cmd_names[] =
    {
        { SOCK_CREATE   , "CREATE"   },
        { SOCK_BIND     , "BIND"     },
        { SOCK_LISTEN   , "LISTEN"   },
        { SOCK_CONNECT  , "CONNECT"  },
        { SOCK_ACCEPT   , "ACCEPT"   },
        { SOCK_SEND     , "SEND"     },
        { SOCK_SENDTO   , "SENDTO"   },
        { SOCK_RECV     , "RECV"     },
        { SOCK_RECVFROM , "RECVFROM" },
    };

    uint32_t  idx;
    uint32_t  nb_names = sizeof(cmd_names) / sizeof(cmd_names[0]);

    for( idx = 0 ; idx < nb_names ; idx++ )
    {
        if( cmd_names[idx].code == type ) return cmd_names[idx].name;
    }

    return "undefined";
}
#endif

/////////////////////////////////////////////////////////////////////////////////
// This function implements all socket related system calls. The <arg0> argument
// selects the operation, and the meaning of the three other arguments depends
// on this operation:
// - SOCK_CREATE  : arg1 = domain (must be AF_INET)
//                  arg2 = type (must be SOCK_DGRAM or SOCK_STREAM)
// - SOCK_BIND    : arg1 = fdid / arg2 = user pointer on a sockaddr_in_t (input)
// - SOCK_LISTEN  : arg1 = fdid / arg2 = max number of pending requests
// - SOCK_CONNECT : arg1 = fdid / arg2 = user pointer on a sockaddr_in_t (input)
// - SOCK_ACCEPT  : arg1 = fdid / arg2 = user pointer on a sockaddr_in_t (output)
// - SOCK_SEND    : arg1 = fdid / arg2 = user buffer / arg3 = length (non zero)
// - SOCK_RECV    : arg1 = fdid / arg2 = user buffer / arg3 = length (non zero)
// Any other value of <arg0> is rejected.
//
// It returns the value returned by the relevant ksocket function when the
// arguments are accepted. It returns -1 when an argument is rejected here.
// In all error cases reported by this function, this->errno is set to EINVAL.
/////////////////////////////////////////////////////////////////////////////////
int sys_socket( reg_t  arg0,
                reg_t  arg1,
                reg_t  arg2,
                reg_t  arg3 )
{

    int32_t         ret;
    vseg_t        * vseg;

    // kernel buffer for one socket address : it is zero-initialized because the
    // ACCEPT operation copies the whole structure to user space, while only
    // the <sin_addr> and <sin_port> fields are written by socket_accept()
    sockaddr_in_t   k_sockaddr = {0};

    thread_t      * this    = CURRENT_THREAD;
    process_t     * process = this->process;

    uint32_t        cmd = arg0;

#if (DEBUG_SYS_SOCKET || CONFIG_INSTRUMENTATION_SYSCALLS)
uint64_t     tm_start = hal_get_cycles();
#endif

#if DEBUG_SYS_SOCKET
if( DEBUG_SYS_SOCKET < tm_start )
printk("\n[%s] thread[%x,%x] enter / %s / a1 %x / a2 %x / a3 %x / cycle %d\n",
__FUNCTION__, process->pid, this->trdid, socket_cmd_type_str(cmd),
arg1, arg2, arg3, (uint32_t)tm_start );
#endif

    // NOTE(review): for all operations taking a user pointer, only the start
    // address <arg2> is passed to vmm_get_vseg() : the last byte of the user
    // structure / buffer is not checked here - confirm this is handled below.

    switch( cmd )
    {
        /////////////////
        case SOCK_CREATE:
        {
            uint32_t domain = arg1;
            uint32_t type   = arg2;

            // only the AF_INET domain is accepted
            if( domain != AF_INET )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for CREATE domain %d != AF_INET\n",
__FUNCTION__ , domain );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // only the UDP and TCP socket types are accepted
            if( (type != SOCK_DGRAM) && (type != SOCK_STREAM) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for CREATE : socket must be SOCK_STREAM(TCP) or SOCK_DGRAM(UDP)\n",
__FUNCTION__ );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // call relevant kernel socket function
            ret = socket_build( domain , type );

            if( ret == -1 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for CREATE : cannot create socket\n",
__FUNCTION__ );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        ///////////////
        case SOCK_BIND:
        {
            uint32_t        fdid = arg1;
            sockaddr_in_t * u_sockaddr = (sockaddr_in_t *)(intptr_t)arg2;

            // check addr pointer in user space
            if( vmm_get_vseg( process , (intptr_t)arg2 , &vseg ) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for BIND : address %x unmapped\n",
__FUNCTION__ , (intptr_t)arg2 );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // copy sockaddr structure from uspace to kernel space
            hal_copy_from_uspace( XPTR( local_cxy , &k_sockaddr ),
                                  u_sockaddr,
                                  sizeof(sockaddr_in_t) );

            // call relevant kernel socket function
            ret = socket_bind( fdid,
                               k_sockaddr.sin_addr,
                               k_sockaddr.sin_port );

            if( ret )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for BIND : cannot access socket[%x,%d]\n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        /////////////////
        case SOCK_LISTEN:
        {
            uint32_t     fdid        = (uint32_t)arg1;
            uint32_t     max_pending = (uint32_t)arg2;

            // call relevant kernel socket function
            ret = socket_listen( fdid , max_pending );

            if( ret )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for LISTEN : cannot access socket[%x,%d]\n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        //////////////////
        case SOCK_CONNECT:
        {
            uint32_t        fdid = (uint32_t)arg1;
            sockaddr_in_t * u_sockaddr = (sockaddr_in_t *)(intptr_t)arg2;

            // check addr pointer in user space
            if( vmm_get_vseg( process , (intptr_t)arg2 , &vseg ) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for CONNECT : server address %x unmapped\n",
__FUNCTION__ , (intptr_t)arg2 );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // copy sockaddr structure from uspace to kernel space
            hal_copy_from_uspace( XPTR( local_cxy , &k_sockaddr ),
                                  u_sockaddr ,
                                  sizeof(sockaddr_in_t) );

            // call relevant kernel function
            ret = socket_connect( fdid,
                                  k_sockaddr.sin_addr,
                                  k_sockaddr.sin_port );

            if( ret )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for CONNECT : cannot access socket[%x,%d]\n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        /////////////////
        case SOCK_ACCEPT:
        {
            uint32_t        fdid = (uint32_t)arg1;
            sockaddr_in_t * u_sockaddr = (sockaddr_in_t *)(intptr_t)arg2;

            // check addr pointer in user space
            if( vmm_get_vseg( process , (intptr_t)arg2 , &vseg ) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for ACCEPT : address %x unmapped\n",
__FUNCTION__ , (intptr_t)arg2 );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // call relevant kernel function
            ret = socket_accept( fdid,
                                 &k_sockaddr.sin_addr,
                                 &k_sockaddr.sin_port );

            if( ret < 0 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for ACCEPT : cannot access socket[%x,%d]\n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;

                // no valid address to report : the user buffer is left untouched
                break;
            }

            // copy sockaddr structure from kernel space to uspace
            hal_copy_to_uspace( u_sockaddr,
                                XPTR( local_cxy , &k_sockaddr ),
                                sizeof(sockaddr_in_t) );

            break;
        }
        ///////////////
        case SOCK_SEND:
        {
            uint32_t     fdid   = (uint32_t)arg1;
            uint8_t    * u_buf  = (uint8_t *)(intptr_t)arg2;
            uint32_t     length = (uint32_t)arg3;

            // check buffer is mapped in user space
            if( vmm_get_vseg( process , (intptr_t)arg2 , &vseg ) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for SEND : buffer %x unmapped\n",
__FUNCTION__ , (intptr_t)arg2 );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // check length
            if( length == 0 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for SEND : buffer length is 0\n",
__FUNCTION__ );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // call relevant kernel socket function
            ret = socket_send( fdid , u_buf , length );

            if( ret < 0 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for SEND : cannot access socket[%x,%d] \n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        ///////////////
        case SOCK_RECV:
        {
            uint32_t     fdid   = (uint32_t)arg1;
            uint8_t    * u_buf  = (uint8_t *)(intptr_t)arg2;
            uint32_t     length = (uint32_t)arg3;

            // check buffer is mapped in user space
            if( vmm_get_vseg( process , (intptr_t)arg2 , &vseg ) )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for RECV : buffer %x unmapped\n",
__FUNCTION__ , (intptr_t)arg2 );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // check length
            if( length == 0 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for RECV : buffer length is 0\n",
__FUNCTION__ );
#endif
                this->errno = EINVAL;
                ret = -1;
                break;
            }

            // call relevant kernel socket function
            ret = socket_recv( fdid , u_buf , length );

            if( ret < 0 )
            {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s for RECV : cannot access socket[%x,%d] \n",
__FUNCTION__ , process->pid, fdid );
#endif
                this->errno = EINVAL;
            }
            break;
        }
        ////////
        default:
        {

#if DEBUG_SYSCALLS_ERROR
printk("\n[ERROR] in %s : undefined socket operation %d\n",
        __FUNCTION__ , cmd );
#endif
            this->errno = EINVAL;
            ret = -1;
            break;
        }
    }  // end switch on cmd

#if (DEBUG_SYS_SOCKET || CONFIG_INSTRUMENTATION_SYSCALLS)
uint64_t     tm_end = hal_get_cycles();
#endif

#if DEBUG_SYS_SOCKET
if( DEBUG_SYS_SOCKET < tm_end )
printk("\n[%s] thread[%x,%x] exit / cycle %d\n",
__FUNCTION__, process->pid, this->trdid, (uint32_t)tm_end );
#endif

#if CONFIG_INSTRUMENTATION_SYSCALLS
hal_atomic_add( &syscalls_cumul_cost[SYS_SOCKET] , tm_end - tm_start );
hal_atomic_add( &syscalls_occurences[SYS_SOCKET] , 1 );
#endif

    return ret;

}  // end sys_socket()
