eigenmath/det.cpp

298 lines
4.0 KiB
C++
Raw Permalink Normal View History

2004-03-03 21:24:06 +01:00
//-----------------------------------------------------------------------------
//
// Input: Matrix on stack
//
// Output: Determinant on stack
//
// Example:
//
// > det(((1,2),(3,4)))
// -2
//
// Note:
//
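// Uses Gaussian elimination for numerical matrices and expansion over all
// permutations (no division needed) for symbolic ones. An argument that is
// not a square matrix is returned unevaluated as det(arg).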
//
//-----------------------------------------------------------------------------
2006-01-06 03:38:07 +01:00
#include "stdafx.h"
2004-03-03 21:24:06 +01:00
#include "defs.h"
// check_arg: is the global p1 a square matrix?
//
// Returns 1 when p1 is a tensor of rank 2 whose two dimensions are equal,
// 0 otherwise. The rank and dimensions are only inspected once p1 is known
// to be a tensor. Nothing is modified.
static int
check_arg(void)
{
	int is_square_matrix = 0;

	if (istensor(p1) && p1->u.tensor->ndim == 2)
		is_square_matrix = (p1->u.tensor->dim[0] == p1->u.tensor->dim[1]);

	return is_square_matrix;
}
void
det(void)
{
int i, n;
U **a;
save();
p1 = pop();
if (check_arg() == 0) {
push_symbol(DET);
push(p1);
list(2);
restore();
return;
}
n = p1->u.tensor->nelem;
a = p1->u.tensor->elem;
for (i = 0; i < n; i++)
if (!isnum(a[i]))
break;
if (i == n)
2006-01-06 03:38:07 +01:00
yydetg();
2004-03-03 21:24:06 +01:00
else {
for (i = 0; i < p1->u.tensor->nelem; i++)
push(p1->u.tensor->elem[i]);
determinant(p1->u.tensor->dim[0]);
}
restore();
}
// determinant of n * n matrix elements on the stack
void
determinant(int n)
{
int h, i, j, k, q, s, sign, t;
int *a, *c, *d;
h = tos - n * n;
a = (int *) malloc(3 * n * sizeof (int));
if (a == NULL)
out_of_memory();
c = a + n;
d = c + n;
for (i = 0; i < n; i++) {
a[i] = i;
c[i] = 0;
d[i] = 1;
}
sign = 1;
2004-06-25 22:45:15 +02:00
push(zero);
2004-03-03 21:24:06 +01:00
for (;;) {
if (sign == 1)
2004-06-25 02:11:40 +02:00
push_integer(1);
2004-03-03 21:24:06 +01:00
else
2004-06-25 02:11:40 +02:00
push_integer(-1);
2004-03-03 21:24:06 +01:00
for (i = 0; i < n; i++) {
k = n * a[i] + i;
push(stack[h + k]);
multiply(); // FIXME -- problem here
}
add();
/* next permutation (Knuth's algorithm P) */
j = n - 1;
s = 0;
P4: q = c[j] + d[j];
if (q < 0) {
d[j] = -d[j];
j--;
goto P4;
}
if (q == j + 1) {
if (j == 0)
break;
s++;
d[j] = -d[j];
j--;
goto P4;
}
t = a[j - c[j] + s];
a[j - c[j] + s] = a[j - q + s];
a[j - q + s] = t;
c[j] = q;
sign = -sign;
}
free(a);
stack[h] = stack[tos - 1];
tos = h + 1;
}
//-----------------------------------------------------------------------------
//
// Input: Matrix on stack
//
// Output: Determinant on stack
//
// Note:
//
// Uses Gaussian elimination which is faster for numerical matrices.
//
// Gaussian Elimination works by walking down the diagonal and clearing
// out the columns below it.
//
//-----------------------------------------------------------------------------
void
detg(void)
{
save();
p1 = pop();
if (check_arg() == 0) {
push_symbol(DET);
push(p1);
list(2);
restore();
return;
}
2006-01-06 03:38:07 +01:00
yydetg();
2004-03-03 21:24:06 +01:00
restore();
}
2006-01-06 03:38:07 +01:00
void
yydetg(void)
2004-03-03 21:24:06 +01:00
{
int i, n;
n = p1->u.tensor->dim[0];
for (i = 0; i < n * n; i++)
push(p1->u.tensor->elem[i]);
2006-01-06 03:38:07 +01:00
lu_decomp(n);
2004-03-03 21:24:06 +01:00
tos -= n * n;
push(p1);
}
//-----------------------------------------------------------------------------
//
// Input: n * n matrix elements on stack
//
// Output: p1 determinant
//
// p2 mangled
//
// upper diagonal matrix on stack
//
//-----------------------------------------------------------------------------
#define M(i, j) stack[h + n * (i) + (j)]
2006-01-06 03:38:07 +01:00
void
lu_decomp(int n)
2004-03-03 21:24:06 +01:00
{
int d, h, i, j;
h = tos - n * n;
2004-06-25 22:45:15 +02:00
p1 = one;
2004-03-03 21:24:06 +01:00
for (d = 0; d < n - 1; d++) {
// diagonal element zero?
2004-06-25 22:45:15 +02:00
if (equal(M(d, d), zero)) {
2004-03-03 21:24:06 +01:00
// find a new row
for (i = d + 1; i < n; i++)
2004-06-25 22:45:15 +02:00
if (!equal(M(i, d), zero))
2004-03-03 21:24:06 +01:00
break;
if (i == n) {
2004-06-25 22:45:15 +02:00
p1 = zero;
2004-03-03 21:24:06 +01:00
break;
}
// exchange rows
for (j = d; j < n; j++) {
p2 = M(d, j);
M(d, j) = M(i, j);
M(i, j) = p2;
}
// negate det
push(p1);
negate();
p1 = pop();
}
// update det
push(p1);
push(M(d, d));
multiply();
p1 = pop();
// update lower diagonal matrix
for (i = d + 1; i < n; i++) {
// multiplier
push(M(i, d));
push(M(d, d));
divide();
negate();
p2 = pop();
// update one row
2004-06-25 22:45:15 +02:00
M(i, d) = zero; // clear column below pivot d
2004-03-03 21:24:06 +01:00
for (j = d + 1; j < n; j++) {
push(M(d, j));
push(p2);
multiply();
push(M(i, j));
add();
M(i, j) = pop();
}
}
}
// last diagonal element
push(p1);
push(M(n - 1, n - 1));
multiply();
p1 = pop();
}