Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import inspect
import re

import networkx as nx
# Only the high-level entry point is exported by ``from <module> import *``.
__all__ = ["make_inheritance_tree"]
def get_classes(p, visited=None):
    """Yield every class reachable from the namespace of module ``p``.

    Walks the values of ``p.__dict__``. Module objects found there are
    scanned recursively; note that this includes modules merely imported by
    ``p``, not only its submodules. Objects are de-duplicated by ``id()``,
    so each one is examined at most once.

    Parameters
    ----------
    p : module
        Module or package to scan.
    visited : set or list of int, optional
        ids of objects already seen. It is shared across the recursion and
        mutated in place. A list still works, as before. Defaults to
        ``{id(p)}``.

    Yields
    ------
    type
        Each class object found, in namespace (dict) order.
    """
    if visited is None:
        visited = {id(p)}
    # Support both a set (O(1) membership) and a caller-supplied list.
    mark = getattr(visited, 'add', None) or visited.append
    # Snapshot the namespace. This is a generator, so the module dict may be
    # mutated (e.g. by a submodule import) while we are suspended at a yield.
    for value in list(vars(p).values()):
        key = id(value)
        if key in visited:
            continue
        mark(key)
        if inspect.ismodule(value):
            for cls in get_classes(value, visited=visited):
                yield cls
        elif isinstance(value, type):
            yield value
def clsname(cls):
    """Return the dotted name ``<module>.<qualified name>`` of ``cls``.

    Uses ``__qualname__``, falling back to ``__name__`` when it is absent.
    Nested classes that share a short name in the same module (e.g.
    ``A.Meta`` and ``B.Meta``) therefore get distinct graph node labels.
    """
    name = getattr(cls, '__qualname__', None) or cls.__name__
    return cls.__module__ + '.' + name
def graph_class(G, cls, root=None, visited=None):
    """Add the inheritance edges of ``cls`` and its ancestors to ``G``.

    For every direct base ``B`` of ``cls`` (the classes the MRO is
    linearized from), the edge ``clsname(B) -> clsname(cls)`` is added, so
    edges point from parent to child. Bases are then expanded recursively,
    each at most once, except in two cases:

    * builtin classes (module ``builtins``, or ``__builtin__`` on Python 2);
    * when ``root`` is given, classes whose module is neither ``root`` nor a
      submodule of it. Such classes appear as nodes, but their own bases
      are not followed.

    Parameters
    ----------
    G : networkx.DiGraph or any object with ``add_edge(u, v)``
        Graph that receives the edges; mutated in place.
    cls : type
        Class whose ancestry is added.
    root : str, optional
        Module or package name that limits the recursion (see above).
    visited : set or list of int, optional
        ids of classes already expanded. It is shared across recursive calls
        and mutated in place. Defaults to ``{id(cls)}``.
    """
    if visited is None:
        visited = {id(cls)}
    # Support both a set and a caller-supplied list (make_package_class_tree).
    mark = getattr(visited, 'add', None) or visited.append
    for base in cls.__bases__:
        # Record the edge even for an already-expanded base. Under multiple
        # inheritance, several children legitimately share one parent.
        G.add_edge(clsname(base), clsname(cls))
        if id(base) in visited:
            continue
        mark(id(base))
        module = getattr(base, '__module__', None) or ''
        if module in ('builtins', '__builtin__'):
            continue
        # The root module itself counts as "within root", not only its
        # submodules.
        if root and module != root and not module.startswith(root + '.'):
            continue
        graph_class(G, base, root=root, visited=visited)
def make_package_class_tree(p, maxcount=0, root=None):
    """Draw the inheritance graph of the classes found in ``p`` and return it.

    Parameters
    ----------
    p : module
        Module or package scanned with ``get_classes``.
    maxcount : int, optional
        Maximum number of classes from ``p`` to graph. ``0`` (the default),
        ``None`` or a negative value means no limit. Classes pulled in only
        as bases do not count toward the limit.
    root : str, optional
        Passed to ``graph_class`` to limit recursion; defaults to
        ``p.__name__``.

    Returns
    -------
    networkx.DiGraph
        The graph that was drawn.
    """
    G = nx.DiGraph()
    root = root or p.__name__
    limit = maxcount if maxcount and maxcount > 0 else None
    # Shared with graph_class, which records expanded classes in it; kept as
    # a list so it works with graph_class's append-based bookkeeping.
    visited = []
    count = 0
    for cls in get_classes(p):
        if id(cls) in visited:
            continue
        visited.append(id(cls))
        graph_class(G, cls, root=root, visited=visited)
        count += 1
        if limit is not None and count >= limit:
            break
    nx.draw_networkx(G)
    return G
def make_class_tree(cls, root=None):
    """Draw the inheritance graph of a single class and return it.

    Parameters
    ----------
    cls : type
        Class whose ancestry is graphed via ``graph_class``.
    root : str, optional
        Module or package name limiting the recursion; defaults to
        ``cls.__module__``.

    Returns
    -------
    networkx.DiGraph
        The graph that was drawn, so callers can inspect or reuse it.
    """
    G = nx.DiGraph()
    graph_class(G, cls, root=root or cls.__module__)
    nx.draw_networkx(G)
    return G
def make_inheritance_tree(obj, maxcount=0, root=None):
    """Draw the inheritance graph for a module, a class, or an instance.

    Dispatch:

    * module -> ``make_package_class_tree(obj, maxcount=maxcount, root=root)``
    * class -> ``make_class_tree(obj, root=root)``
    * anything else -> ``make_class_tree(type(obj), root=root)``

    ``maxcount`` is only used for modules. The delegate's return value is
    passed through unchanged.
    """
    if not inspect.ismodule(obj):
        target = obj if isinstance(obj, type) else type(obj)
        return make_class_tree(target, root=root)
    return make_package_class_tree(obj, maxcount=maxcount, root=root)
if __name__ == '__main__':
    # Demo: graph matplotlib's Axes hierarchy, restricting the recursion to
    # the matplotlib package.
    import matplotlib.pyplot as plt

    make_inheritance_tree(plt.Axes, root='matplotlib')
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement