def find_all_periods()

in cater_preprocessing/utils.py [0:0]


def find_all_periods(scene, threshold, cutoff_t=-1):
    """Collect the temporal periods between consecutive event groups of a scene.

    Every non-no-op movement of every object contributes a start/end event
    pair for its specific action type (via ``action_map``) and a second pair
    for the generic 'moving' type. Every containment interval in
    ``scene['contained_events']`` contributes a 'contained' start/end pair.
    Events are grouped by frame. A period is emitted between every event of
    one group and every event of the next group when the two groups are more
    than ``threshold`` frames apart. Periods opened at the first frame of the
    video use ``None`` as their first element, and periods closed at the last
    frame use ``None`` as their second element. Only periods accepted by
    ``is_valid_event`` are kept.

    Each event is a tuple ``(obj_id, 'start_<type>' | 'end_<type>', nth, t)``.
    ``nth`` is the 1-based occurrence count of that type for the object, and
    ``t`` is the frame.

    Args:
        scene: Scene dict, mutated in place. Reads 'movements',
            'obj_id_to_idx', 'contained_events' and 'objects'. Writes
            'has_contain', 'events<tag>', 'grouped_events<tag>',
            'event_to_group' and 'periods<tag>'.
        threshold: Two consecutive event groups must be strictly more than
            this many frames apart for a period to be created between them.
        cutoff_t: If not -1, movements ending after this frame are ignored,
            this frame is treated as the end of the video, and the output keys
            get a '_till_<cutoff_t>' suffix (<tag> above). Otherwise the last
            frame is ``len(scene['objects'][0]['locations']) - 1``.

    Returns:
        None. Results are stored in ``scene``.
    """
    import warnings

    scene_tag = '' if cutoff_t == -1 else '_till_{}'.format(cutoff_t)
    events = []

    # --- Movement events ---------------------------------------------------
    has_contain = False
    for o, ms in scene['movements'].items():
        # Per-object occurrence counters; 'moving' counts every movement type.
        order = {'sliding': 0, 'rotating': 0, 'flying': 0, 'moving': 0}
        obj_id = scene['obj_id_to_idx'][o]
        prior_end = -1
        for m in ms:
            if m[0] == '_no_op':
                continue
            # NOTE(review): this is set before the cutoff filter below, so a
            # containment movement ending after cutoff_t still marks the
            # scene. Confirm this is intended.
            if 'contain' in m[0]:
                has_contain = True
            a = action_map[m[0]]
            s = m[2]
            e = m[3]
            if cutoff_t != -1 and e > cutoff_t:
                continue

            order[a] += 1
            th = order[a]  # nth occurrence of this specific action
            order['moving'] += 1
            all_th = order['moving']  # nth movement of any kind

            # An object cannot perform two movements at once, and CATER
            # movements are expected to be listed in time order.
            assert s >= prior_end
            prior_end = e
            events.append((obj_id, 'start_{}'.format(a), th, s))
            events.append((obj_id, 'end_{}'.format(a), th, e))
            events.append((obj_id, 'start_{}'.format('moving'), all_th, s))
            events.append((obj_id, 'end_{}'.format('moving'), all_th, e))
    scene['has_contain'] = has_contain

    # --- Containment intervals, treated as a temporal state like actions ---
    # NOTE(review): keys are used directly as obj_id (no obj_id_to_idx
    # lookup), and cutoff_t is not applied here. Verify both against the
    # producer of 'contained_events'.
    for obj_id, cs in scene['contained_events'].items():
        # Intervals must be unique. tuple() lets list-typed intervals (e.g.
        # loaded from JSON) be hashed too.
        assert len(cs) == len({tuple(c) for c in cs})
        # Number the intervals in chronological order of their start frame.
        for th, (s, e) in enumerate(sorted(cs, key=lambda x: x[0]), start=1):
            events.append((obj_id, 'start_{}'.format('contained'), th, s))
            events.append((obj_id, 'end_{}'.format('contained'), th, e))

    # --- Group events that share the same frame ----------------------------
    events = sorted(events, key=lambda x: x[-1])
    grouped_events = []
    event_to_group = {}  # index into `events` -> index into `grouped_events`
    for idx, event in enumerate(events):
        if grouped_events and grouped_events[-1][0][-1] == event[-1]:
            grouped_events[-1].append(event)
        else:
            grouped_events.append([event])
        event_to_group[idx] = len(grouped_events) - 1
    scene['events{}'.format(scene_tag)] = events
    scene['grouped_events{}'.format(scene_tag)] = grouped_events
    # NOTE(review): unlike the keys above, this key has no scene_tag, so a
    # call with cutoff_t != -1 overwrites the mapping from a full-video call.
    scene['event_to_group'] = event_to_group

    # --- Build candidate periods -------------------------------------------
    event_periods = []
    if grouped_events:
        first_group = grouped_events[0]
        last_group = grouped_events[-1]
        first_t = first_group[0][-1]
        if cutoff_t == -1:
            t_eov = len(scene['objects'][0]['locations']) - 1
        else:
            t_eov = cutoff_t

        # NOTE(review): whenever first_t > threshold, the start-of-video block
        # below adds these (None, e) periods again, producing duplicates.
        # They are kept to preserve the existing output.
        if first_t > threshold:
            event_periods.extend((None, e) for e in first_group)

        # Periods between consecutive groups more than `threshold` apart.
        for prev_group, next_group in zip(grouped_events, grouped_events[1:]):
            if next_group[0][-1] - prev_group[0][-1] > threshold:
                event_periods.extend(
                    (e1, e2) for e1 in prev_group for e2 in next_group
                )

        # Periods running to the last frame (None as end). A group sitting
        # exactly on the last frame would give an empty period, so the
        # previous group is used instead (none if there is no previous group).
        if last_group[0][-1] == t_eov:
            tail = grouped_events[-2] if len(grouped_events) > 1 else []
        else:
            tail = last_group
        event_periods.extend((e, None) for e in tail)

        # Periods starting at the first frame (None as start). The same rule
        # applies to a group sitting on frame 0.
        if first_t == 0:
            head = grouped_events[1] if len(grouped_events) > 1 else []
        else:
            head = first_group
        event_periods.extend((None, e) for e in head)

    # --- Trim event periods ------------------------------------------------
    out_periods = [p for p in event_periods if is_valid_event(p, events)]
    if not out_periods:
        # Surface the unexpected empty result without halting batch
        # preprocessing in an interactive debugger.
        warnings.warn(
            'find_all_periods: no valid periods found (cutoff_t={})'.format(
                cutoff_t
            )
        )

    scene['periods{}'.format(scene_tag)] = out_periods