# tinynn/graph/configs/gen_modules_yml.py
"""Generate ``torch_module_override.yml``, listing the classes found in
``torch.quantization.stubs``, ``torch.nn`` and ``torchvision.ops``
(minus containers and composite blocks)."""
import inspect

import torch
import torchvision
import yaml

# namespace -> list of class names to be written into the YAML config
final_dict = {}
# Stage 1: Quantization stubs
for k in torch.quantization.stubs.__dir__():
    c = getattr(torch.quantization.stubs, k)
    # Keep only classes and submodules; skip plain attributes.
    if not isinstance(c, type) and not inspect.ismodule(c):
        continue
    if isinstance(c, type):
        print(k, c, issubclass(c, torch.nn.Module))
        print(c.__module__)
        # Only classes are recorded; submodules just pass through the filter.
        final_dict.setdefault('torch.quantization', [])
        final_dict['torch.quantization'].append(k)
# Stage 2: torch.nn Modules
for k in torch.nn.__dict__:
    c = getattr(torch.nn, k)
    # Keep only classes and submodules; skip plain attributes.
    if not isinstance(c, type) and not inspect.ismodule(c):
        continue
    if isinstance(c, type):
        print(k, c, issubclass(c, torch.nn.Module))
        print(c.__module__)
        # Skip container modules (torch.nn.modules.container, e.g.
        # Sequential, ModuleList, ModuleDict).
        if '.container' in c.__module__:
            continue
        # Skip Parameter (not a module), the Module base class, and
        # wrapper/composite modules.
        if c.__name__ in (
            'Parameter',
            'Module',
            'DataParallel',
            'TransformerEncoder',
            'TransformerDecoder',
            'TransformerEncoderLayer',
            'TransformerDecoderLayer',
        ):
            continue
        final_dict.setdefault('torch.nn', [])
        final_dict['torch.nn'].append(k)
# Stage 3: torchvision.ops Modules
for k in torchvision.ops.__dict__:
    c = getattr(torchvision.ops, k)
    # Keep only classes and submodules; skip plain attributes.
    if not isinstance(c, type) and not inspect.ismodule(c):
        continue
    if isinstance(c, type):
        print(k, c, issubclass(c, torch.nn.Module))
        print(c.__module__)
        # Skip composite blocks that are assembled from other modules.
        if c.__name__ in (
            'FeaturePyramidNetwork',
            'SqueezeExcitation',
            'Conv2dNormActivation',
            'Conv3dNormActivation',
            'MLP',
        ):
            continue
        final_dict.setdefault('torchvision.ops', [])
        final_dict['torchvision.ops'].append(k)
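
# At this point, final_dict maps each namespace to the class names collected
# from it; the exact contents depend on the installed torch / torchvision
# versions. An abridged illustration:
#
#   {'torch.quantization': ['QuantStub', 'DeQuantStub', ...],
#    'torch.nn': ['Conv2d', 'Linear', ...],
#    'torchvision.ops': ['FrozenBatchNorm2d', 'DeformConv2d', ...]}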
# Stage 4: Update config
with open('torch_module_override.yml', 'w') as f:
    yaml.dump(final_dict, f)
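
# Optional sanity check (a minimal sketch, not part of the original script):
# read the file back and summarize what was collected.
with open('torch_module_override.yml') as f:
    generated = yaml.safe_load(f)
for namespace, names in generated.items():
    print(f'{namespace}: {len(names)} classes, e.g. {names[:3]}')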