/*
 * grdxt.c - Three-levels Generic Radix-tree implementation.
 * 
 * authors  Alain Greiner (2016,2017,2018,2019)
 *
 * Copyright (c)  UPMC Sorbonne Universites
 * 
 * This file is part of ALMOS-MKH.
 *
 * ALMOS-MKH is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2.0 of the License.
 *
 * ALMOS-MKH is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with ALMOS-MKH; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <hal_kernel_types.h>
#include <hal_special.h>
#include <hal_remote.h>
#include <errno.h>
#include <printk.h>
#include <kmem.h>
#include <grdxt.h>

////////////////////////////////////////////////////////////////////////////////////////
//               Local access functions
////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////
// Initialises the radix tree descriptor <rt> : records the three index widths
// and allocates the level 1 array.
// - rt        : pointer on the radix tree descriptor to initialise.
// - ix1_width : number of key bits used as index in the level 1 array.
// - ix2_width : number of key bits used as index in a level 2 array.
// - ix3_width : number of key bits used as index in a level 3 array.
// Returns 0 on success. Returns -1 (after an error message) when the level 1
// array cannot be allocated; in that case the widths have been recorded but
// the <root> field has not been written.
error_t grdxt_init( grdxt_t * rt,
                    uint32_t  ix1_width,
                    uint32_t  ix2_width,
                    uint32_t  ix3_width )
{
    kmem_req_t   request;
    void      ** level1;

    // log2 of the pointer size : 2 for 4-byte pointers, 3 otherwise
    const uint32_t ptr_shift = (sizeof(void*) == 4) ? 2 : 3;

    // record the index widths in the descriptor
    rt->ix1_width = ix1_width;
    rt->ix2_width = ix2_width;
    rt->ix3_width = ix3_width;

    // request the level 1 array
    // (presumably <order> is log2 of the size in bytes, and AF_ZERO makes
    //  all slots start as NULL - to be confirmed against kmem)
    request.type  = KMEM_KCM;
    request.order = ix1_width + ptr_shift;
    request.flags = AF_KERNEL | AF_ZERO;
    level1 = kmem_alloc( &request );

    if( level1 != NULL )
    {
        rt->root = level1;
        return 0;
    }

    printk("\n[ERROR] in %s : cannot allocate first level array\n", __FUNCTION__);
    return -1;

}  // end grdxt_init()

//////////////////////////////////
// Releases all the arrays (levels 1, 2 and 3) of the radix tree <rt>.
// The registered items are not released : a warning is displayed for each
// level 3 slot that is still non-NULL. The <rt> descriptor itself is neither
// released nor modified.
// - rt : pointer on the radix tree descriptor (must not be NULL).
void grdxt_destroy( grdxt_t * rt )
{
    kmem_req_t req;

    uint32_t   ix1;
    uint32_t   ix2;
    uint32_t   ix3;

// check the argument BEFORE any access to the descriptor
assert( (rt != NULL) , "pointer on radix tree is NULL\n" );

    uint32_t   w1 = rt->ix1_width;
    uint32_t   w2 = rt->ix2_width;
    uint32_t   w3 = rt->ix3_width;

    void    ** ptr1 = rt->root;
    void    ** ptr2;
    void    ** ptr3;

    // nothing to scan or release when the level 1 array does not exist
    if( ptr1 == NULL ) return;

    for( ix1=0 ; ix1 < (uint32_t)(1 << w1) ; ix1++ )
    {
        ptr2 = ptr1[ix1];

        if( ptr2 == NULL ) continue;

        for( ix2=0 ; ix2 < (uint32_t)(1 << w2) ; ix2++ )
        {
            ptr3 = ptr2[ix2];

            if( ptr3 == NULL ) continue;

            // report the items that are still registered
            for( ix3=0 ; ix3 < (uint32_t)(1 << w3) ; ix3++ )
            {
                 if( ptr3[ix3] != NULL )
                 {
                     printk("\n[WARNING] in %s : ptr3[%d][%d][%d] non empty\n",
                     __FUNCTION__, ix1, ix2, ix3 );
                 }
            }

            // release level 3 array
            req.type = KMEM_KCM;
            req.ptr  = ptr3;
            kmem_free( &req );
        }

        // release level 2 array
        req.type = KMEM_KCM;
        req.ptr  = ptr2;
        kmem_free( &req );
    }

    // release level 1 array
    req.type = KMEM_KCM;
    req.ptr  = ptr1;
    kmem_free( &req );

}  // end grdxt_destroy()

////////////////////////////////////
// Registers <value> in the slot of radix tree <rt> selected by <key>.
// The level 2 and level 3 arrays on the path are allocated when missing.
// The previous content of the slot, if any, is overwritten.
// - rt    : pointer on the radix tree descriptor.
// - key   : must fit in (ix1_width + ix2_width + ix3_width) bits.
// - value : pointer to register.
// Returns 0 on success, -1 when the level 1 array is missing or when a
// memory allocation fails.
error_t grdxt_insert( grdxt_t  * rt,
                      uint32_t   key,
                      void     * value )
{
    kmem_req_t      request;

    uint32_t        w1 = rt->ix1_width;
    uint32_t        w2 = rt->ix2_width;
    uint32_t        w3 = rt->ix3_width;

// the key must not have bits set above the three index fields
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    // split the key into one index per level
    uint32_t        ix3 = key & ((1 << w3) - 1);
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);
    uint32_t        ix1 = key >> (w2 + w3);

    // log2 of the pointer size : 2 for 4-byte pointers, 3 otherwise
    const uint32_t  ptr_shift = (sizeof(void*) == 4) ? 2 : 3;

    // level 1 array must exist
    void ** level1 = rt->root;
    if( level1 == NULL ) return -1;

    // level 2 array : allocate and register it in level 1 when missing
    void ** level2 = level1[ix1];
    if( level2 == NULL )
    {
        request.type  = KMEM_KCM;
        request.order = w2 + ptr_shift;
        request.flags = AF_KERNEL | AF_ZERO;
        level2 = kmem_alloc( &request );

        if( level2 == NULL ) return -1;

        level1[ix1] = level2;
    }

    // level 3 array : allocate and register it in level 2 when missing
    void ** level3 = level2[ix2];
    if( level3 == NULL )
    {
        request.type  = KMEM_KCM;
        request.order = w3 + ptr_shift;
        request.flags = AF_KERNEL | AF_ZERO;
        level3 = kmem_alloc( &request );

        if( level3 == NULL ) return -1;

        level2[ix2] = level3;
    }

    // store the value in the selected level 3 slot
    level3[ix3] = value;

    hal_fence();

    return 0;

}  // end grdxt_insert()

///////////////////////////////////
// Removes the item registered for <key> in radix tree <rt> : the selected
// level 3 slot is reset to NULL. The level 2 and level 3 arrays are kept.
// - rt  : pointer on the radix tree descriptor.
// - key : must fit in (ix1_width + ix2_width + ix3_width) bits.
// Returns the previous content of the slot (possibly NULL), or NULL when one
// of the arrays on the path does not exist.
void * grdxt_remove( grdxt_t  * rt,
                     uint32_t   key )
{
    uint32_t        w1 = rt->ix1_width;
    uint32_t        w2 = rt->ix2_width;
    uint32_t        w3 = rt->ix3_width;

// the key must not have bits set above the three index fields
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    // split the key into one index per level
    uint32_t        ix3 = key & ((1 << w3) - 1);
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);
    uint32_t        ix1 = key >> (w2 + w3);

    void          * removed = NULL;
    void         ** level1  = rt->root;

    if( level1 != NULL )
    {
        void ** level2 = level1[ix1];

        if( level2 != NULL )
        {
            void ** level3 = level2[ix2];

            if( level3 != NULL )
            {
                // read then clear the selected slot
                void ** slot = &level3[ix3];

                removed = *slot;
                *slot   = NULL;
                hal_fence();
            }
        }
    }

    return removed;

}  // end grdxt_remove()

///////////////////////////////////
// Returns the item registered for <key> in radix tree <rt>, without modifying
// the tree.
// - rt  : pointer on the radix tree descriptor.
// - key : must fit in (ix1_width + ix2_width + ix3_width) bits.
// Returns the content of the selected level 3 slot (possibly NULL), or NULL
// when one of the arrays on the path (including level 1) does not exist.
void * grdxt_lookup( grdxt_t  * rt,
                     uint32_t   key )
{
    uint32_t        w1 = rt->ix1_width;
    uint32_t        w2 = rt->ix2_width;
    uint32_t        w3 = rt->ix3_width;

// Check key value
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    void         ** ptr1 = rt->root;
    void         ** ptr2;
    void         ** ptr3;

    // compute indexes
    uint32_t        ix1 = key >> (w2 + w3);              // index in level 1 array
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);  // index in level 2 array
    uint32_t        ix3 = key & ((1 << w3) - 1);         // index in level 3 array

    // a missing level 1 array means an empty tree,
    // as in grdxt_insert() and grdxt_remove()
    if( ptr1 == NULL ) return NULL;

    // get ptr2
    ptr2 = ptr1[ix1];
    if( ptr2 == NULL ) return NULL;

    // get ptr3
    ptr3 = ptr2[ix2];
    if( ptr3 == NULL ) return NULL;

    // get value
    return ptr3[ix3];

}  // end grdxt_lookup()

//////////////////////////////////////
// Scans radix tree <rt> in increasing key order, starting from <start_key>
// (included), and returns the first non-NULL registered item.
// - rt        : pointer on the radix tree descriptor.
// - start_key : first key to examine; must fit in
//               (ix1_width + ix2_width + ix3_width) bits.
// - found_key : written with the key of the returned item; left untouched
//               when no item is found.
// Returns the found item, or NULL when no item is registered for any key
// greater than or equal to <start_key>, or when the level 1 array is missing.
void * grdxt_get_first( grdxt_t  * rt,
                        uint32_t   start_key,
                        uint32_t * found_key )
{
    uint32_t        ix1;
    uint32_t        ix2;
    uint32_t        ix3;

    uint32_t        w1 = rt->ix1_width;
    uint32_t        w2 = rt->ix2_width;
    uint32_t        w3 = rt->ix3_width;

// Check key value
assert( ((start_key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", start_key );

    // compute max indexes
    uint32_t        max1 = 1 << w1;
    uint32_t        max2 = 1 << w2;
    uint32_t        max3 = 1 << w3;

    // compute the indexes of the start key : min2 and min3 are lower bounds
    // only for the first level 2 / level 3 array on the start path
    uint32_t        min1 = start_key >> (w2 + w3);
    uint32_t        min2 = (start_key >> w3) & ((1 << w2) -1);
    uint32_t        min3 = start_key & ((1 << w3) - 1);

    void         ** ptr1 = rt->root;
    void         ** ptr2;
    void         ** ptr3;

    // a missing level 1 array means an empty tree
    if( ptr1 == NULL ) return NULL;

    for( ix1 = min1 ; ix1 < max1 ; ix1++ )
    {
        ptr2 = ptr1[ix1];

        if( ptr2 != NULL )
        {
            for( ix2 = min2 ; ix2 < max2 ; ix2++ )
            {
                ptr3 = ptr2[ix2];

                if( ptr3 != NULL )
                {
                    for( ix3 = min3 ; ix3 < max3 ; ix3++ )
                    {
                        if( ptr3[ix3] != NULL )
                        {
                            *found_key = (ix1 << (w2+w3)) | (ix2 << w3) | ix3;
                            return ptr3[ix3];
                        }
                    }
                }

                // any following level 3 array holds keys larger than
                // start_key : scan it from its first slot
                min3 = 0;
            }
        }

        // any following level 2 array holds keys larger than start_key :
        // scan it (and its level 3 arrays) from the first slot
        min2 = 0;
        min3 = 0;
    }

    return NULL;

}  // end grdxt_get_first()



////////////////////////////////////////////////////////////////////////////////////////
//               Remote access functions
////////////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////
// Registers <value> in the slot selected by <key> of the radix tree identified
// by the extended pointer <rt_xp>. All accesses to the descriptor and to the
// arrays go through the hal_remote_*() primitives, and the missing level 2 /
// level 3 arrays are allocated with kmem_remote_alloc() in the cluster
// extracted from <rt_xp>.
// - rt_xp : extended pointer on the radix tree descriptor.
// - key   : must fit in (ix1_width + ix2_width + ix3_width) bits.
// - value : pointer to register (stored as is in the level 3 slot).
// Returns 0 on success, -1 when the level 1 array is missing or when a
// memory allocation fails.
error_t grdxt_remote_insert( xptr_t     rt_xp,
                             uint32_t   key,
                             void     * value )
{
    kmem_req_t  request;

    // log2 of the pointer size : 2 for 4-byte pointers, 3 otherwise
    const uint32_t ptr_shift = (sizeof(void*) == 4) ? 2 : 3;

    // split the extended pointer into cluster identifier and local pointer
    grdxt_t * rt_ptr = GET_PTR( rt_xp );
    cxy_t     rt_cxy = GET_CXY( rt_xp );

#if DEBUG_GRDXT_INSERT
uint32_t cycle = (uint32_t)hal_get_cycles();
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] enter / rt_xp (%x,%x) / key %x / value %x\n",
__FUNCTION__, rt_cxy, rt_ptr, key, (intptr_t)value ); 
#endif

    // read the three index widths from the descriptor
    uint32_t        w1 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix1_width ) );
    uint32_t        w2 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix2_width ) );
    uint32_t        w3 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix3_width ) );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] get widths : w1 %d / w2 %d / w3 %d\n",
__FUNCTION__, w1, w2, w3 ); 
#endif

// the key must not have bits set above the three index fields
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    // split the key into one index per level
    uint32_t        ix1 = key >> (w2 + w3);
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);
    uint32_t        ix3 = key & ((1 << w3) - 1);

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] compute indexes : ix1 %d / ix2 %d / ix3 %d\n",
__FUNCTION__, ix1, ix2, ix3 ); 
#endif

    // level 1 array must exist
    void ** level1 = hal_remote_lpt( XPTR( rt_cxy , &rt_ptr->root ) );

    if( level1 == NULL ) return -1;

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] compute ptr1 = %x\n",
__FUNCTION__, (intptr_t)level1 ); 
#endif

    // read the level 1 slot
    xptr_t  slot1_xp = XPTR( rt_cxy , &level1[ix1] );
    void ** level2   = hal_remote_lpt( slot1_xp );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] get current ptr2 = %x\n",
__FUNCTION__, (intptr_t)level2 ); 
#endif

    // level 2 array : allocate and register it in level 1 when missing
    if( level2 == NULL )
    {
        request.type  = KMEM_KCM;
        request.order = w2 + ptr_shift;
        request.flags = AF_ZERO | AF_KERNEL;
        level2 = kmem_remote_alloc( rt_cxy , &request );

        if( level2 == NULL ) return -1;

        hal_remote_spt( slot1_xp , level2 );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] update ptr1[%d] : &ptr1[%d] = %x / ptr2 = %x\n",
__FUNCTION__, ix1, ix1, &level1[ix1], level2 );
#endif

    }

    // read the level 2 slot
    xptr_t  slot2_xp = XPTR( rt_cxy , &level2[ix2] );
    void ** level3   = hal_remote_lpt( slot2_xp );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] get current ptr3 = %x\n",
__FUNCTION__, (intptr_t)level3 ); 
#endif

    // level 3 array : allocate and register it in level 2 when missing
    if( level3 == NULL )
    {
        request.type  = KMEM_KCM;
        request.order = w3 + ptr_shift;
        request.flags = AF_ZERO | AF_KERNEL;
        level3 = kmem_remote_alloc( rt_cxy , &request );

        if( level3 == NULL ) return -1;

        hal_remote_spt( slot2_xp , level3 );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] update  ptr2[%d] : &ptr2[%d] %x / ptr3 %x\n",
__FUNCTION__, ix2, ix2, &level2[ix2], level3 );
#endif

    }

    // store the value in the selected level 3 slot
    hal_remote_spt( XPTR( rt_cxy , &level3[ix3] ) , value );

#if DEBUG_GRDXT_INSERT
if(DEBUG_GRDXT_INSERT < cycle)
printk("\n[%s] update  ptr3[%d] : &ptr3[%d] %x / value %x\n",
__FUNCTION__, ix3, ix3, &level3[ix3], value );
#endif

    hal_fence();

    return 0;

}  // end grdxt_remote_insert()

////////////////////////////////////////////
// Removes the item registered for <key> in the radix tree identified by the
// extended pointer <rt_xp> : the selected level 3 slot is reset to NULL.
// All accesses go through the hal_remote_*() primitives. The level 2 and
// level 3 arrays are kept.
// - rt_xp : extended pointer on the radix tree descriptor.
// - key   : must fit in (ix1_width + ix2_width + ix3_width) bits.
// Returns the previous content of the slot (possibly NULL), or NULL when one
// of the arrays on the path (including level 1) does not exist.
void * grdxt_remote_remove( xptr_t    rt_xp,
                            uint32_t  key )
{
    // get cluster and local pointer on remote rt descriptor
    cxy_t     rt_cxy = GET_CXY( rt_xp );
    grdxt_t * rt_ptr = GET_PTR( rt_xp );

    // get widths
    uint32_t        w1 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix1_width ) );
    uint32_t        w2 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix2_width ) );
    uint32_t        w3 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix3_width ) );

// Check key value
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    // compute indexes
    uint32_t        ix1 = key >> (w2 + w3);              // index in level 1 array
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);  // index in level 2 array
    uint32_t        ix3 = key & ((1 << w3) - 1);         // index in level 3 array

    // get ptr1 : a missing level 1 array means an empty tree,
    // as in grdxt_remote_insert()
    void ** ptr1 = hal_remote_lpt( XPTR( rt_cxy , &rt_ptr->root ) );
    if( ptr1 == NULL ) return NULL;

    // get ptr2
    void ** ptr2 = hal_remote_lpt( XPTR( rt_cxy , &ptr1[ix1] ) );
    if( ptr2 == NULL ) return NULL;

    // get ptr3
    void ** ptr3 = hal_remote_lpt( XPTR( rt_cxy , &ptr2[ix2] ) );
    if( ptr3 == NULL ) return NULL;

    // get value
    void * value = hal_remote_lpt( XPTR( rt_cxy , &ptr3[ix3] ) );

    // reset selected slot
    hal_remote_spt( XPTR( rt_cxy, &ptr3[ix3] ) , NULL );
    hal_fence();

    return value;

}  // end grdxt_remote_remove()

////////////////////////////////////////////
// Looks up the item registered for <key> in the radix tree identified by the
// extended pointer <rt_xp>, without modifying the tree. All accesses go
// through the hal_remote_*() primitives.
// - rt_xp : extended pointer on the radix tree descriptor.
// - key   : must fit in (ix1_width + ix2_width + ix3_width) bits.
// Returns an extended pointer built from the cluster of <rt_xp> and the local
// pointer found in the slot, or XPTR_NULL when the slot is empty or when one
// of the arrays on the path (including level 1) does not exist.
xptr_t grdxt_remote_lookup( xptr_t    rt_xp,
                            uint32_t  key )
{
    // get cluster and local pointer on remote rt descriptor
    grdxt_t       * rt_ptr = GET_PTR( rt_xp );
    cxy_t           rt_cxy = GET_CXY( rt_xp );

    // get widths
    uint32_t        w1 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix1_width ) );
    uint32_t        w2 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix2_width ) );
    uint32_t        w3 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix3_width ) );

// Check key value
assert( ((key >> (w1 + w2 + w3)) == 0 ), "illegal key value %x\n", key );

    // compute indexes
    uint32_t        ix1 = key >> (w2 + w3);              // index in level 1 array
    uint32_t        ix2 = (key >> w3) & ((1 << w2) -1);  // index in level 2 array
    uint32_t        ix3 = key & ((1 << w3) - 1);         // index in level 3 array

    // get ptr1 : a missing level 1 array means an empty tree,
    // as in grdxt_remote_insert()
    void ** ptr1 = hal_remote_lpt( XPTR( rt_cxy , &rt_ptr->root ) );
    if( ptr1 == NULL ) return XPTR_NULL;

    // get ptr2
    void ** ptr2 = hal_remote_lpt( XPTR( rt_cxy , &ptr1[ix1] ) );
    if( ptr2 == NULL ) return XPTR_NULL;

    // get ptr3
    void ** ptr3 = hal_remote_lpt( XPTR( rt_cxy , &ptr2[ix2] ) );
    if( ptr3 == NULL ) return XPTR_NULL;

    // get pointer on registered item
    void  * item_ptr = hal_remote_lpt( XPTR( rt_cxy , &ptr3[ix3] ) );

    // return extended pointer on registered item
    if ( item_ptr == NULL )  return XPTR_NULL;
    else                     return XPTR( rt_cxy , item_ptr );

}  // end grdxt_remote_lookup()

//////////////////////////////////////////
// Displays on the kernel terminal all the non-NULL entries of the radix tree
// identified by the extended pointer <rt_xp>. All accesses to the tree go
// through the hal_remote_*() primitives. For each registered item, the key,
// the value and the three array pointers on the path are printed.
// - rt_xp : extended pointer on the radix tree descriptor (not XPTR_NULL).
// - name  : string printed in the header line.
void grdxt_remote_display( xptr_t    rt_xp,
                           char    * name )
{
    uint32_t       ix1;
    uint32_t       ix2;
    uint32_t       ix3;

    void        ** level1;
    void        ** level2;
    void        ** level3;

// the tree must be identified
assert( (rt_xp != XPTR_NULL) , "pointer on radix tree is NULL\n" );

    // split the extended pointer into cluster identifier and local pointer
    cxy_t          rt_cxy = GET_CXY( rt_xp );
    grdxt_t      * rt_ptr = GET_PTR( rt_xp );

    // read the three index widths from the descriptor
    uint32_t       w1 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix1_width ) );
    uint32_t       w2 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix2_width ) );
    uint32_t       w3 = hal_remote_l32( XPTR( rt_cxy , &rt_ptr->ix3_width ) );

    // read the level 1 array pointer
    level1 = hal_remote_lpt( XPTR( rt_cxy , &rt_ptr->root ) );

    printk("\n***** Generic Radix Tree for <%s>\n", name );

    // walk the three levels, skipping the missing arrays and the empty slots
    for( ix1 = 0 ; ix1 < (uint32_t)(1<<w1) ; ix1++ )
    {
        level2 = hal_remote_lpt( XPTR( rt_cxy , &level1[ix1] ) );

        if( level2 != NULL )
        {
            for( ix2 = 0 ; ix2 < (uint32_t)(1<<w2) ; ix2++ )
            {
                level3 = hal_remote_lpt( XPTR( rt_cxy , &level2[ix2] ) );

                if( level3 != NULL )
                {
                    for( ix3 = 0 ; ix3 < (uint32_t)(1<<w3) ; ix3++ )
                    {
                        void * item = hal_remote_lpt( XPTR( rt_cxy , &level3[ix3] ) );

                        if( item != NULL )
                        {
                            // rebuild the key from the three indexes
                            uint32_t key = (ix1<<(w2+w3)) + (ix2<<w3) + ix3;

                            printk(" - key = %x / value = %x / ptr1 = %x / ptr2 = %x / ptr3 = %x\n",
                            key, (intptr_t)item, (intptr_t)level1, (intptr_t)level2, (intptr_t)level3 );
                        }
                    }
                }
            }
        }
    }

} // end grdxt_remote_display()


