source: trunk/third/gmp/demos/pexpr.c @ 18191

Revision 18191, 30.0 KB checked in by ghudson, 22 years ago (diff)
This commit was generated by cvs2svn to compensate for changes in r18190, which included commits to RCS files with non-trunk default branches.
Line 
1/* Program for computing integer expressions using the GNU Multiple Precision
2   Arithmetic Library.
3
4Copyright 1997, 1999, 2000, 2001, 2002 Free Software Foundation, Inc.
5
6This program is free software; you can redistribute it and/or modify it under
7the terms of the GNU General Public License as published by the Free Software
8Foundation; either version 2 of the License, or (at your option) any later
9version.
10
11This program is distributed in the hope that it will be useful, but WITHOUT ANY
12WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14
15You should have received a copy of the GNU General Public License along with
16this program; if not, write to the Free Software Foundation, Inc., 59 Temple
17Place - Suite 330, Boston, MA 02111-1307, USA.  */
18
19
20/* This expressions evaluator works by building an expression tree (using a
21   recursive descent parser) which is then evaluated.  The expression tree is
22   useful since we want to optimize certain expressions (like a^b % c).
23
24   Usage: pexpr [options] expr ...
25   (Assuming you called the executable `pexpr' of course.)
26
27   Command line options:
28
29   -b        print output in binary
30   -o        print output in octal
31   -d        print output in decimal (the default)
32   -x        print output in hexadecimal
33   -b<NUM>   print output in base NUM
34   -t        print timing information
35   -html     output html
36   -wml      output wml
37   -nosplit  do not split long lines each 60th digit
38*/
39
40/* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
41   use up extensive resources (cpu, memory).  Useful for the GMP demo on the
42   GMP web site, since we cannot load the server too much.  */
43
44#include "pexpr-config.h"
45
46#include <string.h>
47#include <stdio.h>
48#include <stdlib.h>
49#include <setjmp.h>
50#include <signal.h>
51#include <ctype.h>
52
53#include <time.h>
54#include <sys/types.h>
55#include <sys/time.h>
56#if HAVE_SYS_RESOURCE_H
57#include <sys/resource.h>
58#endif
59
60#include "gmp.h"
61
62/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
63#ifndef SIGSTKSZ
64#define SIGSTKSZ  4096
65#endif
66
67
68#define TIME(t,func)                                                    \
69  do { int __t0, __times, __t, __tmp;                                   \
70    __times = 1;                                                        \
71    __t0 = cputime ();                                                  \
72    {func;}                                                             \
73    __tmp = cputime () - __t0;                                          \
74    while (__tmp < 100)                                                 \
75      {                                                                 \
76        __times <<= 1;                                                  \
77        __t0 = cputime ();                                              \
78        for (__t = 0; __t < __times; __t++)                             \
79          {func;}                                                       \
80        __tmp = cputime () - __t0;                                      \
81      }                                                                 \
82    (t) = (double) __tmp / __times;                                     \
83  } while (0)
84
85/* GMP version 1.x compatibility.  */
86#if ! (__GNU_MP_VERSION >= 2)
87typedef MP_INT __mpz_struct;
88typedef __mpz_struct mpz_t[1];
89typedef __mpz_struct *mpz_ptr;
90#define mpz_fdiv_q      mpz_div
91#define mpz_fdiv_r      mpz_mod
92#define mpz_tdiv_q_2exp mpz_div_2exp
93#define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
94#endif
95
96/* GMP version 2.0 compatibility.  */
97#if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
98#define mpz_swap(a,b) \
99  do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
100#endif
101
102jmp_buf errjmpbuf;
103
104enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
105           AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
106           LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM};
107
108/* Type for the expression tree.  */
109struct expr
110{
111  enum op_t op;
112  union
113  {
114    struct {struct expr *lhs, *rhs;} ops;
115    mpz_t val;
116  } operands;
117};
118
119typedef struct expr *expr_t;
120
121void cleanup_and_exit __GMP_PROTO ((int));
122
123char *skipspace __GMP_PROTO ((char *));
124void makeexp __GMP_PROTO ((expr_t *, enum op_t, expr_t, expr_t));
125void free_expr __GMP_PROTO ((expr_t));
126char *expr __GMP_PROTO ((char *, expr_t *));
127char *term __GMP_PROTO ((char *, expr_t *));
128char *power __GMP_PROTO ((char *, expr_t *));
129char *factor __GMP_PROTO ((char *, expr_t *));
130int match __GMP_PROTO ((char *, char *));
131int matchp __GMP_PROTO ((char *, char *));
132int cputime __GMP_PROTO ((void));
133
134void mpz_eval_expr __GMP_PROTO ((mpz_ptr, expr_t));
135void mpz_eval_mod_expr __GMP_PROTO ((mpz_ptr, expr_t, mpz_ptr));
136
137char *error;
138int flag_print = 1;
139int print_timing = 0;
140int flag_html = 0;
141int flag_wml = 0;
142int flag_splitup_output = 0;
143char *newline = "";
144gmp_randstate_t rstate;
145
146
147
148/* cputime() returns user CPU time measured in milliseconds.  */
149#if ! HAVE_CPUTIME
150#if HAVE_GETRUSAGE
151int
152cputime (void)
153{
154  struct rusage rus;
155
156  getrusage (0, &rus);
157  return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
158}
159#else
160#if HAVE_CLOCK
161int
162cputime (void)
163{
164  if (CLOCKS_PER_SEC < 100000)
165    return clock () * 1000 / CLOCKS_PER_SEC;
166  return clock () / (CLOCKS_PER_SEC / 1000);
167}
168#else
169int
170cputime (void)
171{
172  return 0;
173}
174#endif
175#endif
176#endif
177
178
179int
180stack_downwards_helper (char *xp)
181{
182  char  y;
183  return &y < xp;
184}
185int
186stack_downwards_p (void)
187{
188  char  x;
189  return stack_downwards_helper (&x);
190}
191
192
193void
194setup_error_handler (void)
195{
196#if HAVE_SIGACTION
197  struct sigaction act;
198  act.sa_handler = cleanup_and_exit;
199  sigemptyset (&(act.sa_mask));
200#define SIGNAL(sig)  sigaction (sig, &act, NULL)
201#else
202  struct { int sa_flags } act;
203#define SIGNAL(sig)  signal (sig, cleanup_and_exit)
204#endif
205  act.sa_flags = 0;
206
207  /* Set up a stack for signal handling.  A typical cause of error is stack
208     overflow, and in such situation a signal can not be delivered on the
209     overflown stack.  */
210#if HAVE_SIGALTSTACK
211  {
212    /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
213       systems have both. */
214#if HAVE_STACK_T
215    stack_t s;
216#else
217    struct sigaltstack s;
218#endif
219    s.ss_sp = malloc (SIGSTKSZ);
220    s.ss_size = SIGSTKSZ;
221    s.ss_flags = 0;
222    if (sigaltstack (&s, NULL) != 0)
223      perror("sigaltstack");
224    act.sa_flags = SA_ONSTACK;
225  }
226#else
227#if HAVE_SIGSTACK
228  {
229    struct sigstack s;
230    s.ss_sp = malloc (SIGSTKSZ);
231    if (stack_downwards_p ())
232      s.ss_sp += SIGSTKSZ;
233    s.ss_onstack = 0;
234    if (sigstack (&s, NULL) != 0)
235      perror("sigstack");
236    act.sa_flags = SA_ONSTACK;
237  }
238#else
239#endif
240#endif
241
242#ifdef LIMIT_RESOURCE_USAGE
243  {
244    struct rlimit limit;
245
246    limit.rlim_cur = limit.rlim_max = 0;
247    setrlimit (RLIMIT_CORE, &limit);
248
249    limit.rlim_cur = 3;
250    limit.rlim_max = 4;
251    setrlimit (RLIMIT_CPU, &limit);
252
253    limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
254    setrlimit (RLIMIT_DATA, &limit);
255
256    getrlimit (RLIMIT_STACK, &limit);
257    limit.rlim_cur = 4 * 1024 * 1024;
258    setrlimit (RLIMIT_STACK, &limit);
259
260    SIGNAL (SIGXCPU);
261  }
262#endif /* LIMIT_RESOURCE_USAGE */
263
264  SIGNAL (SIGILL);
265  SIGNAL (SIGSEGV);
266#ifdef SIGBUS /* not in mingw */
267  SIGNAL (SIGBUS);
268#endif
269  SIGNAL (SIGFPE);
270  SIGNAL (SIGABRT);
271}
272
273int
274main (int argc, char **argv)
275{
276  struct expr *e;
277  int i;
278  mpz_t r;
279  int errcode = 0;
280  char *str;
281  int base = 10;
282
283  setup_error_handler ();
284
285  gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
286
287  {
288#if HAVE_GETTIMEOFDAY
289    struct timeval tv;
290    gettimeofday (&tv, NULL);
291    gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
292#else
293    time_t t;
294    time (&t);
295    gmp_randseed_ui (rstate, t);
296#endif
297  }
298
299  mpz_init (r);
300
301  while (argc > 1 && argv[1][0] == '-')
302    {
303      char *arg = argv[1];
304
305      if (arg[1] >= '0' && arg[1] <= '9')
306        break;
307
308      if (arg[1] == 't')
309        print_timing = 1;
310      else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
311        {
312          base = atoi (arg + 2);
313          if (base < 2 || base > 36)
314            {
315              fprintf (stderr, "error: invalid output base\n");
316              exit (-1);
317            }
318        }
319      else if (arg[1] == 'b' && arg[2] == 0)
320        base = 2;
321      else if (arg[1] == 'x' && arg[2] == 0)
322        base = 16;
323      else if (arg[1] == 'X' && arg[2] == 0)
324        base = -16;
325      else if (arg[1] == 'o' && arg[2] == 0)
326        base = 8;
327      else if (arg[1] == 'd' && arg[2] == 0)
328        base = 10;
329      else if (strcmp (arg, "-html") == 0)
330        {
331          flag_html = 1;
332          newline = "<br>";
333        }
334      else if (strcmp (arg, "-wml") == 0)
335        {
336          flag_wml = 1;
337          newline = "<br/>";
338        }
339      else if (strcmp (arg, "-split") == 0)
340        {
341          flag_splitup_output = 1;
342        }
343      else if (strcmp (arg, "-noprint") == 0)
344        {
345          flag_print = 0;
346        }
347      else
348        {
349          fprintf (stderr, "error: unknown option `%s'\n", arg);
350          exit (-1);
351        }
352      argv++;
353      argc--;
354    }
355
356  for (i = 1; i < argc; i++)
357    {
358      int s;
359      int jmpval;
360
361      /* Set up error handler for parsing expression.  */
362      jmpval = setjmp (errjmpbuf);
363      if (jmpval != 0)
364        {
365          fprintf (stderr, "error: %s%s\n", error, newline);
366          fprintf (stderr, "       %s%s\n", argv[i], newline);
367          if (! flag_html)
368            {
369              /* ??? Dunno how to align expression position with arrow in
370                 HTML ??? */
371              fprintf (stderr, "       ");
372              for (s = jmpval - (long) argv[i]; --s >= 0; )
373                putc (' ', stderr);
374              fprintf (stderr, "^\n");
375            }
376
377          errcode |= 1;
378          continue;
379        }
380
381      str = expr (argv[i], &e);
382
383      if (str[0] != 0)
384        {
385          fprintf (stderr,
386                   "error: garbage where end of expression expected%s\n",
387                   newline);
388          fprintf (stderr, "       %s%s\n", argv[i], newline);
389          if (! flag_html)
390            {
391              /* ??? Dunno how to align expression position with arrow in
392                 HTML ??? */
393              fprintf (stderr, "        ");
394              for (s = str - argv[i]; --s; )
395                putc (' ', stderr);
396              fprintf (stderr, "^\n");
397            }
398
399          errcode |= 1;
400          free_expr (e);
401          continue;
402        }
403
404      /* Set up error handler for evaluating expression.  */
405      if (setjmp (errjmpbuf))
406        {
407          fprintf (stderr, "error: %s%s\n", error, newline);
408          fprintf (stderr, "       %s%s\n", argv[i], newline);
409          if (! flag_html)
410            {
411              /* ??? Dunno how to align expression position with arrow in
412                 HTML ??? */
413              fprintf (stderr, "       ");
414              for (s = str - argv[i]; --s >= 0; )
415                putc (' ', stderr);
416              fprintf (stderr, "^\n");
417            }
418
419          errcode |= 2;
420          continue;
421        }
422
423      if (print_timing)
424        {
425          double t;
426          TIME (t, mpz_eval_expr (r, e));
427          printf ("computation took %.2f ms%s\n", t, newline);
428        }
429      else
430        mpz_eval_expr (r, e);
431
432      if (flag_print)
433        {
434          size_t out_len;
435          char *tmp, *s;
436
437          out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
438          tmp = malloc (out_len);
439
440          if (print_timing)
441            {
442              double t;
443              printf ("output conversion ");
444              TIME (t, mpz_get_str (tmp, base, r));
445              printf ("took %.2f ms%s\n", t, newline);
446            }
447          else
448            mpz_get_str (tmp, base, r);
449
450          out_len = strlen (tmp);
451          if (flag_splitup_output)
452            {
453              for (s = tmp; out_len > 60; s += 60)
454                {
455                  fwrite (s, 1, 60, stdout);
456                  printf ("%s\n", newline);
457                  out_len -= 60;
458                }
459
460              fwrite (s, 1, out_len, stdout);
461            }
462          else
463            {
464              fwrite (tmp, 1, out_len, stdout);
465            }
466
467          free (tmp);
468          printf ("%s\n", newline);
469        }
470      else
471        {
472          printf ("result is approximately %ld digits%s\n",
473                  (long) mpz_sizeinbase (r, 10), newline);
474        }
475
476      free_expr (e);
477    }
478
479  exit (errcode);
480}
481
482char *
483expr (char *str, expr_t *e)
484{
485  expr_t e2;
486
487  str = skipspace (str);
488  if (str[0] == '+')
489    {
490      str = term (str + 1, e);
491    }
492  else if (str[0] == '-')
493    {
494      str = term (str + 1, e);
495      makeexp (e, NEG, *e, NULL);
496    }
497  else if (str[0] == '~')
498    {
499      str = term (str + 1, e);
500      makeexp (e, NOT, *e, NULL);
501    }
502  else
503    {
504      str = term (str, e);
505    }
506
507  for (;;)
508    {
509      str = skipspace (str);
510      switch (str[0])
511        {
512        case 'p':
513          if (match ("plus", str))
514            {
515              str = term (str + 4, &e2);
516              makeexp (e, PLUS, *e, e2);
517            }
518          else
519            return str;
520          break;
521        case 'm':
522          if (match ("minus", str))
523            {
524              str = term (str + 5, &e2);
525              makeexp (e, MINUS, *e, e2);
526            }
527          else
528            return str;
529          break;
530        case '+':
531          str = term (str + 1, &e2);
532          makeexp (e, PLUS, *e, e2);
533          break;
534        case '-':
535          str = term (str + 1, &e2);
536          makeexp (e, MINUS, *e, e2);
537          break;
538        default:
539          return str;
540        }
541    }
542}
543
544char *
545term (char *str, expr_t *e)
546{
547  expr_t e2;
548
549  str = power (str, e);
550  for (;;)
551    {
552      str = skipspace (str);
553      switch (str[0])
554        {
555        case 'm':
556          if (match ("mul", str))
557            {
558              str = power (str + 3, &e2);
559              makeexp (e, MULT, *e, e2);
560              break;
561            }
562          if (match ("mod", str))
563            {
564              str = power (str + 3, &e2);
565              makeexp (e, MOD, *e, e2);
566              break;
567            }
568          return str;
569        case 'd':
570          if (match ("div", str))
571            {
572              str = power (str + 3, &e2);
573              makeexp (e, DIV, *e, e2);
574              break;
575            }
576          return str;
577        case 'r':
578          if (match ("rem", str))
579            {
580              str = power (str + 3, &e2);
581              makeexp (e, REM, *e, e2);
582              break;
583            }
584          return str;
585        case 'i':
586          if (match ("invmod", str))
587            {
588              str = power (str + 6, &e2);
589              makeexp (e, REM, *e, e2);
590              break;
591            }
592          return str;
593        case 't':
594          if (match ("times", str))
595            {
596              str = power (str + 5, &e2);
597              makeexp (e, MULT, *e, e2);
598              break;
599            }
600          if (match ("thru", str))
601            {
602              str = power (str + 4, &e2);
603              makeexp (e, DIV, *e, e2);
604              break;
605            }
606          if (match ("through", str))
607            {
608              str = power (str + 7, &e2);
609              makeexp (e, DIV, *e, e2);
610              break;
611            }
612          return str;
613        case '*':
614          str = power (str + 1, &e2);
615          makeexp (e, MULT, *e, e2);
616          break;
617        case '/':
618          str = power (str + 1, &e2);
619          makeexp (e, DIV, *e, e2);
620          break;
621        case '%':
622          str = power (str + 1, &e2);
623          makeexp (e, MOD, *e, e2);
624          break;
625        default:
626          return str;
627        }
628    }
629}
630
631char *
632power (char *str, expr_t *e)
633{
634  expr_t e2;
635
636  str = factor (str, e);
637  while (str[0] == '!')
638    {
639      str++;
640      makeexp (e, FAC, *e, NULL);
641    }
642  str = skipspace (str);
643  if (str[0] == '^')
644    {
645      str = power (str + 1, &e2);
646      makeexp (e, POW, *e, e2);
647    }
648  return str;
649}
650
651int
652match (char *s, char *str)
653{
654  char *ostr = str;
655  int i;
656
657  for (i = 0; s[i] != 0; i++)
658    {
659      if (str[i] != s[i])
660        return 0;
661    }
662  str = skipspace (str + i);
663  return str - ostr;
664}
665
666int
667matchp (char *s, char *str)
668{
669  char *ostr = str;
670  int i;
671
672  for (i = 0; s[i] != 0; i++)
673    {
674      if (str[i] != s[i])
675        return 0;
676    }
677  str = skipspace (str + i);
678  if (str[0] == '(')
679    return str - ostr + 1;
680  return 0;
681}
682
683struct functions
684{
685  char *spelling;
686  enum op_t op;
687  int arity; /* 1 or 2 means real arity; 0 means arbitrary.  */
688};
689
690struct functions fns[] =
691{
692  {"sqrt", SQRT, 1},
693#if __GNU_MP_VERSION >= 2
694  {"root", ROOT, 2},
695  {"popc", POPCNT, 1},
696  {"hamdist", HAMDIST, 2},
697#endif
698  {"gcd", GCD, 0},
699#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
700  {"lcm", LCM, 0},
701#endif
702  {"and", AND, 0},
703  {"ior", IOR, 0},
704#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
705  {"xor", XOR, 0},
706#endif
707  {"plus", PLUS, 0},
708  {"pow", POW, 2},
709  {"minus", MINUS, 2},
710  {"mul", MULT, 0},
711  {"div", DIV, 2},
712  {"mod", MOD, 2},
713  {"rem", REM, 2},
714#if __GNU_MP_VERSION >= 2
715  {"invmod", INVMOD, 2},
716#endif
717  {"log", LOG, 2},
718  {"log2", LOG2, 1},
719  {"F", FERMAT, 1},
720  {"M", MERSENNE, 1},
721  {"fib", FIBONACCI, 1},
722  {"Fib", FIBONACCI, 1},
723  {"random", RANDOM, 1},
724  {"nextprime", NEXTPRIME, 1},
725  {"binom", BINOM, 2},
726  {"binomial", BINOM, 2},
727  {"", NOP, 0}
728};
729
730char *
731factor (char *str, expr_t *e)
732{
733  expr_t e1, e2;
734
735  str = skipspace (str);
736
737  if (isalpha (str[0]))
738    {
739      int i;
740      int cnt;
741
742      for (i = 0; fns[i].op != NOP; i++)
743        {
744          if (fns[i].arity == 1)
745            {
746              cnt = matchp (fns[i].spelling, str);
747              if (cnt != 0)
748                {
749                  str = expr (str + cnt, &e1);
750                  str = skipspace (str);
751                  if (str[0] != ')')
752                    {
753                      error = "expected `)'";
754                      longjmp (errjmpbuf, (int) (long) str);
755                    }
756                  makeexp (e, fns[i].op, e1, NULL);
757                  return str + 1;
758                }
759            }
760        }
761
762      for (i = 0; fns[i].op != NOP; i++)
763        {
764          if (fns[i].arity != 1)
765            {
766              cnt = matchp (fns[i].spelling, str);
767              if (cnt != 0)
768                {
769                  str = expr (str + cnt, &e1);
770                  str = skipspace (str);
771
772                  if (str[0] != ',')
773                    {
774                      error = "expected `,' and another operand";
775                      longjmp (errjmpbuf, (int) (long) str);
776                    }
777
778                  str = skipspace (str + 1);
779                  str = expr (str, &e2);
780                  str = skipspace (str);
781
782                  if (fns[i].arity == 0)
783                    {
784                      while (str[0] == ',')
785                        {
786                          makeexp (&e1, fns[i].op, e1, e2);
787                          str = skipspace (str + 1);
788                          str = expr (str, &e2);
789                          str = skipspace (str);
790                        }
791                    }
792
793                  if (str[0] != ')')
794                    {
795                      error = "expected `)'";
796                      longjmp (errjmpbuf, (int) (long) str);
797                    }
798
799                  makeexp (e, fns[i].op, e1, e2);
800                  return str + 1;
801                }
802            }
803        }
804    }
805
806  if (str[0] == '(')
807    {
808      str = expr (str + 1, e);
809      str = skipspace (str);
810      if (str[0] != ')')
811        {
812          error = "expected `)'";
813          longjmp (errjmpbuf, (int) (long) str);
814        }
815      str++;
816    }
817  else if (str[0] >= '0' && str[0] <= '9')
818    {
819      expr_t res;
820      char *s, *sc;
821
822      res = malloc (sizeof (struct expr));
823      res -> op = LIT;
824      mpz_init (res->operands.val);
825
826      s = str;
827      while (isalnum (str[0]))
828        str++;
829      sc = malloc (str - s + 1);
830      memcpy (sc, s, str - s);
831      sc[str - s] = 0;
832
833      mpz_set_str (res->operands.val, sc, 0);
834      *e = res;
835      free (sc);
836    }
837  else
838    {
839      error = "operand expected";
840      longjmp (errjmpbuf, (int) (long) str);
841    }
842  return str;
843}
844
845char *
846skipspace (char *str)
847{
848  while (str[0] == ' ')
849    str++;
850  return str;
851}
852
853/* Make a new expression with operation OP and right hand side
854   RHS and left hand side lhs.  Put the result in R.  */
855void
856makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
857{
858  expr_t res;
859  res = malloc (sizeof (struct expr));
860  res -> op = op;
861  res -> operands.ops.lhs = lhs;
862  res -> operands.ops.rhs = rhs;
863  *r = res;
864  return;
865}
866
867/* Free the memory used by expression E.  */
868void
869free_expr (expr_t e)
870{
871  if (e->op != LIT)
872    {
873      free_expr (e->operands.ops.lhs);
874      if (e->operands.ops.rhs != NULL)
875        free_expr (e->operands.ops.rhs);
876    }
877  else
878    {
879      mpz_clear (e->operands.val);
880    }
881}
882
883/* Evaluate the expression E and put the result in R.  */
884void
885mpz_eval_expr (mpz_ptr r, expr_t e)
886{
887  mpz_t lhs, rhs;
888
889  switch (e->op)
890    {
891    case LIT:
892      mpz_set (r, e->operands.val);
893      return;
894    case PLUS:
895      mpz_init (lhs); mpz_init (rhs);
896      mpz_eval_expr (lhs, e->operands.ops.lhs);
897      mpz_eval_expr (rhs, e->operands.ops.rhs);
898      mpz_add (r, lhs, rhs);
899      mpz_clear (lhs); mpz_clear (rhs);
900      return;
901    case MINUS:
902      mpz_init (lhs); mpz_init (rhs);
903      mpz_eval_expr (lhs, e->operands.ops.lhs);
904      mpz_eval_expr (rhs, e->operands.ops.rhs);
905      mpz_sub (r, lhs, rhs);
906      mpz_clear (lhs); mpz_clear (rhs);
907      return;
908    case MULT:
909      mpz_init (lhs); mpz_init (rhs);
910      mpz_eval_expr (lhs, e->operands.ops.lhs);
911      mpz_eval_expr (rhs, e->operands.ops.rhs);
912      mpz_mul (r, lhs, rhs);
913      mpz_clear (lhs); mpz_clear (rhs);
914      return;
915    case DIV:
916      mpz_init (lhs); mpz_init (rhs);
917      mpz_eval_expr (lhs, e->operands.ops.lhs);
918      mpz_eval_expr (rhs, e->operands.ops.rhs);
919      mpz_fdiv_q (r, lhs, rhs);
920      mpz_clear (lhs); mpz_clear (rhs);
921      return;
922    case MOD:
923      mpz_init (rhs);
924      mpz_eval_expr (rhs, e->operands.ops.rhs);
925      mpz_abs (rhs, rhs);
926      mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
927      mpz_clear (rhs);
928      return;
929    case REM:
930      /* Check if lhs operand is POW expression and optimize for that case.  */
931      if (e->operands.ops.lhs->op == POW)
932        {
933          mpz_t powlhs, powrhs;
934          mpz_init (powlhs);
935          mpz_init (powrhs);
936          mpz_init (rhs);
937          mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
938          mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
939          mpz_eval_expr (rhs, e->operands.ops.rhs);
940          mpz_powm (r, powlhs, powrhs, rhs);
941          if (mpz_cmp_si (rhs, 0L) < 0)
942            mpz_neg (r, r);
943          mpz_clear (powlhs);
944          mpz_clear (powrhs);
945          mpz_clear (rhs);
946          return;
947        }
948
949      mpz_init (lhs); mpz_init (rhs);
950      mpz_eval_expr (lhs, e->operands.ops.lhs);
951      mpz_eval_expr (rhs, e->operands.ops.rhs);
952      mpz_fdiv_r (r, lhs, rhs);
953      mpz_clear (lhs); mpz_clear (rhs);
954      return;
955#if __GNU_MP_VERSION >= 2
956    case INVMOD:
957      mpz_init (lhs); mpz_init (rhs);
958      mpz_eval_expr (lhs, e->operands.ops.lhs);
959      mpz_eval_expr (rhs, e->operands.ops.rhs);
960      mpz_invert (r, lhs, rhs);
961      mpz_clear (lhs); mpz_clear (rhs);
962      return;
963#endif
964    case POW:
965      mpz_init (lhs); mpz_init (rhs);
966      mpz_eval_expr (lhs, e->operands.ops.lhs);
967      if (mpz_cmpabs_ui (lhs, 1) <= 0)
968        {
969          /* For 0^rhs and 1^rhs, we just need to verify that
970             rhs is well-defined.  For (-1)^rhs we need to
971             determine (rhs mod 2).  For simplicity, compute
972             (rhs mod 2) for all three cases.  */
973          expr_t two, et;
974          two = malloc (sizeof (struct expr));
975          two -> op = LIT;
976          mpz_init_set_ui (two->operands.val, 2L);
977          makeexp (&et, MOD, e->operands.ops.rhs, two);
978          e->operands.ops.rhs = et;
979        }
980
981      mpz_eval_expr (rhs, e->operands.ops.rhs);
982      if (mpz_cmp_si (rhs, 0L) == 0)
983        /* x^0 is 1 */
984        mpz_set_ui (r, 1L);
985      else if (mpz_cmp_si (lhs, 0L) == 0)
986        /* 0^y (where y != 0) is 0 */
987        mpz_set_ui (r, 0L);
988      else if (mpz_cmp_ui (lhs, 1L) == 0)
989        /* 1^y is 1 */
990        mpz_set_ui (r, 1L);
991      else if (mpz_cmp_si (lhs, -1L) == 0)
992        /* (-1)^y just depends on whether y is even or odd */
993        mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
994      else if (mpz_cmp_si (rhs, 0L) < 0)
995        /* x^(-n) is 0 */
996        mpz_set_ui (r, 0L);
997      else
998        {
999          unsigned long int cnt;
1000          unsigned long int y;
1001          /* error if exponent does not fit into an unsigned long int.  */
1002          if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1003            goto pow_err;
1004
1005          y = mpz_get_ui (rhs);
1006          /* x^y == (x/(2^c))^y * 2^(c*y) */
1007#if __GNU_MP_VERSION >= 2
1008          cnt = mpz_scan1 (lhs, 0);
1009#else
1010          cnt = 0;
1011#endif
1012          if (cnt != 0)
1013            {
1014              if (y * cnt / cnt != y)
1015                goto pow_err;
1016              mpz_tdiv_q_2exp (lhs, lhs, cnt);
1017              mpz_pow_ui (r, lhs, y);
1018              mpz_mul_2exp (r, r, y * cnt);
1019            }
1020          else
1021            mpz_pow_ui (r, lhs, y);
1022        }
1023      mpz_clear (lhs); mpz_clear (rhs);
1024      return;
1025    pow_err:
1026      error = "result of `pow' operator too large";
1027      mpz_clear (lhs); mpz_clear (rhs);
1028      longjmp (errjmpbuf, 1);
1029    case GCD:
1030      mpz_init (lhs); mpz_init (rhs);
1031      mpz_eval_expr (lhs, e->operands.ops.lhs);
1032      mpz_eval_expr (rhs, e->operands.ops.rhs);
1033      mpz_gcd (r, lhs, rhs);
1034      mpz_clear (lhs); mpz_clear (rhs);
1035      return;
1036#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1037    case LCM:
1038      mpz_init (lhs); mpz_init (rhs);
1039      mpz_eval_expr (lhs, e->operands.ops.lhs);
1040      mpz_eval_expr (rhs, e->operands.ops.rhs);
1041      mpz_lcm (r, lhs, rhs);
1042      mpz_clear (lhs); mpz_clear (rhs);
1043      return;
1044#endif
1045    case AND:
1046      mpz_init (lhs); mpz_init (rhs);
1047      mpz_eval_expr (lhs, e->operands.ops.lhs);
1048      mpz_eval_expr (rhs, e->operands.ops.rhs);
1049      mpz_and (r, lhs, rhs);
1050      mpz_clear (lhs); mpz_clear (rhs);
1051      return;
1052    case IOR:
1053      mpz_init (lhs); mpz_init (rhs);
1054      mpz_eval_expr (lhs, e->operands.ops.lhs);
1055      mpz_eval_expr (rhs, e->operands.ops.rhs);
1056      mpz_ior (r, lhs, rhs);
1057      mpz_clear (lhs); mpz_clear (rhs);
1058      return;
1059#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1060    case XOR:
1061      mpz_init (lhs); mpz_init (rhs);
1062      mpz_eval_expr (lhs, e->operands.ops.lhs);
1063      mpz_eval_expr (rhs, e->operands.ops.rhs);
1064      mpz_xor (r, lhs, rhs);
1065      mpz_clear (lhs); mpz_clear (rhs);
1066      return;
1067#endif
1068    case NEG:
1069      mpz_eval_expr (r, e->operands.ops.lhs);
1070      mpz_neg (r, r);
1071      return;
1072    case NOT:
1073      mpz_eval_expr (r, e->operands.ops.lhs);
1074      mpz_com (r, r);
1075      return;
1076    case SQRT:
1077      mpz_init (lhs);
1078      mpz_eval_expr (lhs, e->operands.ops.lhs);
1079      if (mpz_sgn (lhs) < 0)
1080        {
1081          error = "cannot take square root of negative numbers";
1082          mpz_clear (lhs);
1083          longjmp (errjmpbuf, 1);
1084        }
1085      mpz_sqrt (r, lhs);
1086      return;
1087#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1088    case ROOT:
1089      mpz_init (lhs); mpz_init (rhs);
1090      mpz_eval_expr (lhs, e->operands.ops.lhs);
1091      mpz_eval_expr (rhs, e->operands.ops.rhs);
1092      if (mpz_sgn (rhs) <= 0)
1093        {
1094          error = "cannot take non-positive root orders";
1095          mpz_clear (lhs); mpz_clear (rhs);
1096          longjmp (errjmpbuf, 1);
1097        }
1098      if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1099        {
1100          error = "cannot take even root orders of negative numbers";
1101          mpz_clear (lhs); mpz_clear (rhs);
1102          longjmp (errjmpbuf, 1);
1103        }
1104
1105      {
1106        unsigned long int nth = mpz_get_ui (rhs);
1107        if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1108          {
1109            /* If we are asked to take an awfully large root order, cheat and
1110               ask for the largest order we can pass to mpz_root.  This saves
1111               some error prone special cases.  */
1112            nth = ~(unsigned long int) 0;
1113          }
1114        mpz_root (r, lhs, nth);
1115      }
1116      mpz_clear (lhs); mpz_clear (rhs);
1117      return;
1118#endif
1119    case FAC:
1120      mpz_eval_expr (r, e->operands.ops.lhs);
1121      if (mpz_size (r) > 1)
1122        {
1123          error = "result of `!' operator too large";
1124          longjmp (errjmpbuf, 1);
1125        }
1126      mpz_fac_ui (r, mpz_get_ui (r));
1127      return;
1128#if __GNU_MP_VERSION >= 2
1129    case POPCNT:
1130      mpz_eval_expr (r, e->operands.ops.lhs);
1131      { long int cnt;
1132        cnt = mpz_popcount (r);
1133        mpz_set_si (r, cnt);
1134      }
1135      return;
1136    case HAMDIST:
1137      { long int cnt;
1138        mpz_init (lhs); mpz_init (rhs);
1139        mpz_eval_expr (lhs, e->operands.ops.lhs);
1140        mpz_eval_expr (rhs, e->operands.ops.rhs);
1141        cnt = mpz_hamdist (lhs, rhs);
1142        mpz_clear (lhs); mpz_clear (rhs);
1143        mpz_set_si (r, cnt);
1144      }
1145      return;
1146#endif
1147    case LOG2:
1148      mpz_eval_expr (r, e->operands.ops.lhs);
1149      { unsigned long int cnt;
1150        if (mpz_sgn (r) <= 0)
1151          {
1152            error = "logarithm of non-positive number";
1153            longjmp (errjmpbuf, 1);
1154          }
1155        cnt = mpz_sizeinbase (r, 2);
1156        mpz_set_ui (r, cnt - 1);
1157      }
1158      return;
1159    case LOG:
1160      { unsigned long int cnt;
1161        mpz_init (lhs); mpz_init (rhs);
1162        mpz_eval_expr (lhs, e->operands.ops.lhs);
1163        mpz_eval_expr (rhs, e->operands.ops.rhs);
1164        if (mpz_sgn (lhs) <= 0)
1165          {
1166            error = "logarithm of non-positive number";
1167            mpz_clear (lhs); mpz_clear (rhs);
1168            longjmp (errjmpbuf, 1);
1169          }
1170        if (mpz_cmp_ui (rhs, 256) >= 0)
1171          {
1172            error = "logarithm base too large";
1173            mpz_clear (lhs); mpz_clear (rhs);
1174            longjmp (errjmpbuf, 1);
1175          }
1176        cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1177        mpz_set_ui (r, cnt - 1);
1178        mpz_clear (lhs); mpz_clear (rhs);
1179      }
1180      return;
1181    case FERMAT:
1182      {
1183        unsigned long int t;
1184        mpz_init (lhs);
1185        mpz_eval_expr (lhs, e->operands.ops.lhs);
1186        t = (unsigned long int) 1 << mpz_get_ui (lhs);
1187        if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1188          {
1189            error = "too large Mersenne number index";
1190            mpz_clear (lhs);
1191            longjmp (errjmpbuf, 1);
1192          }
1193        mpz_set_ui (r, 1);
1194        mpz_mul_2exp (r, r, t);
1195        mpz_add_ui (r, r, 1);
1196        mpz_clear (lhs);
1197      }
1198      return;
1199    case MERSENNE:
1200      mpz_init (lhs);
1201      mpz_eval_expr (lhs, e->operands.ops.lhs);
1202      if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1203        {
1204          error = "too large Mersenne number index";
1205          mpz_clear (lhs);
1206          longjmp (errjmpbuf, 1);
1207        }
1208      mpz_set_ui (r, 1);
1209      mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1210      mpz_sub_ui (r, r, 1);
1211      mpz_clear (lhs);
1212      return;
1213    case FIBONACCI:
1214      { mpz_t t;
1215        unsigned long int n, i;
1216        mpz_init (lhs);
1217        mpz_eval_expr (lhs, e->operands.ops.lhs);
1218        if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1219          {
1220            error = "Fibonacci index out of range";
1221            mpz_clear (lhs);
1222            longjmp (errjmpbuf, 1);
1223          }
1224        n = mpz_get_ui (lhs);
1225        mpz_clear (lhs);
1226
1227#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1228        mpz_fib_ui (r, n);
1229#else
1230        mpz_init_set_ui (t, 1);
1231        mpz_set_ui (r, 1);
1232
1233        if (n <= 2)
1234          mpz_set_ui (r, 1);
1235        else
1236          {
1237            for (i = 3; i <= n; i++)
1238              {
1239                mpz_add (t, t, r);
1240                mpz_swap (t, r);
1241              }
1242          }
1243        mpz_clear (t);
1244#endif
1245      }
1246      return;
1247    case RANDOM:
1248      {
1249        unsigned long int n;
1250        mpz_init (lhs);
1251        mpz_eval_expr (lhs, e->operands.ops.lhs);
1252        if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1253          {
1254            error = "random number size out of range";
1255            mpz_clear (lhs);
1256            longjmp (errjmpbuf, 1);
1257          }
1258        n = mpz_get_ui (lhs);
1259        mpz_clear (lhs);
1260        mpz_urandomb (r, rstate, n);
1261      }
1262      return;
1263    case NEXTPRIME:
1264      {
1265        mpz_eval_expr (r, e->operands.ops.lhs);
1266        mpz_nextprime (r, r);
1267      }
1268      return;
1269    case BINOM:
1270      mpz_init (lhs); mpz_init (rhs);
1271      mpz_eval_expr (lhs, e->operands.ops.lhs);
1272      mpz_eval_expr (rhs, e->operands.ops.rhs);
1273      {
1274        unsigned long int k;
1275        if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1276          {
1277            error = "k too large in (n over k) expression";
1278            mpz_clear (lhs); mpz_clear (rhs);
1279            longjmp (errjmpbuf, 1);
1280          }
1281        k = mpz_get_ui (rhs);
1282        mpz_bin_ui (r, lhs, k);
1283      }
1284      mpz_clear (lhs); mpz_clear (rhs);
1285      return;
1286    default:
1287      abort ();
1288    }
1289}
1290
1291/* Evaluate the expression E modulo MOD and put the result in R.  */
1292void
1293mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1294{
1295  mpz_t lhs, rhs;
1296
1297  switch (e->op)
1298    {
1299      case POW:
1300        mpz_init (lhs); mpz_init (rhs);
1301        mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1302        mpz_eval_expr (rhs, e->operands.ops.rhs);
1303        mpz_powm (r, lhs, rhs, mod);
1304        mpz_clear (lhs); mpz_clear (rhs);
1305        return;
1306      case PLUS:
1307        mpz_init (lhs); mpz_init (rhs);
1308        mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1309        mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1310        mpz_add (r, lhs, rhs);
1311        if (mpz_cmp_si (r, 0L) < 0)
1312          mpz_add (r, r, mod);
1313        else if (mpz_cmp (r, mod) >= 0)
1314          mpz_sub (r, r, mod);
1315        mpz_clear (lhs); mpz_clear (rhs);
1316        return;
1317      case MINUS:
1318        mpz_init (lhs); mpz_init (rhs);
1319        mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1320        mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1321        mpz_sub (r, lhs, rhs);
1322        if (mpz_cmp_si (r, 0L) < 0)
1323          mpz_add (r, r, mod);
1324        else if (mpz_cmp (r, mod) >= 0)
1325          mpz_sub (r, r, mod);
1326        mpz_clear (lhs); mpz_clear (rhs);
1327        return;
1328      case MULT:
1329        mpz_init (lhs); mpz_init (rhs);
1330        mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1331        mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1332        mpz_mul (r, lhs, rhs);
1333        mpz_mod (r, r, mod);
1334        mpz_clear (lhs); mpz_clear (rhs);
1335        return;
1336      default:
1337        mpz_init (lhs);
1338        mpz_eval_expr (lhs, e);
1339        mpz_mod (r, lhs, mod);
1340        mpz_clear (lhs);
1341        return;
1342    }
1343}
1344
1345void
1346cleanup_and_exit (int sig)
1347{
1348  switch (sig) {
1349#ifdef LIMIT_RESOURCE_USAGE
1350  case SIGXCPU:
1351    printf ("expression took too long to evaluate%s\n", newline);
1352    break;
1353#endif
1354  case SIGFPE:
1355    printf ("divide by zero%s\n", newline);
1356    break;
1357  default:
1358    printf ("expression required too much memory to evaluate%s\n", newline);
1359    break;
1360  }
1361  exit (-2);
1362}
Note: See TracBrowser for help on using the repository browser.