source: trunk/athena/lib/ares/ares_process.c @ 17958

Revision 17958, 15.6 KB checked in by ghudson, 22 years ago (diff)
Fix freed memory access.
Line 
1/* Copyright 1998 by the Massachusetts Institute of Technology.
2 *
3 * Permission to use, copy, modify, and distribute this
4 * software and its documentation for any purpose and without
5 * fee is hereby granted, provided that the above copyright
6 * notice appear in all copies and that both that copyright
7 * notice and this permission notice appear in supporting
8 * documentation, and that the name of M.I.T. not be used in
9 * advertising or publicity pertaining to distribution of the
10 * software without specific, written prior permission.
11 * M.I.T. makes no representations about the suitability of
12 * this software for any purpose.  It is provided "as is"
13 * without express or implied warranty.
14 */
15
/* RCS/CVS identification string; the $Id$ keyword records the revision of
 * this file (presumably kept so tools like ident(1) can find it in the
 * compiled object).
 */
static const char rcsid[] =
  "$Id: ares_process.c,v 1.10 2002-10-08 23:28:37 ghudson Exp $";
17
18#include <sys/types.h>
19#include <sys/socket.h>
20#include <sys/uio.h>
21#include <netinet/in.h>
22#include <arpa/nameser.h>
23#include <string.h>
24#include <stdlib.h>
25#include <unistd.h>
26#include <fcntl.h>
27#include <time.h>
28#include <errno.h>
29#include "ares.h"
30#include "ares_dns.h"
31#include "ares_private.h"
32
33static void write_tcp_data(ares_channel channel, time_t now);
34static void read_tcp_data(ares_channel channel, time_t now);
35static void read_udp_packets(ares_channel channel, time_t now);
36static void process_timeouts(ares_channel channel, time_t now);
37static void process_answer(ares_channel channel, unsigned char *abuf,
38                           int alen, int whichserver, int tcp, int now);
39static void handle_error(ares_channel channel, int whichserver, time_t now);
40static void next_server(ares_channel channel, struct query *query, time_t now);
41static int open_tcp_socket(ares_channel channel, struct server_state *server);
42static int open_udp_socket(ares_channel channel, struct server_state *server);
43static int same_questions(const unsigned char *qbuf, int qlen,
44                          const unsigned char *abuf, int alen);
45static void end_query(ares_channel channel, struct query *query, int status,
46                      unsigned char *abuf, int alen);
47
48/* Something interesting happened on the wire, or there was a timeout.
49 * See what's up and respond accordingly.
50 */
51void ares_process(ares_channel channel, fd_set *read_fds, fd_set *write_fds)
52{
53  int i;
54  struct server_state *server;
55  time_t now;
56
57  /* Set writable/readable flags on server states.  We can't just pass
58   * fd sets around because fds can be closed and reopened during
59   * processing.
60   */
61  for (i = 0; i < channel->nservers; i++)
62    {
63      server = &channel->servers[i];
64      server->udp_readable = (server->udp_socket != -1
65                              && FD_ISSET(server->udp_socket, read_fds));
66      server->tcp_readable = (server->tcp_socket != -1
67                              && FD_ISSET(server->tcp_socket, read_fds));
68      server->tcp_writable = (server->qhead && server->tcp_socket != -1
69                              && FD_ISSET(server->tcp_socket, write_fds));
70    }
71
72  time(&now);
73  write_tcp_data(channel, now);
74  read_tcp_data(channel, now);
75  read_udp_packets(channel, now);
76  process_timeouts(channel, now);
77}
78
79/* If any TCP sockets select true for writing, write out queued data
80 * we have for them.
81 */
82static void write_tcp_data(ares_channel channel, time_t now)
83{
84  struct server_state *server;
85  struct send_request *sendreq;
86  struct iovec *vec;
87  int i, n, count;
88
89  for (i = 0; i < channel->nservers; i++)
90    {
91      server = &channel->servers[i];
92      if (!server->tcp_writable)
93        continue;
94
95      /* Count the number of send queue items. */
96      n = 0;
97      for (sendreq = server->qhead; sendreq; sendreq = sendreq->next)
98        n++;
99
100      /* Allocate iovecs so we can send all our data at once. */
101      vec = malloc(n * sizeof(struct iovec));
102      if (vec)
103        {
104          /* Fill in the iovecs and send. */
105          n = 0;
106          for (sendreq = server->qhead; sendreq; sendreq = sendreq->next)
107            {
108              vec[n].iov_base = (char *) sendreq->data;
109              vec[n].iov_len = sendreq->len;
110              n++;
111            }
112          count = writev(server->tcp_socket, vec, n);
113          free(vec);
114          if (count < 0)
115            {
116              handle_error(channel, i, now);
117              continue;
118            }
119
120          /* Advance the send queue by as many bytes as we sent. */
121          while (count)
122            {
123              sendreq = server->qhead;
124              if (count >= sendreq->len)
125                {
126                  count -= sendreq->len;
127                  server->qhead = sendreq->next;
128                  if (server->qhead == NULL)
129                    server->qtail = NULL;
130                  free(sendreq);
131                }
132              else
133                {
134                  sendreq->data += count;
135                  sendreq->len -= count;
136                  break;
137                }
138            }
139        }
140      else
141        {
142          /* Can't allocate iovecs; just send the first request. */
143          sendreq = server->qhead;
144          count = write(server->tcp_socket, sendreq->data, sendreq->len);
145          if (count < 0)
146            {
147              handle_error(channel, i, now);
148              continue;
149            }
150
151          /* Advance the send queue by as many bytes as we sent. */
152          if (count == sendreq->len)
153            {
154              server->qhead = sendreq->next;
155              if (server->qhead == NULL)
156                server->qtail = NULL;
157              free(sendreq);
158            }
159          else
160            {
161              sendreq->data += count;
162              sendreq->len -= count;
163            }
164        }
165    }
166}
167
168/* If any TCP socket selects true for reading, read some data,
169 * allocate a buffer if we finish reading the length word, and process
170 * a packet if we finish reading one.
171 */
172static void read_tcp_data(ares_channel channel, time_t now)
173{
174  struct server_state *server;
175  int i, count;
176
177  for (i = 0; i < channel->nservers; i++)
178    {
179      server = &channel->servers[i];
180      if (!server->tcp_readable)
181        continue;
182
183      if (server->tcp_lenbuf_pos != 2)
184        {
185          /* We haven't yet read a length word, so read that (or
186           * what's left to read of it).
187           */
188          count = read(server->tcp_socket,
189                       server->tcp_lenbuf + server->tcp_lenbuf_pos,
190                       2 - server->tcp_lenbuf_pos);
191          if (count <= 0)
192            {
193              handle_error(channel, i, now);
194              continue;
195            }
196
197          server->tcp_lenbuf_pos += count;
198          if (server->tcp_lenbuf_pos == 2)
199            {
200              /* We finished reading the length word.  Decode the
201               * length and allocate a buffer for the data.
202               */
203              server->tcp_length = server->tcp_lenbuf[0] << 8
204                | server->tcp_lenbuf[1];
205              server->tcp_buffer = malloc(server->tcp_length);
206              if (!server->tcp_buffer)
207                handle_error(channel, i, now);
208              server->tcp_buffer_pos = 0;
209            }
210        }
211      else
212        {
213          /* Read data into the allocated buffer. */
214          count = read(server->tcp_socket,
215                       server->tcp_buffer + server->tcp_buffer_pos,
216                       server->tcp_length - server->tcp_buffer_pos);
217          if (count <= 0)
218            {
219              handle_error(channel, i, now);
220              continue;
221            }
222
223          server->tcp_buffer_pos += count;
224          if (server->tcp_buffer_pos == server->tcp_length)
225            {
226              /* We finished reading this answer; process it and
227               * prepare to read another length word.
228               */
229              process_answer(channel, server->tcp_buffer, server->tcp_length,
230                             i, 1, now);
231              free(server->tcp_buffer);
232              server->tcp_buffer = NULL;
233              server->tcp_lenbuf_pos = 0;
234            }
235        }
236    }
237}
238
239/* If any UDP sockets select true for reading, process them. */
240static void read_udp_packets(ares_channel channel, time_t now)
241{
242  struct server_state *server;
243  int i, count;
244  unsigned char buf[PACKETSZ + 1];
245
246  for (i = 0; i < channel->nservers; i++)
247    {
248      server = &channel->servers[i];
249      if (!server->udp_readable)
250        continue;
251
252      count = recv(server->udp_socket, buf, sizeof(buf), 0);
253      if (count <= 0)
254        handle_error(channel, i, now);
255
256      process_answer(channel, buf, count, i, 0, now);
257    }
258}
259
260/* If any queries have timed out, note the timeout and move them on. */
261static void process_timeouts(ares_channel channel, time_t now)
262{
263  struct query *query, *next;
264
265  for (query = channel->queries; query; query = next)
266    {
267      next = query->next;
268      if (query->timeout != 0 && now >= query->timeout)
269        {
270          query->error_status = ARES_ETIMEOUT;
271          next_server(channel, query, now);
272        }
273    }
274}
275
276/* Handle an answer from a server. */
277static void process_answer(ares_channel channel, unsigned char *abuf,
278                           int alen, int whichserver, int tcp, int now)
279{
280  int id, tc, rcode;
281  struct query *query;
282
283  /* If there's no room in the answer for a header, we can't do much
284   * with it. */
285  if (alen < HFIXEDSZ)
286    return;
287
288  /* Grab the query ID, truncate bit, and response code from the packet. */
289  id = DNS_HEADER_QID(abuf);
290  tc = DNS_HEADER_TC(abuf);
291  rcode = DNS_HEADER_RCODE(abuf);
292
293  /* Find the query corresponding to this packet. */
294  for (query = channel->queries; query; query = query->next)
295    {
296      if (query->qid == id)
297        break;
298    }
299  if (!query)
300    return;
301
302  /* If we got a truncated UDP packet and are not ignoring truncation,
303   * don't accept the packet, and switch the query to TCP if we hadn't
304   * done so already.
305   */
306  if ((tc || alen > PACKETSZ) && !tcp && !(channel->flags & ARES_FLAG_IGNTC))
307    {
308      if (!query->using_tcp)
309        {
310          query->using_tcp = 1;
311          ares__send_query(channel, query, now);
312        }
313      return;
314    }
315
316  /* Limit alen to PACKETSZ if we aren't using TCP (only relevant if we
317   * are ignoring truncation.
318   */
319  if (alen > PACKETSZ && !tcp)
320    alen = PACKETSZ;
321
322  /* If we aren't passing through all error packets, discard packets
323   * with SERVFAIL, NOTIMP, or REFUSED response codes.
324   */
325  if (!(channel->flags & ARES_FLAG_NOCHECKRESP))
326    {
327      if (rcode == SERVFAIL || rcode == NOTIMP || rcode == REFUSED)
328        {
329          query->skip_server[whichserver] = 1;
330          if (query->server == whichserver)
331            next_server(channel, query, now);
332          return;
333        }
334      if (!same_questions(query->qbuf, query->qlen, abuf, alen))
335        {
336          if (query->server == whichserver)
337            next_server(channel, query, now);
338          return;
339        }
340    }
341
342  end_query(channel, query, ARES_SUCCESS, abuf, alen);
343}
344
345static void handle_error(ares_channel channel, int whichserver, time_t now)
346{
347  struct query *query, *next;
348
349  /* Reset communications with this server. */
350  ares__close_sockets(&channel->servers[whichserver]);
351
352  /* Tell all queries talking to this server to move on and not try
353   * this server again.
354   */
355  for (query = channel->queries; query; query = next)
356    {
357      next = query->next;
358      if (query->server == whichserver)
359        {
360          query->skip_server[whichserver] = 1;
361          next_server(channel, query, now);
362        }
363    }
364}
365
366static void next_server(ares_channel channel, struct query *query, time_t now)
367{
368  /* Advance to the next server or try. */
369  query->server++;
370  for (; query->try < channel->tries; query->try++)
371    {
372      for (; query->server < channel->nservers; query->server++)
373        {
374          if (!query->skip_server[query->server])
375            {
376              ares__send_query(channel, query, now);
377              return;
378            }
379        }
380      query->server = 0;
381
382      /* Only one try if we're using TCP. */
383      if (query->using_tcp)
384        break;
385    }
386  end_query(channel, query, query->error_status, NULL, 0);
387}
388
389void ares__send_query(ares_channel channel, struct query *query, time_t now)
390{
391  struct send_request *sendreq;
392  struct server_state *server;
393
394  server = &channel->servers[query->server];
395  if (query->using_tcp)
396    {
397      /* Make sure the TCP socket for this server is set up and queue
398       * a send request.
399       */
400      if (server->tcp_socket == -1)
401        {
402          if (open_tcp_socket(channel, server) == -1)
403            {
404              query->skip_server[query->server] = 1;
405              next_server(channel, query, now);
406              return;
407            }
408        }
409      sendreq = malloc(sizeof(struct send_request));
410      if (!sendreq)
411        end_query(channel, query, ARES_ENOMEM, NULL, 0);
412      sendreq->data = query->tcpbuf;
413      sendreq->len = query->tcplen;
414      sendreq->next = NULL;
415      if (server->qtail)
416        server->qtail->next = sendreq;
417      else
418        server->qhead = sendreq;
419      server->qtail = sendreq;
420      query->timeout = 0;
421    }
422  else
423    {
424      if (server->udp_socket == -1)
425        {
426          if (open_udp_socket(channel, server) == -1)
427            {
428              query->skip_server[query->server] = 1;
429              next_server(channel, query, now);
430              return;
431            }
432        }
433      if (send(server->udp_socket, query->qbuf, query->qlen, 0) == -1)
434        {
435          query->skip_server[query->server] = 1;
436          next_server(channel, query, now);
437          return;
438        }
439      query->timeout = now
440          + ((query->try == 0) ? channel->timeout
441             : channel->timeout << query->try / channel->nservers);
442    }
443}
444
445static int open_tcp_socket(ares_channel channel, struct server_state *server)
446{
447  int s, flags;
448  struct sockaddr_in sin;
449
450  /* Acquire a socket. */
451  s = socket(AF_INET, SOCK_STREAM, 0);
452  if (s == -1)
453    return -1;
454
455  /* Set the socket non-blocking. */
456  if (fcntl(s, F_GETFL, &flags) == -1)
457    {
458      close(s);
459      return -1;
460    }
461  flags |= O_NONBLOCK;
462  if (fcntl(s, F_SETFL, flags) == -1)
463    {
464      close(s);
465      return -1;
466    }
467
468  /* Connect to the server. */
469  memset(&sin, 0, sizeof(sin));
470  sin.sin_family = AF_INET;
471  sin.sin_addr = server->addr;
472  sin.sin_port = channel->tcp_port;
473  if (connect(s, (struct sockaddr *) &sin, sizeof(sin)) == -1
474      && errno != EINPROGRESS)
475    {
476      close(s);
477      return -1;
478    }
479
480  server->tcp_socket = s;
481  return 0;
482}
483
484static int open_udp_socket(ares_channel channel, struct server_state *server)
485{
486  int s;
487  struct sockaddr_in sin;
488
489  /* Acquire a socket. */
490  s = socket(AF_INET, SOCK_DGRAM, 0);
491  if (s == -1)
492    return -1;
493
494  /* Connect to the server. */
495  memset(&sin, 0, sizeof(sin));
496  sin.sin_family = AF_INET;
497  sin.sin_addr = server->addr;
498  sin.sin_port = channel->udp_port;
499  if (connect(s, (struct sockaddr *) &sin, sizeof(sin)) == -1)
500    {
501      close(s);
502      return -1;
503    }
504
505  server->udp_socket = s;
506  return 0;
507}
508
509static int same_questions(const unsigned char *qbuf, int qlen,
510                          const unsigned char *abuf, int alen)
511{
512  struct {
513    const unsigned char *p;
514    int qdcount;
515    char *name;
516    int namelen;
517    int type;
518    int dnsclass;
519  } q, a;
520  int i, j;
521
522  if (qlen < HFIXEDSZ || alen < HFIXEDSZ)
523    return 0;
524
525  /* Extract qdcount from the request and reply buffers and compare them. */
526  q.qdcount = DNS_HEADER_QDCOUNT(qbuf);
527  a.qdcount = DNS_HEADER_QDCOUNT(abuf);
528  if (q.qdcount != a.qdcount)
529    return 0;
530
531  /* For each question in qbuf, find it in abuf. */
532  q.p = qbuf + HFIXEDSZ;
533  for (i = 0; i < q.qdcount; i++)
534    {
535      /* Decode the question in the query. */
536      if (ares_expand_name(q.p, qbuf, qlen, &q.name, &q.namelen)
537          != ARES_SUCCESS)
538        return 0;
539      q.p += q.namelen;
540      if (q.p + QFIXEDSZ > qbuf + qlen)
541        {
542          free(q.name);
543          return 0;
544        }
545      q.type = DNS_QUESTION_TYPE(q.p);
546      q.dnsclass = DNS_QUESTION_CLASS(q.p);
547      q.p += QFIXEDSZ;
548
549      /* Search for this question in the answer. */
550      a.p = abuf + HFIXEDSZ;
551      for (j = 0; j < a.qdcount; j++)
552        {
553          /* Decode the question in the answer. */
554          if (ares_expand_name(a.p, abuf, alen, &a.name, &a.namelen)
555              != ARES_SUCCESS)
556            {
557              free(q.name);
558              return 0;
559            }
560          a.p += a.namelen;
561          if (a.p + QFIXEDSZ > abuf + alen)
562            {
563              free(q.name);
564              free(a.name);
565              return 0;
566            }
567          a.type = DNS_QUESTION_TYPE(a.p);
568          a.dnsclass = DNS_QUESTION_CLASS(a.p);
569          a.p += QFIXEDSZ;
570
571          /* Compare the decoded questions. */
572          if (strcasecmp(q.name, a.name) == 0 && q.type == a.type
573              && q.dnsclass == a.dnsclass)
574            {
575              free(a.name);
576              break;
577            }
578          free(a.name);
579        }
580
581      free(q.name);
582      if (j == a.qdcount)
583        return 0;
584    }
585  return 1;
586}
587
588static void end_query(ares_channel channel, struct query *query, int status,
589                      unsigned char *abuf, int alen)
590{
591  struct query **q;
592  int i;
593
594  query->callback(query->arg, status, abuf, alen);
595  for (q = &channel->queries; *q; q = &(*q)->next)
596    {
597      if (*q == query)
598        break;
599    }
600  *q = query->next;
601  free(query->tcpbuf);
602  free(query->skip_server);
603  free(query);
604
605  /* Simple cleanup policy: if no queries are remaining, close all
606   * network sockets unless STAYOPEN is set.
607   */
608  if (!channel->queries && !(channel->flags & ARES_FLAG_STAYOPEN))
609    {
610      for (i = 0; i < channel->nservers; i++)
611        ares__close_sockets(&channel->servers[i]);
612    }
613}
Note: See TracBrowser for help on using the repository browser.