/******************************************************************************
 * 
 * rb_stm_holdrel.c
 * 
 * Lock-free red-black trees, based on in-the-clear STM.
 */

#define __SET_IMPLEMENTATION__

#include <stdio.h>
#include <assert.h>
#include <stdlib.h>
#include <unistd.h>
#include <setjmp.h>
#include "portable_defns.h"
#include "gc.h"
#include "stm_holdrel.h"
#include "set.h"

/*
 * Node colour is packed into bit 0 of the value word: 1 = black, 0 = red.
 * This relies on stored values having bit 0 clear (e.g. word-aligned
 * pointers). All conversions go through int_addr_t so pointer-width
 * arithmetic is used consistently.
 */
#define IS_BLACK(_v)   ((int_addr_t)(_v)&1)
#define IS_RED(_v)     (!IS_BLACK(_v))
#define MK_BLACK(_v)   ((setval_t)((int_addr_t)(_v)|1))
#define MK_RED(_v)     ((setval_t)((int_addr_t)(_v)&~1))
#define GET_VALUE(_v)  (MK_RED(_v))
#define GET_COLOUR(_v) (IS_BLACK(_v))
#define SET_COLOUR(_v,_c) ((setval_t)((int_addr_t)(_v)|(int_addr_t)(_c)))

/*
 * Transactional field access. These expand to STM calls on a variable
 * named 'stm' that must be in scope at the point of use.
 */
#define R(_o,_f,_t)    ((_t) stm_read_value(stm, (addr_t) &((_o)->_f)))
#define W(_o,_f,_v)    stm_write_value(stm, (addr_t) &((_o)->_f), (word_t)(_v))

/*
 * Read-modify-write of one field: binds the current value to _b, evaluates
 * _e, and only issues a transactional write if the value actually changes.
 */
#define RMW(_o,_f,_b,_e) do { word_t _b = stm_read_value(stm, (addr_t) &((_o)->_f)); word_t _b2 = (word_t) (_e); if (_b2 != (_b)) stm_write_value(stm, (addr_t) &((_o)->_f), _b2);} while (0)

#define SET_BLACK(_o,_f) RMW(_o,_f,t,MK_BLACK(t))
#define SET_RED(_o,_f)   RMW(_o,_f,t,MK_RED(t))

/*
 * Periodic validation: once every 2^LOG_VAL_PERIOD calls (per ptst), check
 * that the running transaction is still consistent; if not, abort it and
 * siglongjmp to *bail to restart. Wrapped in do/while(0) so that
 * "VALIDATE(...);" is a single statement that is safe under if/else.
 */
#define LOG_VAL_PERIOD 16
#define VAL_PERIOD_MASK ((1<<LOG_VAL_PERIOD)-1)
#define VALIDATE(ptst,stm,bail)                                   \
    do {                                                          \
        if ( ((((ptst)->op_count++) & VAL_PERIOD_MASK) == 0) &&   \
             !stm_validate(stm) )                                 \
        {                                                         \
            stm_abort(stm);                                       \
            siglongjmp(*(bail), 0);                               \
        }                                                         \
    } while (0)

/*
 * A tree node. 'k' is the key, 'v' the value with the colour packed into
 * bit 0 (see IS_BLACK/MK_BLACK), and 'l'/'r'/'p' are the left child, right
 * child and parent links.
 */
typedef struct node_st node_t;
/*
 * The set handle is the header node returned by set_alloc; it is cast
 * to/from node_t * throughout this file.
 * NOTE(review): 'struct node_t' is never defined here, so set_t is an
 * incomplete type. Presumably intentional, to keep the handle opaque;
 * confirm against set.h.
 */
typedef struct node_t set_t;

struct node_st
{
    setkey_t k;
    setval_t v;
    node_t   *l, *r, *p;
};

/*
 * Shared black sentinel used in place of empty child links. Its 'p' field
 * can be written inside a remove transaction; callers pass &null.p to
 * stm_remove_update before committing (presumably to drop that write).
 */
static node_t null;
/* Allocator id from gc_add_allocator, used for node gc_alloc/gc_free. */
static int gc_id;

/*
 * Rotate left about X, entirely through the STM.
 *
 * X's right child (the pivot) takes X's place under X's parent, X becomes
 * the pivot's left child, and the pivot's former left subtree becomes X's
 * right subtree. The 'null' sentinel never has its parent link written here.
 * X's parent (possibly the set header) has its child link updated.
 */
static void left_rotate(stm_state_t *stm, node_t *x)
{
    node_t *const pivot  = R(x, r, node_t *);
    node_t *const parent = R(x, p, node_t *);
    node_t *const inner  = R(pivot, l, node_t *);

    /* The pivot's inner subtree moves across to become X's right subtree. */
    W(x, r, inner);
    if ( inner != &null )
        W(inner, p, x);

    /* X drops down to be the pivot's left child. */
    W(x, p, pivot);
    W(pivot, l, x);

    /* The pivot is hung from X's old parent on whichever side X was. */
    W(pivot, p, parent);
    if ( R(parent, l, node_t *) == x )
        W(parent, l, pivot);
    else
        W(parent, r, pivot);
}


/*
 * Rotate right about X, entirely through the STM (mirror of left_rotate).
 *
 * X's left child (the pivot) takes X's place under X's parent, X becomes
 * the pivot's right child, and the pivot's former right subtree becomes X's
 * left subtree. The 'null' sentinel never has its parent link written here.
 */
static void right_rotate(stm_state_t *stm, node_t *x)
{
    node_t *const pivot  = R(x, l, node_t *);
    node_t *const parent = R(x, p, node_t *);
    node_t *const inner  = R(pivot, r, node_t *);

    /* The pivot's inner subtree moves across to become X's left subtree. */
    W(x, l, inner);
    if ( inner != &null )
        W(inner, p, x);

    /* X drops down to be the pivot's right child. */
    W(x, p, pivot);
    W(pivot, r, x);

    /* The pivot is hung from X's old parent on whichever side X was. */
    W(pivot, p, parent);
    if ( R(parent, l, node_t *) == x )
        W(parent, l, pivot);
    else
        W(parent, r, pivot);
}


/*
 * Restore the red-black invariants after a black node has been spliced out.
 *
 * X is the node that took the removed node's place. It may be the shared
 * 'null' sentinel; set_remove writes null.p within the same transaction so
 * that R(x,p) is meaningful here. This is the classic RB-DELETE-FIXUP
 * (CLRS), with every field access going through the STM. The loop stops
 * when X is red or X is the tree root (the node whose parent is header S).
 *
 * May siglongjmp to *jb via VALIDATE if periodic validation fails.
 */
static void delete_fixup(ptst_t *ptst,
                         stm_state_t *stm,
                         set_t *s,
                         node_t *x,
                         sigjmp_buf *jb)
{
    node_t *p, *w;

    while ( (R(x,p,node_t*) != (node_t*) s) && IS_BLACK(R(x,v,setval_t)) )
    {
        VALIDATE(ptst, stm, jb);
        p = R(x,p,node_t*);

        if ( x == R(p,l,node_t*) )
        {
            node_t *wl, *wr;
            w = R(p,r,node_t*);
            if ( IS_RED(R(w,v,setval_t)) )
            {
                /* Red sibling: recolour and rotate so the sibling is black. */
                SET_BLACK(w, v);
                SET_RED(p, v);
                /* Node W will be new parent of P. */
                left_rotate(stm, p);
                /* Get new sibling W. */
                w = R(p,r,node_t*);
            }

            wl = R(w,l,node_t*);
            wr = R(w,r,node_t*);
            if ( IS_BLACK(R(wl, v, setval_t)) && IS_BLACK(R(wr, v, setval_t)) )
            {
                /* Both nephews black: push the extra black up to P. */
                SET_RED(w, v);
                x = p;
            }
            else
            {
                if ( IS_BLACK(R(wr,v,setval_t)) )
                {
                    /* w->l is red => it cannot be null node. */
                    SET_BLACK(wl, v);
                    SET_RED(w, v);
                    right_rotate(stm, w);
                    /* Old w is new w->r. Old w->l is new w.*/
                    w = R(p,r,node_t*);
                }

                /* Far nephew red: W takes P's colour, P and wr go black. */
                wr = R(w, r, node_t*);
                W(w,v, SET_COLOUR(GET_VALUE(R(w,v,setval_t)),
                                  GET_COLOUR(R(p,v,setval_t))));
                SET_BLACK(p, v);
                SET_BLACK(wr, v);
                left_rotate(stm, p);
                break;
            }
        }
        else /* SYMMETRIC CASE: X is P's right child. */
        {
            node_t *wl, *wr;
            w = R(p,l,node_t*);
            if ( IS_RED(R(w,v,setval_t)) )
            {
                SET_BLACK(w, v);
                SET_RED(p, v);
                /* Node W will be new parent of P. */
                right_rotate(stm, p);
                /* Get new sibling W. */
                w = R(p, l, node_t*);
            }

            wl = R(w,l,node_t*);
            wr = R(w,r,node_t*);
            if ( IS_BLACK(R(wl,v,setval_t)) && IS_BLACK(R(wr,v,setval_t)) )
            {
                SET_RED(w, v);
                x = p;
            }
            else
            {
                if ( IS_BLACK(R(wl, v, setval_t)) )
                {
                    /* w->r is red => it cannot be the null node. */
                    SET_BLACK(wr, v);
                    SET_RED(w, v);
                    left_rotate(stm, w);
                    /* Old w is new w->l. Old w->r is new w. */
                    w = R(p, l, node_t*);   /* a node pointer, not a setval_t */
                }

                wl = R(w, l, node_t*);
                W(w, v, SET_COLOUR(GET_VALUE(R(w, v, setval_t)),
                                   GET_COLOUR(R(p, v, setval_t))));
                SET_BLACK(p, v);
                SET_BLACK(wl, v);
                right_rotate(stm, p);
                break;
            }
        }
    }

    SET_BLACK(x, v);
}


/*
 * Allocate an empty set.
 *
 * The set is represented by a header node holding SENTINEL_KEYMIN, with no
 * parent and both children set to the 'null' sentinel.
 *
 * Returns NULL if the header node cannot be allocated.
 */
set_t *set_alloc(void)
{
    node_t *root = malloc(sizeof *root);

    if ( root == NULL ) return NULL;

    root->k = SENTINEL_KEYMIN;
    root->v = MK_RED(NULL);
    root->l = &null;
    root->r = &null;
    root->p = NULL;

    return (set_t *) root;
}

/*
 * Insert or update the mapping K -> V.
 *
 * If K is present, its value is replaced by V when OVERWRITE is non-zero.
 * The node's colour bit is preserved. Otherwise a new red node is linked in
 * and the red-black invariants are restored.
 *
 * Returns the previous value mapped to K, or NULL if K was absent.
 *
 * Runs as one STM transaction, retried until stm_commit succeeds. VALIDATE
 * may siglongjmp back into the transaction begun by stm_start (presumably a
 * sigsetjmp wrapper), so any local carried across restarts must be volatile.
 */
setval_t set_update(set_t *s, setkey_t k, setval_t v, int overwrite)
{
    ptst_t *ptst;
    stm_state_t  *stm;
    node_t  *x, *p, *g, *y, *new = NULL;
    setkey_t xk;
    /*
     * Node allocated for an insertion, reused by later attempts. volatile:
     * it is modified after stm_start and must keep its value across a
     * siglongjmp (non-volatile automatics are indeterminate, C11 7.13.2.1).
     */
    node_t *volatile alloc = NULL;
    setval_t ov;
    bool_t committed;
    sigjmp_buf *jb;

    k = CALLER_TO_INTERNAL_KEY(k);
    ptst = critical_enter ();
    stm = get_stm_st(ptst);

retry:
    stm_start(stm, &jb);

    /* Descend from the header to K's node, or to the node K would hang from. */
    x = (node_t *) s;
    xk = R(x,k,setkey_t);
    while ( (y = (k < xk) ? R(x,l,node_t*) : R(x,r,node_t*)) != &null )
    {
        VALIDATE(ptst, stm, jb);
        x = y;
        xk = R(x,k,setkey_t);
        if ( k == xk ) break;
    }

    if ( k == xk )
    {
        /* Key present: optionally replace the value, keeping X's colour. */
        ov = R(x,v,setval_t);
        if ( overwrite ) W(x,v, SET_COLOUR(v, GET_COLOUR(ov)));
        ov = GET_VALUE(ov);
        new = NULL;
    }
    else
    {
        ov = NULL;

        if ( alloc == NULL )
        {
            alloc = (node_t *)gc_alloc(ptst, gc_id);
        }
        new = alloc;
        /* NEW is not reachable by others until commit: initialise directly. */
        new->k = k;
        new->v = MK_RED(v);
        new->l = &null;
        new->r = &null;
        new->p = x;

        /* Link NEW in as the appropriate child of X. */
        if ( k < xk ) W(x, l, new); else W(x, r, new);
        x = new;

        /* Insert fixup: walk up while X (red) has a red parent. */
        for ( ; ; )
        {
            VALIDATE(ptst, stm, jb);
            if ( (p = R(x, p, node_t*)) == (node_t*)s )
            {
                /* X is the tree root: the root is always black. */
                SET_BLACK(x, v);
                break;
            }

            /* Black parent: no red-red violation remains. */
            if ( IS_BLACK(R(p, v, setval_t)) ) break;

            g = R(p, p, node_t*);
            if ( p == R(g, l, node_t*) )
            {
                setval_t yv;
                y = R(g, r, node_t*);
                yv = R(y, v, setval_t);
                if ( IS_RED(yv) )
                {
                    /* Red uncle: recolour and continue from the grandparent. */
                    SET_BLACK(p, v);
                    W(y, v, MK_BLACK(yv));
                    SET_RED(g, v);
                    x = g;
                }
                else
                {
                    if ( x == R(p, r, node_t*) )
                    {
                        /* Inner child: rotate it to the outside first. */
                        x = p;
                        left_rotate(stm, x);
                        /* X and P switched round. */
                        p = R(x, p, node_t*);
                    }
                    SET_BLACK(p, v);
                    SET_RED(g, v);
                    right_rotate(stm, g);
                    /* G is off the path; next iteration sees a black parent. */
                }
            }
            else /* SYMMETRIC CASE: P is G's right child. */
            {
                setval_t yv;
                y = R(g, l, node_t*);
                yv = R(y, v, setval_t);
                if ( IS_RED(yv) )
                {
                    SET_BLACK(p, v);
                    W(y, v, MK_BLACK(yv));
                    SET_RED(g, v);
                    x = g;
                }
                else
                {
                    if ( x == R(p, l, node_t*) )
                    {
                        x = p;
                        right_rotate(stm, x);
                        /* X and P switched round. */
                        p = R(x, p, node_t*);
                    }
                    SET_BLACK(p, v);
                    SET_RED(g, v);
                    left_rotate(stm, g);
                    /* G is off the path; next iteration sees a black parent. */
                }
            }
        }
    }

    /* Exclude the shared sentinel's parent link from the commit. */
    stm_remove_update (stm, (addr_t) &(null.p));
    committed = stm_commit(stm);
    if (!committed) goto retry;

    /* The final attempt found K present: release the node an earlier attempt allocated. */
    if ( new == NULL && alloc != NULL )
    {
        gc_free(ptst, alloc, gc_id);
    }

    critical_exit (ptst);

    return ov;
}


/*
 * Remove K from the set.
 *
 * Returns the value that was mapped to K, or NULL if K was absent.
 *
 * Runs as one STM transaction, retried until stm_commit succeeds. If the
 * node Z holding K has two children, its in-order successor Y (leftmost
 * node of Z's right subtree) is unlinked instead. Y's key and value are
 * then copied into Z. Otherwise Z itself is unlinked. The unlinked node is
 * passed to gc_free only after a successful commit.
 */
setval_t set_remove(set_t *s, setkey_t k)
{
    ptst_t *ptst;
    stm_state_t  *stm;
    node_t  *x, *y, *z;
    setkey_t zk;
    setval_t ov;
    bool_t committed;
    void *to_free;
    sigjmp_buf *jb;

    k = CALLER_TO_INTERNAL_KEY(k);
    ptst = critical_enter();
    stm = get_stm_st(ptst);

retry:
    stm_start(stm, &jb);
    /* Reset per attempt: a restart re-enters just after stm_start. */
    ov = NULL;
    to_free = NULL;

    /* Descend from the header looking for K; stops at 'null' if absent. */
    z = (node_t *) s;
    zk = R(z,k,setkey_t);
    while ( (z = (k < zk) ? R(z, l, node_t*) : R(z, r, node_t*)) != &null )
    {
        VALIDATE(ptst, stm, jb);
        zk = R(z,k,setkey_t);
        if ( k == zk ) break;
    }

    if ( k == zk )
    {
        node_t *yl;
        node_t *yp;
        ov = GET_VALUE(R(z,v,setval_t));

        if ( (R(z,l,node_t*) != &null) && (R(z,r,node_t*) != &null) )
        {
            /* Two children: Y will be Z's in-order successor. */
            y = R(z,r,node_t*);
	    yl = R(y,l,node_t*);
	    
            /* Walk to the leftmost node of Z's right subtree. */
            while ( yl != &null )
            {
	        VALIDATE(ptst, stm, jb);
                y = yl;
		yl = R(y,l,node_t*);
            }
        }
        else
        {
            /* At most one child: Z itself is unlinked. */
            y = z;
	    yl = R(y,l,node_t*);
        }
	
	yp = R(y,p,node_t*);
        /*
         * Y has at most one child; X (that child, or the 'null' sentinel)
         * takes Y's place under YP. When X is 'null' this writes null.p so
         * that delete_fixup can find X's parent; the stm_remove_update call
         * below is then passed &null.p (presumably to drop that write).
         */
        x = (yl != &null) ? yl : R(y,r,node_t*);
        W(x, p, yp);

        if ( y == R(yp,l,node_t*) ) W(yp,l,x); else W(yp,r,x);

        if ( y != z )
        {
            /* Move Y's key and value into Z; Z keeps its own colour bit. */
            W(z,k,R(y,k,setkey_t));
            W(z, v, SET_COLOUR(GET_VALUE(R(y, v, setval_t)), GET_COLOUR(R(z, v, setval_t))));
        }

	to_free = y;

        /* Unlinking a black node breaks black-height balance: repair from X. */
        if ( IS_BLACK(R(y,v,setval_t)) ) delete_fixup(ptst, stm, s, x, jb);
	stm_remove_update (stm, (addr_t) &(null.p));
    }

    committed = stm_commit(stm);
    if (!committed) goto retry;

    /* Only reclaim the unlinked node once the removal is committed. */
    if (to_free != NULL) {
      gc_free(ptst, to_free, gc_id);
    }

    critical_exit(ptst);

    return ov;
}


/*
 * Look up K.
 *
 * Returns the value mapped to K (colour bit stripped), or NULL if K is not
 * in the set. The search is a read-only STM transaction that is repeated
 * until it commits, so the result reflects a consistent snapshot.
 */
setval_t set_lookup(set_t *s, setkey_t k)
{
    ptst_t      *ptst;
    stm_state_t *stm;
    node_t      *cur;
    setkey_t     cur_key;
    setval_t     found = NULL;
    sigjmp_buf  *jb;

    k = CALLER_TO_INTERNAL_KEY(k);

    ptst = critical_enter ();
    stm  = get_stm_st(ptst);

    for ( ; ; )
    {
        stm_start(stm, &jb);
        found = NULL;

        cur     = (node_t *) s;
        cur_key = R(cur, k, setkey_t);

        /* Binary-search descent; 'null' marks falling off the tree. */
        for ( ; ; )
        {
            cur = (k < cur_key) ? R(cur, l, node_t *) : R(cur, r, node_t *);
            if ( cur == &null )
                break;

            VALIDATE (ptst, stm, jb);
            cur_key = R(cur, k, setkey_t);
            if ( k == cur_key )
            {
                found = GET_VALUE(R(cur, v, setval_t));
                break;
            }
        }

        if ( stm_commit(stm) )
            break;
    }

    critical_exit(ptst);

    return found;
}

void _init_set_subsystem(void)
{
  stm_init();
  
  gc_id = gc_add_allocator(sizeof(node_t));

  null.k = 0;
  null.v = MK_BLACK(NULL);
  null.l = NULL;
  null.r = NULL;
  null.p = NULL;
}

/*
 * Per-thread initialisation: thread ID out of T threads calls
 * stm_init_cluster on every T-th orec cluster index, starting at ID.
 * Together the T threads cover all indices below OREC_TABLE_SIZE.
 */
void _init_set_per_thread(int id, int t)
{
  int cluster = id;

  while (cluster < OREC_TABLE_SIZE) {
    stm_init_cluster(cluster);
    cluster += t;
  }
}

/*
 * Process-wide teardown hook, the counterpart of _init_set_subsystem.
 * This implementation has nothing to release and simply returns.
 */
void _destroy_set_subsystem(void)
{
  return;
}

