多卡环境指定gpu
首先运行nvidia-smi
以获得gpu的编号
-
在import module时指定
import tensorflow as tf import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "".format(gpu_index)
-
创建session时指定
config = tf.ConfigProto() config.gpu_options.visible_device_list = '{}'.format(gpu_index) sess = tf.Session(config=config)
按需占用显存
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)