# Source code for ccobra.benchmark.modelimporter

""" Model importer implementation. Used for dynamically importing model classes
for automated evaluation.

Copyright 2018 Cognitive Computation Lab
University of Freiburg
Nicolas Riesterer <riestern@tf.uni-freiburg.de>
Daniel Brand <daniel.brand@cognition.uni-freiburg.de>

"""

import importlib
import inspect
import os
import sys
import copy

class ModelImporter():
    """ Model importer class. Supports dynamical importing of modules,
    detection of model classes, and instantiation of said classes.

    """

    def __init__(self, model_path, superclass=object, load_specific_class=None):
        """ Imports a model based on a given python source script. Dynamically
        identifies the contained model class and prepares for instantiation.

        Parameters
        ----------
        model_path : str
            Path to the python script (or directory of scripts) to import.
            May be absolute or relative.

        superclass : type, optional
            Superclass determining which classes to consider for
            initialization.

        load_specific_class : str, optional
            Name of the class to load. Required if a model file contains
            multiple classes derived from ``superclass``.

        Raises
        ------
        ValueError
            When multiple applicable model classes are found in one file
            (determined via the superclass parameter) and no specific class
            name was requested.

        ValueError
            When no model with the given superclass is found, or the main
            class cannot be determined among several files.

        """

        self.load_specific_class = load_specific_class
        self.superclass = superclass

        # Both snapshots are taken *before* any model code is executed, so
        # that unimport() can undo everything the import changed (including
        # sys.path entries appended by the model module at import time).
        self.old_modules = set(sys.modules)
        self.old_path = copy.deepcopy(sys.path)

        self.class_attribute = self.get_class(model_path)

    @staticmethod
    def _load_module(module_name, python_file):
        """ Executes a python source file as a module and registers it in
        ``sys.modules`` under ``module_name``.

        Parameters
        ----------
        module_name : str
            Name under which the module is registered.

        python_file : str
            Path of the source file to execute.

        Returns
        -------
        module
            The freshly executed module object.

        """

        # Function-scope imports: the file only does ``import importlib``,
        # which does not guarantee that these submodules are loaded.
        import importlib.machinery
        import importlib.util

        # An explicit source loader keeps files without a '.py' suffix
        # loadable when model_path points directly at a single file.
        loader = importlib.machinery.SourceFileLoader(module_name, python_file)
        spec = importlib.util.spec_from_file_location(
            module_name, python_file, loader=loader)
        module = importlib.util.module_from_spec(spec)

        # Register before executing so the module can be found by name while
        # its body runs; roll the registration back if execution fails.
        previous = sys.modules.get(module_name)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            if previous is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
            raise

        return module

    def _find_model_class(self, module):
        """ Picks the class of ``module`` that derives from the configured
        superclass (and matches ``load_specific_class`` if one is set).

        Parameters
        ----------
        module : module
            Module whose class members are inspected.

        Returns
        -------
        type or None
            The selected class, or None if the module has no suitable class.

        Raises
        ------
        ValueError
            If no specific class was requested and the module exposes more
            than one distinct suitable class.

        """

        chosen = None
        for _, member_class in inspect.getmembers(module, inspect.isclass):
            # Skip the superclass itself and aliases of the class already
            # selected (e.g. ``Alias = Model``), which are not a second model.
            if member_class is self.superclass or member_class is chosen:
                continue
            if not issubclass(member_class, self.superclass):
                continue

            if self.load_specific_class is None:
                if chosen is not None:
                    raise ValueError(
                        'Multiple model classes found in file ' \
                        '(e.g., {} and {}). \n' \
                        'Please only specify one per file.'.format(
                            member_class.__name__, chosen.__name__))
                chosen = member_class
            elif self.load_specific_class == member_class.__name__:
                chosen = member_class

        return chosen

    def get_class(self, model_path):
        """ Determines the model class attribute.

        Parameters
        ----------
        model_path : str
            Path to the file (or directory of files) to scan for classes
            derived from the configured superclass.

        Returns
        -------
        type
            The model class to instantiate.

        Raises
        ------
        ValueError
            Thrown if the class to load could not be determined.

        """

        abs_path = os.path.abspath(model_path)
        if os.path.isfile(abs_path):
            python_files = [abs_path]
        else:
            # Only real '.py' files; sorted so the result does not depend on
            # the arbitrary order returned by os.listdir().
            python_files = [
                os.path.join(abs_path, f) for f in sorted(os.listdir(abs_path))
                if f.endswith('.py') and os.path.isfile(os.path.join(abs_path, f))]

        # Maps '<module>.<class>' to (module, class) for every file that
        # contains a suitable class.
        candidates = {}
        for python_file in python_files:
            module_name = os.path.splitext(os.path.basename(python_file))[0]
            module = self._load_module(module_name, python_file)
            model_class = self._find_model_class(module)
            if model_class is not None:
                full_name = '{}.{}'.format(module.__name__, model_class.__name__)
                candidates[full_name] = (module, model_class)

        if not candidates:
            raise ValueError(
                "No suitable classes found in model_path '{}'.".format(
                    model_path))

        if len(candidates) == 1:
            return next(iter(candidates.values()))[1]

        if self.load_specific_class is not None:
            # Prefer an exact '<module>.<class>' match, then fall back to the
            # first candidate whose bare class name matches.
            if self.load_specific_class in candidates:
                return candidates[self.load_specific_class][1]

            for _, candidate_class in candidates.values():
                if candidate_class.__name__ == self.load_specific_class:
                    return candidate_class

        # Several candidates remain: a candidate whose module is imported by
        # another candidate module is treated as a dependency, not the main
        # class, and is eliminated.
        remaining_classes = {cls for _, cls in candidates.values()}
        for candidate_module, _ in candidates.values():
            imported_reprs = {
                str(imported) for _, imported
                in inspect.getmembers(candidate_module, inspect.ismodule)}
            for other_module, other_class in candidates.values():
                if str(other_module) in imported_reprs:
                    # discard(): the same dependency may be imported by
                    # several candidates and must not fail on the second hit.
                    remaining_classes.discard(other_class)

        if len(remaining_classes) > 1:
            raise ValueError(
                "Could not determine main class. Candidates were: '{}'.".format(
                    remaining_classes))
        if len(remaining_classes) == 1:
            return next(iter(remaining_classes))
        raise ValueError("Could not determine main class.")

    def unimport(self):
        """ Cuts off all dependencies loaded together with the module from the
        module graph and restores ``sys.path`` to its state before the import.

        Attention: Might cause problems with garbage collection.

        """

        # Make sure that modules with same names do not produce conflicts
        loaded_modules = set(sys.modules) - self.old_modules
        for module_name in loaded_modules:
            # torch modules are deliberately kept loaded
            # NOTE(review): presumably because torch does not tolerate being
            # re-imported within one process - confirm.
            if module_name.startswith('torch'):
                continue
            sys.modules.pop(module_name, None)

        # Assign a copy so later sys.path mutations cannot alter the snapshot.
        sys.path = list(self.old_path)

    def instantiate(self, model_kwargs=None):
        """ Creates an instance of the imported model by calling its
        constructor with the given keyword arguments.

        Parameters
        ----------
        model_kwargs : dict, optional
            Keyword arguments passed to the model constructor. If omitted or
            empty, the constructor is called without arguments.

        Returns
        -------
        object
            Instance of the imported model class.

        """

        if not model_kwargs:
            model_kwargs = {}
        return self.class_attribute(**model_kwargs)