From e153afc401ccb9256385a0a3da10bf412d87fe1f Mon Sep 17 00:00:00 2001
From: venaas <venaas>
Date: Thu, 18 Sep 2008 11:10:44 +0000
Subject: fixed some bugs, improved duplicate detection

git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@391 e88ac4ed-0b26-0410-9574-a7f39faa03bf
---
 radmsg.c      | 10 +++++---
 radsecproxy.c | 75 ++++++++++++++++++++++++++---------------------------------
 radsecproxy.h |  7 +++---
 udp.c         | 36 ++++++++++++++--------------
 4 files changed, 62 insertions(+), 66 deletions(-)

diff --git a/radmsg.c b/radmsg.c
index 422186d..0ea6ee7 100644
--- a/radmsg.c
+++ b/radmsg.c
@@ -225,9 +225,13 @@ uint8_t *radmsg2buf(struct radmsg *msg, uint8_t *secret) {
 	free(buf);
 	return NULL;
     }
-    if (secret && (msg->code == RAD_Access_Accept || msg->code == RAD_Access_Reject || msg->code == RAD_Access_Challenge || msg->code == RAD_Accounting_Response || msg->code == RAD_Accounting_Request) && !_radsign(buf, secret)) {
-	free(buf);
-	return NULL;
+    if (secret) {
+	if ((msg->code == RAD_Access_Accept || msg->code == RAD_Access_Reject || msg->code == RAD_Access_Challenge || msg->code == RAD_Accounting_Response || msg->code == RAD_Accounting_Request) && !_radsign(buf, secret)) {
+	    free(buf);
+	    return NULL;
+	}
+	if (msg->code == RAD_Accounting_Request)
+	    memcpy(msg->auth, buf + 4, 16);
     }
     return buf;
 }
diff --git a/radsecproxy.c b/radsecproxy.c
index fdb2838..6ac3825 100644
--- a/radsecproxy.c
+++ b/radsecproxy.c
@@ -577,7 +577,6 @@ struct client *addclient(struct clsrvconf *conf, uint8_t lock) {
     }
     
     memset(new, 0, sizeof(struct client));
-    pthread_mutex_init(&new->lock, NULL);
     new->conf = conf;
     if (conf->pdef->addclient)
 	conf->pdef->addclient(new);
@@ -597,11 +596,8 @@ void removeclient(struct client *client) {
     conf = client->conf;
     pthread_mutex_lock(conf->lock);
     if (conf->clients) {
-	pthread_mutex_lock(&client->lock);
 	removequeue(client->replyq);
 	list_removedata(conf->clients, client);
-	pthread_mutex_unlock(&client->lock);
-	pthread_mutex_destroy(&client->lock);
 	free(client->addr);
 	free(client);
     }
@@ -1352,14 +1348,14 @@ int pwdrecrypt(uint8_t *pwd, uint8_t len, char *oldsecret, char *newsecret, uint
     return 1;
 }
 
-int msmpprecrypt(uint8_t *msmpp, uint8_t len, char *oldsecret, char *newsecret, unsigned char *oldauth, char *newauth) {
+int msmpprecrypt(uint8_t *msmpp, uint8_t len, char *oldsecret, char *newsecret, uint8_t *oldauth, uint8_t *newauth) {
     if (len < 18)
 	return 0;
-    if (!msmppdecrypt(msmpp + 2, len - 2, (unsigned char *)oldsecret, strlen(oldsecret), oldauth, msmpp)) {
+    if (!msmppdecrypt(msmpp + 2, len - 2, (uint8_t *)oldsecret, strlen(oldsecret), oldauth, msmpp)) {
 	debug(DBG_WARN, "msmpprecrypt: failed to decrypt msppe key");
 	return 0;
     }
-    if (!msmppencrypt(msmpp + 2, len - 2, (unsigned char *)newsecret, strlen(newsecret), (unsigned char *)newauth, msmpp)) {
+    if (!msmppencrypt(msmpp + 2, len - 2, (uint8_t *)newsecret, strlen(newsecret), newauth, msmpp)) {
 	debug(DBG_WARN, "msmpprecrypt: failed to encrypt msppe key");
 	return 0;
     }
@@ -1372,7 +1368,7 @@ int msmppe(unsigned char *attrs, int length, uint8_t type, char *attrtxt, struct
     
     for (attr = attrs; (attr = attrget(attr, length - (attr - attrs), type)); attr += ATTRLEN(attr)) {
 	debug(DBG_DBG, "msmppe: Got %s", attrtxt);
-	if (!msmpprecrypt(ATTRVAL(attr), ATTRVALLEN(attr), oldsecret, newsecret, rq->buf + 4, rq->origauth))
+	if (!msmpprecrypt(ATTRVAL(attr), ATTRVALLEN(attr), oldsecret, newsecret, rq->buf + 4, rq->rqauth))
 	    return 0;
     }
     return 1;
@@ -1722,45 +1718,42 @@ struct request *newrequest() {
     return rq;
 }
 
-int addclientrq(struct request *rq, uint8_t id) {
+int addclientrq(struct request *rq) {
     struct request *r;
     struct timeval now;
-
-    pthread_mutex_lock(&rq->from->lock);
-    gettimeofday(&now, NULL);
-    r = rq->from->rqs[id];
+    
+    r = rq->from->rqs[rq->rqid];
     if (r) {
-	if (now.tv_sec - r->created.tv_sec < r->from->conf->dupinterval) {
+	if (rq->udpport == r->udpport && !memcmp(rq->rqauth, r->rqauth, 16)) {
+	    gettimeofday(&now, NULL);
+	    if (now.tv_sec - r->created.tv_sec < r->from->conf->dupinterval) {
 #if 0
-	    later	    
-	    if (r->replybuf) {
-		debug(DBG_INFO, "radsrv: already sent reply to request with id %d from %s, resending", id, r->from->conf->host);
-		r->refcount++;
-		sendreply(r);
-	    } else
+		later	    
+		    if (r->replybuf) {
+			debug(DBG_INFO, "radsrv: already sent reply to request with id %d from %s, resending", rq->rqid, addr2string(r->from->addr));
+			r->refcount++;
+			sendreply(r);
+		    } else
 #endif		
-		debug(DBG_INFO, "radsrv: already got request with id %d from %s, ignoring", id, r->from->conf->host);		
-	    pthread_mutex_unlock(&rq->from->lock);
-	    return 0;
+			debug(DBG_INFO, "radsrv: already got request with id %d from %s, ignoring", rq->rqid, addr2string(r->from->addr));
+		return 0;
+	    }
 	}
 	freerq(r);
     }
     rq->refcount++;
-    rq->from->rqs[id] = rq;
-    pthread_mutex_unlock(&rq->from->lock);
+    rq->from->rqs[rq->rqid] = rq;
     return 1;
 }
 
 void rmclientrq(struct request *rq, uint8_t id) {
     struct request *r;
 
-    pthread_mutex_lock(&rq->from->lock);
     r = rq->from->rqs[id];
     if (r) {
 	freerq(r);
 	rq->from->rqs[id] = NULL;
     }
-    pthread_mutex_unlock(&rq->from->lock);
 }
 
 /* returns 0 if validation/authentication fails, else 1 */
@@ -1768,7 +1761,6 @@ int radsrv(struct request *rq) {
     struct radmsg *msg = NULL;
     struct tlv *attr;
     uint8_t *userascii = NULL;
-    unsigned char newauth[16];
     struct realm *realm = NULL;
     struct server *to = NULL;
     struct client *from = rq->from;
@@ -1784,13 +1776,16 @@ int radsrv(struct request *rq) {
     }
     
     rq->msg = msg;
+    rq->rqid = msg->id;
+    memcpy(rq->rqauth, msg->auth, 16);
+
     debug(DBG_DBG, "radsrv: code %d, id %d", msg->code, msg->id);
     if (msg->code != RAD_Access_Request && msg->code != RAD_Status_Server && msg->code != RAD_Accounting_Request) {
 	debug(DBG_INFO, "radsrv: server currently accepts only access-requests, accounting-requests and status-server, ignoring");	
 	goto exit;
     }
     
-    if (!addclientrq(rq, msg->id))
+    if (!addclientrq(rq))
 	goto exit;
 
     if (msg->code == RAD_Status_Server) {
@@ -1846,11 +1841,11 @@ int radsrv(struct request *rq) {
 	goto exit;
     }
 
-    if (msg->code != RAD_Accounting_Request) {
-	if (!RAND_bytes(newauth, 16)) {
-	    debug(DBG_WARN, "radsrv: failed to generate random auth");
-	    goto rmclrqexit;
-	}
+    if (msg->code == RAD_Accounting_Request)
+	memset(msg->auth, 0, 16);
+    else if (!RAND_bytes(msg->auth, 16)) {
+	debug(DBG_WARN, "radsrv: failed to generate random auth");
+	goto rmclrqexit;
     }
     
 #ifdef DEBUG
@@ -1860,21 +1855,17 @@ int radsrv(struct request *rq) {
     attr = radmsg_gettype(msg, RAD_Attr_User_Password);
     if (attr) {
 	debug(DBG_DBG, "radsrv: found userpwdattr with value length %d", attr->l);
-	if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, msg->auth, newauth))
+	if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, rq->rqauth, msg->auth))
 	    goto rmclrqexit;
     }
 
     attr = radmsg_gettype(msg, RAD_Attr_Tunnel_Password);
     if (attr) {
 	debug(DBG_DBG, "radsrv: found tunnelpwdattr with value length %d", attr->l);
-	if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, msg->auth, newauth))
+	if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, rq->rqauth, msg->auth))
 	    goto rmclrqexit;
     }
 
-    rq->origid = msg->id;
-    memcpy(rq->origauth, msg->auth, 16);
-    memcpy(msg->auth, newauth, 16);    
-
     if (to->conf->rewriteout && !dorewrite(msg, to->conf->rewriteout))
 	goto rmclrqexit;
     
@@ -1993,8 +1984,8 @@ void replyh(struct server *server, unsigned char *buf) {
 	}
     }
 
-    msg->id = (char)rqout->rq->origid;
-    memcpy(msg->auth, rqout->rq->origauth, 16);
+    msg->id = (char)rqout->rq->rqid;
+    memcpy(msg->auth, rqout->rq->rqauth, 16);
 
 #ifdef DEBUG	
     printfchars(NULL, "origauth/buf+4", "%02x ", buf + 4, 16);
diff --git a/radsecproxy.h b/radsecproxy.h
index 6caf2d9..8c17c96 100644
--- a/radsecproxy.h
+++ b/radsecproxy.h
@@ -49,8 +49,8 @@ struct request {
     struct radmsg *msg;
     struct client *from;
     char *origusername;
-    char origauth[16];
-    uint8_t origid;
+    uint8_t rqid;
+    uint8_t rqauth[16];
     int udpsock; /* only for UDP */
     uint16_t udpport; /* only for UDP */
 };
@@ -102,9 +102,8 @@ struct clsrvconf {
 
 struct client {
     struct clsrvconf *conf;
-    int sock; /* for tcp/dtls */
+    int sock;
     SSL *ssl;
-    pthread_mutex_t lock; /* used for updating rqs */
     struct request *rqs[MAX_REQUESTS];
     struct queue *replyq;
     struct queue *rbios; /* for dtls */
diff --git a/udp.c b/udp.c
index 6b49e49..05e7a6b 100644
--- a/udp.c
+++ b/udp.c
@@ -47,6 +47,7 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
     struct clsrvconf *p;
     struct list_node *node;
     fd_set readfds;
+    struct client *c = NULL;
     
     for (;;) {
 	if (rad) {
@@ -103,26 +104,27 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
 
 	if (client) {
 	    pthread_mutex_lock(p->lock);
-	    for (node = list_first(p->clients); node; node = list_next(node))
-		if (addr_equal((struct sockaddr *)&from, ((struct client *)node->data)->addr))
+	    for (node = list_first(p->clients); node; node = list_next(node)) {
+		c = (struct client *)node->data;
+		if (s == c->sock && addr_equal((struct sockaddr *)&from, c->addr))
 		    break;
-	    if (node) {
-		*client = (struct client *)node->data;
-		pthread_mutex_unlock(p->lock);
-		break;
 	    }
-	    fromcopy = addr_copy((struct sockaddr *)&from);
-	    if (!fromcopy) {
-		pthread_mutex_unlock(p->lock);
-		continue;
+	    if (!node) {
+		fromcopy = addr_copy((struct sockaddr *)&from);
+		if (!fromcopy) {
+		    pthread_mutex_unlock(p->lock);
+		    continue;
+		}
+		c = addclient(p, 0);
+		if (!c) {
+		    free(fromcopy);
+		    pthread_mutex_unlock(p->lock);
+		    continue;
+		}
+		c->sock = s;
+		c->addr = fromcopy;
 	    }
-	    *client = addclient(p, 0);
-	    if (!*client) {
-		free(fromcopy);
-		pthread_mutex_unlock(p->lock);
-		continue;
-	    }
-	    (*client)->addr = fromcopy;
+	    *client = c;
 	    pthread_mutex_unlock(p->lock);
 	} else if (server)
 	    *server = p->servers;
-- 
cgit v1.1