Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from chainer.functions.caffe import CaffeFunction
- import pickle, logging
- from os.path import isfile
class CaffeNetWrapper(object):
    """Callable wrapper around a Caffe network with an on-disk pickle cache.

    On construction the network is obtained in one of two ways:

    * If ``chainer_model`` names an existing file and ``force_caffe_load``
      is false, the network is unpickled from that file.
    * Otherwise the network is built with ``CaffeFunction(caffe_model)`` and
      then pickled to ``chainer_model`` on a best-effort basis, so that a
      later construction can take the first path.

    Calling the wrapper forwards all arguments to the wrapped network.

    Args:
        caffe_model: Value handed to ``CaffeFunction`` (presumably the path
            of a Caffe model file -- confirm against the chainer docs).
        chainer_model: Path of the pickle cache file to read from / write to.
        force_caffe_load: When true, ignore an existing cache file and
            rebuild the network from ``caffe_model``.

    Attributes:
        net: The wrapped network object.
    """

    def __init__(self, caffe_model, chainer_model, force_caffe_load=False):
        super(CaffeNetWrapper, self).__init__()

        if not force_caffe_load and isfile(chainer_model):
            logging.info("loading chainer model from \"{}\"".format(chainer_model))
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point chainer_model at cache files this program wrote itself.
            # The context manager closes the handle deterministically instead
            # of relying on garbage collection.
            with open(chainer_model, "rb") as f:
                self.net = pickle.load(f)
        else:
            logging.info("loading caffe model from \"{}\"".format(caffe_model))
            self.net = CaffeFunction(caffe_model)

            # Writing the cache is deliberately best-effort: the network is
            # already loaded, so a failed save is logged and otherwise ignored.
            try:
                logging.info("saving loaded model to \"{}\"".format(chainer_model))
                with open(chainer_model, "wb") as f:
                    pickle.dump(self.net, f)
            except Exception as e:
                # logging.warn is a deprecated alias (removed in Python 3.13).
                logging.warning("could not save loadded model to \"{}\" Error: {}".format(chainer_model, e))

    def __call__(self, *args, **kw):
        """Forward the call, with all arguments, to the wrapped network."""
        return self.net(*args, **kw)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement