Skip to content

Graph

This module provides the Node class.

Classes

Node: class for facilitating the handling and the creation of nodes for a DAG.

Node

Node class.

Source code in causalflow/graph/Node.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class Node():
    """
    Node of a causal graph.

    A node stores its incoming links in ``sources`` (a dict whose keys are
    indexed as ``key[0]`` = source name and ``key[1]`` = lag) and the names of
    the nodes it points to in ``children``.
    """

    def __init__(self, name, neglect_autodep):
        """
        Class constructor.

        Args:
            name (str): node name.
            neglect_autodep (bool): flag to decide whether to skip the node if it is only auto-dependent.
        """
        self.name = name
        self.sources = dict()
        self.children = list()
        self.neglect_autodep = neglect_autodep
        self.intervention_node = False
        self.associated_context = None


    @property
    def is_autodependent(self) -> bool:
        """
        Return True if the node is autodependent.

        Returns:
            bool: True if the node's own name appears among its sources. Otherwise False.
        """
        return self.name in self.sourcelist


    @property
    def is_isolated(self) -> bool:
        """
        Return True if the node is isolated.

        With ``neglect_autodep`` set, a node whose only links are to itself
        (optionally plus its context variable) also counts as isolated.

        Returns:
            bool: Returns True if the node is isolated. Otherwise False.
        """
        if self.neglect_autodep:
            return (self.is_exogenous and not self.has_child) or self.is_only_autodep or self.is_only_autodep_context
        return (self.is_exogenous or self.has_only_context) and not self.has_child


    @property
    def is_only_autodep(self) -> bool:
        """
        Return True if the node is ONLY auto-dependent.

        Returns:
            bool: True if the node has exactly one source and exactly one child, both being itself. Otherwise False.
        """
        return len(self.sources) == 1 and self.name in self.sourcelist and len(self.children) == 1 and self.name in self.children


    @property
    def has_only_context(self) -> bool:
        """
        Return True if the node has ONLY the context variable as parent.

        Returns:
            bool: Returns True if the node has ONLY the context variable as parent. Otherwise False.
        """
        return len(self.sources) == 1 and self.associated_context in self.sourcelist


    @property
    def is_only_autodep_context(self) -> bool:
        """
        Return True if the node has ONLY the context variable and itself as parent.

        Returns:
            bool: True if the node has exactly two sources (itself and its context variable) and itself as only child. Otherwise False.
        """
        return len(self.sources) == 2 and self.name in self.sourcelist and self.associated_context in self.sourcelist and len(self.children) == 1 and self.name in self.children


    @property
    def is_exogenous(self) -> bool:
        """
        Return True if the node has no parents.

        Returns:
            bool: Returns True if the node has no parents. Otherwise False.
        """
        return len(self.sources) == 0


    @property
    def has_child(self) -> bool:
        """
        Return True if the node has at least one child.

        Returns:
            bool: Returns True if the node has at least one child. Otherwise False.
        """
        return len(self.children) > 0


    @property
    def sourcelist(self) -> list:
        """
        Return list of source names.

        A name appears once per link, so a source linked at several lags is repeated.

        Returns:
            list(str): Returns list of source names.
        """
        return [s[0] for s in self.sources]


    @property
    def autodependency_links(self) -> list:
        """
        Return list of autodependency links.

        Returns:
            list: keys of ``sources`` whose source name is this node's name (empty if there are none).

        """
        return [s for s in self.sources if s[0] == self.name]


    @property
    def get_max_autodependent(self) -> tuple:
        """
        Return the autodependency link with the highest score.

        Returns:
            tuple: key of ``sources`` identifying the autodependency link with the
            highest score, or None if the node has no autodependency link with a
            score greater than 0.
        """
        max_score = 0
        max_s = None
        for s in self.sources:
            if s[0] != self.name:
                continue
            score = self.sources[s][SCORE]
            if score > max_score:
                # Track the running maximum; without this update the last
                # positive-score link would win instead of the strongest one.
                max_score = score
                max_s = s
        return max_s

Return list of autodependency links.

Returns:

Name Type Description
list list

Returns list of autodependency links.

get_max_autodependent: float property

Return max score of autodependent link.

Returns:

Name Type Description
float float

Returns max score of autodependent link.

has_child: bool property

Return True if the node has at least one child.

Returns:

Name Type Description
bool bool

Returns True if the node has at least one child. Otherwise False.

has_only_context: bool property

Return True if the node has ONLY the context variable as parent.

Returns:

Name Type Description
bool bool

Returns True if the node has ONLY the context variable as parent. Otherwise False.

is_autodependent: bool property

Return True if the node is autodependent.

Returns:

Name Type Description
bool bool

Returns True if the node is autodependent. Otherwise False.

is_exogenous: bool property

Return True if the node has no parents.

Returns:

Name Type Description
bool bool

Returns True if the node has no parents. Otherwise False.

is_isolated: bool property

Return True if the node is isolated.

Returns:

Name Type Description
bool bool

Returns True if the node is isolated. Otherwise False.

is_only_autodep: bool property

Return True if the node is ONLY auto-dependent.

Returns:

Name Type Description
bool bool

Returns True if the node is ONLY auto-dependent. Otherwise False.

is_only_autodep_context: bool property

Return True if the node has ONLY the context variable and itself as parent.

Returns:

Name Type Description
bool bool

Returns True if the node has ONLY the context variable and itself as parent. Otherwise False.

sourcelist: list property

Return list of source names.

Returns:

Name Type Description
list str

Returns list of source names.

__init__(name, neglect_autodep)

Class constructor.

Parameters:

Name Type Description Default
name str

node name.

required
neglect_autodep bool

flag to decide whether to skip the node if it is only auto-dependent.

required
Source code in causalflow/graph/Node.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def __init__(self, name, neglect_autodep):
    """
    Class constructor.

    Starts the node with no sources, no children, no associated context and
    the intervention flag cleared.

    Args:
        name (str): node name.
        neglect_autodep (bool): flag to decide whether to skip the node if it is only auto-dependent.
    """
    self.name = name
    self.sources, self.children = {}, []
    self.neglect_autodep = neglect_autodep
    self.intervention_node, self.associated_context = False, None

This module provides the DAG class.

Classes

DAG: class for facilitating the handling and the creation of DAGs.

DAG

DAG class.

Source code in causalflow/graph/DAG.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
class DAG():
    """DAG class."""

    def __init__(self, var_names, min_lag, max_lag, neglect_autodep = False, scm = None):
        """
        DAG constructor.

        Creates one Node per variable and, when an SCM is given, adds its links.

        Args:
            var_names (list): variable list.
            min_lag (int): minimum time lag.
            max_lag (int): maximum time lag.
            neglect_autodep (bool, optional): bit to neglect nodes when they are only autodependent. Defaults to False.
            scm (dict, optional): Build the DAG for SCM. Indexed by target name; each entry is a
                collection of links of length 2 (source, lag) or 3 (source, lag, link type).
                Links of any other length are ignored. Defaults to None.
        """
        nodes = {}
        for var in var_names:
            nodes[var] = Node(var, neglect_autodep)
        self.g = nodes
        self.neglect_autodep = neglect_autodep
        self.sys_context = {}
        self.min_lag = min_lag
        self.max_lag = max_lag

        if scm is None:
            return

        # Links coming from an SCM are added with score 0.3 and p-value 0
        for target in scm:
            for link in scm[target]:
                if len(link) in (2, 3):
                    # link[1:] is (lag,) or (lag, link type)
                    self.add_source(target, link[0], 0.3, 0, *link[1:])


    @property
    def features(self) -> list:
        """
        Return the names of all the nodes in the graph.

        Returns:
            list: node names, in the iteration order of the graph dictionary.
        """
        return list(self.g)


    @property
    def pretty_features(self) -> list:
        """
        Return the feature names wrapped in '$' so they render as LaTeX math.

        Returns:
            list(str): one '$name$' string per node of the graph.
        """
        return [f"${name!s}$" for name in self.g]


    @property
    def autodep_nodes(self) -> list:
        """
        Return the names of the autodependent nodes.

        Returns:
            list: names of the nodes whose ``is_autodependent`` is true.
        """
        # NOTE: the intervention_node flag is deliberately not checked here, so
        # that all the auto-dependent nodes are considered (also with obs data)
        return [name for name, node in self.g.items() if node.is_autodependent]


    @property
    def interventions_links(self) -> list:
        """
        Return the links whose source is an intervention node.

        Returns:
            list: (source name, lag, target name) tuples, one per link whose
            source node has ``intervention_node`` set.
        """
        return [
            (key[0], key[1], target)
            for target, node in self.g.items()
            for key in node.sources
            if self.g[key[0]].intervention_node
        ]


    @property
    def max_auto_score(self) -> float:
        """
        Return maximum score of an auto-dependency link.

        For every autodependent node, the score of the link returned by its
        ``get_max_autodependent`` is collected; the largest one is returned.

        Returns:
            float: maximum score of an auto-dependency link.

        Raises:
            ValueError: if the graph has no autodependent node (max of an empty list).
        """
        scores = []
        for node in self.g.values():
            if node.is_autodependent:
                scores.append(node.sources[node.get_max_autodependent][SCORE])
        return max(scores)


    @property
    def max_cross_score(self) -> float:
        """
        Return maximum score of a cross-dependency link.

        Links from a node to itself are ignored and an infinite score counts as 1.

        Returns:
            float: maximum score of a cross-dependency link.

        Raises:
            ValueError: if the graph has no cross-dependency link (max of an empty list).
        """
        scores = []
        for target, node in self.g.items():
            for key, link in node.sources.items():
                if target != key[0]:
                    score = link[SCORE]
                    scores.append(score if score != float('inf') else 1)
        return max(scores)


    def add_source(self, t, s, score, pval, lag, mode = LinkType.Directed.value):
        """
        Add a link from source node s to target node t.

        The link is stored in the target's sources under the key (s, |lag|),
        replacing any link already stored under that key, and t is appended to
        the children of s.

        Args:
            t (str): target node name.
            s (str): source node name.
            score (float): dependency score.
            pval (float): dependency p-value.
            lag (int): dependency lag; its absolute value is stored.
            mode (LinkType): link type. E.g., Directed -->
        """
        link = {SCORE: score, PVAL: pval, TYPE: mode}
        target_sources = self.g[t].sources
        target_sources[(s, abs(lag))] = link

        source_children = self.g[s].children
        source_children.append(t)


    def del_source(self, t, s, lag):
        """
        Remove the link from source node s to target node t at the given lag.

        The lag is normalised to its absolute value, matching the key format
        written by add_source, so a lag that was accepted when adding the link
        (e.g. a negative one) also works when removing it.

        Args:
            t (str): target node name.
            s (str): source node name.
            lag (int): dependency lag (sign is ignored).

        Raises:
            KeyError: if t or s is not in the graph, or no link (s, |lag|) is stored for t.
            ValueError: if t is not among the children of s.
        """
        del self.g[t].sources[(s, abs(lag))]
        # Removes a single occurrence: children holds one entry per link
        self.g[s].children.remove(t)


    def remove_unneeded_features(self):
        """
        Remove isolated nodes.

        The graph is replaced by a deep copy from which every node whose
        ``is_isolated`` is true has been dropped. When the dropped node is an
        intervention node, its associated context node is dropped as well.

        Removal is tolerant: a context that is None or no longer in the graph
        (e.g. already removed together with another node) is simply skipped
        instead of raising KeyError.
        """
        tmp = copy.deepcopy(self.g)
        for t, node in self.g.items():
            if not node.is_isolated:
                continue
            if node.intervention_node:
                # associated_context may be None or already deleted
                tmp.pop(node.associated_context, None)
            tmp.pop(t, None)
        self.g = tmp


    def add_context(self):
        """
        Add context variables.

        For every (system variable, context variable) pair in ``sys_context``
        whose system variable is in the graph, a node is created for the
        context variable, the system variable is flagged as intervention node
        and a lag-0 link context -> system variable is added. Afterwards every
        such context variable receives a lag-0 link from each other context
        variable present in the graph.
        """
        # First pass: create the context nodes and attach them to their system variables
        for sys_var, context_var in self.sys_context.items():
            if sys_var not in self.features:
                continue

            self.g[context_var] = Node(context_var, self.neglect_autodep)

            sys_node = self.g[sys_var]
            sys_node.intervention_node = True
            sys_node.associated_context = context_var
            self.add_source(sys_var, context_var, 1, 0, 0)

        # Second pass: contemporaneous links between context vars, added in both
        # directions because each context var takes every other one as source
        for sys_var, context_var in self.sys_context.items():
            if sys_var not in self.features:
                continue

            others = []
            for candidate in self.sys_context.values():
                if candidate != context_var and candidate in self.features:
                    others.append(candidate)
            for other in others:
                self.add_source(context_var, other, 1, 0, 0)


    def remove_context(self):
        """Remove context variables."""
        for sys_var, context_var in self.sys_context.items():
            if sys_var in self.g:
                # Removing context var from sys var
                # self.g[sys_var].intervention_node = False
                self.g[sys_var].associated_context = None
                self.del_source(sys_var, context_var, 0)

                # Removing context var from dag
                del self.g[context_var]


    def get_link_assumptions(self, autodep_ok = False) -> dict:
        """
        Return link assumption dictionary.

        The result is indexed by feature position: ``result[j][(i, -lag)]`` is
        the link string assumed for the link from feature i (at the given lag)
        to feature j. The rules applied to each stored link source -> target are:

        - autodependency link with ``autodep_ok`` set: '-->';
        - source is not a context variable:
            - lag 0 and the reverse lag-0 link is stored too: 'o-o';
            - lag 0 otherwise: '-?>' on the target and '<?-' on the source;
            - lag > 0: '-?>';
        - source is the context variable of the target: '-->'.

        Links from a context variable to anything but its own system variable
        are left out.

        Args:
            autodep_ok (bool, optional): If true, autodependency link assumption = -->. Otherwise -?>. Defaults to False.

        Returns:
            dict: link assumption dictionary.
        """
        # Build the lookups once instead of rebuilding the feature list and
        # scanning it with list.index() for every single link.
        features = self.features
        index_of = {name: position for position, name in enumerate(features)}
        context_vars = set(self.sys_context.values())

        link_assump = {position: dict() for position in range(len(features))}
        for t, node in self.g.items():
            t_idx = index_of[t]
            for s in node.sources:
                s_name, lag = s[0], s[1]

                if autodep_ok and s_name == t: # NOTE: avoids controlling the autodependency links twice
                    link_assump[t_idx][(index_of[s_name], -abs(lag))] = '-->'

                elif s_name not in context_vars:
                    if lag == 0:
                        s_idx = index_of[s_name]
                        if (t, 0) in self.g[s_name].sources:
                            link_assump[t_idx][(s_idx, 0)] = 'o-o'
                        else:
                            link_assump[t_idx][(s_idx, 0)] = '-?>'
                            link_assump[s_idx][(t_idx, 0)] = '<?-'
                    elif lag > 0:
                        link_assump[t_idx][(index_of[s_name], -abs(lag))] = '-?>'

                elif t in self.sys_context and s_name == self.sys_context[t]:
                    link_assump[t_idx][(index_of[s_name], -abs(lag))] = '-->'

        return link_assump


    def make_pretty(self) -> dict:
        """
        Return a copy of the graph whose variable names are formatted for LaTeX.

        Each name becomes $ name $, with whatever follows an underscore wrapped
        in braces (e.g. 'X_1' -> '$X_{1}$'). Node names, children and source
        names are renamed; the nodes of this DAG are left untouched.

        Returns:
            dict: pretty DAG, i.e. pretty name -> deep-copied, renamed node.
        """
        def as_latex(name):
            return '$' + re.sub(r'_(\w+)', r'_{\1}', name) + '$'

        pretty = dict()
        for t, node in self.g.items():
            label = as_latex(t)
            clone = copy.deepcopy(node)
            pretty[label] = clone
            clone.name = label
            clone.children = [as_latex(child) for child in node.children]

            # Re-key the copied sources one by one; iterating the original
            # node keeps the loop safe while the copy is being modified.
            # NOTE(review): associated_context is copied as is, not renamed.
            for key, info in node.sources.items():
                del clone.sources[key]
                clone.sources[(as_latex(key[0]), key[1])] = {
                    SCORE: info[SCORE],
                    PVAL: info[PVAL],
                    TYPE: info[TYPE]
                }
        return pretty


    def __add_edge(self, min_width, max_width, min_score, max_score, edges, edge_width, arrows, r, t, s, s_node, t_node):
        """
        Add edge to a graph. Support method for dag and ts_dag.

        The edge (s_node, t_node) is appended to ``edges``, its width is stored
        in ``edge_width`` and its head/tail markers in ``arrows``. A bidirected
        link additionally registers the reverse edge (t_node, s_node).

        Args:
            min_width (int): minimum linewidth.
            max_width (int): maximum linewidth.
            min_score (int): minimum score range.
            max_score (int): maximum score range.
            edges (list): list of edges (updated in place).
            edge_width (dict): dictionary containing the width for each edge of the graph (updated in place).
            arrows (dict): dictionary containing the head ('h') and tail ('t') marker for each edge of the graph (updated in place).
            r (DAG): DAG holding the link, read as r.g[t].sources[s].
            t (str or tuple): target node key in r.g.
            s (str or tuple): source key in the target's sources.
            s_node (str): source node as drawn in the graph.
            t_node (str): target node as drawn in the graph.

        Raises:
            ValueError: link type associated to this edge not included in our LinkType list.
        """
        forward = (s_node, t_node)
        edges.append(forward)

        # An infinite score is drawn as if it were 1
        raw_score = r.g[t].sources[s][SCORE]
        score = 1 if raw_score == float('inf') else raw_score
        edge_width[forward] = self.__scale(score, min_width, max_width, min_score, max_score)

        link_type = r.g[t].sources[s][TYPE]
        if link_type == LinkType.Directed.value:
            arrows[forward] = {'h':'>', 't':''}

        elif link_type == LinkType.Bidirected.value:
            # Drawn as two opposite directed edges of equal width
            backward = (t_node, s_node)
            edges.append(backward)
            edge_width[backward] = self.__scale(score, min_width, max_width, min_score, max_score)
            arrows[backward] = {'h':'>', 't':''}
            arrows[forward] = {'h':'>', 't':''}

        elif link_type == LinkType.HalfUncertain.value:
            arrows[forward] = {'h':'>', 't':'o'}

        elif link_type == LinkType.Uncertain.value:
            arrows[forward] = {'h':'o', 't':'o'}

        else:
            raise ValueError(f"{link_type} not included in LinkType")


    def dag(self,
        node_layout='dot',
        min_auto_width=0.25, 
        max_auto_width=0.75,
        min_cross_width=0.5, 
        max_cross_width=1.5,
        node_size=4, 
        node_color='orange',
        edge_color='grey',
        tail_color='black',
        font_size=8,
        label_type=LabelType.Lag,
        save_name=None,
        img_extention=ImageExt.PNG):
        """
        Build a dag, first with contemporaneous links, then lagged links.

        Contemporaneous (lag 0) links are drawn as straight edges and lagged
        links as curved edges on the same node positions. Links from a node to
        itself are not drawn as edges: they set the width of the node border and,
        when labels are enabled, a node label. If the graph has neither kind of
        edge, nothing is drawn before the figure is shown/saved.

        Args:
            node_layout (str, optional): Node layout. Defaults to 'dot'.
            min_auto_width (float, optional): minimum border linewidth. Defaults to 0.25.
            max_auto_width (float, optional): maximum border linewidth. Defaults to 0.75.
            min_cross_width (float, optional): minimum edge linewidth. Defaults to 0.5.
            max_cross_width (float, optional): maximum edge linewidth. Defaults to 1.5.
            node_size (int, optional): node size. Defaults to 4.
            node_color (str, optional): node color. Defaults to 'orange'.
            edge_color (str, optional): edge color; also used for the node border. Defaults to 'grey'.
            tail_color (str, optional): tail color. Defaults to 'black'.
            font_size (int, optional): font size. Defaults to 8.
            label_type (LabelType, optional): Show the lag time (LabelType.Lag), the strength (LabelType.Score), or no labels (LabelType.NoLabels). Default LabelType.Lag.
            save_name (str, optional): Filename path (without extension). If None, plot is shown and not saved. Defaults to None.
            img_extention (ImageExt, optional): Image Extension, appended to save_name. Defaults to PNG.
        """
        # Work on a deep copy whose node names are formatted for LaTeX, so the
        # graph stored in this object is not modified
        r = copy.deepcopy(self)
        r.g = r.make_pretty()

        Gcont = nx.DiGraph()
        Glag = nx.DiGraph()

        # 1. Nodes definition (both graphs hold the same node set)
        Gcont.add_nodes_from(r.g.keys())
        Glag.add_nodes_from(r.g.keys())

        # 2. Nodes border definition: 0 for non-autodependent nodes, otherwise the
        #    score of the node's strongest autodependency link scaled into
        #    [min_auto_width, max_auto_width]
        border = dict()
        for t in r.g:
            border[t] = 0
            if r.g[t].is_autodependent:
                border[t] = max(self.__scale(r.g[t].sources[r.g[t].get_max_autodependent][SCORE], 
                                             min_auto_width, max_auto_width, 
                                             0, self.max_auto_score), 
                                border[t])

        # 3. Nodes border label definition: lag or score of the strongest
        #    autodependency link, empty string for non-autodependent nodes
        node_label = None
        if label_type == LabelType.Lag or label_type == LabelType.Score:
            node_label = {t: [] for t in r.g.keys()}
            for t in r.g:
                if r.g[t].is_autodependent:
                    autodep = r.g[t].get_max_autodependent
                    if label_type == LabelType.Lag:
                        node_label[t].append(autodep[1])
                    elif label_type == LabelType.Score:
                        node_label[t].append(round(r.g[t].sources[autodep][SCORE], 3))
                node_label[t] = ",".join(str(s) for s in node_label[t])

        # 4. Edges definition, split into contemporaneous and lagged links
        cont_edges = []
        cont_edge_width = dict()
        cont_arrows = {}
        lagged_edges = []
        lagged_edge_width = dict()
        lagged_arrows = {}

        for t in r.g:
            for s in r.g[t].sources:
                if t != s[0]:  # skip self-loops
                    if s[1] == 0:  # Contemporaneous link (no lag)
                        self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score,
                                        cont_edges, cont_edge_width, cont_arrows, r, t, s, s[0], t)
                    else:  # Lagged link
                        self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score,
                                        lagged_edges, lagged_edge_width, lagged_arrows, r, t, s, s[0], t)

        Gcont.add_edges_from(cont_edges)
        Glag.add_edges_from(lagged_edges)


        # NOTE(review): fig and ax are not referenced again in this method; the
        # drawing calls below presumably target the current matplotlib axes - confirm
        fig, ax = plt.subplots(figsize=(8, 6))

        # 5. Edges label definition: one comma-separated label per (source, target)
        #    pair, listing the lags or the rounded scores of its links
        cont_edge_label = None
        lagged_edge_label = None
        if label_type == LabelType.Lag or label_type == LabelType.Score:
            cont_edge_label = {(s[0], t): [] for t in r.g for s in r.g[t].sources if t != s[0] and s[1] == 0}
            lagged_edge_label = {(s[0], t): [] for t in r.g for s in r.g[t].sources if t != s[0] and s[1] != 0}
            for t in r.g:
                for s in r.g[t].sources:
                    if t != s[0]:
                        if s[1] == 0:  # Contemporaneous
                            if label_type == LabelType.Lag:
                                cont_edge_label[(s[0], t)].append(s[1])
                            elif label_type == LabelType.Score:
                                cont_edge_label[(s[0], t)].append(round(r.g[t].sources[s][SCORE], 3))
                        else:  # Lagged
                            if label_type == LabelType.Lag:
                                lagged_edge_label[(s[0], t)].append(s[1])
                            elif label_type == LabelType.Score:
                                lagged_edge_label[(s[0], t)].append(round(r.g[t].sources[s][SCORE], 3))
            for k in cont_edge_label.keys():
                cont_edge_label[k] = ",".join(str(s) for s in cont_edge_label[k])
            for k in lagged_edge_label.keys():
                lagged_edge_label[k] = ",".join(str(s) for s in lagged_edge_label[k])

        # 6. Draw graph - contemporaneous (straight edges, no autodependency node labels)
        # NOTE(review): Graph is presumably netgraph's Graph class - confirm against the file's imports
        if cont_edges:
            a = Graph(Gcont,
                    node_layout=node_layout,
                    node_size=node_size,
                    node_color=node_color,
                    node_labels=None,
                    node_edge_width=border,
                    node_label_fontdict=dict(size=font_size),
                    node_edge_color=edge_color,
                    node_label_offset=0.05,
                    node_alpha=1,

                    arrows=cont_arrows,
                    edge_layout='straight',
                    edge_label=label_type != LabelType.NoLabels,
                    edge_labels=cont_edge_label,
                    edge_label_fontdict=dict(size=font_size),
                    edge_color=edge_color,
                    tail_color=tail_color,
                    edge_width=cont_edge_width,
                    edge_alpha=1,
                    edge_zorder=1,
                    edge_label_position=0.35)

            # Node names are drawn with networkx on the positions computed above
            nx.draw_networkx_labels(Gcont,
                                    pos=a.node_positions,
                                    labels={n: n for n in Gcont},
                                    font_size=font_size)

        # 7. Draw graph - lagged (curved edges). When the contemporaneous graph was
        #    drawn, its node positions are reused so both layers overlap
        if lagged_edges:
            a = Graph(Glag,
                    node_layout=a.node_positions if cont_edges else node_layout,
                    node_size=node_size,
                    node_color=node_color,
                    node_labels=node_label,
                    node_edge_width=border,
                    node_label_fontdict=dict(size=font_size),
                    node_edge_color=edge_color,
                    node_label_offset=0.05,
                    node_alpha=1,

                    arrows=lagged_arrows,
                    edge_layout='curved',
                    edge_label=label_type != LabelType.NoLabels,
                    edge_labels=lagged_edge_label,
                    edge_label_fontdict=dict(size=font_size),
                    edge_color=edge_color,
                    tail_color=tail_color,
                    edge_width=lagged_edge_width,
                    edge_alpha=1,
                    edge_zorder=1,
                    edge_label_position=0.35)

            # Node names were already drawn together with the contemporaneous graph, if any
            if not cont_edges:
                nx.draw_networkx_labels(Glag,
                                        pos=a.node_positions,
                                        labels={n: n for n in Glag},
                                        font_size=font_size)
        # 8. Plot or save
        if save_name is not None:
            plt.savefig(save_name + img_extention.value, dpi=300)
        else:
            plt.show()


    def ts_dag(self,
               min_cross_width = 1, 
               max_cross_width = 5,
               node_size = 8,
               x_disp = 1.5,
               y_disp = 0.2,
               text_disp = 0.1,
               node_color = 'orange',
               edge_color = 'grey',
               tail_color = 'black',
               font_size = 8,
               save_name = None,
               img_extention = ImageExt.PNG):
        """
        Build a timeseries dag.

        Args:
            min_cross_width (int, optional): minimum linewidth. Defaults to 1.
            max_cross_width (int, optional): maximum linewidth. Defaults to 5.
            node_size (int, optional): node size. Defaults to 8.
            x_disp (float, optional): node displacement along x. Defaults to 1.5.
            y_disp (float, optional): node displacement along y. Defaults to 0.2.
            text_disp (float, optional): text displacement along y. Defaults to 0.1.
            node_color (str/list, optional): node color. 
                                             If a string, all the nodes will have the same colour. 
                                             If a list (same dimension of features), each colour will have the specified colour.
                                             Defaults to 'orange'.
            edge_color (str, optional): edge color. Defaults to 'grey'.
            tail_color (str, optional): tail color. Defaults to 'black'.
            font_size (int, optional): font size. Defaults to 8.
            save_name (str, optional): Filename path. If None, plot is shown and not saved. Defaults to None.
            img_extention (ImageExt, optional): Image Extension. Defaults to PNG.
        """
        r = copy.deepcopy(self)
        r.g = r.make_pretty()

        Gcont = nx.DiGraph()
        Glagcross = nx.DiGraph()
        Glagauto = nx.DiGraph()

        # 1. Nodes definition
        if isinstance(node_color, list):
            node_c = dict()
        else:
            node_c = node_color
        for i in range(len(self.features)):
            for j in range(self.max_lag + 1):
                Glagauto.add_node((j, i))
                Glagcross.add_node((j, i))
                Gcont.add_node((j, i))
                if isinstance(node_color, list): node_c[(j, i)] = node_color[abs(i - (len(r.g.keys()) - 1))]

        pos = {n : (n[0]*x_disp, n[1]*y_disp) for n in Glagauto.nodes()}
        scale = max(pos.values())

        # 2. Edges definition
        cont_edges = list()
        cont_edge_width = dict()
        cont_arrows = dict()
        lagged_cross_edges = list()
        lagged_cross_edge_width = dict()
        lagged_cross_arrows = dict()
        lagged_auto_edges = list()
        lagged_auto_edge_width = dict()
        lagged_auto_arrows = dict()

        for t in r.g:
            for s in r.g[t].sources:
                s_index = len(r.g.keys())-1 - list(r.g.keys()).index(s[0])
                t_index = len(r.g.keys())-1 - list(r.g.keys()).index(t)

                # 2.1. Contemporaneous edges definition
                if s[1] == 0:
                    for i in range(self.max_lag + 1):
                        s_node = (i, s_index)
                        t_node = (i, t_index)
                        self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                        cont_edges, cont_edge_width, cont_arrows, r, t, s, 
                                        s_node, t_node)

                else:
                    s_lag = self.max_lag - s[1]
                    t_lag = self.max_lag
                    while s_lag >= 0:
                        s_node = (s_lag, s_index)
                        t_node = (t_lag, t_index)
                        # 2.2. Lagged cross edges definition
                        if s[0] != t:
                            self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                            lagged_cross_edges, lagged_cross_edge_width, lagged_cross_arrows, r, t, s, 
                                            s_node, t_node)
                        # 2.3. Lagged auto edges definition
                        else:
                            self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                            lagged_auto_edges, lagged_auto_edge_width, lagged_auto_arrows, r, t, s, 
                                            s_node, t_node)
                        s_lag -= 1
                        t_lag -= 1

        Gcont.add_edges_from(cont_edges)
        Glagcross.add_edges_from(lagged_cross_edges)
        Glagauto.add_edges_from(lagged_auto_edges)

        fig, ax = plt.subplots(figsize=(8,6))
        edge_layout = self.__get_fixed_edges(ax, x_disp, Gcont, node_size, pos, node_c, font_size, 
                                             cont_arrows, edge_color, tail_color, cont_edge_width, scale)

        # 3. Label definition
        for n in Gcont.nodes():
            if n[0] == 0:
                ax.text(pos[n][0] - text_disp, pos[n][1], list(r.g.keys())[len(r.g.keys()) - 1 - n[1]], horizontalalignment='center', verticalalignment='center', fontsize=font_size)

        # 4. Time line text drawing
        pos_tau = set([pos[p][0] for p in pos])
        max_y = max([pos[p][1] for p in pos])
        for p in pos_tau:
            if abs(int(p/x_disp) - self.max_lag) == 0:
                ax.text(p, max_y + 0.1, r"$t$", horizontalalignment='center', fontsize=font_size)
            else:
                ax.text(p, max_y + 0.1, r"$t-" + str(abs(int(p/x_disp) - self.max_lag)) + "$", horizontalalignment='center', fontsize=font_size)

        # 5. Draw graph - contemporaneous
        if cont_edges:
            a = Graph(Gcont,
                    node_layout={p : np.array(pos[p]) for p in pos},
                    node_size=node_size,
                    node_color=node_c,
                    node_edge_width=0,
                    node_label_fontdict=dict(size=font_size),
                    node_label_offset=0,
                    node_alpha=1,

                    arrows=cont_arrows,
                    edge_layout=edge_layout,
                    edge_label=False,
                    edge_color=edge_color,
                    tail_color=tail_color,
                    edge_width=cont_edge_width,
                    edge_alpha=1,
                    edge_zorder=1,
                    scale = (scale[0] + 2, scale[1] + 2))

        # 6. Draw graph - lagged cross
        if lagged_cross_edges:
            a = Graph(Glagcross,
                    node_layout={p : np.array(pos[p]) for p in pos},
                    node_size=node_size,
                    node_color=node_c,
                    node_edge_width=0,
                    node_label_fontdict=dict(size=font_size),
                    node_label_offset=0,
                    node_alpha=1,

                    arrows=lagged_cross_arrows,
                    edge_layout='curved',
                    edge_label=False,
                    edge_color=edge_color,
                    tail_color=tail_color,
                    edge_width=lagged_cross_edge_width,
                    edge_alpha=1,
                    edge_zorder=1,
                    scale = (scale[0] + 2, scale[1] + 2))

        # 7. Draw graph - lagged auto
        if lagged_auto_edges:
            a = Graph(Glagauto,
                    node_layout={p : np.array(pos[p]) for p in pos},
                    node_size=node_size,
                    node_color=node_c,
                    node_edge_width=0,
                    node_label_fontdict=dict(size=font_size),
                    node_label_offset=0,
                    node_alpha=1,

                    arrows=lagged_auto_arrows,
                    edge_layout='straight',
                    edge_label=False,
                    edge_color=edge_color,
                    tail_color=tail_color,
                    edge_width=lagged_auto_edge_width,
                    edge_alpha=1,
                    edge_zorder=1,
                    scale = (scale[0] + 2, scale[1] + 2))

        # 7. Plot or save
        if save_name is not None:
            plt.savefig(save_name + img_extention.value, dpi = 300)
        else:
            plt.show()


    def __get_fixed_edges(self, ax, x_disp, Gcont, node_size, pos, node_c, font_size, cont_arrows, edge_color, tail_color, cont_edge_width, scale) -> dict:
        """
        Compute the edge paths of the contemporaneous links, made identical at every time step.

        A `Graph` is built with a 'curved' edge layout only to obtain its edge paths.
        The path of each edge whose two endpoints have time index `self.max_lag`
        is then used as a template: the edge connecting the same pair of rows at an
        earlier time index `t` receives that same path shifted by
        `(self.max_lag - t) * x_disp` along x. The axis `ax` is cleared before returning.

        Args:
            ax (Axes): figure axis; it is cleared before returning.
            x_disp (float): node displacement along x.
            Gcont (DiGraph): directed graph containing only contemporaneous links.
            node_size (int): node size.
            pos (dict): node layout (node -> position).
            node_c (str/dict): node color, forwarded to `Graph` as `node_color`.
                               Either a single colour or a node -> colour dictionary.
            font_size (int): font size.
            cont_arrows (dict): edge-arrows dictionary.
            edge_color (str): edge color.
            tail_color (str): tail color.
            cont_edge_width (dict): edge-width dictionary.
            scale (tuple): graph scale; 2 is added to each component before it is passed to `Graph`.

        Returns:
            dict: new edge paths (edge -> path).
        """
        # Temporary graph: only its computed curved edge paths are used below.
        # NOTE(review): `Graph` is presumably netgraph's Graph and draws on the current axis - confirm.
        a = Graph(Gcont,
                  node_layout={p : np.array(pos[p]) for p in pos},
                  node_size=node_size,
                  node_color=node_c,
                  node_edge_width=0,
                  node_label_fontdict=dict(size=font_size),
                  node_label_offset=0,
                  node_alpha=1,

                  arrows=cont_arrows,
                  edge_layout='curved',
                  edge_label=False,
                  edge_color=edge_color,
                  tail_color=tail_color,
                  edge_width=cont_edge_width,
                  edge_alpha=1,
                  edge_zorder=1,
                  scale = (scale[0] + 2, scale[1] + 2))
        # Work on a copy so the paths being iterated over are never modified in place.
        res = copy.deepcopy(a.edge_layout.edge_paths)
        for edge, edge_path in a.edge_layout.edge_paths.items():
            # Template edges: both endpoints lie in the last column (time index == max_lag).
            if edge[0][0] == self.max_lag and edge[1][0] == self.max_lag: # t
                for t in range(0, self.max_lag):
                    for fixed_edge in a.edge_layout.edge_paths.keys():
                        if fixed_edge == edge: continue
                        # Same pair of rows (second node coordinate), but located at time index t.
                        if fixed_edge[0][0] == t and fixed_edge[0][1] == edge[0][1] and fixed_edge[1][0] == t and fixed_edge[1][1] == edge[1][1]:
                            # Reuse the template path, moved left by the x distance between the two columns.
                            # assumes each edge path is an (n_points, 2) array of x, y coordinates - TODO confirm
                            res[fixed_edge] = edge_path - np.array([(self.max_lag - t)*x_disp,0])*np.ones_like(a.edge_layout.edge_paths[edge])
            # if edge[0][0] == 0 and edge[1][0] == 0: # t-tau_max
            #     for shifted_edge, shifted_edge_path in a.edge_layout.edge_paths.items():
            #         if shifted_edge == edge: continue
            #         if shifted_edge[0][0] == self.max_lag and shifted_edge[0][1] == edge[0][1] and shifted_edge[1][0] == self.max_lag and shifted_edge[1][1] == edge[1][1]:
            #             res[edge] = shifted_edge_path - np.array([x_disp,0])*np.ones_like(a.edge_layout.edge_paths[shifted_edge])
        # Clear the axis passed by the caller before handing back the paths.
        ax.clear()              
        return res

    def __scale(self, score, min_width, max_width, min_score = 0, max_score = 1):
        """
        Scale the score of the cause-effect relationship strength to a linewidth.

        The score is mapped linearly from [min_score, max_score] to
        [min_width, max_width]. No clamping is applied, so a score outside the
        score range yields a width outside the width range.

        Args:
            score (float): score to scale.
            min_width (float): minimum linewidth.
            max_width (float): maximum linewidth.
            min_score (int, optional): minimum score range. Defaults to 0.
            max_score (int, optional): maximum score range. Defaults to 1.

        Returns:
            (float): scaled score. If the score range is empty
                (max_score == min_score), max_width is returned.
        """
        score_range = max_score - min_score
        if score_range == 0:
            # Degenerate range: the linear map is undefined (division by zero),
            # so fall back to the widest line instead of failing.
            return max_width
        return ((score - min_score) / score_range) * (max_width - min_width) + min_width


    def get_skeleton(self) -> np.array:
        """
        Return skeleton matrix.

        Skeleton matrix is composed by 0 and 1 and is indexed as [target, source, lag].
        1 <- if there is a link from source to target at that lag
        0 <- if there is not a link from source to target at that lag

        Returns:
            np.array: integer skeleton matrix of shape (n_features, n_features, max_lag + 1).
        """
        features = self.features
        # Start from zeros (not empty strings) so that absent links are really 0, as documented.
        r = np.zeros((len(features), len(features), self.max_lag + 1), dtype=int)
        for t in self.g.keys():
            for s in self.g[t].sources:
                # s is a (source name, lag) key of the target's sources dictionary.
                r[features.index(t), features.index(s[0]), s[1]] = 1
        return r


    def get_val_matrix(self) -> np.array:
        """
        Build the matrix of link strengths.

        The matrix is indexed as [target, source, lag] and stores the score
        recorded for each link of the causal model. Entries without a link stay at 0.

        Returns:
            np.array: val matrix of shape (n_features, n_features, max_lag + 1).
        """
        names = self.features
        n_vars = len(names)
        values = np.zeros((n_vars, n_vars, self.max_lag + 1))
        for target, node in self.g.items():
            for (source, lag), info in node.sources.items():
                values[names.index(target), names.index(source), lag] = info[SCORE]
        return np.array(values)


    def get_pval_matrix(self) -> np.array:
        """
        Build the matrix of link p-values.

        The matrix is indexed as [target, source, lag] and stores the p-value
        recorded for each link of the causal model. Entries without a link stay at 0.

        Returns:
            np.array: pval matrix of shape (n_features, n_features, max_lag + 1).
        """
        names = self.features
        n_vars = len(names)
        pvalues = np.zeros((n_vars, n_vars, self.max_lag + 1))
        for target, node in self.g.items():
            for (source, lag), info in node.sources.items():
                pvalues[names.index(target), names.index(source), lag] = info[PVAL]
        return np.array(pvalues)


    def get_graph_matrix(self) -> np.array:
        """
        Build the matrix of link types.

        The matrix is indexed as [target, source, lag] and stores the type of
        each link (e.g., -->, <->, ..). Entries without a link hold an empty string.

        Returns:
            np.array: graph matrix (dtype object) of shape (n_features, n_features, max_lag + 1).
        """
        names = self.features
        n_vars = len(names)
        link_types = np.full((n_vars, n_vars, self.max_lag + 1), '', dtype=object)
        for target, node in self.g.items():
            for (source, lag), info in node.sources.items():
                link_types[names.index(target), names.index(source), lag] = info[TYPE]
        return np.array(link_types)


    def get_Adj(self, indexed = False) -> dict:   
        """
        Return Adjacency dictionary.

        Every feature is mapped to the list of its sources, each expressed as a
        (source, lag) pair with a non-positive lag.

        If indexed = True: example {0: [(0, -1), (1, -2)], 1: [], ...}
        If indexed = False: example {"X_0": [(X_0, -1), (X_1, -2)], "X_1": [], ...}

        Args:
            indexed (bool, optional): If true, variables are identified by their position in the features list. Otherwise by their names. Defaults to False.

        Returns:
            dict: SCM.
        """
        # Single code path: only the way a variable is turned into a key changes.
        if indexed:
            key_of = lambda name: self.features.index(name)
        else:
            key_of = lambda name: name

        adjacency = {key_of(v): list() for v in self.features}
        for target, node in self.g.items():
            for source, lag in node.sources:
                adjacency[key_of(target)].append((key_of(source), -abs(lag)))
        return adjacency


    def get_Graph(self) -> dict:
        """
        Return Graph dictionary. E.g. {X1: {(X2, -2): '-->'}, X2: {(X3, -1): '-?>'}, X3: {(X3, -1): '-->'}}.

        Each feature is mapped to a dictionary whose keys are (source, lag) pairs,
        with a non-positive lag, and whose values are the corresponding link types.

        Returns:
            dict: graph dictionary.
        """
        graph = {v: dict() for v in self.features}
        for target, node in self.g.items():
            for (source, lag), info in node.sources.items():
                graph[target][(source, -abs(lag))] = info[TYPE]
        return graph

autodep_nodes: list property

Return the autodependent nodes list.

Returns:

Name Type Description
list list

Autodependent nodes list.

features: list property

Return features list.

Returns:

Name Type Description
list list

Features list.

Return the intervention links list.

Returns:

Name Type Description
list list

Intervention link list.

max_auto_score: float property

Return maximum score of an auto-dependency link.

Returns:

Name Type Description
float float

maximum score of an auto-dependency link.

max_cross_score: float property

Return maximum score of a cross-dependency link.

Returns:

Name Type Description
float float

maximum score of a cross-dependency link.

pretty_features: list property

Return list of features with LaTeX symbols.

Returns:

Name Type Description
list str

list of feature names.

__add_edge(min_width, max_width, min_score, max_score, edges, edge_width, arrows, r, t, s, s_node, t_node)

Add edge to a graph. Support method for dag and ts_dag.

Parameters:

Name Type Description Default
min_width int

minimum linewidth. Defaults to 1.

required
max_width int

maximum linewidth. Defaults to 5.

required
min_score int

minimum score range. Defaults to 0.

required
max_score int

maximum score range. Defaults to 1.

required
edges list

list of edges.

required
edge_width dict

dictionary containing the width for each edge of the graph.

required
arrows dict

dictionary containing a bool for each edge of the graph describing if the edge is directed or not.

required
r DAG

DAG.

required
t str or tuple

target node.

required
s str or tuple

source node.

required
s_node str

source node.

required
t_node str

target node.

required

Raises:

Type Description
ValueError

link type associated to this edge not included in our LinkType list.

Source code in causalflow/graph/DAG.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def __add_edge(self, min_width, max_width, min_score, max_score, edges, edge_width, arrows, r, t, s, s_node, t_node):
    """
    Register one link of the DAG as a drawable edge. Support method for dag and ts_dag.

    The edge is appended to `edges`, its linewidth (obtained by scaling the link
    score) is stored in `edge_width`, and its head/tail markers, chosen from the
    link type, are stored in `arrows`. A bidirected link is registered in both directions.

    Args:
        min_width (int): minimum linewidth.
        max_width (int): maximum linewidth.
        min_score (int): minimum score range.
        max_score (int): maximum score range.
        edges (list): list of edges, extended in place.
        edge_width (dict): dictionary containing the width for each edge of the graph, updated in place.
        arrows (dict): dictionary containing the head ('h') and tail ('t') markers for each edge of the graph, updated in place.
        r (DAG): DAG.
        t (str or tuple): target node (key of r.g).
        s (str or tuple): source node (key of r.g[t].sources).
        s_node (str): source node as drawn.
        t_node (str): target node as drawn.

    Raises:
        ValueError: link type associated to this edge not included in our LinkType list.
    """
    forward = (s_node, t_node)
    edges.append(forward)

    link = r.g[t].sources[s]
    # An infinite score cannot be scaled to a width: treat it as a score of 1.
    strength = link[SCORE]
    if strength == float('inf'):
        strength = 1
    edge_width[forward] = self.__scale(strength, min_width, max_width, min_score, max_score)

    link_type = link[TYPE]
    if link_type == LinkType.Directed.value:
        arrows[forward] = {'h':'>', 't':''}

    elif link_type == LinkType.Bidirected.value:
        # Bidirected links are drawn as two opposite arrows of equal width.
        backward = (t_node, s_node)
        edges.append(backward)
        edge_width[backward] = self.__scale(strength, min_width, max_width, min_score, max_score)
        arrows[backward] = {'h':'>', 't':''}
        arrows[forward] = {'h':'>', 't':''}

    elif link_type == LinkType.HalfUncertain.value:
        arrows[forward] = {'h':'>', 't':'o'}

    elif link_type == LinkType.Uncertain.value:
        arrows[forward] = {'h':'o', 't':'o'}

    else:
        raise ValueError(f"{link_type} not included in LinkType")

__get_fixed_edges(ax, x_disp, Gcont, node_size, pos, node_c, font_size, cont_arrows, edge_color, tail_color, cont_edge_width, scale)

Fix edge paths at t-tau_max.

Parameters:

Name Type Description Default
ax Axes

figure axis.

required
x_disp float

node displacement along x. Defaults to 1.5.

required
Gcont DiGraph

Directed graph containing only contemporaneous links.

required
node_size int

node size.

required
pos dict

node layout.

required
node_c str / list

node color. If a string, all the nodes will have the same colour. If a list (same dimension of features), each colour will have the specified colour.

required
font_size int

font size.

required
cont_arrows dict

edge-arrows dictionary .

required
edge_color str

edge color.

required
tail_color str

tail color.

required
cont_edge_width dict

edge-width dictionary.

required
scale tuple

graph scale.

required

Returns:

Name Type Description
dict dict

new edge paths

Source code in causalflow/graph/DAG.py
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
def __get_fixed_edges(self, ax, x_disp, Gcont, node_size, pos, node_c, font_size, cont_arrows, edge_color, tail_color, cont_edge_width, scale) -> dict:
    """
    Compute the edge paths of the contemporaneous links, made identical at every time step.

    A `Graph` is built with a 'curved' edge layout only to obtain its edge paths.
    The path of each edge whose two endpoints have time index `self.max_lag`
    is then used as a template: the edge connecting the same pair of rows at an
    earlier time index `t` receives that same path shifted by
    `(self.max_lag - t) * x_disp` along x. The axis `ax` is cleared before returning.

    Args:
        ax (Axes): figure axis; it is cleared before returning.
        x_disp (float): node displacement along x.
        Gcont (DiGraph): directed graph containing only contemporaneous links.
        node_size (int): node size.
        pos (dict): node layout (node -> position).
        node_c (str/dict): node color, forwarded to `Graph` as `node_color`.
                           Either a single colour or a node -> colour dictionary.
        font_size (int): font size.
        cont_arrows (dict): edge-arrows dictionary.
        edge_color (str): edge color.
        tail_color (str): tail color.
        cont_edge_width (dict): edge-width dictionary.
        scale (tuple): graph scale; 2 is added to each component before it is passed to `Graph`.

    Returns:
        dict: new edge paths (edge -> path).
    """
    # Temporary graph: only its computed curved edge paths are used below.
    # NOTE(review): `Graph` is presumably netgraph's Graph and draws on the current axis - confirm.
    a = Graph(Gcont,
              node_layout={p : np.array(pos[p]) for p in pos},
              node_size=node_size,
              node_color=node_c,
              node_edge_width=0,
              node_label_fontdict=dict(size=font_size),
              node_label_offset=0,
              node_alpha=1,

              arrows=cont_arrows,
              edge_layout='curved',
              edge_label=False,
              edge_color=edge_color,
              tail_color=tail_color,
              edge_width=cont_edge_width,
              edge_alpha=1,
              edge_zorder=1,
              scale = (scale[0] + 2, scale[1] + 2))
    # Work on a copy so the paths being iterated over are never modified in place.
    res = copy.deepcopy(a.edge_layout.edge_paths)
    for edge, edge_path in a.edge_layout.edge_paths.items():
        # Template edges: both endpoints lie in the last column (time index == max_lag).
        if edge[0][0] == self.max_lag and edge[1][0] == self.max_lag: # t
            for t in range(0, self.max_lag):
                for fixed_edge in a.edge_layout.edge_paths.keys():
                    if fixed_edge == edge: continue
                    # Same pair of rows (second node coordinate), but located at time index t.
                    if fixed_edge[0][0] == t and fixed_edge[0][1] == edge[0][1] and fixed_edge[1][0] == t and fixed_edge[1][1] == edge[1][1]:
                        # Reuse the template path, moved left by the x distance between the two columns.
                        # assumes each edge path is an (n_points, 2) array of x, y coordinates - TODO confirm
                        res[fixed_edge] = edge_path - np.array([(self.max_lag - t)*x_disp,0])*np.ones_like(a.edge_layout.edge_paths[edge])
        # if edge[0][0] == 0 and edge[1][0] == 0: # t-tau_max
        #     for shifted_edge, shifted_edge_path in a.edge_layout.edge_paths.items():
        #         if shifted_edge == edge: continue
        #         if shifted_edge[0][0] == self.max_lag and shifted_edge[0][1] == edge[0][1] and shifted_edge[1][0] == self.max_lag and shifted_edge[1][1] == edge[1][1]:
        #             res[edge] = shifted_edge_path - np.array([x_disp,0])*np.ones_like(a.edge_layout.edge_paths[shifted_edge])
    # Clear the axis passed by the caller before handing back the paths.
    ax.clear()              
    return res

__init__(var_names, min_lag, max_lag, neglect_autodep=False, scm=None)

DAG constructor.

Parameters:

Name Type Description Default
var_names list

variable list.

required
min_lag int

minimum time lag.

required
max_lag int

maximum time lag.

required
neglect_autodep bool

bit to neglect nodes when they are only autodependent. Defaults to False.

False
scm dict

Build the DAG for SCM. Defaults to None.

None
Source code in causalflow/graph/DAG.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(self, var_names, min_lag, max_lag, neglect_autodep = False, scm = None):
    """
    DAG constructor.

    One node is created per variable. When an SCM is supplied, every
    listed link is added to the graph with a score of 0.3 and a p-value of 0.

    Args:
        var_names (list): variable list.
        min_lag (int): minimum time lag.
        max_lag (int): maximum time lag.
        neglect_autodep (bool, optional): bit to neglect nodes when they are only autodependent. Defaults to False.
        scm (dict, optional): SCM used to build the DAG, mapping each target to its
            sources given as (source, lag) or (source, lag, link type). Defaults to None.
    """
    self.min_lag = min_lag
    self.max_lag = max_lag
    self.neglect_autodep = neglect_autodep
    self.sys_context = dict()
    self.g = {var: Node(var, neglect_autodep) for var in var_names}

    if scm is None:
        return

    for target in scm:
        for source in scm[target]:
            # Only 2- and 3-element entries are recognised; anything else is ignored.
            if len(source) in (2, 3):
                # source[1:] is the lag, optionally followed by the link type.
                self.add_source(target, source[0], 0.3, 0, *source[1:])

__scale(score, min_width, max_width, min_score=0, max_score=1)

Scale the score of the cause-effect relationship strength to a linewidth.

Parameters:

Name Type Description Default
score float

score to scale.

required
min_width float

minimum linewidth.

required
max_width float

maximum linewidth.

required
min_score int

minimum score range. Defaults to 0.

0
max_score int

maximum score range. Defaults to 1.

1

Returns:

Type Description
float

scaled score.

Source code in causalflow/graph/DAG.py
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
def __scale(self, score, min_width, max_width, min_score = 0, max_score = 1):
    """
    Scale the score of the cause-effect relationship strength to a linewidth.

    The score is mapped linearly from [min_score, max_score] to
    [min_width, max_width]. No clamping is applied, so a score outside the
    score range yields a width outside the width range.

    Args:
        score (float): score to scale.
        min_width (float): minimum linewidth.
        max_width (float): maximum linewidth.
        min_score (int, optional): minimum score range. Defaults to 0.
        max_score (int, optional): maximum score range. Defaults to 1.

    Returns:
        (float): scaled score. If the score range is empty
            (max_score == min_score), max_width is returned.
    """
    score_range = max_score - min_score
    if score_range == 0:
        # Degenerate range: the linear map is undefined (division by zero),
        # so fall back to the widest line instead of failing.
        return max_width
    return ((score - min_score) / score_range) * (max_width - min_width) + min_width

add_context()

Add context variables.

Source code in causalflow/graph/DAG.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def add_context(self):
    """
    Add context variables.

    For every (system variable, context variable) pair of `self.sys_context`
    whose system variable is a feature, a node is created for the context
    variable and added as a lag-0 source of the system variable, which is
    flagged as an intervention node. Afterwards each such context variable
    receives every other context variable present among the features as a lag-0 source.
    """
    # First pass: create the context nodes and attach them to their system variables.
    for sys_var, context_var in self.sys_context.items():
        if sys_var not in self.features:
            continue

        self.g[context_var] = Node(context_var, self.neglect_autodep)

        sys_node = self.g[sys_var]
        sys_node.intervention_node = True
        sys_node.associated_context = context_var
        self.add_source(sys_var, context_var, 1, 0, 0)

    # Second pass: contemporaneous links between context variables, added in both directions.
    for sys_var, context_var in self.sys_context.items():
        if sys_var not in self.features:
            continue

        siblings = [candidate
                    for candidate in self.sys_context.values()
                    if candidate != context_var and candidate in self.features]
        for sibling in siblings:
            self.add_source(context_var, sibling, 1, 0, 0)

add_source(t, s, score, pval, lag, mode=LinkType.Directed.value)

Add source node to a target node.

Parameters:

Name Type Description Default
t str

target node name.

required
s str

source node name.

required
score float

dependency score.

required
pval float

dependency p-value.

required
lag int

dependency lag.

required
mode LinkType

link type. E.g., Directed -->

LinkType.Directed.value
Source code in causalflow/graph/DAG.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def add_source(self, t, s, score, pval, lag, mode = LinkType.Directed.value):
    """
    Register node `s` as a source of node `t`.

    The link is stored in the sources of `t` under the key (s, |lag|), and
    `t` is appended to the children of `s`.

    Args:
        t (str): target node name.
        s (str): source node name.
        score (float): dependency score.
        pval (float): dependency p-value.
        lag (int): dependency lag; its absolute value is stored.
        mode (LinkType): link type. E.g., Directed -->
    """
    link_info = {SCORE: score, PVAL: pval, TYPE: mode}
    source_key = (s, abs(lag))
    self.g[t].sources[source_key] = link_info
    self.g[s].children.append(t)

dag(node_layout='dot', min_auto_width=0.25, max_auto_width=0.75, min_cross_width=0.5, max_cross_width=1.5, node_size=4, node_color='orange', edge_color='grey', tail_color='black', font_size=8, label_type=LabelType.Lag, save_name=None, img_extention=ImageExt.PNG)

Build a dag, first with contemporaneous links, then lagged links.

Parameters:

Name Type Description Default
node_layout str

Node layout. Defaults to 'dot'.

'dot'
min_auto_width float

minimum border linewidth. Defaults to 0.25.

0.25
max_auto_width float

maximum border linewidth. Defaults to 0.75.

0.75
min_cross_width float

minimum edge linewidth. Defaults to 0.5.

0.5
max_cross_width float

maximum edge linewidth. Defaults to 1.5.

1.5
node_size int

node size. Defaults to 4.

4
node_color str

node color. Defaults to 'orange'.

'orange'
edge_color str

edge color for contemporaneous links. Defaults to 'grey'.

'grey'
tail_color str

tail color. Defaults to 'black'.

'black'
font_size int

font size. Defaults to 8.

8
label_type LabelType

Show the lag time (LabelType.Lag), the strength (LabelType.Score), or no labels (LabelType.NoLabels). Default LabelType.Lag.

LabelType.Lag
save_name str

Filename path. If None, plot is shown and not saved. Defaults to None.

None
img_extention ImageExt

Image Extension. Defaults to PNG.

ImageExt.PNG
Source code in causalflow/graph/DAG.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
def dag(self,
    node_layout='dot',
    min_auto_width=0.25, 
    max_auto_width=0.75,
    min_cross_width=0.5, 
    max_cross_width=1.5,
    node_size=4, 
    node_color='orange',
    edge_color='grey',
    tail_color='black',
    font_size=8,
    label_type=LabelType.Lag,
    save_name=None,
    img_extention=ImageExt.PNG):
    """
    Build a dag, first with contemporaneous links, then lagged links.

    The drawing is produced from a deep copy of this object whose graph is
    replaced by the output of ``make_pretty()``; ``self.g`` is not modified here.
    Links between different variables are split by lag: lag 0 links go into one
    graph drawn with straight edges, all other lags into a second graph drawn
    with curved edges. Links from a variable to itself are never drawn as edges;
    an autodependent node instead gets a border whose width is scaled from a score.

    Args:
        node_layout (str, optional): Node layout. Defaults to 'dot'.
        min_auto_width (float, optional): minimum border linewidth. Defaults to 0.25.
        max_auto_width (float, optional): maximum border linewidth. Defaults to 0.75.
        min_cross_width (float, optional): minimum edge linewidth. Defaults to 0.5.
        max_cross_width (float, optional): maximum edge linewidth. Defaults to 1.5.
        node_size (int, optional): node size. Defaults to 4.
        node_color (str, optional): node color. Defaults to 'orange'.
        edge_color (str, optional): edge color for contemporaneous links. Defaults to 'grey'.
        tail_color (str, optional): tail color. Defaults to 'black'.
        font_size (int, optional): font size. Defaults to 8.
        label_type (LabelType, optional): Show the lag time (LabelType.Lag), the strength (LabelType.Score), or no labels (LabelType.NoLabels). Default LabelType.Lag.
        save_name (str, optional): Filename path. If None, plot is shown and not saved. Defaults to None.
        img_extention (ImageExt, optional): Image Extension. Defaults to PNG.
    """
    # Work on a copy so that the prettified node names never leak into self.g.
    r = copy.deepcopy(self)
    r.g = r.make_pretty()

    Gcont = nx.DiGraph()
    Glag = nx.DiGraph()

    # 1. Nodes definition (both graphs share the same node set)
    Gcont.add_nodes_from(r.g.keys())
    Glag.add_nodes_from(r.g.keys())

    # 2. Nodes border definition
    # Border width is 0 unless the node is autodependent; in that case it is the
    # score of the source returned by get_max_autodependent, rescaled from
    # [0, self.max_auto_score] to [min_auto_width, max_auto_width].
    border = dict()
    for t in r.g:
        border[t] = 0
        if r.g[t].is_autodependent:
            border[t] = max(self.__scale(r.g[t].sources[r.g[t].get_max_autodependent][SCORE], 
                                         min_auto_width, max_auto_width, 
                                         0, self.max_auto_score), 
                            border[t])

    # 3. Nodes border label definition
    # For autodependent nodes the label is either the lag or the rounded score of
    # the get_max_autodependent source; every other node gets an empty string.
    node_label = None
    if label_type == LabelType.Lag or label_type == LabelType.Score:
        node_label = {t: [] for t in r.g.keys()}
        for t in r.g:
            if r.g[t].is_autodependent:
                autodep = r.g[t].get_max_autodependent
                if label_type == LabelType.Lag:
                    node_label[t].append(autodep[1])
                elif label_type == LabelType.Score:
                    node_label[t].append(round(r.g[t].sources[autodep][SCORE], 3))
            node_label[t] = ",".join(str(s) for s in node_label[t])

    # 4. Edges definition
    cont_edges = []
    cont_edge_width = dict()
    cont_arrows = {}
    lagged_edges = []
    lagged_edge_width = dict()
    lagged_arrows = {}

    for t in r.g:
        for s in r.g[t].sources:
            if t != s[0]:  # skip self-loops
                if s[1] == 0:  # Contemporaneous link (no lag)
                    self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score,
                                    cont_edges, cont_edge_width, cont_arrows, r, t, s, s[0], t)
                else:  # Lagged link
                    self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score,
                                    lagged_edges, lagged_edge_width, lagged_arrows, r, t, s, s[0], t)

    Gcont.add_edges_from(cont_edges)
    Glag.add_edges_from(lagged_edges)


    # The figure is created before any Graph(...) call below.
    # NOTE(review): presumably Graph draws onto the current axes; fig and ax are
    # not referenced again in this function - confirm.
    fig, ax = plt.subplots(figsize=(8, 6))

    # 5. Edges label definition
    # One label per (source, target) pair; several lags/scores between the same
    # pair are collected in a list and joined with commas.
    cont_edge_label = None
    lagged_edge_label = None
    if label_type == LabelType.Lag or label_type == LabelType.Score:
        cont_edge_label = {(s[0], t): [] for t in r.g for s in r.g[t].sources if t != s[0] and s[1] == 0}
        lagged_edge_label = {(s[0], t): [] for t in r.g for s in r.g[t].sources if t != s[0] and s[1] != 0}
        for t in r.g:
            for s in r.g[t].sources:
                if t != s[0]:
                    if s[1] == 0:  # Contemporaneous
                        if label_type == LabelType.Lag:
                            cont_edge_label[(s[0], t)].append(s[1])
                        elif label_type == LabelType.Score:
                            cont_edge_label[(s[0], t)].append(round(r.g[t].sources[s][SCORE], 3))
                    else:  # Lagged
                        if label_type == LabelType.Lag:
                            lagged_edge_label[(s[0], t)].append(s[1])
                        elif label_type == LabelType.Score:
                            lagged_edge_label[(s[0], t)].append(round(r.g[t].sources[s][SCORE], 3))
        for k in cont_edge_label.keys():
            cont_edge_label[k] = ",".join(str(s) for s in cont_edge_label[k])
        for k in lagged_edge_label.keys():
            lagged_edge_label[k] = ",".join(str(s) for s in lagged_edge_label[k])

    # 6. Draw graph - contemporaneous
    # NOTE(review): node_labels is None here, so the autodependency labels built
    # in step 3 are only passed to the lagged graph below - confirm this is intended.
    if cont_edges:
        a = Graph(Gcont,
                node_layout=node_layout,
                node_size=node_size,
                node_color=node_color,
                node_labels=None,
                node_edge_width=border,
                node_label_fontdict=dict(size=font_size),
                node_edge_color=edge_color,
                node_label_offset=0.05,
                node_alpha=1,

                arrows=cont_arrows,
                edge_layout='straight',
                edge_label=label_type != LabelType.NoLabels,
                edge_labels=cont_edge_label,
                edge_label_fontdict=dict(size=font_size),
                edge_color=edge_color,
                tail_color=tail_color,
                edge_width=cont_edge_width,
                edge_alpha=1,
                edge_zorder=1,
                edge_label_position=0.35)

        # Variable names are drawn with networkx at the positions computed by Graph.
        nx.draw_networkx_labels(Gcont,
                                pos=a.node_positions,
                                labels={n: n for n in Gcont},
                                font_size=font_size)

    # 7. Draw graph - lagged
    # When the contemporaneous graph was drawn, its node positions are reused so
    # that both layers overlap; otherwise the requested layout is computed here.
    if lagged_edges:
        a = Graph(Glag,
                node_layout=a.node_positions if cont_edges else node_layout,
                node_size=node_size,
                node_color=node_color,
                node_labels=node_label,
                node_edge_width=border,
                node_label_fontdict=dict(size=font_size),
                node_edge_color=edge_color,
                node_label_offset=0.05,
                node_alpha=1,

                arrows=lagged_arrows,
                edge_layout='curved',
                edge_label=label_type != LabelType.NoLabels,
                edge_labels=lagged_edge_label,
                edge_label_fontdict=dict(size=font_size),
                edge_color=edge_color,
                tail_color=tail_color,
                edge_width=lagged_edge_width,
                edge_alpha=1,
                edge_zorder=1,
                edge_label_position=0.35)

        # Variable names were already drawn in step 6 if that step ran.
        if not cont_edges:
            nx.draw_networkx_labels(Glag,
                                    pos=a.node_positions,
                                    labels={n: n for n in Glag},
                                    font_size=font_size)
    # 8. Plot or save
    # NOTE(review): when both cont_edges and lagged_edges are empty no Graph is
    # drawn at all, so the figure saved/shown here contains no nodes.
    if save_name is not None:
        plt.savefig(save_name + img_extention.value, dpi=300)
    else:
        plt.show()

del_source(t, s, lag)

Remove source node from a target node.

Parameters:

Name Type Description Default
t str

target node name.

required
s str

source node name.

required
lag int

dependency lag.

required
Source code in causalflow/graph/DAG.py
138
139
140
141
142
143
144
145
146
147
148
def del_source(self, t, s, lag):
    """
    Detach a lagged source from a target node.

    The (s, lag) entry is dropped from the target's sources and one occurrence
    of the target is dropped from the source's children list.

    Args:
        t (str): target node name.
        s (str): source node name.
        lag (int): dependency lag.
    """
    target_node = self.g[t]
    del target_node.sources[(s, lag)]

    source_node = self.g[s]
    source_node.children.remove(t)

get_Adj(indexed=False)

Return Adjacency dictionary.

If indexed = True: example {0: [(0, -1), (1, -2)], 1: [], ...} If indexed = False: example {"X_0": [(X_0, -1), (X_1, -2)], "X_1": [], ...}

Parameters:

Name Type Description Default
indexed bool

If true, returns the SCM with index instead of variables' names. Otherwise it uses variables' names. Defaults to False.

False

Returns:

Name Type Description
dict dict

SCM.

Source code in causalflow/graph/DAG.py
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
def get_Adj(self, indexed = False) -> dict:
    """
    Build the adjacency dictionary of the graph.

    Every feature gets an entry listing its sources as (source, -|lag|) pairs.

    With indexed = True variables are replaced by their position in
    self.features, e.g. {0: [(0, -1), (1, -2)], 1: [], ...}.
    With indexed = False variable names are used,
    e.g. {"X_0": [(X_0, -1), (X_1, -2)], "X_1": [], ...}.

    Args:
        indexed (bool, optional): If true, returns the SCM with index instead of variables' names. Otherwise it uses variables' names. Defaults to False.

    Returns:
        dict: SCM.
    """
    # A single key-mapping function covers both output flavours.
    if indexed:
        as_key = self.features.index
    else:
        def as_key(name):
            return name

    adjacency = {as_key(feature): [] for feature in self.features}
    for target, node in self.g.items():
        for source in node.sources:
            adjacency[as_key(target)].append((as_key(source[0]), -abs(source[1])))
    return adjacency

get_Graph()

Return Graph dictionary. E.g. {X1: {(X2, -2): '-->'}, X2: {(X3, -1): '-?>'}, X3: {(X3, -1): '-->'}}.

Returns:

Name Type Description
dict dict

graph dictionary.

Source code in causalflow/graph/DAG.py
833
834
835
836
837
838
839
840
841
842
843
844
def get_Graph(self) -> dict:
    """
    Build the graph dictionary, mapping each feature to its typed incoming links.

    Example: {X1: {(X2, -2): '-->'}, X2: {(X3, -1): '-?>'}, X3: {(X3, -1): '-->'}}.

    Returns:
        dict: graph dictionary.
    """
    graph = {feature: {} for feature in self.features}
    for target, node in self.g.items():
        for source, info in node.sources.items():
            # Lags are reported as non-positive offsets.
            graph[target][(source[0], -abs(source[1]))] = info[TYPE]
    return graph

get_graph_matrix()

Return graph matrix.

Graph matrix contains information about the link type. E.g., -->, <->, ..

Returns:

Type Description
np.array

np.array: graph matrix.

Source code in causalflow/graph/DAG.py
791
792
793
794
795
796
797
798
799
800
801
802
803
804
def get_graph_matrix(self) -> np.array:
    """
    Build the graph matrix.

    The matrix is indexed as [target, source, lag] and stores the link type
    (e.g., -->, <->, ..); cells without a link hold an empty string.

    Returns:
        np.array: graph matrix.
    """
    n_features = len(self.features)
    matrix = np.full((n_features, n_features, self.max_lag + 1), '', dtype=object)
    for target, node in self.g.items():
        for source, info in node.sources.items():
            link_type = info[TYPE]
            per_lag = matrix[self.features.index(target), self.features.index(source[0])]
            per_lag[source[1]] = link_type
    return np.array(matrix)

Return link assumption dictionary.

Parameters:

Name Type Description Default
autodep_ok bool

If True, autodependency links are assumed to be -->; otherwise -?>. Defaults to False.

False

Returns:

Name Type Description
dict dict

link assumption dictionary.

Source code in causalflow/graph/DAG.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def get_link_assumptions(self, autodep_ok = False) -> dict:
    """
    Build the link assumption dictionary.

    Keys are feature indices (positions in self.features); each value maps
    (source index, -|lag|) to a link symbol.

    Args:
        autodep_ok (bool, optional): If true, autodependecy link assumption = -->. Otherwise -?>. Defaults to False.

    Returns:
        dict: link assumption dictionary.
    """
    index_of = self.features.index
    assumptions = {index_of(name): {} for name in self.features}

    for target in self.g:
        for source in self.g[target].sources:
            parent, lag = source[0], source[1]

            # Trusted autodependency: fixed as a directed link, nothing else to check.
            if autodep_ok and parent == target:
                assumptions[index_of(target)][(index_of(parent), -abs(lag))] = '-->'
                continue

            # Context variables only yield a link towards their own system variable.
            if parent in self.sys_context.values():
                if target in self.sys_context and parent == self.sys_context[target]:
                    assumptions[index_of(target)][(index_of(parent), -abs(lag))] = '-->'
                continue

            if lag == 0:
                if (target, 0) in self.g[parent].sources:
                    # Contemporaneous link present in both directions.
                    assumptions[index_of(target)][(index_of(parent), 0)] = 'o-o'
                else:
                    # One-directional contemporaneous link: record both end points.
                    assumptions[index_of(target)][(index_of(parent), 0)] = '-?>'
                    assumptions[index_of(parent)][(index_of(target), 0)] = '<?-'
            elif lag > 0:
                assumptions[index_of(target)][(index_of(parent), -abs(lag))] = '-?>'

    return assumptions

get_pval_matrix()

Return pval matrix.

The pval matrix contains the p-value of each link composing the causal model.

Returns:

Type Description
np.array

np.array: pval matrix

Source code in causalflow/graph/DAG.py
775
776
777
778
779
780
781
782
783
784
785
786
787
788
def get_pval_matrix(self) -> np.array:
    """
    Build the pval matrix.

    The matrix is indexed as [target, source, lag] and stores the pval of each
    link composing the causal model; cells without a link are 0.

    Returns:
        np.array: pval matrix
    """
    n_features = len(self.features)
    matrix = np.zeros((n_features, n_features, self.max_lag + 1))
    for target, node in self.g.items():
        for source, info in node.sources.items():
            pval = info[PVAL]
            per_lag = matrix[self.features.index(target), self.features.index(source[0])]
            per_lag[source[1]] = pval
    return np.array(matrix)

get_skeleton()

Return skeleton matrix.

The skeleton matrix is composed of 0s and 1s: 1 if there is a link from source to target, 0 if there is no link from source to target.

Returns:

Type Description
np.array

np.array: skeleton matrix

Source code in causalflow/graph/DAG.py
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
def get_skeleton(self) -> np.array:
    """
    Return skeleton matrix.

    Skeleton matrix is composed of 0 and 1 and is indexed as [target, source, lag].
    1 <- if there is a link from source to target
    0 <- if there is not a link from source to target

    Returns:
        np.array: integer skeleton matrix of shape
            (len(features), len(features), max_lag + 1).
    """
    n_features = len(self.features)
    # Start from an all-zero integer matrix so that missing links really are 0
    # (not empty strings), as the contract above states.
    skeleton = np.zeros((n_features, n_features, self.max_lag + 1), dtype=int)
    for t, node in self.g.items():
        for s in node.sources:
            skeleton[self.features.index(t), self.features.index(s[0]), s[1]] = 1
    return skeleton

get_val_matrix()

Return val matrix.

The val matrix contains the strength of each link composing the causal model.

Returns:

Type Description
np.array

np.array: val matrix.

Source code in causalflow/graph/DAG.py
759
760
761
762
763
764
765
766
767
768
769
770
771
772
def get_val_matrix(self) -> np.array:
    """
    Build the val matrix.

    The matrix is indexed as [target, source, lag] and stores the strength
    (score) of each link composing the causal model; cells without a link are 0.

    Returns:
        np.array: val matrix.
    """
    n_features = len(self.features)
    matrix = np.zeros((n_features, n_features, self.max_lag + 1))
    for target, node in self.g.items():
        for source, info in node.sources.items():
            score = info[SCORE]
            per_lag = matrix[self.features.index(target), self.features.index(source[0])]
            per_lag[source[1]] = score
    return np.array(matrix)

make_pretty()

Make variables' names pretty, i.e. $ varname $ with '{' after '_' and '}' at the end of the string.

Returns:

Name Type Description
dict dict

pretty DAG.

Source code in causalflow/graph/DAG.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
def make_pretty(self) -> dict:
    """
    Produce a copy of the graph whose variable names are math-text formatted.

    Each name is wrapped in '$...$' and whatever follows an underscore is
    enclosed in braces, e.g. 'X_1' becomes '$X_{1}$'. Node names, children and
    source keys are all renamed; self.g is left untouched.

    Returns:
        dict: pretty DAG.
    """
    def to_mathtext(label):
        subscripted = re.sub(r'_(\w+)', r'_{\1}', label)
        return '$' + subscripted + '$'

    pretty_graph = {}
    for name, node in self.g.items():
        pretty_name = to_mathtext(name)
        clone = copy.deepcopy(node)
        pretty_graph[pretty_name] = clone
        clone.name = pretty_name
        clone.children = [to_mathtext(child) for child in node.children]
        # Re-key every source under its prettified name, keeping its attributes.
        for key, info in node.sources.items():
            del clone.sources[key]
            clone.sources[(to_mathtext(key[0]), key[1])] = {
                SCORE: info[SCORE],
                PVAL: info[PVAL],
                TYPE: info[TYPE]
            }
    return pretty_graph

remove_context()

Remove context variables.

Source code in causalflow/graph/DAG.py
181
182
183
184
185
186
187
188
189
190
191
def remove_context(self):
    """Strip every context variable, and its link to its system variable, from the graph."""
    for system_var, context_var in self.sys_context.items():
        if system_var not in self.g:
            continue

        # Detach the context variable from its system variable
        # self.g[system_var].intervention_node = False
        node = self.g[system_var]
        node.associated_context = None
        self.del_source(system_var, context_var, 0)

        # Drop the context variable itself
        del self.g[context_var]

remove_unneeded_features()

Remove isolated nodes.

Source code in causalflow/graph/DAG.py
151
152
153
154
155
156
157
158
def remove_unneeded_features(self):
    """Drop isolated nodes, plus the context node of any isolated intervention node."""
    # Deletions happen on a deep copy so that self.g is not mutated while iterating.
    kept = copy.deepcopy(self.g)
    for name, node in self.g.items():
        if not node.is_isolated:
            continue
        if node.intervention_node:
            del kept[node.associated_context]
        del kept[name]
    self.g = kept

ts_dag(min_cross_width=1, max_cross_width=5, node_size=8, x_disp=1.5, y_disp=0.2, text_disp=0.1, node_color='orange', edge_color='grey', tail_color='black', font_size=8, save_name=None, img_extention=ImageExt.PNG)

Build a timeseries dag.

Parameters:

Name Type Description Default
min_cross_width int

minimum linewidth. Defaults to 1.

1
max_cross_width int

maximum linewidth. Defaults to 5.

5
node_size int

node size. Defaults to 8.

8
x_disp float

node displacement along x. Defaults to 1.5.

1.5
y_disp float

node displacement along y. Defaults to 0.2.

0.2
text_disp float

text displacement along y. Defaults to 0.1.

0.1
node_color str / list

node color. If a string, all the nodes will have the same colour. If a list (same dimension as features), each node will have the corresponding colour. Defaults to 'orange'.

'orange'
edge_color str

edge color. Defaults to 'grey'.

'grey'
tail_color str

tail color. Defaults to 'black'.

'black'
font_size int

font size. Defaults to 8.

8
save_name str

Filename path. If None, plot is shown and not saved. Defaults to None.

None
img_extention ImageExt

Image Extension. Defaults to PNG.

ImageExt.PNG
Source code in causalflow/graph/DAG.py
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
def ts_dag(self,
           min_cross_width = 1, 
           max_cross_width = 5,
           node_size = 8,
           x_disp = 1.5,
           y_disp = 0.2,
           text_disp = 0.1,
           node_color = 'orange',
           edge_color = 'grey',
           tail_color = 'black',
           font_size = 8,
           save_name = None,
           img_extention = ImageExt.PNG):
    """
    Build a timeseries dag.

    Nodes are laid out on a grid with one column per time step (max_lag + 1
    columns, the last one being time t) and one row per variable. Each link is
    repeated at every time step where both of its end points fit on the grid.
    The drawing is produced from a deep copy of this object whose graph is
    replaced by the output of ``make_pretty()``.

    Args:
        min_cross_width (int, optional): minimum linewidth. Defaults to 1.
        max_cross_width (int, optional): maximum linewidth. Defaults to 5.
        node_size (int, optional): node size. Defaults to 8.
        x_disp (float, optional): node displacement along x. Defaults to 1.5.
        y_disp (float, optional): node displacement along y. Defaults to 0.2.
        text_disp (float, optional): text displacement along y. Defaults to 0.1.
        node_color (str/list, optional): node color. 
                                         If a string, all the nodes will have the same colour. 
                                         If a list (same dimension of features), each node will have the specified colour.
                                         Defaults to 'orange'.
        edge_color (str, optional): edge color. Defaults to 'grey'.
        tail_color (str, optional): tail color. Defaults to 'black'.
        font_size (int, optional): font size. Defaults to 8.
        save_name (str, optional): Filename path. If None, plot is shown and not saved. Defaults to None.
        img_extention (ImageExt, optional): Image Extension. Defaults to PNG.
    """
    # Work on a copy so that the prettified node names never leak into self.g.
    r = copy.deepcopy(self)
    r.g = r.make_pretty()

    Gcont = nx.DiGraph()
    Glagcross = nx.DiGraph()
    Glagauto = nx.DiGraph()

    # 1. Nodes definition
    # Nodes are (time column, variable row) pairs shared by the three graphs.
    if isinstance(node_color, list):
        node_c = dict()
    else:
        node_c = node_color
    for i in range(len(self.features)):
        for j in range(self.max_lag + 1):
            Glagauto.add_node((j, i))
            Glagcross.add_node((j, i))
            Gcont.add_node((j, i))
            # Rows are numbered in reverse variable order (see s_index/t_index
            # below), so the colour list is read back to front.
            if isinstance(node_color, list): node_c[(j, i)] = node_color[abs(i - (len(r.g.keys()) - 1))]

    pos = {n : (n[0]*x_disp, n[1]*y_disp) for n in Glagauto.nodes()}
    # max over (x, y) tuples compares lexicographically; the result is used
    # below to size the drawing area.
    scale = max(pos.values())

    # 2. Edges definition
    cont_edges = list()
    cont_edge_width = dict()
    cont_arrows = dict()
    lagged_cross_edges = list()
    lagged_cross_edge_width = dict()
    lagged_cross_arrows = dict()
    lagged_auto_edges = list()
    lagged_auto_edge_width = dict()
    lagged_auto_arrows = dict()

    for t in r.g:
        for s in r.g[t].sources:
            # Row index: the first variable in r.g gets the highest row.
            s_index = len(r.g.keys())-1 - list(r.g.keys()).index(s[0])
            t_index = len(r.g.keys())-1 - list(r.g.keys()).index(t)

            # 2.1. Contemporaneous edges definition
            # A lag-0 link is repeated inside every time column.
            if s[1] == 0:
                for i in range(self.max_lag + 1):
                    s_node = (i, s_index)
                    t_node = (i, t_index)
                    self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                    cont_edges, cont_edge_width, cont_arrows, r, t, s, 
                                    s_node, t_node)

            else:
                # Start with the target in the last column and slide the pair
                # backwards in time until the source falls off the grid.
                s_lag = self.max_lag - s[1]
                t_lag = self.max_lag
                while s_lag >= 0:
                    s_node = (s_lag, s_index)
                    t_node = (t_lag, t_index)
                    # 2.2. Lagged cross edges definition
                    if s[0] != t:
                        self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                        lagged_cross_edges, lagged_cross_edge_width, lagged_cross_arrows, r, t, s, 
                                        s_node, t_node)
                    # 2.3. Lagged auto edges definition
                    else:
                        self.__add_edge(min_cross_width, max_cross_width, 0, self.max_cross_score, 
                                        lagged_auto_edges, lagged_auto_edge_width, lagged_auto_arrows, r, t, s, 
                                        s_node, t_node)
                    s_lag -= 1
                    t_lag -= 1

    Gcont.add_edges_from(cont_edges)
    Glagcross.add_edges_from(lagged_cross_edges)
    Glagauto.add_edges_from(lagged_auto_edges)

    fig, ax = plt.subplots(figsize=(8,6))
    # The returned value is passed as edge_layout to the contemporaneous Graph
    # in step 5. NOTE(review): this helper is defined outside this block - what
    # it draws on ax, if anything, cannot be confirmed from here.
    edge_layout = self.__get_fixed_edges(ax, x_disp, Gcont, node_size, pos, node_c, font_size, 
                                         cont_arrows, edge_color, tail_color, cont_edge_width, scale)

    # 3. Label definition
    # Variable names are written next to the first time column only.
    for n in Gcont.nodes():
        if n[0] == 0:
            ax.text(pos[n][0] - text_disp, pos[n][1], list(r.g.keys())[len(r.g.keys()) - 1 - n[1]], horizontalalignment='center', verticalalignment='center', fontsize=font_size)

    # 4. Time line text drawing
    # Each column is captioned "t" or "t-k", where k is its distance from the last column.
    pos_tau = set([pos[p][0] for p in pos])
    max_y = max([pos[p][1] for p in pos])
    for p in pos_tau:
        if abs(int(p/x_disp) - self.max_lag) == 0:
            ax.text(p, max_y + 0.1, r"$t$", horizontalalignment='center', fontsize=font_size)
        else:
            ax.text(p, max_y + 0.1, r"$t-" + str(abs(int(p/x_disp) - self.max_lag)) + "$", horizontalalignment='center', fontsize=font_size)

    # 5. Draw graph - contemporaneous
    if cont_edges:
        a = Graph(Gcont,
                node_layout={p : np.array(pos[p]) for p in pos},
                node_size=node_size,
                node_color=node_c,
                node_edge_width=0,
                node_label_fontdict=dict(size=font_size),
                node_label_offset=0,
                node_alpha=1,

                arrows=cont_arrows,
                edge_layout=edge_layout,
                edge_label=False,
                edge_color=edge_color,
                tail_color=tail_color,
                edge_width=cont_edge_width,
                edge_alpha=1,
                edge_zorder=1,
                scale = (scale[0] + 2, scale[1] + 2))

    # 6. Draw graph - lagged cross
    if lagged_cross_edges:
        a = Graph(Glagcross,
                node_layout={p : np.array(pos[p]) for p in pos},
                node_size=node_size,
                node_color=node_c,
                node_edge_width=0,
                node_label_fontdict=dict(size=font_size),
                node_label_offset=0,
                node_alpha=1,

                arrows=lagged_cross_arrows,
                edge_layout='curved',
                edge_label=False,
                edge_color=edge_color,
                tail_color=tail_color,
                edge_width=lagged_cross_edge_width,
                edge_alpha=1,
                edge_zorder=1,
                scale = (scale[0] + 2, scale[1] + 2))

    # 7. Draw graph - lagged auto
    if lagged_auto_edges:
        a = Graph(Glagauto,
                node_layout={p : np.array(pos[p]) for p in pos},
                node_size=node_size,
                node_color=node_c,
                node_edge_width=0,
                node_label_fontdict=dict(size=font_size),
                node_label_offset=0,
                node_alpha=1,

                arrows=lagged_auto_arrows,
                edge_layout='straight',
                edge_label=False,
                edge_color=edge_color,
                tail_color=tail_color,
                edge_width=lagged_auto_edge_width,
                edge_alpha=1,
                edge_zorder=1,
                scale = (scale[0] + 2, scale[1] + 2))

    # 8. Plot or save
    if save_name is not None:
        plt.savefig(save_name + img_extention.value, dpi = 300)
    else:
        plt.show()

This module provides the PAG class.

Classes

PAG: class for facilitating the handling and the creation of PAGs.

PAG

PAG class.

Source code in causalflow/graph/PAG.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
class PAG():
    """
    PAG class.

    Builds a time-unrolled DAG from a set of link assumptions and derives from
    it a time-series DPAG restricted to the non-latent variables.
    """

    def __init__(self, dag, tau_max, latents) -> None:
        """
        Class constructor.

        Args:
            dag (DAG): DAG to convert. It is read as a mapping from each target variable to its list of (source, lag[, ...]) tuples.
            tau_max (int): max time lag.
            latents (list[str]): list of latent variables.

        Raises:
            ValueError: latents must be a list.
        """
        if not isinstance(latents, list): raise ValueError('latents must be a list')
        self.link_assumptions = dag
        self.tau_max = tau_max
        self.latents = latents
        # Cache of the separation sets already computed, keyed by (source, target)
        self.dSepSets = {}

        self.tsDAG = self.createDAG(self.link_assumptions, self.tau_max)

        self.pag = self.tsDAG2tsDPAG()


    def convert2Graph(self) -> dict:
        """
        Convert a PAG to a graph representation.

        Returns:
            dict: Graph representation of a PAG, i.e. {target: {(source, lag): link type}}.
        """
        out = {t: {} for t in self.pag}
        for t in self.pag:
            for s in self.pag[t]:
                out[t][(s[0], s[1])] = s[2]
        return out


    @staticmethod
    def createDAG(link_assumptions, tau_max) -> BayesianNetwork:
        """
        Create a DAG represented by a Bayesian Network.

        Nodes are (variable, lag) tuples with lag in [-tau_max, 0]. Every link
        is replicated on each earlier time slice that still fits in the window.

        Args:
            link_assumptions (dict): DAG link assumptions.
            tau_max (int): max time lag.

        Raises:
            ValueError: source not well defined.

        Returns:
            BayesianNetwork: DAG represented by a Bayesian Network.
        """
        BN = BayesianNetwork()
        # tau_max + 1 so that the oldest slice (-tau_max) is registered even
        # for variables that have no link touching it
        BN.add_nodes_from([(t, -l) for t in link_assumptions.keys() for l in range(tau_max + 1)])

        # Edges
        edges = []
        for t in link_assumptions.keys():
            for source in link_assumptions[t]:
                if len(source) == 0: continue
                elif len(source) == 2: s, l = source
                elif len(source) == 3: s, l, _ = source
                else: raise ValueError("Source not well defined")
                edges.append(((s, l), (t, 0)))
                # Add edges across time slices from -1 to -tau_max
                for lag in range(1, tau_max + 1):
                    if l - lag >= -tau_max:
                        edges.append(((s, l - lag), (t, -lag)))
        BN.add_edges_from(edges)
        return BN


    def alreadyChecked(self, source, target):
        """
        Check if a link has been already checked.

        The pair is looked up as given, reversed, and in both orders after
        shifting it in time so that the second node sits at lag 0.

        Args:
            source (tuple): source node as (name, lag).
            target (tuple): target node as (name, lag).

        Returns:
            (bool, tuple): tuple containing if the link has been checked and, if so, their separation set. Otherwise None.
        """
        if (source, target) in self.dSepSets: return True, self.dSepSets[(source, target)]
        elif (target, source) in self.dSepSets: return True, self.dSepSets[(target, source)]
        elif ((source[0], source[1] - target[1]), (target[0], 0)) in self.dSepSets: return True, self.dSepSets[((source[0], source[1] - target[1]), (target[0], 0))]
        elif ((target[0], target[1] - source[1]), (source[0], 0)) in self.dSepSets: return True, self.dSepSets[((target[0], target[1] - source[1]), (source[0], 0))]
        return False, None


    def tsDAG2tsDPAG(self) -> dict:
        """
        Convert a DAG to a Time-series DPAG.

        Links between observed variables start as '-->'. When latent variables
        are present, non-adjacent node pairs whose separation set contains a
        latent node are linked with 'o-o' and then re-oriented in three steps.

        Returns:
            dict: Time-series DPAG.
        """
        self.tsDPAG = {t: [(s[0], s[1], '-->') for s in self.link_assumptions[t] if s[0] not in self.latents] for t in self.link_assumptions.keys() if t not in self.latents}

        if len(self.latents) > 0:
            self.ambiguous_links = []

            # Separate nodes based on their time index: time-zero nodes are analysed first
            all_nodes = list(self.tsDAG.nodes())
            time_zero_nodes = [node for node in all_nodes if node[1] == 0]
            other_nodes = [node for node in all_nodes if node[1] != 0]

            for target in time_zero_nodes + other_nodes:
                print(f"Analysing target: {target}")
                if target[0] in self.latents: continue
                # Candidate sources: every other node not adjacent to the target
                adjacent = list(self.tsDAG.predecessors(target)) + list(self.tsDAG.successors(target))
                tmp = [n for n in all_nodes if (n[0] != target[0] or n[1] != target[1]) and n not in adjacent]
                for source in tmp:
                    alreadyChecked, d_sep = self.alreadyChecked(source, target)
                    if alreadyChecked:
                        print(f"\t- {source}{target} | {d_sep} ALREADY CHECKED")
                    else:
                        _, d_sep = self.find_d_separators(source, target)
                        self.dSepSets[(source, target)] = d_sep
                        print(f"\t- {source}{target} | {d_sep}")
                    # find_d_separators returns an empty set when the pair is not
                    # d-separated, so a latent node in d_sep already implies
                    # d-separation. Deciding from d_sep alone also works for cached pairs.
                    if any(node[0] in self.latents for node in d_sep):
                        if source[0] in self.latents or target[0] in self.latents: continue
                        if target[1] == 0:
                            print(f"\t- SPURIOUS LINK: ({source[0]}, {source[1]}) o-o ({target[0]}, {target[1]})")
                            if (source[0], source[1], 'o-o') not in self.tsDPAG[target[0]]: self.tsDPAG[target[0]].append((source[0], source[1], 'o-o'))
                            if (source, target, 'o-o') not in self.ambiguous_links: self.ambiguous_links.append((source, target, 'o-o'))
                        elif source[1] == 0:
                            print(f"\t- SPURIOUS LINK: ({target[0]}, {target[1]}) o-o ({source[0]}, {source[1]})")
                            if (target[0], target[1], 'o-o') not in self.tsDPAG[source[0]]: self.tsDPAG[source[0]].append((target[0], target[1], 'o-o'))
                            if (target, source, 'o-o') not in self.ambiguous_links: self.ambiguous_links.append((target, source, 'o-o'))


            print("--------------------------------------------------")
            print("    Bidirected link due to latent confounders     ")
            print("--------------------------------------------------")
            # *(1) Bidirected link between variables confounded by a latent variable
            # *    if a link between them does not exist already
            confounders = self.find_latent_confounders()
            for confounded in copy.deepcopy(list(confounders.values())):
                for c1 in copy.deepcopy(confounded):
                    tmp = copy.deepcopy(confounded)
                    tmp.remove(c1)
                    for c2 in tmp:
                        if (c1, c2, 'o-o') in self.ambiguous_links:
                            self.update_link_type(c1, c2, '<->')
                            self.ambiguous_links.remove((c1, c2, 'o-o'))
                            print(f"\t- SPURIOUS LINK REMOVED: {c1} o-o {c2}")
                        elif (c2, c1, 'o-o') in self.ambiguous_links:
                            # The link is stored under c1 with c2 as its source,
                            # so that is the entry to update
                            self.update_link_type(c2, c1, '<->')
                            self.ambiguous_links.remove((c2, c1, 'o-o'))
                            print(f"\t- SPURIOUS LINK REMOVED: {c2} o-o {c1}")
                    confounded.remove(c1)

            print("--------------------------------------------------")
            print("              Collider orientation                ")
            print("--------------------------------------------------")
            # *(2) Identify and orient the colliders:
            # *    for any path X – Z – Y where there is no edge between
            # *    X and Y and, Z was never included in the conditioning set ==> X → Z ← Y collider
            colliders = self.find_colliders()
            for ambiguous_link in copy.deepcopy(self.ambiguous_links):
                source, target, linktype = ambiguous_link
                for parent1, collider, parent2 in colliders:
                    if collider == target and (parent1 == source or parent2 == source):
                        if not self.tsDAG.has_edge(parent1, parent2) and not self.tsDAG.has_edge(parent2, parent1):
                            self.update_link_type(parent1, target, '-->')
                            self.update_link_type(parent2, target, '-->')
                            self.ambiguous_links.remove(ambiguous_link)
                            break

            print("--------------------------------------------------")
            print("Non-collider orientation (orientation propagation)")
            print("--------------------------------------------------")
            # *(3) Orient the non-colliders edges (orientation propagation)
            # *    any edge Z – Y part of a partially directed path X → Z – Y,
            # *    where there is no edge between X and Y can be oriented as Z → Y
            for ambiguous_link in copy.deepcopy(self.ambiguous_links):
                triples = self.find_triples_containing_link(ambiguous_link)
                for triple in triples: self.update_link_type(triple[1], triple[2], '-->')

        # TODO: (3) Check if cycles are present

        else:
            print("No latent variable")

        return self.tsDPAG


    def find_colliders(self) -> list:
        """
        Find colliders.

        Every unordered pair of sources pointing at the same variable is
        reported as a candidate (parent1, (variable, 0), parent2) triple.

        Returns:
            list: colliders.
        """
        colliders = []
        for node in self.tsDPAG.keys():
            parents = [(p[0], p[1]) for p in self.tsDPAG[node]]
            for i, parent1 in enumerate(parents):
                for parent2 in parents[i + 1:]:
                    colliders.append((parent1, (node, 0), parent2))
        return colliders


    def update_link_type(self, parent, target, linktype):
        """
        Update link type.

        Args:
            parent (tuple): parent node as (name, lag).
            target (tuple): target node as (name, lag); only the name is used.
            linktype (str): link type. E.g. --> or -?>.
        """
        for idx, link in enumerate(self.tsDPAG[target[0]]):
            if link[0] == parent[0] and link[1] == parent[1]:
                self.tsDPAG[target[0]][idx] = (link[0], link[1], linktype)


    def find_latent_confounders(self) -> dict:
        """
        Find latent confounders.

        A latent node is a confounder when it has more than one successor.
        Each (latent, variable) name pair is reported only once, under the
        first time slice in which it is met.

        Returns:
            dict: latent confounders, i.e. {latent node: list of confounded nodes}.
        """
        confounders = {}
        for latent in self.latents:
            for t in range(self.tau_max + 1):
                children = list(self.tsDAG.successors((latent, -t)))
                if len(children) > 1: confounders[(latent, -t)] = children

        # Initialize a new dictionary to store unique edges
        shrinked_confounders = defaultdict(list)

        # Set to keep track of added edges without considering the time slice
        seen_edges = set()

        for key, value in confounders.items():
            # Normalize key by removing the time slice
            key_normalized = key[0]

            for v in value:
                # Normalize value by removing the time slice
                v_normalized = v[0]

                # Create a tuple of normalized edge
                edge = (key_normalized, v_normalized)

                # Check if edge or its reverse has been seen
                if edge not in seen_edges and (v_normalized, key_normalized) not in seen_edges:
                    # If not seen, add to unique edges and mark as seen
                    shrinked_confounders[key].append(v)
                    seen_edges.add(edge)

        return shrinked_confounders


    def find_d_separators(self, source, target, latents=None):
        """
        Find D-Separation set.

        Separation sets made only of observed nodes are tried first, from the
        smallest to the largest. Sets that may include latent nodes are tried
        only if none of them works.

        Args:
            source (tuple): source node as (name, lag).
            target (tuple): target node as (name, lag).
            latents (list[str], optional): latent variable names. Defaults to None, meaning self.latents.

        Returns:
            (bool, set): (True, separation set) if source and target are d-separated. Otherwise (False, empty set).
        """
        if latents is None: latents = self.latents
        paths = self.find_all_paths(source, target)

        if len(paths) == 0:
            return True, set()

        nodes = set()
        obs_nodes = set()
        for path in paths:
            for node in path:
                if node == source or node == target: continue
                nodes.add(node)
                # Nodes are (name, lag) tuples while latents holds names
                if node[0] not in latents: obs_nodes.add(node)

        for candidates in (obs_nodes, nodes):
            for r in range(len(candidates) + 1):
                for subset in combinations(candidates, r):
                    subset_set = set(subset)
                    if not self.tsDAG.is_dconnected(source, target, subset_set):
                        return True, subset_set

        return False, set()

    def find_triples_containing_link(self, ambiguous_link) -> set:
        """
        Find all triples containing a link.

        Args:
            ambiguous_link (tuple): ambiguous link as (source, target, link type).

        Returns:
            set: triples (n, a, b) where n precedes a, a - b is the link and n is not adjacent to b.
        """
        pag = self.createDAG(self.tsDPAG, self.tau_max)

        source, target, _ = ambiguous_link
        triples = set()

        for n in pag.predecessors(source):
            if n != target and not pag.has_edge(n, target) and not pag.has_edge(target, n): triples.add((n, source, target))
        for n in pag.predecessors(target):
            if n != source and not pag.has_edge(n, source) and not pag.has_edge(source, n): triples.add((n, target, source))

        return triples


    # DFS to find all paths
    def find_all_paths(self, start, goal, path=None) -> list:
        """
        Find all paths from start to goal.

        Edges are followed in both directions and no node is visited twice
        within the same path.

        Args:
            start (tuple): starting node.
            goal (tuple): goal node.
            path (list, optional): nodes visited so far. Defaults to None, meaning an empty path.

        Returns:
            list: paths, each one a list of nodes from start to goal.
        """
        # A fresh list per call: the argument is never mutated
        path = ([] if path is None else path) + [start]
        if start == goal:
            return [path]
        paths = []
        neighbours = list(self.tsDAG.successors(start)) + list(self.tsDAG.predecessors(start))
        for node in neighbours:
            if node not in path:
                paths.extend(self.find_all_paths(node, goal, path))
        return paths

__init__(dag, tau_max, latents)

Class constructor.

Parameters:

Name Type Description Default
dag DAG

DAG to convert.

required
tau_max int

max time lag.

required
latents list[str]

list of latent variables.

required

Raises:

Type Description
ValueError

latents must be a list

Source code in causalflow/graph/PAG.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(self, dag, tau_max, latents) -> None:
    """
    Build the time-series DAG and derive its PAG.

    Args:
        dag (DAG): DAG to convert.
        tau_max (int): max time lag.
        latents (list[str]): list of latent variables.

    Raises:
        ValueError: latents must be a list.
    """
    if not isinstance(latents, list):
        raise ValueError('latents must be a list')

    # Inputs are stored as given; the separation-set cache starts empty
    self.link_assumptions, self.tau_max, self.latents = dag, tau_max, latents
    self.dSepSets = {}

    # The DAG must exist before the conversion, which reads self.tsDAG
    self.tsDAG = self.createDAG(dag, tau_max)
    self.pag = self.tsDAG2tsDPAG()

alreadyChecked(source, target)

Check if a link has been already checked.

Parameters:

Name Type Description Default
source str

source node

required
target str

target node

required

Returns:

Type Description
(bool, tuple)

tuple containing if the link has been checked and, if so, their separation set. Otherwise None.

Source code in causalflow/graph/PAG.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def alreadyChecked(self, source, target):
    """
    Look up a pair of nodes in the cache of separation sets.

    Args:
        source (tuple): source node as (name, lag).
        target (tuple): target node as (name, lag).

    Returns:
        (bool, tuple): (True, separation set) if the link has been checked. Otherwise (False, None).
    """
    def _candidate_keys():
        # Generated lazily, in lookup order: as given, reversed, then both
        # orders shifted in time so that the second node sits at lag 0
        yield (source, target)
        yield (target, source)
        yield ((source[0], source[1] - target[1]), (target[0], 0))
        yield ((target[0], target[1] - source[1]), (source[0], 0))

    for key in _candidate_keys():
        if key in self.dSepSets:
            return True, self.dSepSets[key]
    return False, None

convert2Graph()

Convert a PAG to a graph representation.

Returns:

Name Type Description
dict dict

Graph representation of a PAG.

Source code in causalflow/graph/PAG.py
39
40
41
42
43
44
45
46
47
48
49
50
def convert2Graph(self) -> dict:
    """
    Re-shape the PAG into a nested dictionary.

    Returns:
        dict: Graph representation of a PAG, i.e. {target: {(source, lag): link type}}.
    """
    graph = {}
    for target, links in self.pag.items():
        # A later link with the same (source, lag) overrides an earlier one
        graph[target] = {(link[0], link[1]): link[2] for link in links}
    return graph

createDAG(link_assumptions, tau_max) staticmethod

Create a DAG represented by a Bayesian Network.

Parameters:

Name Type Description Default
link_assumptions dict

DAG link assumptions.

required
tau_max int

max time lag.

required

Raises:

Type Description
ValueError

source not well defined.

Returns:

Name Type Description
BayesianNetwork BayesianNetwork

DAG represented by a Bayesian Network.

Source code in causalflow/graph/PAG.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@staticmethod
def createDAG(link_assumptions, tau_max) -> BayesianNetwork:
    """
    Create a DAG represented by a Bayesian Network.

    Nodes are (variable, lag) tuples with lag in [-tau_max, 0]. Every link is
    replicated on each earlier time slice that still fits in the window.

    Args:
        link_assumptions (dict): DAG link assumptions.
        tau_max (int): max time lag.

    Raises:
        ValueError: source not well defined.

    Returns:
        BayesianNetwork: DAG represented by a Bayesian Network.
    """
    BN = BayesianNetwork()
    # tau_max + 1 so that the oldest slice (-tau_max) is registered even for
    # variables that have no link touching it
    BN.add_nodes_from([(t, -l) for t in link_assumptions.keys() for l in range(tau_max + 1)])

    # Edges
    edges = []
    for t in link_assumptions.keys():
        for source in link_assumptions[t]:
            if len(source) == 0: continue
            elif len(source) == 2: s, l = source
            elif len(source) == 3: s, l, _ = source
            else: raise ValueError("Source not well defined")
            edges.append(((s, l), (t, 0)))
            # Add edges across time slices from -1 to -tau_max
            for lag in range(1, tau_max + 1):
                if l - lag >= -tau_max:
                    edges.append(((s, l - lag), (t, -lag)))
    BN.add_edges_from(edges)
    return BN

find_all_paths(start, goal, path=[])

Find all paths from start to goal.

Parameters:

Name Type Description Default
start str

starting node.

required
goal str

goal node.

required
path list

Found paths. Defaults to [].

[]

Returns:

Name Type Description
list list

paths

Source code in causalflow/graph/PAG.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def find_all_paths(self, start, goal, path=None) -> list:
    """
    Find all paths from start to goal.

    Edges are followed in both directions and no node is visited twice within
    the same path.

    Args:
        start (tuple): starting node.
        goal (tuple): goal node.
        path (list, optional): nodes visited so far. Defaults to None, meaning an empty path.

    Returns:
        list: paths, each one a list of nodes from start to goal.
    """
    # A fresh list per call: the argument is never mutated
    path = ([] if path is None else path) + [start]
    if start == goal:
        return [path]
    paths = []
    # Successors first, then predecessors
    neighbours = list(self.tsDAG.successors(start)) + list(self.tsDAG.predecessors(start))
    for node in neighbours:
        if node not in path:
            paths.extend(self.find_all_paths(node, goal, path))
    return paths

find_colliders()

Find colliders.

Returns:

Name Type Description
list list

colliders.

Source code in causalflow/graph/PAG.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def find_colliders(self) -> list:
    """
    List the candidate colliders of the DPAG.

    Returns:
        list: (parent1, (variable, 0), parent2) triples, one per unordered pair of sources of each variable.
    """
    found = []
    for variable, links in self.tsDPAG.items():
        endpoints = [(link[0], link[1]) for link in links]
        # Pair each source with the ones that follow it, so that every
        # unordered pair is produced exactly once
        for position, first in enumerate(endpoints):
            for second in endpoints[position + 1:]:
                found.append((first, (variable, 0), second))
    return found

find_d_separators(source, target)

Find D-Separation set.

Parameters:

Name Type Description Default
source str

source node.

required
target str

target node.

required

Returns:

Type Description
(bool, set)

(True, separation set) if source and target are d-separated. Otherwise (False, empty set).

Source code in causalflow/graph/PAG.py
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def find_d_separators(self, source, target, latents=None):
    """
    Find D-Separation set.

    Separation sets made only of observed nodes are tried first, from the
    smallest to the largest. Sets that may include latent nodes are tried only
    if none of them works.

    Args:
        source (tuple): source node as (name, lag).
        target (tuple): target node as (name, lag).
        latents (list[str], optional): latent variable names. Defaults to None, meaning self.latents.

    Returns:
        (bool, set): (True, separation set) if source and target are d-separated. Otherwise (False, empty set).
    """
    if latents is None: latents = self.latents
    paths = self.find_all_paths(source, target)

    if len(paths) == 0:
        return True, set()

    nodes = set()
    obs_nodes = set()
    for path in paths:
        for node in path:
            if node == source or node == target: continue
            nodes.add(node)
            # Nodes are (name, lag) tuples while latents holds names
            if node[0] not in latents: obs_nodes.add(node)

    for candidates in (obs_nodes, nodes):
        for r in range(len(candidates) + 1):
            for subset in combinations(candidates, r):
                subset_set = set(subset)
                if not self.tsDAG.is_dconnected(source, target, subset_set):
                    return True, subset_set

    return False, set()

find_latent_confounders()

Find latent confounders.

Returns:

Name Type Description
dict dict

latent confounders.

Source code in causalflow/graph/PAG.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def find_latent_confounders(self) -> dict:
    """
    Collect the nodes confounded by each latent node.

    Returns:
        dict: latent confounders, i.e. {latent node: list of confounded nodes}.
    """
    # A latent node confounds when it has more than one successor
    children_of = {}
    for latent in self.latents:
        for lag in range(self.tau_max + 1):
            latent_node = (latent, -lag)
            children = list(self.tsDAG.successors(latent_node))
            if len(children) > 1:
                children_of[latent_node] = children

    # Keep each (latent, variable) name pair only once, whatever the time slice
    unique = defaultdict(list)
    visited = set()
    for latent_node, children in children_of.items():
        for child in children:
            pair = (latent_node[0], child[0])
            if pair in visited or (child[0], latent_node[0]) in visited:
                continue
            unique[latent_node].append(child)
            visited.add(pair)

    return unique

find_triples_containing_link(ambiguous_link)

Find all triples containing a link.

Parameters:

Name Type Description Default
ambiguous_link tuple

ambiguous_link

required

Returns:

Name Type Description
set set

triples containing the specified link.

Source code in causalflow/graph/PAG.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def find_triples_containing_link(self, ambiguous_link) -> set:
    """
    Collect the triples built around an ambiguous link.

    Args:
        ambiguous_link (tuple): ambiguous link as (source, target, link type).

    Returns:
        set: triples (n, a, b) where n precedes a, a - b is the link and n is not adjacent to b.
    """
    graph = self.createDAG(self.tsDPAG, self.tau_max)
    tail, head, _linktype = ambiguous_link

    def _triples_through(middle, end):
        # Predecessors of `middle` that are neither `end` nor adjacent to it
        return {
            (n, middle, end)
            for n in graph.predecessors(middle)
            if n != end and not graph.has_edge(n, end) and not graph.has_edge(end, n)
        }

    return _triples_through(tail, head) | _triples_through(head, tail)

tsDAG2tsDPAG()

Convert a DAG to a Time-series DPAG.

Returns:

Name Type Description
dict dict

Time-series DPAG.

Source code in causalflow/graph/PAG.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def tsDAG2tsDPAG(self) -> dict:
    """
    Convert a DAG to a Time-series DPAG.

    Links between observed variables start as '-->'. When latent variables are
    present, non-adjacent node pairs whose separation set contains a latent
    node are linked with 'o-o' and then re-oriented in three steps.

    Returns:
        dict: Time-series DPAG.
    """
    self.tsDPAG = {t: [(s[0], s[1], '-->') for s in self.link_assumptions[t] if s[0] not in self.latents] for t in self.link_assumptions.keys() if t not in self.latents}

    if len(self.latents) > 0:
        self.ambiguous_links = []

        # Separate nodes based on their time index: time-zero nodes are analysed first
        all_nodes = list(self.tsDAG.nodes())
        time_zero_nodes = [node for node in all_nodes if node[1] == 0]
        other_nodes = [node for node in all_nodes if node[1] != 0]

        for target in time_zero_nodes + other_nodes:
            print(f"Analysing target: {target}")
            if target[0] in self.latents: continue
            # Candidate sources: every other node not adjacent to the target
            adjacent = list(self.tsDAG.predecessors(target)) + list(self.tsDAG.successors(target))
            tmp = [n for n in all_nodes if (n[0] != target[0] or n[1] != target[1]) and n not in adjacent]
            for source in tmp:
                alreadyChecked, d_sep = self.alreadyChecked(source, target)
                if alreadyChecked:
                    print(f"\t- {source}{target} | {d_sep} ALREADY CHECKED")
                else:
                    # Two arguments only, matching find_d_separators(source, target)
                    _, d_sep = self.find_d_separators(source, target)
                    self.dSepSets[(source, target)] = d_sep
                    print(f"\t- {source}{target} | {d_sep}")
                # find_d_separators returns an empty set when the pair is not
                # d-separated, so a latent node in d_sep already implies
                # d-separation. Deciding from d_sep alone also works for cached pairs.
                if any(node[0] in self.latents for node in d_sep):
                    if source[0] in self.latents or target[0] in self.latents: continue
                    if target[1] == 0:
                        print(f"\t- SPURIOUS LINK: ({source[0]}, {source[1]}) o-o ({target[0]}, {target[1]})")
                        if (source[0], source[1], 'o-o') not in self.tsDPAG[target[0]]: self.tsDPAG[target[0]].append((source[0], source[1], 'o-o'))
                        if (source, target, 'o-o') not in self.ambiguous_links: self.ambiguous_links.append((source, target, 'o-o'))
                    elif source[1] == 0:
                        print(f"\t- SPURIOUS LINK: ({target[0]}, {target[1]}) o-o ({source[0]}, {source[1]})")
                        if (target[0], target[1], 'o-o') not in self.tsDPAG[source[0]]: self.tsDPAG[source[0]].append((target[0], target[1], 'o-o'))
                        if (target, source, 'o-o') not in self.ambiguous_links: self.ambiguous_links.append((target, source, 'o-o'))


        print("--------------------------------------------------")
        print("    Bidirected link due to latent confounders     ")
        print("--------------------------------------------------")
        # *(1) Bidirected link between variables confounded by a latent variable
        # *    if a link between them does not exist already
        confounders = self.find_latent_confounders()
        for confounded in copy.deepcopy(list(confounders.values())):
            for c1 in copy.deepcopy(confounded):
                tmp = copy.deepcopy(confounded)
                tmp.remove(c1)
                for c2 in tmp:
                    if (c1, c2, 'o-o') in self.ambiguous_links:
                        self.update_link_type(c1, c2, '<->')
                        self.ambiguous_links.remove((c1, c2, 'o-o'))
                        print(f"\t- SPURIOUS LINK REMOVED: {c1} o-o {c2}")
                    elif (c2, c1, 'o-o') in self.ambiguous_links:
                        # The link is stored under c1 with c2 as its source,
                        # so that is the entry to update
                        self.update_link_type(c2, c1, '<->')
                        self.ambiguous_links.remove((c2, c1, 'o-o'))
                        print(f"\t- SPURIOUS LINK REMOVED: {c2} o-o {c1}")
                confounded.remove(c1)

        print("--------------------------------------------------")
        print("              Collider orientation                ")
        print("--------------------------------------------------")
        # *(2) Identify and orient the colliders:
        # *    for any path X – Z – Y where there is no edge between
        # *    X and Y and, Z was never included in the conditioning set ==> X → Z ← Y collider
        colliders = self.find_colliders()
        for ambiguous_link in copy.deepcopy(self.ambiguous_links):
            source, target, linktype = ambiguous_link
            for parent1, collider, parent2 in colliders:
                if collider == target and (parent1 == source or parent2 == source):
                    if not self.tsDAG.has_edge(parent1, parent2) and not self.tsDAG.has_edge(parent2, parent1):
                        self.update_link_type(parent1, target, '-->')
                        self.update_link_type(parent2, target, '-->')
                        self.ambiguous_links.remove(ambiguous_link)
                        break

        print("--------------------------------------------------")
        print("Non-collider orientation (orientation propagation)")
        print("--------------------------------------------------")
        # *(3) Orient the non-colliders edges (orientation propagation)
        # *    any edge Z – Y part of a partially directed path X → Z – Y,
        # *    where there is no edge between X and Y can be oriented as Z → Y
        for ambiguous_link in copy.deepcopy(self.ambiguous_links):
            triples = self.find_triples_containing_link(ambiguous_link)
            for triple in triples: self.update_link_type(triple[1], triple[2], '-->')

    # TODO: (3) Check if cycles are present

    else:
        print("No latent variable")

    return self.tsDPAG

update_link_type(parent, target, linktype)

Update link type.

Parameters:

Name Type Description Default
parent str

parent node.

required
target str

target node

required
linktype str

link type. E.g. --> or -?>.

required
Source code in causalflow/graph/PAG.py
232
233
234
235
236
237
238
239
240
241
242
243
def update_link_type(self, parent, target, linktype):
    """
    Set the type of every link going from parent into the target variable.

    Args:
        parent (tuple): parent node as (name, lag).
        target (tuple): target node as (name, lag); only the name is used.
        linktype (str): link type. E.g. --> or -?>.
    """
    links = self.tsDPAG[target[0]]
    # Positions of the entries whose (name, lag) match the parent
    matches = [
        position
        for position, link in enumerate(links)
        if link[0] == parent[0] and link[1] == parent[1]
    ]
    for position in matches:
        name, lag = links[position][0], links[position][1]
        links[position] = (name, lag, linktype)