多卡环境指定gpu

首先运行nvidia-smi以获得gpu的编号

  • 在import module时指定

    import tensorflow as tf
    import os
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "".format(gpu_index)
    
  • 创建session时指定

    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = '{}'.format(gpu_index)
    sess = tf.Session(config=config)
    

按需占用显存

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)