Plan 9 from Bell Labs’s /usr/web/sources/contrib/axel/8021x/v214/ttls.c

Copyright © 2021 Plan 9 Foundation.
Distributed under the MIT License.
Download the Plan 9 distribution.


#include <u.h>
#include <libc.h>
#include <thread.h>
#include <ip.h>
#include <libsec.h>
#include "dat.h"
#include "fns.h"


/*
 * EAP-TTLS header as it appears on the wire; the code overlays it on
 * packet bytes with (TTLS*) casts.  All members are uchar, so (in
 * practice) there is no padding and no alignment requirement.
 */
typedef struct TTLS {
	uchar tp;	// EAP method type; checked against EapTpTtls
	uchar flags;	// TtlsFlagL/M/S bits plus version bits (TtlsVersion)
	uchar tln[4];	//optional, present if L flag set
			// tln: total TLS message length, written with hnputl, read with nhgetl
} TTLS;

enum {
	/* bits of TTLS.flags */
	TtlsFlagL	= 0x80,	// header contains tln field
	TtlsFlagM	= 0x40,	// more fragment(s) will follow for current msg
	TtlsFlagS	= 0x20,	// start of tls session
	TtlsVersion	= 0x07,	// mask of the version number in the low bits

	/* header sizes in bytes */
	TtlsShortHlen	= 2,			// tp + flags, without tln field
	TtlsLongHlen	= TtlsShortHlen + 4,	// with the 4-byte tln field

	/* states of the ttls() state machine; also indices into snames[] */
	Idle		= 0,
	Start		= 1,
	Waiting		= 2,
	Timeout		= 3,
	Sending		= 4,
	RecvAck		= 5,
	Receiving	= 6,
	SendAck		= 7,
	Received	= 8,
	Cleanup		= 9,
};

/*
 * Printable names of the ttls() states, indexed by state value; the
 * [State] designators keep each name at its enum's index.  Used here
 * in log messages (trans() and the per-state diagnostics).
 */
char *snames[] = {
[Idle]			"Idle",
[Start]		"Start",
[Waiting]		"Waiting",
[Timeout]		"Timeout",
[Sending]		"Sending",
[RecvAck]		"RecvAck",
[Receiving]	"Receiving",
[SendAck]		"SendAck",
[Received]		"Received",
[Cleanup]		"Cleanup",
};

/*
 * All state of the single TTLS session.  It is shared between the
 * state machine (processTTLS/ttls), cleanup, and the readproc and
 * clientproc procs, which all receive a pointer to it.
 */
typedef struct TTLSstate {
	TLSconn conn;	// our handle to the tls connection
	int p[2];	// double pipe over which we talk with our tls
			// the stuff we read from it has to be fragmented, encapsulated and sent
			// the fragments we receive have to be reassembled and then written to it
			// (tlsClient is given p[1]; we read and write p[0]; -1 when closed)
	Channel *ctidc;		// client thread id, client about to exit will confirm here
	Channel *readc;	// contains pointer to Buf containing last msg read from p
	Channel *eofc;		// confirm eof on p
	Channel *cstarterc;	// start new clientproc session (NOTE(review): created in initTTLS but never used in this file)
	Channel *rstarterc;	// start new readproc session (nonzero: start, zero: readproc exits)
	Channel *tickc;		// timer ticks: the Timers channel passed to initTTLS

	int fd;		// before tlsClient call: -2; after return: -1 (error) or >= 0 (ok)
	int ctid;	// clientproc thread id
	int rtid;	// readproc thread id
	int cStarted; // did tlsClient start?
	int cReturned; // have we already received from clientproc?

	Timer *ttlsWhile;	// bounds the wait in state Waiting
	Timer *cleanupWhile;	// bounds each wait in cleanup
	int ttlsPeriod;		// period of ttlsWhile (seconds, per initTTLS)
	int cleanupPeriod;	// period of cleanupWhile (seconds, per initTTLS)


	int txLen;	// length of frame we prepared for sending (returned by processTTLS)
	int done;	// done processing the frame (and, if needed, preparing the response)?
	int state;	// ttls state we are in
	int more;	// was M flag set in previous packet received?
	uint version;	// version bits of the last Start packet

	Buf *rbuf[Nbuf];	// buffers read from p, to be sent/fragmented over ether
	int ridx;	// index of first free rbuf
	Buf *ptoe;	// current pipe-to-ether buf (one of rbuf)

	Buf *wbuf;	// buffer read/ reassembled  from ether, to be written to p
	Buf *etop;	// current ether-to-pipe buf (wbuf)

	Thumbprint *thumbTable;	// accepted server cert SHA1s; nil: no check

	int inuse;		// is cleanup needed at all? (set by setupTls, never cleared here)

	// NOTE(review): the four fields below are not referenced in this file
	uchar*theSessionCert;
	int theSessionCertlen;
	uchar* theSessionID;
	int theSessionIDlen;

} TTLSstate;

/* the one session state; a pointer to it is handed to every proc below */
static TTLSstate theTTLSstate;

/* helper procs */
static void readproc(void *arg);
static void clientproc(void *arg);

/* session setup and teardown */
static void cleanup(TTLSstate *s);
static void setupTls(TTLSstate *s);

/* outgoing frame construction */
static int buildFrameStart(TTLSstate *s, Packet *p);
static int buildFrameMiddle(TTLSstate *s, Packet *p);
static int buildMsg(TTLSstate *s, Packet *p);
static void buildAck(TTLSstate *s, Packet *p);

/* state machine */
static void trans(TTLSstate *s, int new);
static void ttls(TTLSstate *s, Packet *erx, Packet *etx, int *succp, int *failp);

/* Buf allocation */
static Buf* allocBuf(ulong n);
static Buf* reallocBuf(Buf *p, ulong n);


/*
 * readproc: the pipe reader, run as a proc of its own.  initTTLS
 * starts it, and cleanup restarts it if it had to be killed.
 *
 * Each nonzero value received on rstarterc starts one session.
 * readproc then reads, from s->p[0], what tlsClient writes on the
 * other pipe end (p[1]).  Each chunk goes into the next Buf of the
 * rbuf ring, and a pointer to that Buf is sent on readc.  On read
 * error or EOF it reports on eofc and waits for the next start
 * request.  A zero on rstarterc makes it exit.
 *
 * NOTE(review): a Buf handed out on readc may still be in use (being
 * fragmented via s->ptoe) while readproc fills the next one.  This
 * relies on Nbuf >= 2; confirm in dat.h.
 */
static void
readproc(void *arg)
{
	TTLSstate *s;
	Buf *prx;
	int fd, n;

	s = arg;

	loglog("readproc starts: %d", threadid());
	while(recvul(s->rstarterc)) {
		loglog("readproc monitoring pipe: %d", s->p[0]);
		// snapshot the fd: cleanup/abortTTLS may set s->p[0] to -1
		// while we are blocked in read below
		fd = s->p[0];
		for(;;) {
			// p[0] was closed underneath us: only logged; the loop
			// ends when read fails or returns 0
			if (s->p[0] < 0) {
				loglog("(readproc pipe not active: %d :%d)", fd, s->p[0]);
//				break;
			}
			// reuse the next ring slot: reset its cursor and data end
			prx = s->rbuf[s->ridx];
			prx->p = prx->b;
			prx->e = prx->b;
			n = read(fd, prx->b,  prx->end - prx->b);
			if (n < 0) {
				loglog("readproc fail on pipe: %d: %r", fd);
				break;
			} else if (n == 0) {
				loglog("readproc read 0 or eof on pipe: %d", fd);
				break;
			} else {
				prx->e = prx->b +n;
				loglog("readproc read from %d into %p: %d", fd, prx, n);
			}
			if (s->p[0] < 0) {
				loglog("(readproc pipe no longer active: %d: %d)", fd, s->p[0]);
//				break;
			}
//			loglog("readproc sending...");
			// blocks until ttls() (state Waiting) or cleanup receives it
			send(s->readc, &prx);
			s->ridx = (s->ridx+1)%Nbuf;
		}
		// tell cleanup that this session's pipe is drained
		loglog("readproc sending eofc: %d: %d", fd, s->p[0]);
		sendul(s->eofc, 0);
		loglog("readproc restarts: %d: %d", fd, s->p[0]);
	}
	loglog("readproc exits: %d", threadid());
	threadexits(nil);
}

/*
 * clientproc runs tlsClient over our tls end of the pipe, s->p[1], in a
 * proc of its own; setupTls starts it.  ttls() waits for its report on
 * ctidc while feeding the handshake through the pipe.  tlsClient
 * therefore presumably does not return until the handshake finishes or
 * fails.
 *
 * The result is stored in s->fd only if this is still the current
 * clientproc (s->ctid).  On success the SHA1 thumbprint of the server
 * certificate is checked against s->thumbTable.  It then frees whatever
 * certificate and session id are left in s->conn and reports its
 * thread id on ctidc; ttls() receives it in state Waiting, or cleanup
 * does.
 *
 * NOTE(review): an unrecognized or unchecked server certificate is only
 * logged.  s->fd stays valid and ttls() proceeds to phase 2
 * (doTTLSphase2).  Confirm that talking to an unverified server is
 * intended.
 */
static void
clientproc(void *arg)
{
	TTLSstate *s;
	int fd;
	uchar hash[SHA1dlen];

	s = arg;
	loglog("clientproc starts: %d", threadid());
	s->cStarted = 1; // beyond this we are forced to send on ctidc

	loglog("clientproc (re)starting: p[1]=%d", s->p[1]);
	if (s->p[1] <= 0)
		logfatal(0,  "clientproc: no fd for tlsClient:%d", s->p[1]);

	loglog("calling tlsClient debugTLS=%d, tlslog=%p", debugTLS, s->conn.trace);
	markPhaseStart(1, "ttls");
	fd  = tlsClient(s->p[1], &s->conn);
	loglog("return tlsClient...");
	markPhaseResult(1, "ttls", fd >= 0);
	// a newer session may have replaced us meanwhile: then don't touch s->fd
	if (s->ctid != threadid())
		logall("oops tlsClient schizophrenie ctid=%d threadid=%d", s->ctid, threadid());
	else
		s->fd  = fd;
	loglog("tlsClient %d result: fd=%d", threadid(), fd);
	if (fd < 0 && s->state == Cleanup)
		logall("while cleaning up: tlsClient %d failed: %r", threadid());
	else if (fd < 0)
		logall("tlsClient %d failed: %r", threadid());
	else {
		loglog("tlsClient %d ok fd=%d", threadid(), fd);
		if (s->conn.cert==nil || s->conn.certlen<=0)
			logall("server did not provide TLS certificate");
		else {
			// X509dump(s->conn.cert, s->conn.certlen);
			if (s->thumbTable != nil) {
				sha1(s->conn.cert, s->conn.certlen, hash, nil);
				if(!okThumbprint(hash, s->thumbTable))
					logall("server certificate %.*H not recognized", SHA1dlen, hash);
			} else
				logall("no thumbprint to check server certificate");
		}
	}

	// clean up before we (implicitly) yield in the sendul
	if (s->conn.sessionID != nil)
		free(s->conn.sessionID);
	s->conn.sessionID = nil;
	s->conn.sessionIDlen = 0;

	if (s->conn.cert)
		free(s->conn.cert);
	s->conn.cert = nil;
	s->conn.certlen = 0;

	sendul(s->ctidc, threadid());

	loglog("clientproc  %d ... finished: fd=%d", threadid(), s->fd);
	loglog("clientproc exits: %d", threadid());
	threadexits(nil);
}

/*
 * closeP0 closes our end of the tls pipe, s->p[0], if it is still open,
 * logs the outcome, and marks it closed (-1).
 */
static void
closeP0(TTLSstate *s)
{
	if (s->p[0] < 0)
		return;
	loglog("\tcleanup: closing p[0]: %d", s->p[0]);
	if (close(s->p[0]) < 0)
		loglog("\tcleanup: failed closing p[0]: %d:%r", s->p[0]);
	else
		loglog("\tcleanup: closed p[0]: %d", s->p[0]);
	s->p[0] = -1;
}

/*
 * cleanup tears down the previous TLS session, if one was ever set up,
 * so that setupTls can start a new one.
 *
 *  - If tlsClient succeeded (s->fd >= 0), it closes the tls fd.  That
 *    should make devtls close p[1] and so cause EOF on p[0] in readproc.
 *  - Otherwise it writes a zero-length message to p[1] and closes both
 *    pipe ends itself.
 *  - It then waits, bounded by the cleanupWhile timer, for readproc to
 *    confirm EOF on eofc.  If clientproc started and has not reported
 *    yet, it also waits for clientproc's thread id on ctidc.  When the
 *    timer expires, the stuck proc is killed: clientproc first, then
 *    readproc, which is restarted.
 *
 * When a session was in use, s->fd is -2 again on return and the pipe
 * is closed, as setupTls requires.
 */
static void
cleanup(TTLSstate* s)
{
	int consumeRead, consumeClient, id, ret, timeout, n;
	Buf *prx;
	char dummy[1];
	Alt a[] = {
	/*	 c			v		op   */
		{s->readc,	&prx,	CHANRCV},
		{s->eofc,		nil,	CHANRCV},
		{s->ctidc,	&id,	CHANRCV},
		{s->tickc,	nil,	CHANRCV},
		{nil,			nil,	CHANEND},
	};

	loglog("cleanup pre fd=%d cStarted=%d cReturned=%d ctid=%d cpid=%d p[0]=%d  p[1]=%d", s->fd, s->cStarted, s->cReturned, s->ctid, threadpid(s->ctid), s->p[0], s->p[1]);
	trans(s, Cleanup);

	// no session was ever set up: no procs to reap, no fds to close
	if (!s->inuse)
		return;

	if (s->fd >= 0) {
		loglog("\tcleanup: closing fd: %d", s->fd);
		// should make devtls  close s->p[1], causing eof on s->p[0] in readproc
		if (close(s->fd) < 0)
			loglog("\tcleanup: failed closing fd: %d:%r", s->fd);
		else
			loglog("\tcleanup: closed fd: %d", s->fd);
		s->fd = -2;
		s->p[1] = -1;	// owned (and closed) by devtls now
	} else {
		if (s->p[1] >= 0) {
			// presumably so that readproc's read on p[0] returns 0
			loglog("\tcleanup: writing 0 to p[1]: %d", s->p[1]);
			if ((n = write(s->p[1], dummy, 0)) < 0)
				loglog("\tcleanup: failed writing 0 to p[1]: %d : %r", s->p[1]);
			else
				loglog("\tcleanup: written 0 to p[1]: %d : %d", s->p[1], n);
			loglog("\tcleanup: closing p[1]: %d", s->p[1]);
			if (close(s->p[1]) < 0)
				loglog("\tcleanup: failed closing p[1]: %d: %r", s->p[1]);
			else
				loglog("\tcleanup: closed p[1]: %d", s->p[1]);
			s->p[1] = -1;
		} else {
			logall("\tcleanup: oops should not happen p[1] < 0: %d", s->p[1]);
		}
		if (s->p[0] >= 0)
			closeP0(s);
		else
			logall("\tcleanup: oops should not happen p[0] < 0: %d", s->p[0]);
	}
	consumeRead = 1;
	consumeClient = s->cStarted && !s->cReturned;

	loglog("\tcleanup middle consumeRead=%d consumeClient=%d fd=%d cStarted=%d cReturned=%d  ctid=%d cpid=%d p[0]=%d p[1]=%d", consumeRead, consumeClient, s->fd, s->cStarted, s->cReturned, s->ctid, threadpid(s->ctid), s->p[0], s->p[1]);

	timeout = 0;
	startTimer(s->cleanupWhile, s->cleanupPeriod);
	while((consumeRead || consumeClient) && !timeout) {
		loglog("\tcleanup receiving...");
		switch(ret = alt(a)){
		case 0:
			// late output of the dying tls connection; dropped
			loglog("\t\toops... cleanup recv from readc: %p", prx);
			// is this the close assert . if so, should we write this to ether? 
			break;
		case 1:
			loglog("\t\tcleanup: confirmed eof from readproc");
			consumeRead =  0;
			closeP0(s);
			break;
		case 2:
			if (s->ctid == id) {
				loglog("\t\tcleanup: confirmed return from clientproc %d", id);
				consumeClient = 0;
			} else
				loglog("\t\tcleanup: oops return from older clientproc %d", id);
			break;
		case 3:	/* timer tick */
			tickTimer(s->cleanupWhile);
			if (s->cleanupWhile->counter == 0){
				logall("ttls cleanup timer expired");
				if (consumeClient) {
					logall("ttls cleanup threadkill clientproc %d", s->ctid);
					threadkill(s->ctid);
					consumeClient = 0;
					startTimer(s->cleanupWhile, s->cleanupPeriod);
				} else if (consumeRead) {
					logall("ttls cleanup threadkill readproc %d", s->rtid);
					threadkill(s->rtid);
					// readproc never confirmed eof, so case 1 never
					// closed p[0]; close it now, or the next setupTls
					// finds the pipe still open and aborts
					closeP0(s);
					// restart it exactly as initTTLS first started it
					s->rtid = procrfork(readproc, s, STACK, RFNAMEG|RFNOTEG);
					logall("ttls cleanup started readproc tid=%d pid=%d", s->rtid, threadpid(s->rtid));
					consumeRead = 0;
					startTimer(s->cleanupWhile, s->cleanupPeriod);
				} else {
					logall("ttls cleanup giving up");
					timeout = 1;
				}
			}
			break;
		default:
			logall("\t\tcleanup: unexpected %d", ret);
			break;
		}
	}
	resetTimer(s->cleanupWhile);
	if (s->fd != -2) {
		loglog("\tcleanup: reset fd: %d", s->fd);
		s->fd = -2;
	}
	loglog("\tcleanup post consumeRead=%d consumeClient=%d fd=%d cStarted=%d cReturned=%d  ctid=%d cpid=%d p[0]=%d p[1]=%d", consumeRead, consumeClient, s->fd, s->cStarted, s->cReturned, s->ctid, threadpid(s->ctid), s->p[0], s->p[1]);
}

/*
 * setupTls starts a fresh TLS session.  It creates the pipe between
 * tlsClient (which gets p[1]) and our fragmentation code (p[0]), starts
 * clientproc to run tlsClient, tells readproc to begin reading p[0],
 * and marks the session as needing cleanup.  Leftovers from a previous
 * session that was not fully cleaned up are fatal.
 */
static void
setupTls(TTLSstate *s)
{
	int *pp;

	pp = s->p;
	loglog("setupTls pre p[0]=%d  p[1]=%d", pp[0], pp[1]);

	/* cleanup must have closed the old pipe and reset the tls fd */
	if (pp[0] >= 0 || pp[1] >= 0)
		logfatal(0, "setupTls: pipe already open? %d %d", pp[0], pp[1]);
	if (s->fd != -2)
		logfatal(0, "setupTls: fd init error? %d (expected -2)", s->fd);
	if (pipe(pp) < 0)
		logfatal(0, "pipe failed: %r");

	/* clientproc has neither started nor reported back yet */
	s->cReturned = 0;
	s->cStarted = 0;

	/* call tlsClient in its own proc; ttls() waits for its result */
	loglog("setupTls clientproc...");
	s->ctid = procrfork(clientproc, s, STACK, RFNAMEG|RFNOTEG);
	loglog("started clientproc tid=%d pid=%d", s->ctid, threadpid(s->ctid));

	/* wake readproc so it monitors the new p[0] */
	loglog("setupTls rstarterc...");
	sendul(s->rstarterc, 1);

	s->inuse = 1;
	loglog("setupTls post p[0]=%d  p[1]=%d", pp[0], pp[1]);
}

/*
 * buildFrameStart writes into p the first fragment of the pending
 * message s->ptoe.  The long header carries flags M|L and, in tln, the
 * total length still to send.  The header is followed by as much data
 * as fits in p->n bytes.  ptoe->p advances past the copied data.
 *
 * Returns the frame length, also stored in s->txLen.  If p->n cannot
 * even hold the header, nothing is built and 0 is returned.  If the
 * message would fit unfragmented (a caller error), only the pending
 * bytes are copied, never data past ptoe->e.
 */
static int
buildFrameStart(TTLSstate *s, Packet *p)
{
	TTLS *t;
	int todo, space;

	todo = s->ptoe->e - s->ptoe->p;
	space = p->n - TtlsLongHlen;
	if (space <= 0) {
		logall("buildFrameStart error: mtu much too small: mtu=%d, longhdr=%d", p->n, TtlsLongHlen);
		s->txLen = 0;
		return 0;
	}
	if (todo <= space) {
		logall("buildFrameStart error: small enough, no framing needed: sz=%d, space=%d", todo, space);
		// don't copy beyond the pending data (nor move ptoe->p past e)
		space = todo > 0 ? todo : 0;
	}
	t = (TTLS*)p->p;
	memset(t, 0, TtlsLongHlen);
	t->tp = EapTpTtls;
	t->flags = TtlsFlagM | TtlsFlagL;
	hnputl(t->tln, todo);	// total length of the whole message
	memcpy(p->p + TtlsLongHlen, s->ptoe->p, space);
	s->txLen = TtlsLongHlen + space;
	s->ptoe->p += space;
	return s->txLen;
}

/*
 * buildFrameMiddle writes into p a non-final fragment of the pending
 * message s->ptoe.  It uses the short header with flag M, followed by
 * as much data as fits in p->n bytes, and advances ptoe->p past the
 * copied data.
 *
 * Returns the frame length, also stored in s->txLen.  If p->n cannot
 * even hold the header, nothing is built and 0 is returned.  If the
 * rest would fit in one frame (a caller error), only the pending bytes
 * are copied, never data past ptoe->e.
 */
static int
buildFrameMiddle(TTLSstate *s, Packet *p)
{
	TTLS *t;
	int todo, space;

	todo = s->ptoe->e - s->ptoe->p;
	space = p->n - TtlsShortHlen;
	if (space <= 0) {
		logall("buildFrameMiddle error: mtu much too small: mtu=%d, longhdr=%d", p->n, TtlsShortHlen);
		s->txLen = 0;
		return 0;
	}
	if (todo <= space) {
		logall("buildFrameMiddle error: small enough, no framing needed: sz=%d, space=%d", todo, space);
		// don't copy beyond the pending data (nor move ptoe->p past e)
		space = todo > 0 ? todo : 0;
	}
	t = (TTLS*)p->p;
	memset(t, 0, TtlsShortHlen);
	t->tp = EapTpTtls;
	t->flags = TtlsFlagM;
	memcpy(p->p + TtlsShortHlen, s->ptoe->p, space);
	s->txLen = TtlsShortHlen + space;
	s->ptoe->p += space;
	return s->txLen;
}

/*
 * buildMsg writes into p the final (or only) frame of the pending
 * message s->ptoe: a short header without flags followed by all
 * remaining data.  s->txLen is set to the frame length, and ptoe->p is
 * moved to ptoe->e to mark the message as consumed.
 *
 * Returns the number of payload bytes copied.  If p->n cannot even hold
 * the header, nothing is built and 0 is returned.  If the data does not
 * fit (a caller error), it is truncated rather than overrunning p.
 */
static int
buildMsg(TTLSstate *s, Packet *p)
{
	TTLS *t;
	int todo, space;

	todo = s->ptoe->e - s->ptoe->p;
	space = p->n - TtlsShortHlen;
	if (space <= 0) {
		logall("buildMsg error: mtu much too small: mtu=%d, longhdr=%d", p->n, TtlsShortHlen);
		s->txLen = 0;
		return 0;
	}
	if (todo > space) {
		logall("buildMsg error: too big, framing needed: sz=%d, space=%d", todo, space);
		todo = space;	// truncate rather than overrun the transmit buffer
	}
	if (todo < 0)
		todo = 0;
	t = (TTLS*)p->p;
	memset(t, 0, TtlsShortHlen);
	t->tp = EapTpTtls;
	memcpy(p->p + TtlsShortHlen, s->ptoe->p, todo);
	s->txLen = TtlsShortHlen + todo;
	// everything pending has been handed off: mark it consumed
	// (this used to set the cursor to a nil pointer)
	s->ptoe->p = s->ptoe->e;
	return todo;
}

/*
 * buildAck writes into p a bare short TTLS header: type set, no flags,
 * no data.  The state machine sends it (state SendAck) after receiving
 * a fragment with the M flag.  s->txLen becomes the header length.
 */
static void
buildAck(TTLSstate *s, Packet *p)
{
	TTLS *hdr;

	s->txLen = TtlsShortHlen;	/* header only, no payload */
	hdr = (TTLS*)p->p;
	memset(hdr, 0, s->txLen);
	hdr->tp = EapTpTtls;
}

/*
 * trans moves the state machine to state new and logs the transition.
 * Entering RecvAck, Receiving or Idle means a reply (possibly empty)
 * is ready: s->done is set, so processTTLS stops calling ttls().
 */
static void
trans(TTLSstate *s, int new)
{
	loglog("ttls trans: %s -> %s", snames[s->state], snames[new]);
	if (new == RecvAck || new == Receiving || new == Idle)
		s->done = 1;
	s->state = new;
}

/*
 * ttls performs one step of the EAP-TTLS state machine for the received
 * packet erx.  It may fill etx (etx->n is the space available, see the
 * build* helpers).  processTTLS calls it repeatedly until a transition
 * sets s->done.
 *
 *	Idle		nothing to do
 *	Start		set up a new tls session (setupTls), empty reassembly
 *	Waiting		block until tlsClient returns, writes to the pipe
 *			(readc), or ttlsWhile expires
 *	Timeout		tlsClient wants more input: go receive
 *	Sending		put (a fragment of) s->ptoe into etx
 *	RecvAck		peer acknowledged our fragment: send the next one
 *	Receiving	append erx's payload to the reassembly buffer s->etop
 *	SendAck		acknowledge a received fragment (M flag)
 *	Received	complete message: write it to the pipe for tlsClient
 *
 * *failp is set when tlsClient fails.  succp is unused.
 *
 * NOTE(review): the total length announced by the peer (tln) is not
 * capped, so a huge value leads to a huge realloc (fatal on failure).
 */
static void
ttls(TTLSstate *s, Packet* erx, Packet *etx, int *succp, int *failp)
{
	Buf *prx;
	int id, done;
	Alt a[] = {
	/*	 c			v		op   */
		{s->ctidc,	&id,		CHANRCV},
		{s->readc,	&prx,	CHANRCV},
		{s->tickc,	nil,		CHANRCV},
		{nil,			nil,		CHANEND},
	};
	TTLS *t;
	uchar *p;
	uint l;
	int n;
	int olen, flen;
	int todo, total;
	int hlen;

	USED(succp); // just get rid of it?
	switch(s->state){
	case Idle:
		trans(s, Idle);
		break;
	case Start:
		setupTls(s); // new session
		s->etop = s->wbuf;
		s->etop->p = s->etop->b;
		s->etop->e = s->etop->b;
		s->more = 0;
		trans(s, Waiting);
		break;
	case Waiting:
		done = 0;
		startTimer(s->ttlsWhile, s->ttlsPeriod);
		while(!done)
			switch(alt(a)){
			case 0: // the tlsClient call returned
				loglog("ttls Waiting tlsClient %d returns", id);
				done = 1;
				if (s->ctid == id) {
					loglog("ttls Waiting tlsClient %d return %d", id, s->fd);
					s->cReturned = 1;
					if (s->fd < 0) {
						*failp = 1;
						// XXX build fail packet???
						trans(s, Idle);
					} else {
						markPhaseStart(2, "pap");
						doTTLSphase2(s->fd);
						markPhaseDone(2, "pap");
					}
				} else
					loglog("ttls Waiting oops older tlsClient %d returned %d", id, s->fd);
				break;
			case 1: // something read from p: encapsulate and send
				    // we do no treat end-of-file on p here, but leave that for cleanup.
				    // is this a wise choice?
				loglog("ttls Waiting read from p:%p", prx);
				done = 1;
				s->ptoe = prx;
				s->ptoe->p = s->ptoe->b;
				if (debug) print("ttls readc: prx=%p prx->p=%p prx->e=%p n=%ld\n", prx, prx->p, prx->e, prx->e - prx->b);
				trans(s, Sending);
				break;
			case 2:	/* timer tick */
				tickTimer(s->ttlsWhile);
				if (s->ttlsWhile->counter == 0) {
					logall("ttls Waiting: tlsClient timer expired");
					done = 1;
					trans(s, Timeout);
				}
				break;
			}
		resetTimer(s->ttlsWhile);
		break;
	case Timeout:
		trans(s, Receiving); // seems we need more stuff to satisfy tlsClient
		break;
	case Sending:
		todo = s->ptoe->e - s->ptoe->p;
		total = s->ptoe->e - s->ptoe->b;
		if (s->ptoe->p == s->ptoe->b && todo > etx->n - TtlsShortHlen) {
			olen = todo;
			flen = buildFrameStart(s, etx);
			if (debug) print("ttls first fragment framed %d of %d, total %d, todo %d\n", flen, olen, total, todo);
			trans(s, RecvAck);
		} else if (todo > etx->n - TtlsShortHlen) {
			olen = todo;
			flen = buildFrameMiddle(s, etx);
			if (debug) print("ttls framed %d of %d, total %d, todo %d\n", flen, olen, total, todo);
			trans(s, RecvAck);
		} else {
			olen = todo;
			flen = buildMsg(s, etx);
			if (debug) print("ttls framed %d of %d, total %d, todo %d\n", flen, olen, total, todo);
			// our message is out: prepare to reassemble the reply
			s->etop = s->wbuf;
			s->etop->p = s->etop->b;
			s->etop->e = s->etop->b;
			s->more = 0;
			trans(s, Receiving);
		}
		break;
	case RecvAck:
		t = (TTLS*)erx->p;
		if (t->flags&TtlsFlagS)
			logall("ttls %s: unexpected TtlsFlagS", snames[s->state]);
		if (t->flags&TtlsFlagM)
			logall("ttls %s: unexpected TtlsFlagM", snames[s->state]);
		if (t->flags&TtlsFlagL)
			logall("ttls %s: unexpected TtlsFlagL", snames[s->state]);
		trans(s, Sending);
		break;
	case Receiving:
		t = (TTLS*)erx->p;
		// drop packets too short for the header their flags announce;
		// otherwise the unsigned payload length l below would wrap
		hlen = (t->flags&TtlsFlagL) ? TtlsLongHlen : TtlsShortHlen;
		if (erx->n < hlen) {
			logall("ttls %s: packet too short: len=%d < hdrlen=%d", snames[s->state], erx->n, hlen);
			trans(s, Receiving);
			break;
		}
		if (t->flags&TtlsFlagS)
			logall("tls %s: unexpected TtlsFlagS", snames[s->state]);
		if (t->flags&TtlsFlagL) {
			// first fragment: tln announces the whole message length
			total = nhgetl(t->tln);
			if (total < 0) {
				// > INT_MAX: would make etop->e < etop->b below
				logall("ttls %s: bad length=%d", snames[s->state], total);
				trans(s, Receiving);
				break;
			}
			if (s->more)
				logall("ttls %s: L flag but more=%d", snames[s->state], s->more);
			if (s->etop->p != s->etop->b) {
				logall("ttls %s: L flag: etop->p=%p != etop->b=%p", snames[s->state], s->etop->p, s->etop->b);
				// a new message starts: drop the partial one
				s->etop->p = s->etop->b;
				s->etop->e = s->etop->b;
			}
			if (total > s->etop->end - s->etop->b) {
				logall("ttls %s: requested length=%d > reassemble bufsize=%ld", snames[s->state], total, s->etop->end - s->etop->b);
				reallocBuf(s->etop, total);
			}
			s->etop->e = s->etop->b + total;
			if (debug) print("ttls:  TtlsFlagL len=%d\n", total);
			p = erx->p + TtlsLongHlen;
			l = erx->n - TtlsLongHlen;
		} else if (s->etop->e > s->etop->b) {
			// continuation of a message whose length we know
			if (!s->more)
				logall("ttls %s: etop->p=%p > etop->b=%p but more=%d", snames[s->state], s->etop->p, s->etop->b, s->more);
			p = erx->p + TtlsShortHlen;
			l = erx->n -TtlsShortHlen;
			total = s->etop->e - s->etop->b;
		} else  if (s->etop->e == s->etop->b) {
			// unfragmented message: its length is this packet's payload
			if (s->more)
				logall("ttls %s: etop->p=%p == etop->b=%p but more=%d", snames[s->state], s->etop->p, s->etop->b, s->more);
			p = erx->p + TtlsShortHlen;
			l = erx->n -TtlsShortHlen;
			total = l;
			if (total > s->etop->end - s->etop->b) {
				logall("ttls %s: requested length=%d > reassemble bufsize=%ld", snames[s->state], total, s->etop->end - s->etop->b);
				reallocBuf(s->etop, total);
			}
			s->etop->e = s->etop->b + total;
		} else {	// (s->etop->e < s->etop->b)
			SET(total);
			SET(p);
			SET(l);
			logfatal(0, "ttls %s: should not happen: s->etop->e=%p < s->etop->b=%p", snames[s->state], s->etop->e , s->etop->b);
		}
		if (t->flags&TtlsFlagM)
			s->more = 1;
		else
			s->more = 0;
		if (s->etop->p + l > s->etop->e) {
			logall("ttls: fragment sz=%d + sofar=%ld > total=%d", l, s->etop->p - s->etop->b, total);
			// just store, it will be discarded later (because ->p != ->e);
			// should discard this + future framents now
			// (extend state machine with  Discard state version of Receive?)
		}
		if (s->etop->p + l > s->etop->end) {
			logall("ttls %s: buffer still to small to reassemble: bufsize=%ld, needed=%ld", snames[s->state], s->etop->end - s->etop->b, s->etop->p + l -  s->etop->b);
			reallocBuf(s->etop, s->etop->p + l - s->etop->b);
		}
		memcpy(s->etop->p, p, l);
		s->etop->p += l;
		if (debug) print("ttls %s: received %d; total=%d\n", snames[s->state], l, total);
		if (t->flags&TtlsFlagM)
			trans(s, SendAck);
		else {
			if (s->etop->b != s->etop->e && s->etop->p != s->etop->e)
				logall("ttls %s: reassembled=%ld !=  total=%d", snames[s->state], s->etop->p  - s->etop->b, total);
			if (s->etop->b != s->etop->e && s->etop->p == s->etop->e)
				trans(s, Received);
			else if (s->etop->b != s->etop->e && s->etop->p != s->etop->e) {
				// just ignore the packet
				// should we report the wrong-amount-read to other side?
				logall("ttls %s: skipping short reassembly (total=%d reass=%ld)", snames[s->state], total, s->etop->p  - s->etop->b);
				trans(s, Waiting);
			} else
				trans(s, Waiting);
		}
		break;
	case SendAck:
		buildAck(s, etx);
		trans(s, Receiving);
		break;
	case Received:
		// we only get here if we read the right amount
		total = s->etop->e - s->etop->b;
		if (debug) print("ttls %s: writing p[0]: %.*H\n", snames[s->state], total, s->etop->b);
		n = write(s->p[0], s->etop->b, total);
		if (n<0)
			logall("ttls %s: error writing p[0]: %r", snames[s->state]);
		loglog("writeproc written to %d: %d", s->p[0], n);
		if (n != total)
			logall("ttls %s: writing p[0]: n=%d != total=%d", snames[s->state], n, total);
		// the message now belongs to tlsClient: empty the reassembly
		// buffer, so that a message arriving via Waiting -> Timeout ->
		// Receiving is not taken as a continuation of this one
		s->etop->p = s->etop->b;
		s->etop->e = s->etop->b;
		s->more = 0;
		trans(s, Waiting);
		break;
	}
}

/*
 * allocBuf returns a new, zeroed Buf with a zero-filled data area of
 * n bytes.  b is the start of the data area, e (end of valid data)
 * starts at b, end marks the end of the allocation, and the cursor p
 * is left zero (nil).  Allocation failure is fatal.
 */
static Buf*
allocBuf(ulong n)
{
	Buf *nb;

	nb = malloc(sizeof *nb);
	if (nb == nil)
		logfatal(1, "could not alloc Buf");
	memset(nb, 0, sizeof *nb);	/* also leaves nb->p nil */

	if ((nb->b = malloc(n)) == nil)
		logfatal(1, "could not alloc Buf buffer");
	memset(nb->b, 0, n);
	nb->end = nb->b + n;
	nb->e = nb->b;
	return nb;
}

/*
 * reallocBuf grows p's data area to n bytes.  It keeps the contents and
 * the offsets of the p (cursor) and e (end of data) pointers relative
 * to b.  It never shrinks: a request no larger than the current size is
 * logged and ignored.  Returns p; allocation failure is fatal.
 */
static Buf*
reallocBuf(Buf* p, ulong n)
{
	int pos, fill;

	if (n > p->end - p->b) {
		fill = p->e - p->b;
		pos = p->p - p->b;
		p->b = realloc(p->b, n);
		if (p->b == nil)
			logfatal(1, "could not realloc Buf buffer");
		p->end = p->b + n;
		p->e = p->b + fill;
		p->p = p->b + pos;
	} else
		logall("will not shrink Buf (from %uld to %uld)", p->end - p->b, n);
	return p;
}

/*
 * initTTLS resets the module state and prepares it for processTTLS.
 * It creates the channels used with the helper procs, registers the
 * ttlsWhile and cleanupWhile timers with t (whose channel delivers the
 * ticks), and sets both timer periods to 5 seconds.  It fills in the
 * TLSconn session parameters and, when file is non-nil, loads the
 * accepted server thumbprints from file/filex (failure is fatal).
 * Finally it allocates the pipe and reassembly buffers and starts
 * readproc.
 */
void
initTTLS(char *file, char *filex, Timers *t)
{
	TTLSstate *s;
	Buf **bp;

	loglog("initTTLS");

	s = &theTTLSstate;
	memset(s, 0, sizeof *s);
	s->state = Idle;

	/* channels between the state machine and its helper procs */
	s->ctidc = chancreate(sizeof(int), 0);
	s->readc = chancreate(sizeof(Buf*), 0);
	s->eofc = chancreate(sizeof(int), 0);
	s->cstarterc = chancreate(sizeof(int), 0);
	s->rstarterc = chancreate(sizeof(int), 0);
	s->tickc = t->chan;

	/* no tls fd and no pipe yet */
	s->fd = -2;
	s->p[0] = s->p[1] = -1;

	s->ttlsWhile = addTimer(t, "ttlsWhile");
	s->cleanupWhile = addTimer(t, "cleanupWhile");
	s->ttlsPeriod = s->cleanupPeriod = 5;	/* seconds */

	s->conn.sessionType = "ttls";
	s->conn.sessionConst = "ttls keying material";
	s->conn.sessionKey = theSessionKey;
	s->conn.sessionKeylen = sizeof(theSessionKey);
	s->conn.trace = nil;
	if (debugTLS)
		s->conn.trace = tlslog;

	fmtinstall('H', encodefmt);
	if (file != nil) {
		s->thumbTable = initThumbprints(file, filex);
		if (s->thumbTable == nil)
			logfatal(1, "initThumbprints: %r");
	}

	for (bp = s->rbuf; bp < s->rbuf + Nbuf; bp++)
		*bp = allocBuf(Buflen);
	s->wbuf = allocBuf(0);

	s->rtid = procrfork(readproc, s, STACK, RFNAMEG|RFNOTEG);
	loglog("initTTLS started readproc tid=%d pid=%d", s->rtid, threadpid(s->rtid));
}

/*
 * abortTTLS closes our end of the tls pipe (s->p[0]), if it is open,
 * and marks it closed.  Nothing else is torn down here.
 */
void
abortTTLS(void)
{
	TTLSstate *s = &theTTLSstate;

	loglog("abortTTLS");
	if (s->p[0] < 0)
		return;
	close(s->p[0]);
	s->p[0] = -1;
}

/*
 * processTTLS handles one received EAP-TTLS packet rx and prepares the
 * reply in tx (tx->n is the space available).  It returns the number of
 * bytes of tx to send.  The result is 0 when nothing is to be sent,
 * including when rx is too short or is not an EAP-TTLS packet.
 *
 * A packet with the S flag ends the previous session (cleanup) and
 * starts a new one.  When expectStart is set, a packet without the S
 * flag is fatal.  *failp is set by the state machine when tlsClient
 * fails; succp is passed through but unused there.
 */
int
processTTLS(Packet *rx, Packet *tx, int expectStart, int*succp, int*failp)
{
	TTLS *t;
	uchar version;
	TTLSstate *s;

	s = &theTTLSstate;

	t = (TTLS*)rx->p;

	// we need at least the type and flags bytes before looking at them
	if (rx->n < TtlsShortHlen) {
		logall("processTTLS: packet too short: %d", rx->n);
		return 0;
	}
	if (t->tp != EapTpTtls)
		return 0; // flag error??

	// first thing should be EAP-TTLS start packet
	version = t->flags & TtlsVersion;
	loglog("processTTLS flags=%s%s%s ver=%d mtu=%d bl=%d",
		(t->flags&TtlsFlagS ? "S":""),
		(t->flags&TtlsFlagM ? "M":""),
		(t->flags&TtlsFlagL ? "L":""),
		version, tx->n, rx->n);
	// was `!t->flags&TtlsFlagS`, i.e. (!t->flags)&TtlsFlagS: always 0
	if (expectStart && !(t->flags&TtlsFlagS))
		logfatal(0, "expected EAP-TTLS start packet");
	if (t->flags & TtlsFlagS) {
		cleanup(s); // previous session

		// ack??
		// look for piggy-backed stuff?

		s->version = version;
		s->state = Start;
		s->ptoe = nil;
		s->etop = nil;
		// empty the reassembly buffer (e used to be assigned to itself)
		s->wbuf->p = s->wbuf->b;
		s->wbuf->e = s->wbuf->b;
		// we don't have a client certificate
		s->conn.cert = nil;
		s->conn.certlen = 0;
		// avoid trying session resumption - tlsClient does not support it
		s->conn.sessionID = nil;
		s->conn.sessionIDlen = 0;

		if (debugTLS)
			s->conn.trace = tlslog;
		else
			s->conn.trace = nil;
	}
	// run the state machine until some state has a reply ready
	s->txLen = 0;
	s->done = 0;
	while (!s->done)
		ttls(s, rx, tx, succp, failp);
	return s->txLen;
}

Bell Labs OSI certified Powered by Plan 9

(Return to Plan 9 Home Page)

Copyright © 2021 Plan 9 Foundation. All Rights Reserved.
Comments to [email protected].