Edinburgh Speech Tools 2.4-release
wagon_aux.cc
1/*************************************************************************/
2/* */
3/* Centre for Speech Technology Research */
4/* University of Edinburgh, UK */
5/* Copyright (c) 1996,1997 */
6/* All Rights Reserved. */
7/* */
8/* Permission is hereby granted, free of charge, to use and distribute */
9/* this software and its documentation without restriction, including */
10/* without limitation the rights to use, copy, modify, merge, publish, */
11/* distribute, sublicense, and/or sell copies of this work, and to */
12/* permit persons to whom this work is furnished to do so, subject to */
13/* the following conditions: */
14/* 1. The code must retain the above copyright notice, this list of */
15/* conditions and the following disclaimer. */
16/* 2. Any modifications must be clearly marked as such. */
17/* 3. Original authors' names are not deleted. */
18/* 4. The authors' names are not used to endorse or promote products */
19/* derived from this software without specific prior written */
20/* permission. */
21/* */
22/* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23/* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24/* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25/* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26/* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27/* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28/* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29/* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30/* THIS SOFTWARE. */
31/* */
32/*************************************************************************/
33/* Author : Alan W Black */
34/* Date : May 1996 */
35/*-----------------------------------------------------------------------*/
36/* */
37/* Various method functions */
38/*=======================================================================*/
39
40#include <cstdlib>
41#include <iostream>
42#include <cstring>
43#include "EST_unix.h"
44#include "EST_cutils.h"
45#include "EST_Token.h"
46#include "EST_Wagon.h"
47#include "EST_math.h"
48
49
50EST_Val WNode::predict(const WVector &d)
51{
52 if (leaf())
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
56 else
57 return right->predict(d);
58}
59
60WNode *WNode::predict_node(const WVector &d)
61{
62 if (leaf())
63 return this;
64 else if (question.ask(d))
65 return left->predict_node(d);
66 else
67 return right->predict_node(d);
68}
69
70int WNode::pure(void)
71{
72 // A node is pure if it has no sub-nodes or its not of type class
73
74 if ((left == 0) && (right == 0))
75 return TRUE;
76 else if (get_impurity().type() != wnim_class)
77 return TRUE;
78 else
79 return FALSE;
80}
81
82void WNode::prune(void)
83{
84 // Check all sub-nodes and if they are all of the same class
85 // delete their sub nodes. Returns pureness of this node
86
87 if (pure() == FALSE)
88 {
89 // Ok lets try and make it pure
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
92
93 // Have to check purity as well as values to ensure left and right
94 // don't further split
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
97 {
98 delete left; left = 0;
99 delete right; right = 0;
100 }
101 }
102
103}
104
105void WNode::held_out_prune()
106{
107 // prune tree with held out data
108 // Check if node's questions differentiates for the held out data
109 // if not, prune all sub_nodes
110
111 // Rescore with prune data
112 set_impurity(WImpurity(get_data())); // for this new data
113
114 if (left != 0)
115 {
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
118 { // its worth goint ot the next level
119 wgn_find_split(question,get_data(),
120 left->get_data(),
121 right->get_data());
122 left->held_out_prune();
123 right->held_out_prune();
124 }
125 else
126 { // not worth the split so prune both sub_nodes
127 delete left; left = 0;
128 delete right; right = 0;
129 }
130 }
131}
132
133void WNode::print_out(ostream &s, int margin)
134{
135 int i;
136
137 s << endl;
138 for (i=0;i<margin;i++) s << " ";
139 s << "(";
140 if (left==0) // base case
141 s << impurity;
142 else
143 {
144 s << question;
145 left->print_out(s,margin+1);
146 right->print_out(s,margin+1);
147 }
148 s << ")";
149}
150
151ostream & operator <<(ostream &s, WNode &n)
152{
153 // Output this node and its sub-node
154
155 n.print_out(s,0);
156 s << endl;
157 return s;
158}
159
160void WDataSet::ignore_non_numbers()
161{
162 /* For ols we want to ignore anything that is categorial */
163 int i;
164
165 for (i=0; i<dlength; i++)
166 {
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
169 continue;
170 else
171 {
172 p_ignore[i] = TRUE;
173 }
174 }
175
176 return;
177}
178
179void WDataSet::load_description(const EST_String &fname, LISP ignores)
180{
181 // Initialise a dataset with sizes and types
182 EST_String tname;
183 int i;
184 LISP description,d;
185
186 description = car(vload(fname,1));
187 dlength = siod_llength(description);
188
189 p_type.resize(dlength);
190 p_ignore.resize(dlength);
191 p_name.resize(dlength);
192
193 if (wgn_predictee_name == "")
194 wgn_predictee = 0; // default predictee is first field
195 else
196 wgn_predictee = -1;
197
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
199 {
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
202 p_ignore[i] = FALSE;
203 if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
204 wgn_predictee = i;
205 if ((wgn_count_field_name != "") &&
206 (wgn_count_field_name == p_name[i]))
207 wgn_count_field = i;
208 if ((tname == "count") || (i == wgn_count_field))
209 {
210 // The count must be ignored, repeat it if you want it too
211 p_type[i] = wndt_ignore; // the count must be ignored
212 p_ignore[i] = TRUE;
213 wgn_count_field = i;
214 }
215 else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
216 {
217 p_type[i] = wndt_ignore; // user specified ignore
218 p_ignore[i] = TRUE;
219 if (i == wgn_predictee)
220 wagon_error(EST_String("predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
222 }
223 else if (siod_llength(car(d)) > 2)
224 {
225 LISP rest = cdr(car(d));
226 EST_StrList sl;
227 siod_list_to_strlist(rest,sl);
228 p_type[i] = wgn_discretes.def(sl);
229 if (streq(get_c_string(car(rest)),"_other_"))
230 wgn_discretes[p_type[i]].def_val("_other_");
231 }
232 else if (tname == "binary")
233 p_type[i] = wndt_binary;
234 else if (tname == "cluster")
235 p_type[i] = wndt_cluster;
236 else if (tname == "vector")
237 p_type[i] = wndt_vector;
238 else if (tname == "trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (tname == "ols")
241 p_type[i] = wndt_ols;
242 else if (tname == "matrix")
243 p_type[i] = wndt_matrix;
244 else if (tname == "float")
245 p_type[i] = wndt_float;
246 else
247 {
248 wagon_error(EST_String("Unknown type \"")+tname+
249 "\" for field number "+itoString(i)+
250 "/"+p_name[i]+" in description file \""+fname+"\"");
251 }
252 }
253
254 if (wgn_predictee == -1)
255 {
256 wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
257 "\" not found in description ");
258 }
259}
260
261const int WQuestion::ask(const WVector &w) const
262{
263 // Ask this question of the given vector
264 switch (op)
265 {
266 case wnop_equal: // for numbers
267 if (w.get_flt_val(feature_pos) == operand1.Float())
268 return TRUE;
269 else
270 return FALSE;
271 case wnop_binary: // for numbers
272 if (w.get_int_val(feature_pos) == 1)
273 return TRUE;
274 else
275 return FALSE;
276 case wnop_greaterthan:
277 if (w.get_flt_val(feature_pos) > operand1.Float())
278 return TRUE;
279 else
280 return FALSE;
281 case wnop_lessthan:
282 if (w.get_flt_val(feature_pos) < operand1.Float())
283 return TRUE;
284 else
285 return FALSE;
286 case wnop_is: // for classes
287 if (w.get_int_val(feature_pos) == operand1.Int())
288 return TRUE;
289 else
290 return FALSE;
291 case wnop_in: // for subsets -- note operand is list of ints
292 if (ilist_member(operandl,w.get_int_val(feature_pos)))
293 return TRUE;
294 else
295 return FALSE;
296 default:
297 wagon_error("Unknown test operator");
298 }
299
300 return FALSE;
301}
302
303ostream& operator<<(ostream& s, const WQuestion &q)
304{
305 EST_String name;
306 static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
307
308 s << "(" << wgn_dataset.feat_name(q.get_fp());
309 switch (q.get_op())
310 {
311 case wnop_equal:
312 s << " = " << q.get_operand1().string();
313 break;
314 case wnop_binary:
315 break;
316 case wnop_greaterthan:
317 s << " > " << q.get_operand1().Float();
318 break;
319 case wnop_lessthan:
320 s << " < " << q.get_operand1().Float();
321 break;
322 case wnop_is:
323 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324 name(q.get_operand1().Int());
325 s << " is ";
326 if (name.matches(needquotes))
327 s << quote_string(name,"\"","\\",1);
328 else
329 s << name;
330 break;
331 case wnop_matches:
332 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333 name(q.get_operand1().Int());
334 s << " matches " << quote_string(name,"\"","\\",1);
335 break;
336 case wnop_in:
337 s << " in (";
338 for (int l=0; l < q.get_operandl().length(); l++)
339 {
340 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341 name(q.get_operandl().nth(l));
342 if (name.matches(needquotes))
343 s << quote_string(name,"\"","\\",1);
344 else
345 s << name;
346 s << " ";
347 }
348 s << ")";
349 break;
350 // SunCC wont let me add this
351// default:
352// s << " unknown operation ";
353 }
354 s << ")";
355
356 return s;
357}
358
359EST_Val WImpurity::value(void)
360{
361 // Returns the recommended value for this
362 EST_String s;
363 double prob;
364
365 if (t==wnim_unset)
366 {
367 cerr << "WImpurity: no value currently set\n";
368 return EST_Val(0.0);
369 }
370 else if (t==wnim_class)
371 return EST_Val(p.most_probable(&prob));
372 else if (t==wnim_cluster)
373 return EST_Val(a.mean());
374 else if (t==wnim_ols) /* OLS TBA */
375 return EST_Val(a.mean());
376 else if (t==wnim_vector)
377 return EST_Val(a.mean()); /* wnim_vector */
378 else if (t==wnim_trajectory)
379 return EST_Val(a.mean()); /* NOT YET WRITTEN */
380 else
381 return EST_Val(a.mean());
382}
383
384double WImpurity::samples(void)
385{
386 if (t==wnim_float)
387 return a.samples();
388 else if (t==wnim_class)
389 return (int)p.samples();
390 else if (t==wnim_cluster)
391 return members.length();
392 else if (t==wnim_ols)
393 return members.length();
394 else if (t==wnim_vector)
395 return members.length();
396 else if (t==wnim_trajectory)
397 return members.length();
398 else
399 return 0;
400}
401
402WImpurity::WImpurity(const WVectorVector &ds)
403{
404 int i;
405
406 t=wnim_unset;
407 a.reset(); trajectory=0; l=0; width=0;
408 data = &ds; // for ols, model calculation
409 for (i=0; i < ds.n(); i++)
410 {
411 if (t == wnim_ols)
412 cumulate(i,1);
413 else if (wgn_count_field == -1)
414 cumulate((*(ds(i)))[wgn_predictee],1);
415 else
416 cumulate((*(ds(i)))[wgn_predictee],
417 (*(ds(i)))[wgn_count_field]);
418 }
419}
420
421float WImpurity::measure(void)
422{
423 if (t == wnim_float)
424 return a.variance()*a.samples();
425 else if (t == wnim_vector)
426 return vector_impurity();
427 else if (t == wnim_trajectory)
428 return trajectory_impurity();
429 else if (t == wnim_matrix)
430 return a.variance()*a.samples();
431 else if (t == wnim_class)
432 return p.entropy()*p.samples();
433 else if (t == wnim_cluster)
434 return cluster_impurity();
435 else if (t == wnim_ols)
436 return ols_impurity(); /* RMSE for OLS model */
437 else
438 {
439 cerr << "WImpurity: can't measure unset object" << endl;
440 return 0.0;
441 }
442}
443
444float WImpurity::vector_impurity()
445{
446 // Find the mean/stddev for all values in all vectors
447 // sum the variances and multiply them by the number of members
448 EST_Litem *pp;
449 EST_Litem *countpp;
450 int i,j;
452 double count = 1;
453
454 a.reset();
455#if 1
456 /* simple distance */
457 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
458 {
459 if (wgn_VertexFeats.a(0,j) > 0.0)
460 {
461 b.reset();
462 for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
463 {
464 i = members.item(pp);
465
466 // Accumulate the value with count
467 b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
468 }
469 a += b.stddev();
470 count = b.samples();
471 }
472 }
473#endif
474
475#if 0
476 EST_SuffStats *c;
477 float x, lshift, rshift, ushift;
478 /* Find base mean, then measure do fshift to find best match */
479 c = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
480 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
481 {
482 if (wgn_VertexFeats.a(0,j) > 0.0)
483 {
484 c[j].reset();
485 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
486 pp=pp->next(), countpp=countpp->next())
487 {
488 i = members.item(pp);
489 // Accumulate the value with count
490 c[j].cumulate(wgn_VertexTrack.a(i,j),member_counts.item(countpp));
491 }
492 count = c[j].samples();
493 }
494 }
495
496 /* Pass through again but vary the num_channels offset (hardcoded) */
497 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
498 pp=pp->next(), countpp=countpp->next())
499 {
500 int q;
501 float bshift, qshift;
502 /* For each sample */
503 i = members.item(pp);
504 /* Find the value left shifted, unshifted, and right shifted */
505 lshift = 0; ushift = 0; rshift = 0;
506 bshift = 0;
507 for (q=-20; q<=20; q++)
508 {
509 qshift = 0;
510 for (j=67+q; j<147+q/*hardcoded*/; j++)
511 {
512 x = c[j].mean() - wgn_VertexTrack(i,j);
513 qshift += sqrt(x*x);
514 if ((bshift > 0) && (qshift > bshift))
515 break;
516 }
517 if ((bshift == 0) || (qshift < bshift))
518 bshift = qshift;
519 }
520 a += bshift;
521 }
522
523#endif
524
525#if 0
526 /* full covariance */
527 /* worse in listening experiments */
528 EST_SuffStats **cs;
529 int mmm;
530 cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
531 for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
532 cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
533 /* Find means for diagonal */
534 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
535 {
536 if (wgn_VertexFeats.a(0,j) > 0.0)
537 {
538 for (pp=members.head(); pp != 0; pp=pp->next())
539 cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
540 }
541 }
542 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
543 {
544 for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
545 if (wgn_VertexFeats.a(0,j) > 0.0)
546 {
547 for (pp=members.head(); pp != 0; pp=pp->next())
548 {
549 mmm = members.item(pp);
550 cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
551 (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
552 }
553 }
554 }
555 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
556 {
557 for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
558 if (wgn_VertexFeats.a(0,j) > 0.0)
559 a += cs[i][j].stddev();
560 }
561 count = cs[0][0].samples();
562#endif
563
564#if 0
565 // look at mean euclidean distance between vectors
566 EST_Litem *qq;
567 int x,y;
568 double d,q;
569 count = 0;
570 for (pp=members.head(); pp != 0; pp=pp->next())
571 {
572 x = members.item(pp);
573 count++;
574 for (qq=pp->next(); qq != 0; qq=qq->next())
575 {
576 y = members.item(qq);
577 for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
578 if (wgn_VertexFeats.a(0,j) > 0.0)
579 {
580 d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
581 q += d*d;
582 }
583 a += sqrt(q);
584 }
585
586 }
587#endif
588
589 // This is sum of stddev*samples
590 return a.mean() * count;
591}
592
593WImpurity::~WImpurity()
594{
595 int j;
596
597 if (trajectory != 0)
598 {
599 for (j=0; j<l; j++)
600 delete [] trajectory[j];
601 delete [] trajectory;
602 trajectory = 0;
603 l = 0;
604 }
605}
606
607
608float WImpurity::trajectory_impurity()
609{
610 // Find the mean length of all the units in the cluster
611 // Create that number of points
612 // Interpolate each unit to that number of points
613 // collect means and standard deviations for each point
614 // impurity is sum of the variance for each point and each coef
615 // multiplied by the number of units.
616 EST_Litem *pp;
617 int i, j;
618 int s, ti, ni, q;
619 int s1l, s2l;
620 double n, m, m1, m2, w;
621 EST_SuffStats lss, stdss;
622 EST_SuffStats l1ss, l2ss;
623 int l1, l2;
624 int ola=0;
625
626 if (trajectory != 0)
627 { /* already done this */
628 return score;
629 }
630
631 lss.reset();
632 l = 0;
633 for (pp=members.head(); pp != 0; pp=pp->next())
634 {
635 i = members.item(pp);
636 for (q=0; q<wgn_UnitTrack.a(i,1); q++)
637 {
638 ni = (int)wgn_UnitTrack.a(i,0)+q;
639 if (wgn_VertexTrack.a(ni,0) == -1.0)
640 {
641 l1ss += q;
642 ola = 1;
643 break;
644 }
645 }
646 if (q==wgn_UnitTrack.a(i,1))
647 { /* can't find -1 center point, so put all in l2 */
648 l1ss += 0;
649 l2ss += q;
650 }
651 else
652 l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
653 lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
654 if (wgn_UnitTrack.a(i,1) > l)
655 l = (int)wgn_UnitTrack.a(i,1);
656 }
657
658 if (ola==0) /* no -1's so its not an ola type cluster */
659 {
660 l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();
661
662 /* a list of SuffStats on for each point in the trajectory */
663 trajectory = new EST_SuffStats *[l];
664 width = wgn_VertexTrack.num_channels()+1;
665 for (j=0; j<l; j++)
666 trajectory[j] = new EST_SuffStats[width];
667
668 for (pp=members.head(); pp != 0; pp=pp->next())
669 { /* for each unit */
670 i = members.item(pp);
671 m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
672 s = (int)wgn_UnitTrack.a(i,0); /* start point */
673 for (ti=0,n=0.0; ti<l; ti++,n+=m)
674 {
675 ni = (int)n; // hmm floor or nint ??
676 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
677 {
678 if (wgn_VertexFeats.a(0,j) > 0.0)
679 trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
680 }
681 }
682 }
683
684 /* find sum of sum of stddev for all coefs of all traj points */
685 stdss.reset();
686 for (ti=0; ti<l; ti++)
687 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
688 {
689 if (wgn_VertexFeats.a(0,j) > 0.0)
690 stdss += trajectory[ti][j].stddev();
691 }
692
693 // This is sum of all stddev * samples
694 score = stdss.mean() * members.length();
695 }
696 else
697 { /* OLA model */
698 l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
699 l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
700 l = l1 + l2 + 1 + 1;
701
702 /* a list of SuffStats on for each point in the trajectory */
703 trajectory = new EST_SuffStats *[l];
704 for (j=0; j<l; j++)
705 trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
706
707 for (pp=members.head(); pp != 0; pp=pp->next())
708 { /* for each unit */
709 i = members.item(pp);
710 s1l = 0;
711 s = (int)wgn_UnitTrack.a(i,0); /* start point */
712 for (q=0; q<wgn_UnitTrack.a(i,1); q++)
713 if (wgn_VertexTrack.a(s+q,0) == -1.0)
714 {
715 s1l = q; /* printf("awb q is -1 at %d\n",q); */
716 break;
717 }
718 s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
719 m1 = (float)(s1l)/(float)l1; /* find interpolation step */
720 m2 = (float)(s2l)/(float)l2; /* find interpolation step */
721 /* First half */
722 for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
723 {
724 ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
725 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
726 if (wgn_VertexFeats.a(0,j) > 0.0)
727 trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
728 }
729 ti = l1; /* do it explicitly in case s1l < 1 */
730 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
731 if (wgn_VertexFeats.a(0,j) > 0.0)
732 trajectory[ti][j] += -1;
733 /* Second half */
734 s += s1l+1;
735 for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
736 {
737 ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
738 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
739 if (wgn_VertexFeats.a(0,j) > 0.0)
740 trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
741 }
742 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
743 if (wgn_VertexFeats.a(0,j) > 0.0)
744 trajectory[ti][j] += -2;
745 }
746
747 /* find sum of sum of stddev for all coefs of all traj points */
748 /* windowing the sums with a triangular weight window */
749 stdss.reset();
750 m = 1.0/(float)l1;
751 for (w=0.0,ti=0; ti<l1; ti++,w+=m)
752 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
753 if (wgn_VertexFeats.a(0,j) > 0.0)
754 stdss += trajectory[ti][j].stddev() * w;
755 m = 1.0/(float)l2;
756 for (w=1.0,ti++; ti<l-1; ti++,w-=m)
757 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
758 if (wgn_VertexFeats.a(0,j) > 0.0)
759 stdss += trajectory[ti][j].stddev() * w;
760
761 // This is sum of all stddev * samples
762 score = stdss.mean() * members.length();
763 }
764 return score;
765}
766
767static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
768 EST_IVector &included,
769 EST_StrList &feat_names,
770 const EST_IList &members,
771 const WVectorVector &d)
772{
773 int m,n,p;
774 int w, xm=0;
775 EST_Litem *pp;
776 WVector *wv;
777
778 w = wgn_dataset.width();
779 included.resize(w);
780 X.resize(members.length(),w);
781 Y.resize(members.length(),1);
782 feat_names.append("Intercept");
783 included[0] = TRUE;
784
785 for (p=0,pp=members.head(); pp; p++,pp=pp->next())
786 {
787 n = members.item(pp);
788 if (n < 0)
789 {
790 p--;
791 continue;
792 }
793 wv = d(n);
794 Y.a_no_check(p,0) = (*wv)[0];
795 X.a_no_check(p,0) = 1;
796 for (m=1,xm=1; m < w; m++)
797 {
798 if (wgn_dataset.ftype(m) == wndt_float)
799 {
800 if (p == 0) // only do this once
801 {
802 feat_names.append(wgn_dataset.feat_name(m));
803 }
804 X.a_no_check(p,xm) = (*wv)[m];
805 included.a_no_check(xm) = FALSE;
806 included.a_no_check(xm) = TRUE;
807 xm++;
808 }
809 }
810 }
811
812 included.resize(xm);
813 X.resize(p,xm);
814 Y.resize(p,1);
815}
816
817float WImpurity::ols_impurity()
818{
819 // Build an OLS model for the current data and measure it against
820 // the data itself and give a RMSE
821 EST_FMatrix X,Y;
822 EST_IVector included;
823 EST_FMatrix coeffs;
824 EST_StrList feat_names;
825 float best_score;
826 EST_FMatrix coeffsl;
827 EST_FMatrix pred;
828 float cor,rmse;
829
830 // Load the sample members into matrices for ols
831 part_to_ols_data(X,Y,included,feat_names,members,*data);
832
833 // Find the best ols model.
834 // Far too computationally expensive
835 // if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
836 // X,Y,included,best_score))
837 // return WGN_HUGE_VAL; // couldn't find a model
838
839 // Non stepwise model
840 if (!robust_ols(X,Y,included,coeffsl))
841 {
842 // printf("no robust ols\n");
843 return WGN_HUGE_VAL;
844 }
845 ols_apply(X,coeffsl,pred);
846 ols_test(Y,pred,cor,rmse);
847 best_score = cor;
848
849 printf("Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850 X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
851 rmse,cor,
852 1-best_score);
853 if (fabs(coeffsl[0]) > 10000)
854 {
855 // printf("weird sized Intercept %f\n",coeffsl[0]);
856 return WGN_HUGE_VAL;
857 }
858
859 return (1-best_score) *members.length();
860}
861
862float WImpurity::cluster_impurity()
863{
864 // Find the mean distance between all members of the dataset
865 // Uses the global DistMatrix for distances between members of
866 // the cluster set. Distances are assumed to be symmetric thus only
867 // the bottom half of the distance matrix is filled
868 EST_Litem *pp, *q;
869 int i,j;
870 double dist;
871
872 a.reset();
873 for (pp=members.head(); pp != 0; pp=pp->next())
874 {
875 i = members.item(pp);
876 for (q=pp->next(); q != 0; q=q->next())
877 {
878 j = members.item(q);
879 dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
880 wgn_DistMatrix.a_no_check(j,i));
881 a+=dist; // cumulate for whole cluster
882 }
883 }
884
885 // This is sum distance between cross product of members
886// return a.sum();
887 if (a.samples() > 1)
888 return a.stddev() * a.samples();
889 else
890 return 0.0;
891}
892
893float WImpurity::cluster_distance(int i)
894{
895 // Distance this unit is from all others in this cluster
896 // in absolute standard deviations from the the mean.
897 float dist = cluster_member_mean(i);
898 float mdist = dist-a.mean();
899
900 if (mdist == 0.0)
901 return 0.0;
902 else
903 return fabs((dist-a.mean())/a.stddev());
904
905}
906
907int WImpurity::in_cluster(int i)
908{
909 // Would this be a member of this cluster?. Returns 1 if
910 // its distance is less than at least one other
911 float dist = cluster_member_mean(i);
912 EST_Litem *pp;
913
914 for (pp=members.head(); pp != 0; pp=pp->next())
915 {
916 if (dist < cluster_member_mean(members.item(pp)))
917 return 1;
918 }
919 return 0;
920}
921
922float WImpurity::cluster_ranking(int i)
923{
924 // Position in ranking closest to centre
925 float dist = cluster_distance(i);
926 EST_Litem *pp;
927 int ranking = 1;
928
929 for (pp=members.head(); pp != 0; pp=pp->next())
930 {
931 if (dist >= cluster_distance(members.item(pp)))
932 ranking++;
933 }
934
935 return ranking;
936}
937
938float WImpurity::cluster_member_mean(int i)
939{
940 // Returns the mean difference between this member and all others
941 // in cluster
942 EST_Litem *q;
943 int j,n;
944 double dist,sum;
945
946 for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
947 {
948 j = members.item(q);
949 if (i != j)
950 {
951 dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
952 sum += dist;
953 n++;
954 }
955 }
956
957 return ( n == 0 ? 0.0 : sum/n );
958}
959
960void WImpurity::cumulate(const float pv,double count)
961{
962 // Cumulate data for impurity calculation
963
964 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
965 {
966 t = wnim_cluster;
967 members.append((int)pv);
968 }
969 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
970 {
971 t = wnim_ols;
972 members.append((int)pv);
973 }
974 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
975 {
976 t = wnim_vector;
977
978 // AUP: Implement counts in vectors
979 members.append((int)pv);
980 member_counts.append((float)count);
981 }
982 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
983 {
984 t = wnim_trajectory;
985 members.append((int)pv);
986 }
987 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
988 {
989 if (t == wnim_unset)
990 p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
991 t = wnim_class;
992 p.cumulate((int)pv,count);
993 }
994 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
995 {
996 t = wnim_float;
997 a.cumulate((int)pv,count);
998 }
999 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1000 {
1001 t = wnim_float;
1002 a.cumulate(pv,count);
1003 }
1004 else
1005 {
1006 wagon_error("WImpurity: cannot cumulate EST_Val type");
1007 }
1008}
1009
1010ostream & operator <<(ostream &s, WImpurity &imp)
1011{
1012 int j,i;
1013 EST_SuffStats b;
1014
1015 if (imp.t == wnim_float)
1016 s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
1017 else if (imp.t == wnim_vector)
1018 {
1019 EST_Litem *p, *countp;
1020 s << "((";
1021 imp.vector_impurity();
1022 if (wgn_vertex_output == "mean") //output means
1023 {
1024 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1025 {
1026 b.reset();
1027 for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
1028 {
1029 // Accumulate the members with their counts
1030 b.cumulate(wgn_VertexTrack.a(imp.members.item(p),j), imp.member_counts.item(countp));
1031 //b += wgn_VertexTrack.a(imp.members.item(p),j);
1032 }
1033 s << "(" << b.mean() << " ";
1034 if (isfinite(b.stddev()))
1035 s << b.stddev() << ")";
1036 else
1037 s << "0.001" << ")";
1038 if (j+1<wgn_VertexTrack.num_channels())
1039 s << " ";
1040 }
1041 }
1042 else /* output best in the cluster */
1043 {
1044 /* print out vector closest to center, rather than average */
1045 /* printf("awb_debug outputing best\n"); */
1046 double best = WGN_HUGE_VAL;
1047 double x,d;
1048 int bestp = 0;
1049 EST_SuffStats *cs;
1050
1051 cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];
1052
1053 for (j=0; j<wgn_VertexFeats.num_channels(); j++)
1054 {
1055 cs[j].reset();
1056 for (p=imp.members.head(); p != 0; p=p->next())
1057 {
1058 cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
1059 }
1060 }
1061
1062 for (p=imp.members.head(); p != 0; p=p->next())
1063 {
1064 for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
1065 if (wgn_VertexFeats.a(0,j) > 0.0)
1066 {
1067 d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
1068 /* / cs[j].stddev() */ ; /* seems worse 061218 */
1069 x += d*d;
1070 }
1071 if (x < best)
1072 {
1073 /* printf("awb_debug updating best %d %f %d %f\n",
1074 bestp, best, imp.members.item(p), x); */
1075 bestp = imp.members.item(p);
1076 best = x;
1077 }
1078 }
1079 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1080 {
1081 s << "( ";
1082 s << wgn_VertexTrack.a(bestp,j);
1083 // s << " 0 "; // fake stddev
1084 s << " ";
1085 if (isfinite(cs[j].stddev()))
1086 s << cs[j].stddev();
1087 else
1088 s << "0";
1089 s << " ) ";
1090 if (j+1<wgn_VertexTrack.num_channels())
1091 s << " ";
1092 }
1093
1094 delete [] cs;
1095 }
1096 s << ") ";
1097 s << imp.a.mean() << ")";
1098 }
1099 else if (imp.t == wnim_trajectory)
1100 {
1101 s << "((";
1102 imp.trajectory_impurity();
1103 for (i=0; i<imp.l; i++)
1104 {
1105 s << "(";
1106 for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1107 {
1108 s << "(" << imp.trajectory[i][j].mean() << " "
1109 << imp.trajectory[i][j].stddev() << " " << ")";
1110 }
1111 s << ")\n";
1112 }
1113 s << ") ";
1114 // Mean of cross product of distances (cluster score)
1115 s << imp.a.mean() << ")";
1116 }
1117 else if (imp.t == wnim_cluster)
1118 {
1119 EST_Litem *p;
1120 s << "((";
1121 for (p=imp.members.head(); p != 0; p=p->next())
1122 {
1123 // Ouput cluster member and its mean distance to others
1124 s << "(" << imp.members.item(p) << " " <<
1125 imp.cluster_member_mean(imp.members.item(p)) << ")";
1126 if (p->next() != 0)
1127 s << " ";
1128 }
1129 s << ") ";
1130 // Mean of cross product of distances (cluster score)
1131 s << imp.a.mean() << ")";
1132 }
1133 else if (imp.t == wnim_ols)
1134 {
1135 /* Output intercept, feature names and coefficients for ols model */
1136 EST_FMatrix X,Y;
1137 EST_IVector included;
1138 EST_FMatrix coeffs;
1139 EST_StrList feat_names;
1140 EST_FMatrix coeffsl;
1141 EST_FMatrix pred;
1142 float cor=0.0,rmse;
1143
1144 s << "((";
1145 // Load the sample members into matrices for ols
1146 part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1147 if (!robust_ols(X,Y,included,coeffsl))
1148 {
1149 printf("no robust ols\n");
1150 // shouldn't happen
1151 }
1152 else
1153 {
1154 ols_apply(X,coeffsl,pred);
1155 ols_test(Y,pred,cor,rmse);
1156 for (i=0; i<coeffsl.num_rows(); i++)
1157 {
1158 s << "(";
1159 s << feat_names.nth(i);
1160 s << " ";
1161 s << coeffsl[i];
1162 s << ") ";
1163 }
1164 }
1165
1166 // Mean of cross product of distances (cluster score)
1167 s << ") " << cor << ")";
1168 }
1169 else if (imp.t == wnim_class)
1170 {
1171 EST_Litem *i;
1172 EST_String name;
1173 double prob;
1174
1175 s << "(";
1176 for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
1177 {
1178 imp.p.item_prob(i,name,prob);
1179 s << "(" << name << " " << prob << ") ";
1180 }
1181 s << imp.p.most_probable(&prob) << ")";
1182 }
1183 else
1184 s << "([WImpurity unset])";
1185
1186 return s;
1187}
1188
1189
1190
1191
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
EST_Litem * item_start() const
Used for iterating through members of the distribution.
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double samples(void) const
Total number of examples found.
double entropy(void) const
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
int matches(const char *e, int pos=0) const
Exactly match this string?
Definition: EST_String.cc:652
double stddev(void) const
standard deviation of currently cumulated values
double variance(void) const
variance of currently cumulated values
double mean(void) const
mean of currently cumulated values
void reset(void)
reset internal values
double samples(void)
number of samples in set
T & item(const EST_Litem *p)
Definition: EST_TList.h:133
T & nth(int n)
return the Nth value
Definition: EST_TList.h:139
void append(const T &item)
add item onto end of list
Definition: EST_TList.h:191
int num_columns() const
return number of columns
Definition: EST_TMatrix.h:181
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check, care recommended
Definition: EST_TMatrix.h:184
int num_rows() const
return number of rows
Definition: EST_TMatrix.h:179
void resize(int rows, int cols, int set=1)
resize matrix
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
Definition: EST_TVector.cc:196
INLINE int n() const
number of items in vector.
Definition: EST_TVector.h:254
INLINE const T & a_no_check(int n) const
read-only const access operator: without bounds checking
Definition: EST_TVector.h:257
float & a(int i, int c=0)
Definition: EST_Track.cc:1022
int num_channels() const
return number of channels in track
Definition: EST_Track.h:656
const EST_String & string(void) const
Definition: EST_Val.h:150
const int Int(void) const
Definition: EST_Val.h:130
const float Float(void) const
Definition: EST_Val.h:138