source: trunk/athena/lib/ares/ares_process.c @ 11893

Revision 11893, 15.4 KB checked in by ghudson, 26 years ago (diff)
Catch oversized UDP packets.
Line 
1/* Copyright 1998 by the Massachusetts Institute of Technology.
2 *
3 * Permission to use, copy, modify, and distribute this
4 * software and its documentation for any purpose and without
5 * fee is hereby granted, provided that the above copyright
6 * notice appear in all copies and that both that copyright
7 * notice and this permission notice appear in supporting
8 * documentation, and that the name of M.I.T. not be used in
9 * advertising or publicity pertaining to distribution of the
10 * software without specific, written prior permission.
11 * M.I.T. makes no representations about the suitability of
12 * this software for any purpose.  It is provided "as is"
13 * without express or implied warranty.
14 */
15
/* RCS keyword string recording the revision of this source file. */
static const char rcsid[] = "$Id: ares_process.c,v 1.3 1998-08-22 15:28:33 ghudson Exp $";
17
18#include <sys/types.h>
19#include <sys/socket.h>
20#include <sys/uio.h>
21#include <arpa/nameser.h>
22#include <string.h>
23#include <stdlib.h>
24#include <unistd.h>
25#include <fcntl.h>
26#include <time.h>
27#include <errno.h>
28#include "ares.h"
29#include "ares_dns.h"
30#include "ares_private.h"
31
/* Forward declarations for the file-local helpers used by ares_process()
 * and ares__send_query().
 */
static void write_tcp_data(ares_channel channel, fd_set *write_fds,
                           time_t now);
static void read_tcp_data(ares_channel channel, fd_set *read_fds, time_t now);
static void read_udp_packets(ares_channel channel, fd_set *read_fds,
                             time_t now);
static void process_timeouts(ares_channel channel, time_t now);
/* NOTE(review): NOW is declared int here although every caller passes a
 * time_t; the value is narrowed on entry.  Changing it requires updating
 * this declaration and the definition together.
 */
static void process_answer(ares_channel channel, unsigned char *abuf,
                           int alen, int whichserver, int tcp, int now);
static void handle_error(ares_channel channel, int whichserver, time_t now);
static void next_server(ares_channel channel, struct query *query, time_t now);
static int open_tcp_socket(ares_channel channel, struct server_state *server);
static int open_udp_socket(ares_channel channel, struct server_state *server);
static int same_questions(const unsigned char *qbuf, int qlen,
                          const unsigned char *abuf, int alen);
static void end_query(ares_channel channel, struct query *query, int status,
                      unsigned char *abuf, int alen);
48
49/* Something interesting happened on the wire, or there was a timeout.
50 * See what's up and respond accordingly.
51 */
52void ares_process(ares_channel channel, fd_set *read_fds, fd_set *write_fds)
53{
54  time_t now;
55
56  time(&now);
57  write_tcp_data(channel, write_fds, now);
58  read_tcp_data(channel, read_fds, now);
59  read_udp_packets(channel, read_fds, now);
60  process_timeouts(channel, now);
61}
62
63/* If any TCP sockets select true for writing, write out queued data
64 * we have for them.
65 */
66static void write_tcp_data(ares_channel channel, fd_set *write_fds, time_t now)
67{
68  struct server_state *server;
69  struct send_request *sendreq;
70  struct iovec *vec;
71  int i, n, count;
72
73  for (i = 0; i < channel->nservers; i++)
74    {
75      /* Make sure server has data to send and is selected in write_fds. */
76      server = &channel->servers[i];
77      if (!server->qhead || server->tcp_socket == -1
78          || !FD_ISSET(server->tcp_socket, write_fds))
79        continue;
80
81      /* Count the number of send queue items. */
82      n = 0;
83      for (sendreq = server->qhead; sendreq; sendreq = sendreq->next)
84        n++;
85
86      /* Allocate iovecs so we can send all our data at once. */
87      vec = malloc(n * sizeof(struct iovec));
88      if (vec)
89        {
90          /* Fill in the iovecs and send. */
91          n = 0;
92          for (sendreq = server->qhead; sendreq; sendreq = sendreq->next)
93            {
94              vec[n].iov_base = (char *) sendreq->data;
95              vec[n].iov_len = sendreq->len;
96              n++;
97            }
98          count = writev(server->tcp_socket, vec, n);
99          free(vec);
100          if (count < 0)
101            {
102              handle_error(channel, i, now);
103              continue;
104            }
105
106          /* Advance the send queue by as many bytes as we sent. */
107          while (count)
108            {
109              sendreq = server->qhead;
110              if (count >= sendreq->len)
111                {
112                  count -= sendreq->len;
113                  server->qhead = sendreq->next;
114                  if (server->qhead == NULL)
115                    server->qtail = NULL;
116                  free(sendreq);
117                }
118              else
119                {
120                  sendreq->data += count;
121                  sendreq->len -= count;
122                  break;
123                }
124            }
125        }
126      else
127        {
128          /* Can't allocate iovecs; just send the first request. */
129          sendreq = server->qhead;
130          count = write(server->tcp_socket, sendreq->data, sendreq->len);
131          if (count < 0)
132            {
133              handle_error(channel, i, now);
134              continue;
135            }
136
137          /* Advance the send queue by as many bytes as we sent. */
138          if (count == sendreq->len)
139            {
140              server->qhead = sendreq->next;
141              if (server->qhead == NULL)
142                server->qtail = NULL;
143              free(sendreq);
144            }
145          else
146            {
147              sendreq->data += count;
148              sendreq->len -= count;
149            }
150        }
151    }
152}
153
154/* If any TCP socket selects true for reading, read some data,
155 * allocate a buffer if we finish reading the length word, and process
156 * a packet if we finish reading one.
157 */
158static void read_tcp_data(ares_channel channel, fd_set *read_fds, time_t now)
159{
160  struct server_state *server;
161  int i, count;
162
163  for (i = 0; i < channel->nservers; i++)
164    {
165      /* Make sure the server has a socket and is selected in read_fds. */
166      server = &channel->servers[i];
167      if (server->tcp_socket == -1 || !FD_ISSET(server->tcp_socket, read_fds))
168        continue;
169
170      if (server->tcp_lenbuf_pos != 2)
171        {
172          /* We haven't yet read a length word, so read that (or
173           * what's left to read of it).
174           */
175          count = read(server->tcp_socket,
176                       server->tcp_lenbuf + server->tcp_lenbuf_pos,
177                       2 - server->tcp_lenbuf_pos);
178          if (count <= 0)
179            {
180              handle_error(channel, i, now);
181              continue;
182            }
183
184          server->tcp_lenbuf_pos += count;
185          if (server->tcp_lenbuf_pos == 2)
186            {
187              /* We finished reading the length word.  Decode the
188               * length and allocate a buffer for the data.
189               */
190              server->tcp_length = server->tcp_lenbuf[0] << 8
191                | server->tcp_lenbuf[1];
192              server->tcp_buffer = malloc(server->tcp_length);
193              if (!server->tcp_buffer)
194                handle_error(channel, i, now);
195              server->tcp_buffer_pos = 0;
196            }
197        }
198      else
199        {
200          /* Read data into the allocated buffer. */
201          count = read(server->tcp_socket,
202                       server->tcp_buffer + server->tcp_buffer_pos,
203                       server->tcp_length - server->tcp_buffer_pos);
204          if (count <= 0)
205            {
206              handle_error(channel, i, now);
207              continue;
208            }
209
210          server->tcp_buffer_pos += count;
211          if (server->tcp_buffer_pos == server->tcp_length)
212            {
213              /* We finished reading this answer; process it and
214               * prepare to read another length word.
215               */
216              process_answer(channel, server->tcp_buffer, server->tcp_length,
217                             i, 1, now);
218              free(server->tcp_buffer);
219              server->tcp_buffer = NULL;
220              server->tcp_lenbuf_pos = 0;
221            }
222        }
223    }
224}
225
226/* If any UDP sockets select true for reading, process them. */
227static void read_udp_packets(ares_channel channel, fd_set *read_fds,
228                             time_t now)
229{
230  struct server_state *server;
231  int i, count;
232  unsigned char buf[PACKETSZ + 1];
233
234  for (i = 0; i < channel->nservers; i++)
235    {
236      /* Make sure the server has a socket and is selected in read_fds. */
237      server = &channel->servers[i];
238      if (server->udp_socket == -1 || !FD_ISSET(server->udp_socket, read_fds))
239        continue;
240
241      count = recv(server->udp_socket, buf, sizeof(buf), 0);
242      if (count <= 0)
243        handle_error(channel, i, now);
244
245      process_answer(channel, buf, count, i, 0, now);
246    }
247}
248
249/* If any queries have timed out, note the timeout and move them on. */
250static void process_timeouts(ares_channel channel, time_t now)
251{
252  struct query *query;
253
254  for (query = channel->queries; query; query = query->next)
255    {
256      if (query->timeout != 0 && now >= query->timeout)
257        {
258          query->error_status = ARES_ETIMEOUT;
259          next_server(channel, query, now);
260        }
261    }
262}
263
264/* Handle an answer from a server. */
265static void process_answer(ares_channel channel, unsigned char *abuf,
266                           int alen, int whichserver, int tcp, int now)
267{
268  int id, tc, rcode;
269  struct query *query;
270
271  /* If there's no room in the answer for a header, we can't do much
272   * with it. */
273  if (alen < HFIXEDSZ)
274    return;
275
276  /* Grab the query ID, truncate bit, and response code from the packet. */
277  id = DNS_HEADER_QID(abuf);
278  tc = DNS_HEADER_TC(abuf);
279  rcode = DNS_HEADER_RCODE(abuf);
280
281  /* Find the query corresponding to this packet. */
282  for (query = channel->queries; query; query = query->next)
283    {
284      if (query->qid == id)
285        break;
286    }
287  if (!query)
288    return;
289
290  /* If we got a truncated UDP packet and are not ignoring truncation,
291   * don't accept the packet, and switch the query to TCP if we hadn't
292   * done so already.
293   */
294  if ((tc || alen > PACKETSZ) && !tcp && !(channel->flags & ARES_FLAG_IGNTC))
295    {
296      if (!query->using_tcp)
297        {
298          query->using_tcp = 1;
299          ares__send_query(channel, query, now);
300        }
301      return;
302    }
303
304  /* Limit alen to PACKETSZ if we aren't using TCP (only relevant if we
305   * are ignoring truncation.
306   */
307  if (alen > PACKETSZ && !tcp)
308    alen = PACKETSZ;
309
310  /* If we aren't passing through all error packets, discard packets
311   * with SERVFAIL, NOTIMP, or REFUSED response codes.
312   */
313  if (!(channel->flags & ARES_FLAG_NOCHECKRESP))
314    {
315      if (rcode == SERVFAIL || rcode == NOTIMP || rcode == REFUSED)
316        {
317          query->skip_server[whichserver] = 1;
318          if (query->server == whichserver)
319            next_server(channel, query, now);
320          return;
321        }
322      if (!same_questions(query->qbuf, query->qlen, abuf, alen))
323        {
324          if (query->server == whichserver)
325            next_server(channel, query, now);
326          return;
327        }
328    }
329
330  end_query(channel, query, ARES_SUCCESS, abuf, alen);
331}
332
333static void handle_error(ares_channel channel, int whichserver, time_t now)
334{
335  struct query *query;
336
337  /* Reset communications with this server. */
338  ares__close_sockets(&channel->servers[whichserver]);
339
340  /* Tell all queries talking to this server to move on and not try
341   * this server again.
342   */
343  for (query = channel->queries; query; query = query->next)
344    {
345      if (query->server == whichserver)
346        {
347          query->skip_server[whichserver] = 1;
348          next_server(channel, query, now);
349        }
350    }
351}
352
353static void next_server(ares_channel channel, struct query *query, time_t now)
354{
355  /* Advance to the next server or try. */
356  query->server++;
357  for (; query->try < channel->tries; query->try++)
358    {
359      for (; query->server < channel->nservers; query->server++)
360        {
361          if (!query->skip_server[query->server])
362            {
363              ares__send_query(channel, query, now);
364              return;
365            }
366        }
367      query->server = 0;
368
369      /* Only one try if we're using TCP. */
370      if (query->using_tcp)
371        break;
372    }
373  end_query(channel, query, query->error_status, NULL, 0);
374}
375
376void ares__send_query(ares_channel channel, struct query *query, time_t now)
377{
378  struct send_request *sendreq;
379  struct server_state *server;
380
381  server = &channel->servers[query->server];
382  if (query->using_tcp)
383    {
384      /* Make sure the TCP socket for this server is set up and queue
385       * a send request.
386       */
387      if (server->tcp_socket == -1)
388        {
389          if (open_tcp_socket(channel, server) == -1)
390            {
391              query->skip_server[query->server] = 1;
392              next_server(channel, query, now);
393            }
394        }
395      sendreq = malloc(sizeof(struct send_request));
396      if (!sendreq)
397        end_query(channel, query, ARES_ENOMEM, NULL, 0);
398      sendreq->data = query->tcpbuf;
399      sendreq->len = query->tcplen;
400      sendreq->next = NULL;
401      if (server->qtail)
402        server->qtail->next = sendreq;
403      else
404        server->qhead = sendreq;
405      server->qtail = sendreq;
406      query->timeout = 0;
407    }
408  else
409    {
410      if (server->udp_socket == -1)
411        {
412          if (open_udp_socket(channel, server) == -1)
413            {
414              query->skip_server[query->server] = 1;
415              next_server(channel, query, now);
416            }
417        }
418      if (send(server->udp_socket, query->qbuf, query->qlen, 0) == -1)
419        {
420          query->skip_server[query->server] = 1;
421          next_server(channel, query, now);
422        }
423      query->timeout = now
424          + ((query->try == 0) ? channel->timeout
425             : channel->timeout << query->try / channel->nservers);
426    }
427}
428
429static int open_tcp_socket(ares_channel channel, struct server_state *server)
430{
431  int s, flags;
432  struct sockaddr_in sin;
433
434  /* Acquire a socket. */
435  s = socket(AF_INET, SOCK_STREAM, 0);
436  if (s == -1)
437    return -1;
438
439  /* Set the socket non-blocking. */
440  if (fcntl(s, F_GETFL, &flags) == -1)
441    {
442      close(s);
443      return -1;
444    }
445  flags &= O_NONBLOCK;
446  if (fcntl(s, F_SETFL, flags) == -1)
447    {
448      close(s);
449      return -1;
450    }
451
452  /* Connect to the server. */
453  memset(&sin, 0, sizeof(sin));
454  sin.sin_family = AF_INET;
455  sin.sin_addr = server->addr;
456  sin.sin_port = channel->tcp_port;
457  if (connect(s, (struct sockaddr *) &sin, sizeof(sin)) == -1
458      && errno != EINPROGRESS)
459    {
460      close(s);
461      return -1;
462    }
463
464  server->tcp_socket = s;
465  return 0;
466}
467
468static int open_udp_socket(ares_channel channel, struct server_state *server)
469{
470  int s;
471  struct sockaddr_in sin;
472
473  /* Acquire a socket. */
474  s = socket(AF_INET, SOCK_DGRAM, 0);
475  if (s == -1)
476    return -1;
477
478  /* Connect to the server. */
479  memset(&sin, 0, sizeof(sin));
480  sin.sin_family = AF_INET;
481  sin.sin_addr = server->addr;
482  sin.sin_port = channel->udp_port;
483  if (connect(s, (struct sockaddr *) &sin, sizeof(sin)) == -1)
484    {
485      close(s);
486      return -1;
487    }
488
489  server->udp_socket = s;
490  return 0;
491}
492
493static int same_questions(const unsigned char *qbuf, int qlen,
494                          const unsigned char *abuf, int alen)
495{
496  struct {
497    const unsigned char *p;
498    int qdcount;
499    char *name;
500    int namelen;
501    int type;
502    int class;
503  } q, a;
504  int i, j;
505
506  if (qlen < HFIXEDSZ || alen < HFIXEDSZ)
507    return 0;
508
509  /* Extract qdcount from the request and reply buffers and compare them. */
510  q.qdcount = DNS_HEADER_QDCOUNT(qbuf);
511  a.qdcount = DNS_HEADER_QDCOUNT(abuf);
512  if (q.qdcount != a.qdcount)
513    return 0;
514
515  /* For each question in qbuf, find it in abuf. */
516  q.p = qbuf + HFIXEDSZ;
517  for (i = 0; i < q.qdcount; i++)
518    {
519      /* Decode the question in the query. */
520      if (ares_expand_name(q.p, qbuf, qlen, &q.name, &q.namelen)
521          != ARES_SUCCESS)
522        return 0;
523      q.p += q.namelen;
524      if (q.p + QFIXEDSZ > qbuf + qlen)
525        {
526          free(q.name);
527          return 0;
528        }
529      q.type = DNS_QUESTION_TYPE(q.p);
530      q.class = DNS_QUESTION_CLASS(q.p);
531      q.p += QFIXEDSZ;
532
533      /* Search for this question in the answer. */
534      a.p = abuf + HFIXEDSZ;
535      for (j = 0; j < a.qdcount; j++)
536        {
537          /* Decode the question in the answer. */
538          if (ares_expand_name(a.p, abuf, alen, &a.name, &a.namelen)
539              != ARES_SUCCESS)
540            {
541              free(q.name);
542              return 0;
543            }
544          a.p += a.namelen;
545          if (a.p + QFIXEDSZ > abuf + alen)
546            {
547              free(q.name);
548              free(a.name);
549              return 0;
550            }
551          a.type = DNS_QUESTION_TYPE(a.p);
552          a.class = DNS_QUESTION_CLASS(a.p);
553          a.p += QFIXEDSZ;
554
555          /* Compare the decoded questions. */
556          if (strcasecmp(q.name, a.name) == 0 && q.type == a.type
557              && q.class == a.class)
558            {
559              free(a.name);
560              break;
561            }
562          free(a.name);
563        }
564
565      free(q.name);
566      if (j == a.qdcount)
567        return 0;
568    }
569  return 1;
570}
571
572static void end_query(ares_channel channel, struct query *query, int status,
573                      unsigned char *abuf, int alen)
574{
575  struct query **q;
576  int i;
577
578  query->callback(query->arg, status, abuf, alen);
579  for (q = &channel->queries; *q; q = &(*q)->next)
580    {
581      if (*q == query)
582        break;
583    }
584  *q = query->next;
585  free(query->tcpbuf);
586  free(query->skip_server);
587  free(query);
588
589  /* Simple cleanup policy: if no queries are remaining, close all
590   * network sockets unless STAYOPEN is set.
591   */
592  if (!channel->queries && !(channel->flags & ARES_FLAG_STAYOPEN))
593    {
594      for (i = 0; i < channel->nservers; i++)
595        ares__close_sockets(&channel->servers[i]);
596    }
597}
Note: See TracBrowser for help on using the repository browser.