DeepModel.py 775 B

12345678910111213141516171819202122232425262728293031
  1. class DeePredict:
  2. def __init__(self, model=None, bridge=None):
  3. if model:
  4. self._model = model
  5. if bridge:
  6. self._bridge = bridge
  7. def model(self, *args, **kwargs):
  8. if self.__getattribute__('_model'):
  9. return self._model(*args, **kwargs)
  10. else:
  11. raise ValueError('缺少模型')
  12. def bridge(self, *args, **kwargs):
  13. return self._bridge(*args, **kwargs)
  14. def predict(self, *args, **kwargs):
  15. x = self.model(*args, **kwargs)
  16. if self.__getattribute__('_bridge'):
  17. y = self.bridge(x)
  18. return y
  19. else:
  20. return x
  21. if __name__ == '__main__':
  22. def f1(x):
  23. return x
  24. dp = DeePredict(f1, f1)
  25. print(dp.predict([1]))