in dowhy/causal_graph.py [0:0]
def __init__(self,
             treatment_name, outcome_name,
             graph=None,
             common_cause_names=None,
             instrument_names=None,
             effect_modifier_names=None,
             mediator_names=None,
             observed_node_names=None,
             missing_nodes_as_confounders=False):
    """Build the causal graph, either from an explicit graph description or from variable roles.

    :param treatment_name: treatment variable name(s); normalized with ``parse_state`` and
        stored on ``self.treatment_name``.
    :param outcome_name: outcome variable name(s); normalized with ``parse_state`` and
        stored on ``self.outcome_name``.
    :param graph: one of

        * ``None`` -- the graph is built by ``self.build_graph`` from
          ``common_cause_names``, ``instrument_names``, ``effect_modifier_names`` and
          ``mediator_names``;
        * a path ending in ``.txt`` whose contents are one of the string forms below;
        * a path ending in ``.dot`` (read with pygraphviz, falling back to pydot);
        * a path ending in ``.gml``;
        * a daggity string (starting with ``dag``), converted with ``daggity_to_dot``;
        * a DOT string such as ``digraph {...}``, ``digraph G {...}`` or
          ``strict digraph {...}``;
        * a GML string such as ``graph [...]``.

    :param common_cause_names: common-cause variable name(s); only used when ``graph`` is None.
    :param instrument_names: instrument variable name(s); only used when ``graph`` is None.
    :param effect_modifier_names: effect-modifier variable name(s); only used when ``graph``
        is None.
    :param mediator_names: mediator variable name(s); only used when ``graph`` is None.
    :param observed_node_names: passed to ``add_node_attributes`` and, when
        ``missing_nodes_as_confounders`` is True, to ``add_missing_nodes_as_common_causes``.
    :param missing_nodes_as_confounders: if True, call
        ``add_missing_nodes_as_common_causes(observed_node_names)`` after the graph is built.
    :raises ValueError: if ``graph`` is a string in none of the recognized formats.

    DOT input is parsed with pygraphviz first and pydot as a fallback; if pydot also fails,
    its exception is logged and re-raised.
    """
    self.treatment_name = parse_state(treatment_name)
    self.outcome_name = parse_state(outcome_name)
    instrument_names = parse_state(instrument_names)
    common_cause_names = parse_state(common_cause_names)
    effect_modifier_names = parse_state(effect_modifier_names)
    mediator_names = parse_state(mediator_names)
    self.logger = logging.getLogger(__name__)

    if isinstance(graph, str):
        # A .txt file holds the graph description itself, so load it into a string.
        # Only a string that *ends* in ".txt" is a path: a DOT/GML description that merely
        # mentions e.g. "notes.txt" (as a node label) must not be opened as a file.
        if graph.endswith(".txt"):
            with open(graph, "r") as text_file:
                graph = text_file.read()
        if graph.startswith("dag"):  # Convert daggity output to dot format
            graph = daggity_to_dot(graph)
        # Flatten to a single line so the ".*" patterns below can span the whole description.
        graph = graph.replace("\n", " ")

    if graph is None:
        self._graph = nx.DiGraph()
        self._graph = self.build_graph(common_cause_names,
                                       instrument_names,
                                       effect_modifier_names,
                                       mediator_names)
    # The extension patterns are anchored with "$" so that a graph *description*
    # containing ".dot"/".gml" somewhere inside it is not mistaken for a file path.
    elif re.match(r".*\.dot$", graph):
        # Load a dot file: prefer pygraphviz, fall back to pydot.
        try:
            import pygraphviz  # noqa: F401 -- availability probe; read_dot needs it
            self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
        except Exception as e:
            self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
            try:
                import pydot  # noqa: F401 -- availability probe; read_dot needs it
                self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
            except Exception as pydot_err:
                self.logger.error("Error: Pydot cannot be loaded. " + str(pydot_err))
                raise
    elif re.match(r".*\.gml$", graph):
        self._graph = nx.DiGraph(nx.read_gml(graph))
    # DOT string: "[strict] (graph|digraph) [ID] { ... }". The optional ID (a bare word or a
    # double-quoted string) lets named graphs like "digraph G {...}" through.
    elif re.match(r'.*graph\s*(?:\w+|"[^"]*")?\s*\{.*\}\s*', graph):
        try:
            import pygraphviz as pgv
            agraph = pgv.AGraph(graph, strict=True, directed=True)
            # Wrap in DiGraph so this branch yields the same type as every other branch.
            self._graph = nx.DiGraph(nx.drawing.nx_agraph.from_agraph(agraph))
        except Exception as e:
            self.logger.error("Error: Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot ...")
            try:
                import pydot
                P_list = pydot.graph_from_dot_data(graph)
                # from_pydot yields a MultiDiGraph for non-strict DOT; collapse it to a
                # DiGraph like the strict pygraphviz parse and the file-based branches.
                self._graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
            except Exception as pydot_err:
                self.logger.error("Error: Pydot cannot be loaded. " + str(pydot_err))
                raise
    elif re.match(r".*graph\s*\[.*\]\s*", graph):  # GML string
        self._graph = nx.DiGraph(nx.parse_gml(graph))
    else:
        self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.")
        self.logger.error("Error: Incorrect graph format")
        raise ValueError("Incorrect graph format: please provide graph (as string or text file) "
                         "in dot or gml format.")

    if missing_nodes_as_confounders:
        self._graph = self.add_missing_nodes_as_common_causes(observed_node_names)
    # Adding node attributes
    self._graph = self.add_node_attributes(observed_node_names)