Use constraints more widely in matrix functions.
authorBen Pfaff <blp@cs.stanford.edu>
Sun, 21 Nov 2021 21:50:12 +0000 (13:50 -0800)
committerBen Pfaff <blp@cs.stanford.edu>
Sun, 21 Nov 2021 21:50:12 +0000 (13:50 -0800)
src/language/stats/matrix.c
tests/language/stats/matrix.at

index c166473e87d9f7f01aa90a04aa008e1ed35b4924..fe02d7269f1d9646533f6e0d425690373d8d05f3 100644 (file)
@@ -180,6 +180,35 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value)
   var->value = value;
 }
 \f
+/* The third argument to F() is a "prototype".  For most prototypes, the first
+   letter (before the _) represents the return type and each other letter
+   (after the _) is an argument type.  The types are:
+
+     - "m": A matrix of unrestricted dimensions.
+
+     - "d": A scalar.
+
+     - "v": A row or column vector.
+
+     - "e": Primarily for the first argument, this is a matrix with
+       unrestricted dimensions treated elementwise.  Each element in the matrix
+       is passed to the implementation function separately.
+
+   The fourth argument is an optional constraints string.  For this purpose the
+   first argument is named "a", the second "b", and so on.  The following kinds
+   of constraints are supported.  For matrix arguments, the constraints are
+   applied to each value in the matrix separately:
+
+     - "a(0,1)" or "a[0,1]": 0 < a < 1 or 0 <= a <= 1, respectively.  Any
+       integer may substitute for 0 and 1.  Half-open constraints (] and [) are
+       also supported.
+
+     - "ai": Restrict a to integer values.
+
+     - "a>0", "a<0", "a>=0", "a<=0".
+
+     - "a<b", "a>b", "a<=b", "a>=b".
+*/
 #define MATRIX_FUNCTIONS                                \
     F(ABS,      "ABS",      m_e, NULL)                \
     F(ALL,      "ALL",      d_m, NULL)                  \
@@ -206,7 +235,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value)
     F(KRONEKER, "KRONEKER", m_mm, NULL)                 \
     F(LG10,     "LG10",     m_e, "a>0")               \
     F(LN,       "LN",       m_e, "a>0")               \
-    F(MAGIC,    "MAGIC",    m_d, NULL)                  \
+    F(MAGIC,    "MAGIC",    m_d, "ai>=3")                  \
     F(MAKE,     "MAKE",     m_ddd, NULL)                \
     F(MDIAG,    "MDIAG",    m_v, NULL)                  \
     F(MMAX,     "MMAX",     d_m, NULL)                  \
@@ -226,7 +255,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value)
     F(RSUM,     "RSUM",     m_m, NULL)                  \
     F(SIN,      "SIN",      m_e, NULL)                \
     F(SOLVE,    "SOLVE",    m_mm, NULL)                 \
-    F(SQRT,     "SQRT",     m_m, NULL)                  \
+    F(SQRT,     "SQRT",     m_e, "a>=0")                  \
     F(SSCP,     "SSCP",     m_m, NULL)                  \
     F(SVAL,     "SVAL",     m_m, NULL)                  \
     F(SWEEP,    "SWEEP",    m_md, NULL)                 \
@@ -234,7 +263,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value)
     F(TRACE,    "TRACE",    d_m, NULL)                  \
     F(TRANSPOS, "TRANSPOS", m_m, NULL)                  \
     F(TRUNC,    "TRUNC",    m_e, NULL)                \
-    F(UNIFORM,  "UNIFORM",  m_dd, NULL)                 \
+    F(UNIFORM,  "UNIFORM",  m_dd, "ai>=0 bi>=0")                 \
     F(PDF_BETA, "PDF.BETA", m_edd, "a[0,1] b>0 c>0")  \
     F(CDF_BETA, "CDF.BETA", m_edd, "a[0,1] b>0 c>0")  \
     F(IDF_BETA, "IDF.BETA", m_edd, "a[0,1] b>0 c>0")  \
@@ -1329,11 +1358,6 @@ matrix_eval_MAGIC_singly_even (gsl_matrix *m, size_t n)
 static gsl_matrix *
 matrix_eval_MAGIC (double n_)
 {
-  if (n_ < 3 || n_ >= sqrt (SIZE_MAX))
-    {
-      msg (SE, _("MAGIC argument must be an integer 3 or greater."));
-      return NULL;
-    }
   size_t n = n_;
 
   gsl_matrix *m = gsl_matrix_calloc (n, n);
@@ -1620,19 +1644,10 @@ matrix_eval_SOLVE (gsl_matrix *m1, gsl_matrix *m2)
   return x;
 }
 
-static gsl_matrix *
-matrix_eval_SQRT (gsl_matrix *m)
+static double
+matrix_eval_SQRT (double d)
 {
-  MATRIX_FOR_ALL_ELEMENTS (d, y, x, m)
-    {
-      if (*d < 0)
-        {
-          msg (SE, _("Argument to SQRT must be nonnegative."));
-          return NULL;
-        }
-      *d = sqrt (*d);
-    }
-  return m;
+  return sqrt (d);
 }
 
 static gsl_matrix *
@@ -1757,11 +1772,6 @@ matrix_eval_TRUNC (double d)
 static gsl_matrix *
 matrix_eval_UNIFORM (double r_, double c_)
 {
-  if (r_ < 0 || r_ >= SIZE_MAX || c_ < 0 || c_ >= SIZE_MAX)
-    {
-      msg (SE, _("Arguments to UNIFORM must be integers."));
-      return NULL;
-    }
   size_t r = r_;
   size_t c = c_;
   if (size_overflow_p (xtimes (r, xmax (c, 1))))
index bd81cd7ec99cc5dae31ce8614f7fe1eb4e716e22..07fd8fa3c61b8946d34fa76b9e81b1f907994e97 100644 (file)
@@ -1841,7 +1841,8 @@ number of rows, but the first argument has dimensions 2×2 and the second 1×2.
 SQRT({0, 1, 2, 3, 4, 9, 81})
    .00  1.00  1.41  1.73  2.00  3.00  9.00
 
-matrix.sps:13: error: MATRIX: Argument to SQRT must be nonnegative.
+error: Argument 1 to matrix function SQRT has invalid value -1.  This argument
+must be greater than or equal to 0.
 ])
 AT_CLEANUP