in evaluations/evaluator.py [0:0]
def _update_shapes(pool3):
    """Relax a static leading dimension of 1 to ``None`` across pool3's graph.

    Adapted from TTUR's ``fid.py`` (link below). For every output tensor of
    every operation in ``pool3.graph``, the function checks the static shape.
    If its rank is known and its first dimension is exactly 1, that dimension
    is replaced with ``None`` (unknown). Presumably this lets the graph
    accept batch sizes other than 1; TODO confirm against callers.

    Args:
        pool3: A graph-mode TensorFlow tensor; only its ``.graph`` is used.

    Returns:
        The same ``pool3`` object. Tensors in its graph are modified in place.
    """
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    for op in pool3.graph.get_operations():
        for tensor in op.outputs:
            shape = tensor.get_shape()
            # Unknown rank: nothing to relax, and as_list() would raise.
            if shape.ndims is None:
                continue
            # as_list() yields plain ints/None on both TF 1.x and TF 2.x,
            # unlike iterating the shape (Dimension objects on TF 1.x).
            dims = shape.as_list()
            if dims and dims[0] == 1:
                dims[0] = None
                # set_shape() can only refine a shape, never loosen a known
                # dimension, so the cached static shape is overwritten
                # directly. NOTE(review): relies on the TF-internal
                # `_shape_val` attribute; verify it still exists in the
                # installed TF version.
                tensor.__dict__["_shape_val"] = tf.TensorShape(dims)
    return pool3