Simple coroutine library / objects in plain C Snapshot
cthread.c
Go to the documentation of this file.
00001 #include "cthread.h"
00002 #include <pthread.h>
00003 #include <stdio.h>
00004 
00005 #ifdef NO_TLS
00006 static pthread_key_t tls_key;
00007 #else
00008 __thread CTHREAD *tls_thread;
00009 #endif
00010 
/*
 * Accessors for the per-thread "current coroutine" pointer.
 *
 * GET_TLS()     evaluates to the stored CTHREAD pointer.
 * SET_TLS(val)  stores val; it is a statement and yields no value.
 *
 * Neither macro carries a trailing semicolon, so both can be used like
 * ordinary expressions/statements (e.g. as the body of an if/else).
 */
#ifdef NO_TLS
#define GET_TLS()     ((CTHREAD *) pthread_getspecific( tls_key ))
#define SET_TLS(val)  ((void) pthread_setspecific( tls_key, (val) ))
#else
#define GET_TLS()     (tls_thread)
#define SET_TLS(val)  do { tls_thread = (val); } while( 0 )
#endif
00018 
00019 
/* Last coroutine id handed out; do_start pre-increments it, so the first id is 1. */
static uint32_t next_tid = 0;
00021 
/*
 * Library initialisation.
 *
 * In NO_TLS builds this creates the pthread key that stores the current
 * coroutine pointer (no destructor is registered); in TLS builds there is
 * nothing to set up.
 *
 * Returns 0 on success, -1 if the pthread key could not be created.
 */
int CTHREAD_libinit(void)
{
#ifdef NO_TLS
  if (pthread_key_create( &tls_key, NULL ) != 0) {
    return -1;
  }
#endif
  return 0;
}
00029 
00030 void cthread_init( CTHREAD *arg )
00031 {
00032   arg->proc( &arg->caller_to_thread_value );
00033   
00034   arg->state = CTHREAD_STATE_EXIT; 
00035 
00036   //fprintf( stderr,"relase stack %p\n",arg->stack_entry->stack_start);
00037 
00038 
00039   STACKS_release( arg->stack_entry );
00040 
00041 #if 0
00042   setcontext( &arg->context_caller );
00043 #endif
00044 }
00045 
/* Function-pointer type taking no arguments; cthread_init is cast to it when passed to makecontext. */
typedef void (*MK_CTX_FUNC)(void);
00047 
00048 CTHREAD * CTHREAD_init( STACKS *stacks, CTHREAD_PROC proc )
00049 {
00050   void *stack;
00051   CTHREAD *ret;
00052   int stage = 0;
00053 
00054   ret = (CTHREAD *) malloc( sizeof(CTHREAD) );
00055   if (!ret) {
00056     return 0;
00057   }
00058   stage = 1;
00059 
00060   if (VALUES_init(&ret->thread_to_caller_value)) {
00061     goto err;
00062   } 
00063   stage = 2;
00064 
00065   if (VALUES_init(&ret->caller_to_thread_value)) {
00066     goto err;
00067   } 
00068   stage = 3;
00069 
00070   stack = STACKS_get( stacks, &ret->stack_entry );
00071   if (!stack) {
00072     goto err;
00073   }
00074   stage = 4;
00075 
00076   ret->proc = proc;
00077 
00078   ret->prev_thread = 0;
00079 
00080   ret->thread_id = -1;   
00081   ret->state = CTHREAD_STATE_INIT;
00082 
00083   ret->caller_to_thread_value_set = ret->thread_to_caller_value_set = 0;
00084  
00085   if (getcontext( &ret->context_coroutine )) {
00086     if (ret->state == CTHREAD_STATE_INIT) {
00087       goto err;
00088     }
00089     return 0;
00090   }
00091 
00092   //fprintf( stderr,"thread stack start %p\n",stack);
00093 
00094   ret->context_coroutine.uc_stack.ss_sp = stack;
00095   ret->context_coroutine.uc_stack.ss_size = STACKS_get_stack_size( stacks );
00096 
00097 
00098   return ret;
00099 
00100 err:
00101   if (stage > 3) {
00102     STACKS_release(ret->stack_entry);
00103   }
00104   if (stage > 2) {
00105     VALUES_free(&ret->caller_to_thread_value);
00106   }
00107    if (stage > 1) {
00108     VALUES_free(&ret->thread_to_caller_value);
00109   }
00110   if (stage > 0) {
00111     free(ret);
00112   }
00113   return 0;
00114 }
00115 
00116 static int do_start( CTHREAD *thread )
00117 {
00118   thread->state = CTHREAD_STATE_RUNNING;
00119   
00120   thread->thread_id = ++next_tid;
00121   
00122   thread->prev_thread = GET_TLS();  
00123   SET_TLS( thread );
00124 
00125 #if 0
00126   thread->context_coroutine.uc_link = 0;
00127 #else 
00128   thread->context_coroutine.uc_link = &thread->context_caller;
00129 #endif
00130   makecontext( & thread->context_coroutine, (MK_CTX_FUNC) cthread_init, 1, thread);
00131   
00132   setcontext( &thread->context_coroutine );
00133 
00134  
00135   // should not get here.
00136   return -1;
00137 }
00138 
00139 int CTHREAD_start( CTHREAD *thread, VALUES **rvalue, const char *format , ... )
00140 {
00141   if (thread->state != CTHREAD_STATE_INIT) {
00142     return -1;
00143   }
00144   getcontext( &thread->context_caller );
00145   if (thread->state == CTHREAD_STATE_RUNNING) {
00146     // got here when the running thread called yield for the first time, or exited without calling yield.
00147     thread->state = CTHREAD_STATE_SUSPENDED;
00148     if (rvalue) {
00149      if (thread->thread_to_caller_value_set) {
00150       *rvalue = &thread->thread_to_caller_value; 
00151      } else {
00152        *rvalue = 0;
00153      }
00154     }
00155     return 0;
00156   }
00157   if (thread->state == CTHREAD_STATE_EXIT) {
00158     // got here from thread that has exited.
00159     if (rvalue) {
00160      if (thread->thread_to_caller_value_set) {
00161        *rvalue = &thread->thread_to_caller_value; 
00162      } else {
00163        *rvalue = 0;
00164      }
00165     }
00166     return 0;
00167   }
00168 
00169   thread->caller_to_thread_value_set = 0;
00170   if (format) {
00171     va_list vlist;
00172     
00173     va_start( vlist, format );
00174     
00175     if (VALUES_printv( &thread->caller_to_thread_value, format, vlist ) ) {
00176       return -1;
00177     }
00178 
00179     thread->caller_to_thread_value_set = 1;
00180   } 
00181   
00182 
00183   return do_start( thread );
00184 }
00185 
00186 
00187 
00188 
00189 int CTHREAD_join( CTHREAD *thread, VALUES **rvalue )
00190 {  
00191    if( thread->state == CTHREAD_STATE_INIT ||
00192        thread->state == CTHREAD_STATE_RUNNING) {
00193      return -1;
00194    }
00195    while(thread->state != CTHREAD_STATE_EXIT) {
00196      CTHREAD_resume( thread, rvalue, 0 );
00197    }
00198    return 0;
00199 }
00200 
00201 int CTHREAD_resume( CTHREAD *thread, VALUES **rvalue, const char *format, ... )
00202 {
00203   if (thread->state != CTHREAD_STATE_SUSPENDED) {
00204     return -1;
00205   }
00206 
00207   getcontext( &thread->context_caller );
00208   if (thread->state == CTHREAD_STATE_RUNNING) {
00209     // got here when the running thread called yield for the first time, or exited without calling yield.
00210     thread->state = CTHREAD_STATE_SUSPENDED;
00211     if (rvalue) {
00212      if (thread->thread_to_caller_value_set) {
00213        *rvalue = &thread->thread_to_caller_value; 
00214      } else {
00215        *rvalue = 0;
00216      }
00217     }
00218     return 0;
00219   }
00220   if (thread->state == CTHREAD_STATE_EXIT) {
00221     // got here from thread that has exited.
00222     if (rvalue) {
00223      if (thread->thread_to_caller_value_set) {
00224        *rvalue = &thread->thread_to_caller_value; 
00225      } else {
00226        *rvalue = 0;
00227      }
00228     }
00229     return 0;
00230   }
00231 
00232   thread->caller_to_thread_value_set = 0;
00233   if (format) {
00234     va_list vlist;
00235     
00236     va_start( vlist, format );
00237     
00238     if (VALUES_printv( &thread->caller_to_thread_value, format, vlist ) ) {
00239       return -1;
00240     }
00241     thread->caller_to_thread_value_set = 1;
00242   } 
00243 
00244   SET_TLS( thread );
00245   
00246   return setcontext( &thread->context_coroutine );
00247 }
00248 
00249 
00250 int CTHREAD_yield(VALUES **rvalue, const char *format, ... )
00251 {
00252   CTHREAD *thread;
00253   
00254   thread = GET_TLS();
00255 
00256   if (!thread) {
00257     return -1;
00258   }
00259 
00260   if (thread->state != CTHREAD_STATE_RUNNING) {
00261       return -1;
00262   }
00263 
00264   getcontext( &thread->context_coroutine );
00265   if (thread->state != CTHREAD_STATE_RUNNING) {
00266     // got here when this thread was resumed.
00267     thread->state = CTHREAD_STATE_RUNNING;
00268     if (rvalue) {
00269      if (thread->caller_to_thread_value_set) {
00270        *rvalue = &thread->caller_to_thread_value; 
00271      } else {
00272        *rvalue = 0;
00273      }
00274     }
00275     return 0;
00276   }
00277  
00278   thread->thread_to_caller_value_set = 0;
00279   if (format) {
00280     va_list vlist;
00281     
00282     va_start( vlist, format );
00283     
00284     if (VALUES_printv( &thread->thread_to_caller_value, format, vlist ) ) {
00285       return -1;
00286     }
00287     thread->thread_to_caller_value_set = 1;
00288   }
00289 
00290   SET_TLS( thread->prev_thread );
00291  
00292   return setcontext( &thread->context_caller );
00293 }
00294 
00295 
00296 int CTHREAD_free( CTHREAD *thread )
00297 {
00298   if (thread->state != CTHREAD_STATE_EXIT) {
00299     return -1;
00300   }
00301   free(thread);
00302   return 0;
00303 }
00304 
00305 uint32_t CTHREAD_get_tid()
00306 {
00307   CTHREAD *thread;
00308  
00309   thread = GET_TLS();
00310   if (!thread) {
00311     return (uint32_t) -1;
00312   }
00313   return thread->thread_id;
00314 }
00315 
00316 int CTHREAD_set_return_value( const char *format, ... )
00317 {
00318   CTHREAD *thread;
00319   va_list vlist;
00320   
00321   thread = GET_TLS();
00322   if (!thread) {
00323    return (uint32_t) -1;
00324   }
00325 
00326   va_start( vlist, format );
00327   
00328   thread->thread_to_caller_value_set = 0;
00329   if (VALUES_printv( &thread->thread_to_caller_value, format, vlist )) {
00330     thread->thread_to_caller_value_set = 0;
00331     return -1;
00332   }
00333 
00334   thread->thread_to_caller_value_set = 1;
00335   return 0;
00336 }
00337 
00338 
00339 
00340