#include <pthread.h>
#include <stdlib.h>
#include <string.h>

#include "msg.h"

/*
 * Capacity of a port's message ring: the number of struct msg slots held
 * in struct port. Senders wait once this many messages are queued, and the
 * head/tail indices wrap modulo this value.
 */
#define N 2

/*
 * A bounded message queue shared between sending and receiving threads.
 *
 * Messages are stored by value in a ring of N slots. msg_send() and
 * msg_receive() take the mutex `m` before touching any other field.
 */
struct port {
  struct msg messages[N]; /* ring storage; slots are copied in and out by value */
  pthread_mutex_t m;      /* held by msg_send()/msg_receive() around all queue access */
  pthread_cond_t empty;   /* receivers wait here while count == 0; signalled after a send */
  pthread_cond_t full;    /* senders wait here while count == N; signalled after a receive */
  int count;              /* number of messages currently queued (0..N) */
  int head;               /* index of the slot the next receive reads from */
  int tail;               /* index of the slot the next send writes to */
};

/*
 * Append a copy of *m to the queue of port p.
 *
 * If the ring already holds N messages the caller blocks until a slot is
 * freed. The message is copied by value, so the caller keeps ownership of
 * *m. After enqueueing, one thread waiting on the `empty` condition (if
 * any) is woken.
 */
void msg_send(struct port *p, struct msg *m)
{
  struct msg *slot;

  pthread_mutex_lock(&p->m);

  /* Re-test after every wake-up: the ring may still be full. */
  while (p->count == N) {
    pthread_cond_wait(&p->full, &p->m);
  }

  /* Copy into the slot at the tail, then advance the tail with wrap-around. */
  slot = &p->messages[p->tail];
  memcpy(slot, m, sizeof *slot);
  p->tail = (p->tail + 1) % N;
  p->count++;

  /* Signal while still holding the lock, then release it. */
  pthread_cond_signal(&p->empty);
  pthread_mutex_unlock(&p->m);
}

void msg_receive(struct port *p, struct msg *m)
{
  pthread_mutex_lock(&p->m);
  while (p->count == 0) {
    pthread_cond_wait(&p->empty, &p->m);
  }
  memcpy(m, &p->messages[p->head], sizeof(struct msg));
  p->head = (p->head + 1) % N;
  p->count = p->count - 1;
  pthread_cond_signal(&p->full);
  pthread_mutex_unlock(&p->m);
}

/*
 * Allocate and initialise an empty port.
 *
 * Returns a heap-allocated port whose queue is empty and whose mutex and
 * condition variables are initialised with default attributes, or NULL if
 * the allocation or any of those initialisations fails. On failure every
 * object that was already initialised is destroyed and the memory is
 * freed, so nothing is leaked and the caller never sees a half-built port.
 */
struct port *port_init(void)
{
  struct port *p = malloc(sizeof *p);

  if (p == NULL) {
    return NULL;
  }

  p->head = 0;
  p->tail = 0;
  p->count = 0;

  /* The pthread init calls return 0 on success and an error number otherwise. */
  if (pthread_mutex_init(&p->m, NULL) != 0) {
    goto fail_free;
  }
  if (pthread_cond_init(&p->empty, NULL) != 0) {
    goto fail_mutex;
  }
  if (pthread_cond_init(&p->full, NULL) != 0) {
    goto fail_empty;
  }

  return p;

  /* Unwind in reverse order of construction. */
fail_empty:
  pthread_cond_destroy(&p->empty);
fail_mutex:
  pthread_mutex_destroy(&p->m);
fail_free:
  free(p);
  return NULL;
}