From ebfa3fa3297da38db9d1d57221fc184e83ae8c4e Mon Sep 17 00:00:00 2001 From: Ben Pfaff Date: Sun, 21 Nov 2021 13:50:12 -0800 Subject: [PATCH] Use constraints more widely in matrix functions. --- src/language/stats/matrix.c | 60 ++++++++++++++++++++-------------- tests/language/stats/matrix.at | 3 +- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/language/stats/matrix.c b/src/language/stats/matrix.c index c166473e87..fe02d7269f 100644 --- a/src/language/stats/matrix.c +++ b/src/language/stats/matrix.c @@ -180,6 +180,35 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value) var->value = value; } +/* The third argument to F() is a "prototype". For most prototypes, the first + letter (before the _) represents the return type and each other letter + (after the _) is an argument type. The types are: + + - "m": A matrix of unrestricted dimensions. + + - "d": A scalar. + + - "v": A row or column vector. + + - "e": Primarily for the first argument, this is a matrix with + unrestricted dimensions treated elementwise. Each element in the matrix + is passed to the implementation function separately. + + The fourth argument is an optional constraints string. For this purpose the + first argument is named "a", the second "b", and so on. The following kinds + of constraints are supported. For matrix arguments, the constraints are + applied to each value in the matrix separately: + + - "a(0,1)" or "a[0,1]": 0 < a < 1 or 0 <= a <= 1, respectively. Any + integer may substitute for 0 and 1. Half-open constraints (] and [) are + also supported. + + - "ai": Restrict a to integer values. + + - "a>0", "a<0", "a>=0", "a<=0". + + - "ab", "a<=b", "a>=b". +*/ #define MATRIX_FUNCTIONS \ F(ABS, "ABS", m_e, NULL) \ F(ALL, "ALL", d_m, NULL) \ @@ -206,7 +235,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value) F(KRONEKER, "KRONEKER", m_mm, NULL) \ F(LG10, "LG10", m_e, "a>0") \ F(LN, "LN", m_e, "a>0") \ - F(MAGIC, "MAGIC", m_d, NULL) \ + F(MAGIC, "MAGIC", m_d, "ai>=3") \ F(MAKE, "MAKE", m_ddd, NULL) \ F(MDIAG, "MDIAG", m_v, NULL) \ F(MMAX, "MMAX", d_m, NULL) \ @@ -226,7 +255,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value) F(RSUM, "RSUM", m_m, NULL) \ F(SIN, "SIN", m_e, NULL) \ F(SOLVE, "SOLVE", m_mm, NULL) \ - F(SQRT, "SQRT", m_m, NULL) \ + F(SQRT, "SQRT", m_e, "a>=0") \ F(SSCP, "SSCP", m_m, NULL) \ F(SVAL, "SVAL", m_m, NULL) \ F(SWEEP, "SWEEP", m_md, NULL) \ @@ -234,7 +263,7 @@ matrix_var_set (struct matrix_var *var, gsl_matrix *value) F(TRACE, "TRACE", d_m, NULL) \ F(TRANSPOS, "TRANSPOS", m_m, NULL) \ F(TRUNC, "TRUNC", m_e, NULL) \ - F(UNIFORM, "UNIFORM", m_dd, NULL) \ + F(UNIFORM, "UNIFORM", m_dd, "ai>=0 bi>=0") \ F(PDF_BETA, "PDF.BETA", m_edd, "a[0,1] b>0 c>0") \ F(CDF_BETA, "CDF.BETA", m_edd, "a[0,1] b>0 c>0") \ F(IDF_BETA, "IDF.BETA", m_edd, "a[0,1] b>0 c>0") \ @@ -1329,11 +1358,6 @@ matrix_eval_MAGIC_singly_even (gsl_matrix *m, size_t n) static gsl_matrix * matrix_eval_MAGIC (double n_) { - if (n_ < 3 || n_ >= sqrt (SIZE_MAX)) - { - msg (SE, _("MAGIC argument must be an integer 3 or greater.")); - return NULL; - } size_t n = n_; gsl_matrix *m = gsl_matrix_calloc (n, n); @@ -1620,19 +1644,10 @@ matrix_eval_SOLVE (gsl_matrix *m1, gsl_matrix *m2) return x; } -static gsl_matrix * -matrix_eval_SQRT (gsl_matrix *m) +static double +matrix_eval_SQRT (double d) { - MATRIX_FOR_ALL_ELEMENTS (d, y, x, m) - { - if (*d < 0) - { - msg (SE, _("Argument to SQRT must be nonnegative.")); - return NULL; - } - *d = sqrt (*d); - } - return m; + return sqrt (d); } static gsl_matrix * @@ -1757,11 +1772,6 @@ matrix_eval_TRUNC (double d) static gsl_matrix * matrix_eval_UNIFORM (double r_, double c_) { - if (r_ < 0 || r_ >= SIZE_MAX || c_ < 0 || c_ >= SIZE_MAX) - { - msg (SE, _("Arguments to UNIFORM must be integers.")); - return NULL; - } size_t r = r_; size_t c = c_; if (size_overflow_p (xtimes (r, xmax (c, 1)))) diff --git a/tests/language/stats/matrix.at b/tests/language/stats/matrix.at index bd81cd7ec9..07fd8fa3c6 100644 --- a/tests/language/stats/matrix.at +++ b/tests/language/stats/matrix.at @@ -1841,7 +1841,8 @@ number of rows, but the first argument has dimensions 2×2 and the second 1×2. SQRT({0, 1, 2, 3, 4, 9, 81}) .00 1.00 1.41 1.73 2.00 3.00 9.00 -matrix.sps:13: error: MATRIX: Argument to SQRT must be nonnegative. +error: Argument 1 to matrix function SQRT has invalid value -1. This argument +must be greater than or equal to 0. ]) AT_CLEANUP -- 2.30.2