easy_rec/__init__.py (53 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import platform
import sys
from easy_rec.version import __version__
curr_dir, _ = os.path.split(__file__)
parent_dir = os.path.dirname(curr_dir)
sys.path.insert(0, parent_dir)
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
if 'PROCESSOR_TEST' not in os.environ:
from tensorflow.python.platform import tf_logging
# In DeepRec, logger.propagate of tf_logging is False, should be True
tf_logging._logger.propagate = True
def get_ops_dir():
import tensorflow as tf
if platform.system() == 'Linux':
ops_dir = os.path.join(curr_dir, 'python/ops')
if 'PAI' in tf.__version__:
ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
if 'IS_ON_PAI' in os.environ:
ops_dir = os.path.join(ops_dir, 'DeepRec')
else:
ops_dir = os.path.join(ops_dir, '1.15')
else:
tmp_version = tf.__version__.split('.')
tmp_version = '.'.join(tmp_version[:2])
ops_dir = os.path.join(ops_dir, tmp_version)
return ops_dir
else:
return None
ops_dir = get_ops_dir()
if ops_dir is not None and not os.path.exists(ops_dir):
logging.warning('ops_dir[%s] does not exist' % ops_dir)
ops_dir = None
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export # isort:skip # noqa: E402
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
try:
import tensorflow_io.oss
except Exception:
pass
print('easy_rec version: %s' % __version__)
print('Usage: easy_rec.help()')
_global_config = {}
def help():
print("""
1 Train
1.1 Train 1gpu
CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.train_eval
--pipeline_config_path deepfm_combo_on_avazu_ctr.config
1.2 Train 2gpu
sh scripts/train_2gpu.sh deepfm_combo_on_avazu_ctr.config
2 Eval
CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.eval
--pipeline_config_path deepfm_combo_on_avazu_ctr.config
3 Export
CUDA_VISIBLE_DEVICES=""
python -m easy_rec.python.export
--pipeline_config_path deepfm_combo_on_avazu_ctr.config
--export_dir models/export
4 Create config from excel
python -m easy_rec.python.tools.create_config_from_excel
--excel_path dwd_avazu_ctr_multi_tower.xls
--output_path dwd_avazu_ctr_multi_tower.config
5. Inference:
# use list input
import csv
from easy_rec.python.inference.predictor import Predictor
predictor = Predictor(SAVED_MODEL_DIR)
with open(INPUT_CSV, 'r') as fin:
reader = csv.reader(fin)
inputs = []
for row in reader:
inputs.append(row[1:])
output_res = self._predictor.predict(inputs, batch_size=32)
# use dict input
import csv
from easy_rec.python.inference.predictor import Predictor
predictor = Predictor(SAVED_MODEL_DIR)
field_keys = [ "field1", "field2", "field3", "field4", "field5",
"field6", "field7", "field8", "field9", "field10",
"field11", "field12", "field13", "field14", "field15",
"field16", "field17", "field18", "field19", "field20" ]
with open(INPUT_CSV, 'r') as fin:
reader = csv.reader(fin)
inputs = []
for row in reader:
inputs.append({ f : row[fid+1] for fid, f in enumerate(field_keys) })
output_res = self._predictor.predict(inputs, batch_size=32)
""")